In [None]:
from langchain_community.document_loaders import PyPDFLoader

# Path of the risk-definitions PDF under the /dbfs mount.
pdf_path = "/dbfs/FileStore/risk/bank_risk_definitions.pdf"

# Read the PDF into a list of LangChain documents.
loader = PyPDFLoader(file_path=pdf_path)
docs = loader.load()

# Cell output: number of documents the loader returned.
len(docs)

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Cut the loaded documents into chunks of at most 800 characters, with a
# 100-character overlap between neighbouring chunks.
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
chunks = splitter.split_documents(docs)

print("Total chunks:", len(chunks))

In [None]:
# Databricks Vector Search uses Delta + embedding index.
# Make sure your workspace has Databricks Vector Search enabled.
# Create a Delta Table for storing embeddings
table_name = "risk_definitions_vectors"

# Create the table on first run. Change Data Feed is switched on because a
# Delta Sync vector index requires it on its source table.
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {table_name} (
  id STRING,
  text STRING,
  embedding ARRAY<FLOAT>
)
USING DELTA
TBLPROPERTIES (delta.enableChangeDataFeed = true)
""")

# CREATE ... IF NOT EXISTS is a no-op for a table left over from an earlier
# run, so set the property explicitly as well (safe to repeat).
spark.sql(
    f"ALTER TABLE {table_name} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)

In [None]:
# Generate embeddings & write to Delta
from langchain_openai import OpenAIEmbeddings
import uuid

embed = OpenAIEmbeddings(model="text-embedding-3-small")

# Embed all chunks through the batch API instead of one request per chunk.
texts = [chunk.page_content for chunk in chunks]
vectors = embed.embed_documents(texts) if texts else []

# One (id, text, embedding) row per chunk; ids are fresh random UUIDs.
data = [
    (str(uuid.uuid4()), text, vector)
    for text, vector in zip(texts, vectors)
]

# Explicit schema: Python floats would otherwise be inferred as ARRAY<DOUBLE>,
# which does not match the table's ARRAY<FLOAT> column, and schema inference
# fails outright on an empty list.
df = spark.createDataFrame(data, "id STRING, text STRING, embedding ARRAY<FLOAT>")

# NOTE: append mode plus random ids means re-running this cell stores every
# chunk again under a new id.
df.write.mode("append").format("delta").saveAsTable(table_name)

In [None]:
# Create Databricks Vector Index
from databricks.vector_search.client import VectorSearchClient

vs = VectorSearchClient()

endpoint_name = "risk-endpoint"
index_name = "risk_index"

# Only create the endpoint when it is missing, so the cell can be re-run.
existing_endpoints = {
    endpoint["name"] for endpoint in vs.list_endpoints().get("endpoints", [])
}
if endpoint_name not in existing_endpoints:
    vs.create_endpoint(name=endpoint_name)

# Delta Sync index over the precomputed vectors in `table_name`.
# NOTE(review): Unity Catalog workspaces expect three-level names
# (catalog.schema.name) for both the index and the source table - confirm.
# NOTE(review): this call is not guarded, so it presumably fails if the
# index already exists - confirm and guard if re-runs are expected.
vs.create_delta_sync_index(
    endpoint_name=endpoint_name,
    index_name=index_name,
    primary_key="id",
    source_table_name=table_name,
    pipeline_type="TRIGGERED",
    # text-embedding-3-small produces 1536-dimensional vectors by default.
    embedding_dimension=1536,
    embedding_vector_column="embedding",
)

In [None]:
# Create the Retriever Wrapper
from databricks.vector_search.client import VectorSearchIndex

index = vs.get_index(endpoint_name=endpoint_name, index_name=index_name)


def retrieve_risk_context(query, k=4):
    """Return the stored definition texts most similar to ``query``.

    Args:
        query: Free-text description to embed and search with.
        k: Number of nearest chunks to fetch (defaults to 4).

    Returns:
        The ``text`` column of the matching rows joined with newlines, or an
        empty string when the search returns no rows.
    """
    vector = embed.embed_query(query)

    results = index.similarity_search(
        query_vector=vector,
        columns=["text"],
        num_results=k,
    )

    # The response carries rows under result.data_array, each row listing the
    # requested columns in order (followed by the score), so text is row[0].
    rows = (results or {}).get("result", {}).get("data_array") or []
    return "\n".join(row[0] for row in rows)

In [None]:
# Build RunnableSequence LLM Classification Chain
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough

# temperature=0 keeps the classification output as repeatable as the model allows.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Prompt with two inputs: {context} (retrieved definitions) and
# {issue_description} (the event text). Doubled braces are literal JSON braces.
# The "1-2" range uses a plain ASCII hyphen so the text stays encoding-safe.
classification_prompt = ChatPromptTemplate.from_template("""
You are a senior banking risk officer.

Bank-specific risk definitions:
----------------
{context}
----------------

Risk event:
{issue_description}

TASK:
1. Tag the event with one or more bank-defined risk types. (Multi-label)
2. Provide a 1-2 line summary.
3. Output JSON:
{{
  "risk_type": [string],
  "issue_summary": string
}}
""")

# Input: {"issue_description": <text>}. The parallel step builds the two
# prompt variables, then the prompt is rendered and sent to the model.
rag_chain = (
    RunnableParallel(
        context=lambda x: retrieve_risk_context(x["issue_description"]),
        # Pull out the text itself; passing the input through unchanged would
        # put the whole dict (its repr) into the prompt.
        issue_description=lambda x: x["issue_description"],
    )
    | classification_prompt
    | llm
)

In [None]:
# Load Risk Events File from DBFS
import pandas as pd  # not imported by any earlier cell

df = pd.read_csv("/dbfs/FileStore/risk/banking_dummy_issues.csv")
df.head()

In [None]:
# Process the Entire DataFrame Using abatch()
import asyncio

import json


def _parse_classification(content):
    """Turn one model reply into ``(issue_summary, risk_type)`` or the fallback pair."""
    fallback = ("Summary unavailable.", "Other")
    if not isinstance(content, str):
        return fallback

    # Keep only the outermost {...} so replies wrapped in Markdown fences or
    # surrounding prose still parse.
    start, end = content.find("{"), content.rfind("}")
    if start == -1 or end <= start:
        return fallback

    try:
        # json.loads, not eval: model output is untrusted text and must never
        # be executed; it also understands JSON true/false/null.
        obj = json.loads(content[start:end + 1])
        summary = obj["issue_summary"]
        labels = obj["risk_type"]
        if isinstance(labels, str):
            # A single label returned as a bare string; avoid joining its characters.
            labels = [labels]
        return summary, ", ".join(labels)
    except (ValueError, KeyError, TypeError):
        # Malformed JSON, missing keys or unexpected value types.
        return fallback


async def process_df(df, chain=None):
    """Classify every row of ``df`` and return an enriched copy.

    Args:
        df: pandas DataFrame with an ``issue_description`` column.
        chain: Object exposing ``abatch(inputs)`` whose results have a
            ``.content`` string. Defaults to the notebook's ``rag_chain``.

    Returns:
        A new DataFrame with ``issue_summary`` and ``risk_type`` columns added
        (labels joined with ", "). Rows whose reply cannot be parsed get
        "Summary unavailable." / "Other". The input frame is not modified.
    """
    chain = rag_chain if chain is None else chain
    inputs = [{"issue_description": text} for text in df["issue_description"]]

    outputs = await chain.abatch(inputs)

    parsed = [_parse_classification(out.content) for out in outputs]
    summaries = [summary for summary, _ in parsed]
    risk_types = [labels for _, labels in parsed]

    return df.assign(issue_summary=summaries, risk_type=risk_types)

from concurrent.futures import ThreadPoolExecutor


def _run_coroutine(coro):
    """Run ``coro`` to completion and return its result, with or without a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop on this thread (plain script): asyncio.run works directly.
        return asyncio.run(coro)
    # A loop is already running on this thread (typical for notebook kernels),
    # where asyncio.run() raises RuntimeError. Run it on a fresh loop in a
    # worker thread instead.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


result_df = _run_coroutine(process_df(df))
result_df.head()

In [None]:
# Write the enriched events to DBFS as a CSV file, leaving out the pandas index.
result_df.to_csv(
    path_or_buf="/dbfs/FileStore/risk/risk_events_enriched.csv",
    index=False,
)

In [None]:
# Convert the pandas result to a Spark DataFrame and save it as a table,
# replacing any existing contents.
spark_df = spark.createDataFrame(result_df)
(
    spark_df.write
    .mode("overwrite")
    .saveAsTable("classified_risk_events")
)