<a href="https://colab.research.google.com/github/harjeet88/llm-course/blob/main/data_engineering_and_LLMs/apache_spark_and_llms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pyspark findspark langchain transformers chromadb pandas tqdm

Collecting findspark
  Downloading findspark-2.0.1-py2.py3-none-any.whl.metadata (352 bytes)
Collecting chromadb
  Downloading chromadb-1.3.5-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.2 kB)
Collecting build>=1.0.3 (from chromadb)
  Downloading build-1.3.0-py3-none-any.whl.metadata (5.6 kB)
Collecting pybase64>=1.4.1 (from chromadb)
  Downloading pybase64-1.4.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (8.7 kB)
Collecting posthog<6.0.0,>=2.4.0 (from chromadb)
  Downloading posthog-5.4.0-py3-none-any.whl.metadata (5.7 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.38.0-py3-none-any.whl.metadata (2.4 kB)
Collecting pypika>=0.48.9 (from chromadb)
  Downloading 

In [2]:
# Import findspark to locate Spark installation
import findspark
findspark.init()

In [3]:
# Check PySpark version
import pyspark
print(f"PySpark Version: {pyspark.__version__}")

PySpark Version: 3.5.1


In [4]:
# Import Spark Session
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, lit
from pyspark.sql.types import ArrayType, FloatType, StringType

In [5]:
# --- Create Spark Session ---
# Use 'local[*]' to use all available cores
spark = SparkSession.builder\
    .appName("ScalableRAGDemo")\
    .config("spark.driver.memory", "4g")\
    .getOrCreate()

print("Spark Session successfully created!")
spark

Spark Session successfully created!


In [6]:
import pandas as pd
import random

# A long, repetitive text to simulate a large document
base_text = """
Apache Spark is a unified analytics engine for large-scale data processing.
It provides high-level APIs in Java, Scala, Python, and R, and an optimized engine
that supports general execution graphs. It also supports a rich set of higher-level
tools including Spark SQL for SQL and structured data processing, MLlib for machine
learning, GraphX for graph processing, and Spark Streaming for stream processing.
It is often used in large enterprises to handle petabytes of data efficiently.
The key to its speed is processing data in memory.
""" * 50 # Repeat 50 times to make a large "document"

In [7]:
# Create a Pandas DataFrame
pdf = pd.DataFrame({
    'doc_id': [f"doc_{i}" for i in range(10)],
    'text': [base_text + f" Document {i} unique text." for i in range(10)]
})

In [8]:
# Convert to Spark DataFrame
# We now have 10 large documents ready for distributed processing
data_df = spark.createDataFrame(pdf)
print("Input Data Schema:")
data_df.printSchema()
print(f"Input Data Count: {data_df.count()}")
data_df.show(1, truncate=50)

Input Data Schema:
root
 |-- doc_id: string (nullable = true)
 |-- text: string (nullable = true)

Input Data Count: 10
+------+--------------------------------------------------+
|doc_id|                                              text|
+------+--------------------------------------------------+
| doc_0|\nApache Spark is a unified analytics engine fo...|
+------+--------------------------------------------------+
only showing top 1 row



In [16]:
!pip install langchain-text-splitters
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm.auto import tqdm # For progress visualization

Collecting langchain-text-splitters
  Downloading langchain_text_splitters-1.0.0-py3-none-any.whl.metadata (2.6 kB)
Downloading langchain_text_splitters-1.0.0-py3-none-any.whl (33 kB)
Installing collected packages: langchain-text-splitters
Successfully installed langchain-text-splitters-1.0.0


In [17]:
# 1. Define the chunking logic function
def chunk_text(text_series: pd.Series) -> pd.Series:
    """
    Applies RecursiveCharacterTextSplitter to a Pandas Series of text.
    Returns a list of chunks for each input document.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )

    # Process each document in the batch
    all_chunks = []
    for text in text_series:
        chunks = text_splitter.split_text(text)
        all_chunks.append(chunks)

    return pd.Series(all_chunks)

In [18]:
# 2. Register the UDF with the correct return type
# The function returns a list of strings (chunks) for each document
chunking_udf = pandas_udf(chunk_text, ArrayType(StringType()))

# 3. Apply the UDF to the DataFrame
chunked_df = data_df.withColumn("chunks", chunking_udf(col("text")))

In [26]:
from pyspark.sql.functions import explode # Import explode function

In [27]:
exploded_df = chunked_df.select(
    col("doc_id"),
    # `explode` creates a new row for each element in the 'chunks' array.
    explode(col("chunks")).alias("chunk")
).withColumn(
    # Add a simple index for tracking if needed, though not strictly necessary for the failure.
    "chunk_index",
    # Use a monotonically increasing ID for a simple index after explosion
    lit(0) # This is a placeholder for simplicity in the demo
)

# The actual output will be many more rows than the input (10 documents -> many chunks)
print(f"Total Chunks Generated (Final Row Count): {exploded_df.count()}")
exploded_df.show(3, truncate=50)

Total Chunks Generated (Final Row Count): 500
+------+--------------------------------------------------+-----------+
|doc_id|                                             chunk|chunk_index|
+------+--------------------------------------------------+-----------+
| doc_0|Apache Spark is a unified analytics engine for ...|          0|
| doc_0|Apache Spark is a unified analytics engine for ...|          0|
| doc_0|Apache Spark is a unified analytics engine for ...|          0|
+------+--------------------------------------------------+-----------+
only showing top 3 rows



In [28]:
from transformers import AutoTokenizer, AutoModel
import torch

# Define the model to use
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

# 1. Define the embedding function for the Pandas Iterator (Same as before)
def embed_chunks_iterator(iterator):
    # Load model and tokenizer once per Spark worker process (executor)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)

    # Check for GPU and move model if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Function to get mean pooling (necessary for some models)
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    for chunk_batch_pdf in iterator:
        # Extract the chunks column
        chunks = chunk_batch_pdf["chunk"].tolist()

        # Tokenize and run inference in batch on the GPU/CPU
        encoded_input = tokenizer(chunks, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)

        # Perform mean pooling to get sentence embeddings
        embeddings = mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy()

        # Add the embeddings as a new column
        chunk_batch_pdf['embedding'] = embeddings.tolist()
        yield chunk_batch_pdf



In [30]:
# 2. Define the output schema (Same as before)
embedding_schema = exploded_df.schema.add("embedding", ArrayType(FloatType()))

# 3. Apply the function using mapInPandas (Same as before)
embedded_df = exploded_df.mapInPandas(
    embed_chunks_iterator,
    schema=embedding_schema
)

# ⭐️ THE FIX: FORCE EVALUATION ⭐️
# .cache() stores the result in memory/disk, and .count() forces the execution
# of the embedding logic on the cluster, making the 'embedding' column concrete.
embedded_df = embedded_df.cache()
count = embedded_df.count()

# 4. Show results and check column availability
print("\n--- Embedded DataFrame Schema ---")
embedded_df.printSchema()

print(f"\nTotal Embedded Chunks: {count}")
# Now the 'embedding' column is successfully resolved!
embedded_df.select("doc_id", "chunk", col("embedding").alias("vector")).show(1, truncate=50)

# Example of the first vector size (should be 384 for MiniLM-L6-v2)
first_vector_size = embedded_df.head(1)[0]['embedding'].__len__()
print(f"Embedding Vector Dimension: {first_vector_size}")

AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `embedding` cannot be resolved. Did you mean one of the following? [`doc_id`, `chunk`, `chunk_index`].

In [None]:
import chromadb
import uuid

# 1. Initialize ChromaDB client (stores data in memory/file)
chroma_client = chromadb.Client()

# 2. Create a collection (similar to a table in SQL)
collection_name = "spark_rag_index"
try:
    collection = chroma_client.get_collection(name=collection_name)
except:
    collection = chroma_client.create_collection(name=collection_name)
    print(f"Created collection: {collection_name}")


# 3. Collect the embedded data (WARNING: Only for small demos!)
# In a real scenario, Spark writes directly to the DB via a connector.
# We use .toPandas() only because Colab is running locally.
print("\nCollecting data to client driver for simulation...")
final_pdf = embedded_df.toPandas()

# 4. Prepare data for ChromaDB
documents = final_pdf['chunk'].tolist()
embeddings = final_pdf['embedding'].tolist()
metadata = final_pdf[['doc_id']].to_dict('records') # Must be a list of dicts
ids = [str(uuid.uuid4()) for _ in documents]

# 5. Bulk Add to Vector DB
print(f"Bulk loading {len(documents)} vectors into ChromaDB...")
collection.add(
    embeddings=embeddings,
    documents=documents,
    metadatas=metadata,
    ids=ids
)
print("Bulk load complete.")
print(f"Total vectors in DB: {collection.count()}")

# 6. Clean up Spark Session
spark.stop()