In [None]:
from pyspark.sql import SparkSession
from sentence_transformers import SentenceTransformer
from langchain import OpenAI, LLMChain, PromptTemplate
import faiss, numpy as np, os

# Obtain the Spark session for this notebook under the app name
# "SparkGPT-RAG"; getOrCreate() reuses an already-running session if any.
spark = (
    SparkSession.builder
    .appName("SparkGPT-RAG")
    .getOrCreate()
)

In [None]:
# Load & preprocess documents
# Read every .txt file matching the glob with Spark's plain-text source.
# Later cells read each row's text from the `value` column.
docs = spark.read.format("text").load("data/sample_docs/*.txt")
# Preview the first five rows.
docs.show(n=5)

In [None]:
# Embed a bounded sample of the corpus with a sentence-transformer model.
MAX_DOCS = 50  # cap on the number of rows collected to the driver

embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Spark's text reader emits one row per input line, so blank lines would
# otherwise be embedded and could later be retrieved as "context".
# Filter them inside Spark *before* the limit so up to MAX_DOCS usable rows
# are still collected.
rows = docs.where("length(trim(value)) > 0").limit(MAX_DOCS).collect()
texts = [r.value for r in rows]

# Fail here with a clear message instead of at `vectors.shape[1]` downstream.
if not texts:
    raise ValueError("No non-empty lines found in docs; nothing to embed.")

# `vectors[i]` is the embedding of `texts[i]`; retrieval relies on this alignment.
vectors = embedder.encode(texts)

In [None]:
# Build vector index
# The flat L2 index is sized from the embedding width (second axis of `vectors`).
embedding_dim = vectors.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
# Vectors are handed to FAISS as float32; row order matches `texts`.
index.add(np.asarray(vectors, dtype="float32"))

In [None]:
# Define LLM retrieval chain
# NOTE(review): "gpt-4" is a chat model, while `OpenAI` here is presumably
# langchain's completion-style wrapper — confirm this combination works with
# the installed langchain version (a chat-model wrapper may be required).
llm = OpenAI(
    temperature=0.2,
    model="gpt-4",
    # Key is read from the environment; never hard-code it in the notebook.
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)

# Prompt with two slots: the retrieved context and the user's question.
prompt = PromptTemplate(
    template="Context: {context}\nQuestion: {query}\nAnswer concisely.",
    input_variables=["context", "query"],
)

# Chain that fills the prompt and sends it to the LLM.
chain = LLMChain(llm=llm, prompt=prompt)

In [None]:
# Query
# Retrieve the nearest indexed texts for the question and pass them to the
# LLM chain as context.
query = "What are Databricks' advantages for LLMOps?"

# encode() on a single string yields one vector; wrap it in a list so FAISS
# receives a 2-D (1 x dim) float32 batch.
query_vec = embedder.encode(query)
D, I = index.search(np.array([query_vec]).astype("float32"), 3)

# FAISS pads the label array with -1 when the index holds fewer vectors than
# requested. Skip those: texts[-1] would silently inject the *last* document
# into the context instead of signalling "no result".
context = "\n".join(texts[i] for i in I[0] if i >= 0)

result = chain.run(context=context, query=query)
print(result)

In [None]:
# Evaluate & visualize
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt

In [None]:
# Heatmap of pairwise distances between the first (at most) 100 embeddings.
# pairwise_distances uses its default metric here (Euclidean in scikit-learn).
fig, ax = plt.subplots()
heatmap = ax.imshow(pairwise_distances(vectors[:100]), cmap='viridis')
ax.set_title("Embedding Distribution")
# Both axes index rows of `vectors`, i.e. positions in `texts`.
ax.set_xlabel("Text index")
ax.set_ylabel("Text index")
# The colour scale is meaningless without a legend.
fig.colorbar(heatmap, ax=ax, label="Pairwise distance")
plt.show()