In [None]:
import boto3
import time
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
from pyspark.sql import SparkSession
from langchain_community.vectorstores import FAISS
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS

In [None]:
# Build (or reuse) the Spark session for this job. The executor processes
# receive HF_HOME and TRANSFORMERS_CACHE pointing at the YARN user cache
# directory, presumably so HuggingFace model files land there — confirm.
spark = (
    SparkSession.builder.appName("TextFileProcessing")
    .config("spark.executorEnv.HF_HOME", "/mnt/yarn/usercache/")
    .config("spark.executorEnv.TRANSFORMERS_CACHE", "/mnt/yarn/usercache/")
    .getOrCreate()
)


# Create a list of the files to be read from S3 bucket
def list_all_txt_files(bucket, prefix):
    """Return ``s3://`` URIs for every ``.txt`` object under ``prefix`` in ``bucket``.

    Walks every page of the ``list_objects_v2`` paginator, so listings longer
    than a single page are fully covered. Pages that carry no ``Contents``
    entry are skipped. URIs are returned in listing order.
    """
    pages = boto3.client('s3').get_paginator('list_objects_v2').paginate(
        Bucket=bucket, Prefix=prefix
    )
    uri_root = 's3://' + bucket + '/'
    uris = []
    for page in pages:
        for entry in page.get('Contents', []):
            key = entry['Key']
            if key.endswith('.txt'):
                uris.append(uri_root + key)
    return uris

# Read a text file from S3
def read_text_from_s3(bucket, key):
    """Download object ``key`` from ``bucket`` and return its body decoded as UTF-8."""
    response = boto3.client('s3').get_object(Bucket=bucket, Key=key)
    raw_bytes = response['Body'].read()
    return raw_bytes.decode('utf-8')

import os
import boto3
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS

# Create embeddings, build a FAISS vector index, and upload the index to S3
def process_text(text, embeddings_broadcast, path,
                 bucket_name='metcs777-term-project',
                 local_root="/mnt/yarn/usercache"):
    """Embed one document, build a FAISS index for it, and upload the index to S3.

    The text is split into ~1500-character chunks (10 characters of overlap),
    embedded with the model held in ``embeddings_broadcast``, and indexed with
    FAISS. The index is saved to a private temporary directory and uploaded
    file-by-file under ``output/faiss_index/<source key>/`` in ``bucket_name``.

    Args:
        text: Full contents of the source document.
        embeddings_broadcast: Spark broadcast variable whose ``.value`` is the
            LangChain embeddings object used to embed the chunks.
        path: ``s3://`` URI of the source document. The part after
            ``s3://<bucket_name>/`` becomes part of the output key prefix.
        bucket_name: Bucket that holds the source files and receives the index.
        local_root: Worker-local directory under which a per-call temporary
            directory is created for the saved index.

    Returns:
        The ``s3://`` URI prefix the index was uploaded to. Returns ``None``
        when the text produces no chunks; nothing is indexed or uploaded then.
    """
    import shutil
    import tempfile

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=10)
    split_text = text_splitter.split_text(text)
    if not split_text:
        # FAISS cannot build an index from zero chunks. Skip the file instead
        # of failing the whole Spark job.
        print(f"Skipping {path}: no text to index")
        return None

    # Utilize the broadcasted embeddings model
    embeddings = embeddings_broadcast.value
    faiss = FAISS.from_texts(split_text, embeddings)

    # Remove 's3://<bucket>/' from the path and use the rest as part of the S3 key prefix
    relative_path = path.replace(f's3://{bucket_name}/', '').strip()
    s3_prefix = f'output/faiss_index/{relative_path}/'

    # Each call gets its own directory. Several tasks can run concurrently in
    # one executor, and a shared fixed directory would let them overwrite
    # each other's index files before they are uploaded.
    os.makedirs(local_root, exist_ok=True)
    local_dir = tempfile.mkdtemp(prefix="faiss_index_", dir=local_root)
    try:
        faiss.save_local(local_dir)  # Save the FAISS index locally

        # Upload each file in the directory to S3
        s3_client = boto3.client('s3')
        for filename in os.listdir(local_dir):
            local_path = os.path.join(local_dir, filename)
            # S3 keys always use '/', independent of the worker's os.sep.
            s3_key = s3_prefix + filename
            with open(local_path, 'rb') as data:
                s3_client.put_object(Bucket=bucket_name, Key=s3_key, Body=data)
    finally:
        # Do not leave per-file index copies accumulating on the worker disk.
        shutil.rmtree(local_dir, ignore_errors=True)

    print(f"Successfully uploaded FAISS index to s3://{bucket_name}/{s3_prefix}")
    return f"s3://{bucket_name}/{s3_prefix}"

#Merge the databases
def merge_faiss_dbs(dbs):
    """Merge multiple FAISS databases into one.

    A FAISS store cannot be created empty with ``FAISS()``: its constructor
    requires an embedding function, index, docstore and id mapping. So the
    first database is used as the accumulator, and every following database
    is merged into it with ``merge_from``.

    Args:
        dbs: Iterable of FAISS vector stores. They are presumably built with
            the same embedding model, since ``merge_from`` needs compatible
            indexes — confirm at the call site.

    Returns:
        The first database. It is modified in place and now also holds the
        contents of all the others.

    Raises:
        ValueError: If ``dbs`` is empty.
    """
    db_iter = iter(dbs)
    try:
        final_db = next(db_iter)
    except StopIteration:
        raise ValueError("merge_faiss_dbs() requires at least one database") from None
    for db in db_iter:
        final_db.merge_from(db)
    return final_db


# Load the sentence-transformer embedding model once on the driver and
# broadcast it, so Spark tasks read it via embeddings_broadcast.value.
cache_dir = "/mnt/yarn/usercache/"
os.environ.update({
    "TRANSFORMERS_CACHE": cache_dir,  # HuggingFace cache
    "SENTENCE_TRANSFORMERS_HOME": cache_dir,  # Sentence Transformers cache
})
embeddings_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
embeddings_broadcast = spark.sparkContext.broadcast(embeddings_model)

In [None]:
# Define bucket and prefix, then retrieve file paths
bucket_name = 'metcs777-term-project'
prefix = 'output/'

files = list_all_txt_files(bucket_name, prefix)


def _index_s3_file(path):
    """Read one ``s3://`` text file, then build and upload its FAISS index (runs on an executor)."""
    key = path.replace(f's3://{bucket_name}/', '')
    return process_text(read_text_from_s3(bucket_name, key), embeddings_broadcast, path)


# Read and process files in parallel using Spark
rdd = spark.sparkContext.parallelize(files)
processed_dbs = rdd.map(_index_s3_file).collect()
print(f"Databases created and stored in S3 for {len(processed_dbs)} file(s)")
