<a href="https://colab.research.google.com/github/harjeet88/A_For_Algorithms/blob/master/data_engg/spark_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
# 1. Install necessary libraries
# We use Pyspark for distribution, LangChain for the chunking logic,
# and sentence-transformers for the embedding model.
!pip install -q pyspark findspark langchain sentence-transformers pandas tqdm

In [32]:
# 2. Import findspark and initialize
import findspark
findspark.init()

# 3. Import PySpark components
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, udf, monotonically_increasing_id, lit
from pyspark.sql.types import ArrayType, FloatType, StringType

# 4. Initialize Spark Session
# Using 'local[*]' utilizes all available cores for parallel processing.
# Configure driver memory for stability in Colab.
spark = SparkSession.builder\
    .appName("DistributedRAGDemo")\
    .config("spark.driver.memory", "4g")\
    .getOrCreate()

print("Spark Session successfully created! Ready for distributed processing.")
# Display the session details
spark

Spark Session successfully created! Ready for distributed processing.


In [34]:
pip install -q PyPDF2 pymilvus

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/232.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/278.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.0/278.0 kB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [35]:
import os, os.path

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import StringType, ArrayType, FloatType
from sentence_transformers import SentenceTransformer
import findspark
from PyPDF2 import PdfReader
from pyspark.sql import SparkSession

from pymilvus import connections, Collection

import findspark
import re

In [36]:
os.environ['PYARROW_IGNORE_TIMEZONE']='1'
os.environ['NUMEXPR_MAX_THREADS'] = '2'
os.environ['NUMEXPR_NUM_THREADS'] = '2'
os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES'


In [37]:
CHUNK_SIZE = 1600
CHUNK_OVERLAP = 50

In [38]:
# Define a UDF to extract text using PyPDF
def extract_text(file_path):
    reader = PdfReader(file_path)
    text = ''
    for i in range(0,len(reader.pages)):
        text += reader.pages[i].extract_text()
    return text

In [39]:
# Define the function to create embeddings
def create_embedding(text):
    # Create a SentenceTransformer model
    transformer = SentenceTransformer(os.getenv('EMBEDDING_MODEL'))
    embeddings = transformer.encode(text, convert_to_tensor=True)
    return embeddings.numpy().tolist()

In [40]:
def extract_text_chunks(symbol, text):
    metadata = "Document contains context of " + symbol \
        + " and is relevant to the annual reports / financial statements/ 10-K SEC fillings\n"
    chunks = []
    for i in range(0, len(text), CHUNK_SIZE):
        if i > CHUNK_OVERLAP:
            chunks.append(metadata + text[i - CHUNK_OVERLAP : i + CHUNK_SIZE])
        else:
            chunks.append(metadata + text[i : i + CHUNK_SIZE])
    return chunks

In [41]:
def get_stock_symbol(file_name):
    match = re.search(r'NASDAQ_([A-Z]{1,5})_2022\.pdf', file_name)
    if match:
        return match.group(1)
    return "NA"

In [42]:
# Register the UDF
extract_text_udf = udf(extract_text, StringType())
spark.udf.register("extract_text", extract_text_udf)

extract_text_chunks_udf = udf(extract_text_chunks, ArrayType(StringType()))
spark.udf.register("extract_text_chunks", extract_text_chunks_udf)

create_embedding_udf = udf(create_embedding, ArrayType(FloatType()))
spark.udf.register("create_embeddings", create_embedding_udf)

get_stock_symbol_udf = udf(get_stock_symbol, StringType())
spark.udf.register("get_stock_symbol", get_stock_symbol_udf)


<pyspark.sql.udf.UserDefinedFunction at 0x7f3014469df0>

In [65]:
def get_embedded_chunks(pdf_directory):
    pdf_file_paths = []
    for file in os.listdir(pdf_directory):
        if file == '.DS_Store':
            continue
        print(file)
        if file.endswith(".pdf"):
            pdf_file_paths.append(os.path.join(pdf_directory, file))
    print("Creating dataframe with file paths")
    # Create DataFrame with file paths
    pdf_files = spark.createDataFrame(pdf_file_paths, "string").toDF("file_path")
    pdf_files = pdf_files.select(
        'file_path', get_stock_symbol_udf('file_path').alias('stock_symbol'))

    print("Extracting text from PDF files")
    # Extract text from PDF files with each line containing name of file and array of page text
    chunked_text_data = pdf_files.withColumn("text", extract_text_udf("file_path"))

    print("Chunking text into chunks")
    # Break text into individual row per page using explode()
    chunked_text_data = chunked_text_data.withColumn("relevant_text", \
        extract_text_chunks_udf("stock_symbol", "text"))

    print("Break text into individual row per page using explode()")
    # Break text into individual row per page using explode()
    chunked_text_data = chunked_text_data.select('stock_symbol', 'file_path',
        explode(chunked_text_data.relevant_text).alias('chunked_text'))

    print("Converting into embeddings")
    # Convert into embeddings
    chunked_text_data = chunked_text_data.withColumn("embedded_vectors", \
        create_embedding_udf("chunked_text"))
    print("returning chunked text data")
    return chunked_text_data

In [66]:
def ingest_data():
    print("PDF ingestion started...")
    chunked_data = get_embedded_chunks("./rag-spark/data/annual_reports")
    # Connect to Milvus Database

    connections.connect(host=os.getenv('MILVUS_HOST'),
                        port=os.getenv('MILVUS_PORT'), secure=False)
    # Create collection if not exists
    collection_name = os.getenv('MILVUS_COLLECTION_NAME')
    collection = Collection(collection_name)
    collection.insert(chunked_data.toPandas())
    collection.flush()
    print("PDF ingestion completed...")

In [46]:
!mkdir -p rag-spark/data/annual_reports

In [48]:
!curl https://github.com/nairnavin/ml-playground/tree/main/rag-spark/data/annual_reports/NASDAQ_AAPL_2022.pdf -o rag-spark/data/annual_reports

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0


In [67]:
ingest_data()

PDF ingestion started...
Creating dataframe with file paths
Extracting text from PDF files
Chunking text into chunks
Break text into individual row per page using explode()
Converting into embeddings
returning chunked text data


ConnectionConfigException: <ConnectionConfigException: (code=1, message=Type of 'host' must be str.)>

#setting up Milvius

In [68]:
import os
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from dotenv import load_dotenv
load_dotenv()

False

In [71]:
os.environ['MILVUS_HOST']='localhost'
os.environ['MILVUS_PORT']='19530'
os.environ['MILVUS_COLLECTION_NAME']='financial_docs_collection'

In [72]:
def init_vectordb():
    HOST = os.getenv('MILVUS_HOST')
    PORT = os.getenv('MILVUS_PORT')
    # Connect to Milvus Database
    connections.connect(host=HOST, port=PORT, secure=False)

    # Create collection if not exists
    collection_name = os.getenv('MILVUS_COLLECTION_NAME')

    # Remove collection if it already exists (only for test)
    if utility.has_collection(collection_name):
        print('Dropping existing collection "%s"' % collection_name)
        utility.drop_collection(collection_name)

    # Create collection which includes the id, title, and embedding.
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='stock_symbol', dtype=DataType.VARCHAR, max_length=10),
        FieldSchema(name='file_path', dtype=DataType.VARCHAR, max_length=200),
        FieldSchema(name='chunked_text', dtype=DataType.VARCHAR, max_length=2200),
        FieldSchema(name='embedded_vectors', dtype=DataType.FLOAT_VECTOR, dim=384)
    ]

    print('Creating collection and index for "%s"' % collection_name)
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=collection_name, schema=schema)
    # Create an IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{"nlist":768}
    }
    collection.create_index(field_name="embedded_vectors", index_params=index_params)
    collection.load()
    return collection

In [73]:
init_vectordb()

MilvusException: <MilvusException: (code=2, message=Fail connecting to server on localhost:19530, illegal connection params or server unavailable)>