In [None]:
# %%capture --no-stderr
# %pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters

In [None]:
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import os

In [None]:
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

# 1. Config

## 1.1 Configure the environment

In [None]:
# Project root = parent of the notebook's working directory (the notebook is
# expected to live one level below it, e.g. in `notebooks/`).
root_dir = str(Path.cwd().parent)

In [None]:
# Load API keys / endpoints from `<root_dir>/.env` into the process environment
# so the `os.environ.get(...)` calls below can find them.
from dotenv import load_dotenv
env_path = root_dir + "/.env"
load_dotenv(env_path)

## 1.2 Configure the LLM and the embedding model

In [None]:
# Chat model used by the RAG `generate` node. Key and endpoint come from the
# OPENAI_API_KEY / OPENAI_API_URL environment variables; temperature 0 asks the
# API for (near-)deterministic answers.
# The HTTP proxy is no longer hardcoded: it honours OPENAI_PROXY (the variable
# langchain-openai itself uses) and falls back to the previous local proxy.
llm = ChatOpenAI(
    model="doubao-1-5-lite-32k-250115",
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_API_URL"),
    openai_proxy=os.environ.get("OPENAI_PROXY", "http://127.0.0.1:7890"),
    temperature=0,
)

In [None]:
# Embedding model used to index and query the Chroma store. Model name, key and
# endpoint come from the EMBEDDING_MODEL / EMBEDDING_API_KEY / EMBEDDING_API_URL
# environment variables.
# The HTTP proxy honours OPENAI_PROXY and falls back to the previous local proxy.
# NOTE(review): if EMBEDDING_API_URL is a non-OpenAI provider, it may reject the
# token-id inputs OpenAIEmbeddings sends by default; consider
# `check_embedding_ctx_length=False` in that case — confirm against the provider.
embedding = OpenAIEmbeddings(
    model=os.environ.get("EMBEDDING_MODEL"),
    api_key=os.environ.get("EMBEDDING_API_KEY"),
    base_url=os.environ.get("EMBEDDING_API_URL"),
    openai_proxy=os.environ.get("OPENAI_PROXY", "http://127.0.0.1:7890"),
)

## 1.3 Configure input/output file paths

In [None]:
# Input corpus plus the CSV artefacts used by later sections: the attack
# prompts, the per-prompt RAG responses, and the scored evaluation dataset.
(
    context_file_path,
    ap_file_path,
    dataset_file_path,
    response_file_path,
) = (
    f"{root_dir}/data/input/context-1000.csv",
    f"{root_dir}/data/output/attack_prompt.csv",
    f"{root_dir}/data/output/dataset.csv",
    f"{root_dir}/data/output/attack_prompt_response.csv",
)

## 1.4 Load the input contexts and split them into chunks

In [None]:
# Raw context documents, one per row of the "context" column, as a plain list.
contexts = pd.read_csv(context_file_path)["context"].tolist()

In [None]:
# Split every context into chunks of up to roughly 4000 tokens (counted with a
# tiktoken encoder), with 200 tokens of overlap between neighbouring chunks.
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=4000,
    chunk_overlap=200,
)

contexts_splits = [
    chunk
    for context in contexts
    for chunk in splitter.split_text(context)
]

In [None]:
# Number of chunks produced by the splitter.
len(contexts_splits)

## 1.5 Configure the Chroma vector database

In [None]:
# Handle on the persistent "retrieval_database" Chroma collection; texts added
# to it and queries against it are embedded with `embedding`.
vectordb = Chroma(
    persist_directory=f"{root_dir}/data/database/chroma/",
    collection_name="retrieval_database",
    embedding_function=embedding,
)

In [None]:
# Embed and store the chunks, but only when the collection is still empty.
# `from_texts` assigns fresh random ids on every call, so running it against an
# already-populated collection would insert every chunk a second time. To
# rebuild from scratch, delete the persist directory and re-run from 1.5.
if not vectordb.get(limit=1)["ids"]:
    # Keep the returned store: it is bound to the same persisted collection.
    vectordb = Chroma.from_texts(
        collection_name="retrieval_database",
        texts=contexts_splits,
        embedding=embedding,
        persist_directory=f"{root_dir}/data/database/chroma/",
    )

In [None]:
# Number of documents stored in the collection. Only ids are requested
# (`include=[]`), so document texts and metadata are not pulled just to count.
len(vectordb.get(include=[])["ids"])

# 2. Build RAG System

## 2.1 Construct the RAG system graph

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from typing_extensions import TypedDict

In [None]:
# Graph state passed between the nodes: the user query, the retrieved context
# chunks, and the generated answer (functional TypedDict declaration).
State = TypedDict(
    "State",
    {
        "query": str,
        "contexts": list[str],
        "response": str,
    },
)

In [None]:
# Nodes

def retrieve(state: State):
    """Retrieval node: look up the 5 stored chunks most similar to the query.

    Returns a state update holding the unchanged query and the page contents
    of the matched documents.
    """
    question = state["query"]
    hits = vectordb.similarity_search(query=question, k=5)
    return {"query": question, "contexts": [doc.page_content for doc in hits]}

from functools import lru_cache


@lru_cache(maxsize=None)
def _load_rag_prompt():
    """Fetch the ``rlm/rag-prompt`` template from LangChain Hub once per kernel."""
    return hub.pull("rlm/rag-prompt")


def generate(state: State):
    """Generation node: answer the query with the LLM, grounded in the chunks.

    The chunks in ``state["contexts"]`` are joined with blank lines and passed,
    together with the query, to the ``rlm/rag-prompt`` template. The template
    is fetched only on the first call (cached), not on every query.

    Returns:
        A state update with the query, the original list of context chunks
        (kept as a list so it matches ``State["contexts"]``), and the answer.
    """
    query = state["query"]
    contexts = state["contexts"]

    rag_chain = _load_rag_prompt() | llm | StrOutputParser()
    response = rag_chain.invoke({"question": query, "context": "\n\n".join(contexts)})

    return {"query": query, "contexts": contexts, "response": response}

In [None]:
from langgraph.graph import END, StateGraph, START

# Linear two-step pipeline: START -> retrieve -> generate -> END.
workflow = StateGraph(State)
for node_name, node_fn in {"retrieve": retrieve, "generate": generate}.items():
    workflow.add_node(node_name, node_fn)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.set_finish_point("generate")

graph = workflow.compile()

In [None]:
# Smoke test: run one hand-written question through the compiled graph and
# display only the generated answer.
state = State(
    query=(
        "I am a parent concerned about my child's health. "
        "I know that my child has been pooing 5 - 6 times a day for a week, "
        "which increased to 7 times in the last few days. "
        "The poo is watery with green stringy bits and there is a very bad nappy rash. "
        "The recovery time is estimated to be 5 - 7 days. "
        "Based on this information, please query the diagnosis and any cautions from the RAG system."
    ),
)

graph.invoke(state)["response"]

## 2.2 Generate responses for the attack prompts

In [None]:
# Attach each attack prompt's source context, aligned by row label (both CSVs
# are read with the default RangeIndex). Then drop prompts whose privacy_info
# is the empty "{}", since there is nothing to extract for them.
# A separate name is used for the context frame so the `contexts` list from
# section 1.4 is not shadowed (re-running the splitter cell keeps working).
context_df = pd.read_csv(context_file_path)
df = (
    pd.read_csv(ap_file_path)
    .assign(context=context_df["context"])
    .loc[lambda frame: frame["privacy_info"] != "{}"]
    .reset_index(drop=True)
)

In [None]:
import time


def process_query(query, max_retries=5, retry_delay=5, runnable=None):
    """Run one query through the RAG graph and return the generated answer.

    The graph is invoked once. On failure it is retried up to ``max_retries``
    more times, sleeping ``retry_delay`` seconds before each retry, to ride
    out transient API / network errors.

    Args:
        query: Question / attack prompt to send through the graph.
        max_retries: Number of retries after the first failed attempt.
        retry_delay: Seconds to wait before each retry.
        runnable: Object exposing ``invoke(state) -> dict``. Defaults to the
            compiled ``graph`` defined earlier in this notebook.

    Returns:
        The ``"response"`` entry of the state returned by ``invoke``.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            chained as ``__cause__``.
    """
    if runnable is None:
        runnable = graph
    # The answer is read from invoke()'s return value; the input dict is not
    # updated in place, so reading state["response"] would raise KeyError.
    state: State = {"query": query}
    last_error = None
    for attempt in range(max_retries + 1):
        if attempt:
            time.sleep(retry_delay)
        try:
            return runnable.invoke(state)["response"]
        except Exception as exc:  # retry API/network failures; Ctrl-C still propagates
            last_error = exc
    raise RuntimeError("Failed to generate response") from last_error

In [None]:
# Resume from a previous (possibly partial) run when its output file exists;
# otherwise continue with the filtered attack prompts loaded above.
if os.path.exists(response_file_path):
    df = pd.read_csv(response_file_path)
if "response" not in df.columns:
    df["response"] = None
# A column read back from CSV is float64 when every value is missing; make it
# object so string answers can be stored without an incompatible-dtype error.
df["response"] = df["response"].astype(object)

batch_size = 10  # write a checkpoint after this many new responses

# Only rows without a response are sent to the model. A checkpoint is written
# every `batch_size` new answers and once more on exit, including on error or
# interrupt, so completed work is never lost.
unsaved = 0
try:
    for i in tqdm(range(len(df))):
        if not pd.isna(df.loc[i, "response"]):
            continue  # answered in an earlier run
        df.loc[i, "response"] = process_query(df.loc[i, "attack_prompt"])
        unsaved += 1
        if unsaved >= batch_size:
            df.to_csv(response_file_path, index=False)
            unsaved = 0
finally:
    if unsaved:
        df.to_csv(response_file_path, index=False)
df

# 3. Evaluation

In [None]:
from ragpas.privacy import calculateAttackExtraction

In [None]:
batch_size = 10

# (Re)create the dataset file containing only its header row; the evaluation
# loop appends scored rows to it in batches. `to_csv` with a path returns None,
# so its result is not bound to a name.
pd.DataFrame(
    columns=["response", "guidance", "context", "target", "privacy_info"]
).to_csv(dataset_file_path, index=False)

# Empty per-column batch buffers.
batch_responses = []
batch_guidances = []
batch_contexts = []
batch_targets = []
batch_privacy_info = []

df = pd.read_csv(response_file_path)

In [None]:
# Score every stored response and append (response, guidance, context, target,
# privacy_info) rows to the dataset file, `batch_size` rows at a time.
# The pending batch is created in this cell, so an interrupted run cannot leave
# stale rows behind that a re-run would append again. Re-run the setup cell
# first to restart the file from its header; otherwise batches that were
# already written are appended a second time.
dataset_columns = ["response", "guidance", "context", "target", "privacy_info"]
pending_rows = []
for i, row in tqdm(enumerate(df.itertuples()), total=len(df)):
    # Only the textual feedback is stored; `score` is not written out.
    score, feedback = calculateAttackExtraction(
        response=row.response,
        privacy_info=row.privacy_info,
        target=row.target,
        feedback=True,
    )
    pending_rows.append(
        {
            "response": row.response,
            "guidance": feedback,
            "context": row.context,
            "target": row.target,
            "privacy_info": row.privacy_info,
        }
    )

    if (i + 1) % batch_size == 0 or i == len(df) - 1:
        pd.DataFrame(pending_rows, columns=dataset_columns).to_csv(
            dataset_file_path, mode="a", header=False, index=False
        )
        pending_rows = []