In [53]:
import os
import json
import random
from collections import defaultdict
from typing import List, Tuple
from pydantic import BaseModel, computed_field

from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph

import pandas as pd
import json

from langchain import hub
from typing_extensions import List, TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from langchain_core.language_models.chat_models import BaseChatModel


## Import model dependencies and load API Key

In [54]:
# for OPENAI

from langchain_openai import ChatOpenAI

# Load variables from a .env file (if one exists) into the process environment,
# then read the OpenAI key from it. os.environ.get returns None when
# OPENAI_API_KEY is not set, so a missing key only surfaces when the model is used.
load_dotenv()
api_key_openai = os.environ.get("OPENAI_API_KEY")

In [55]:
# for LLAMA (via GROQ)

from langchain_groq import ChatGroq

# Prefer a key supplied through the environment (GROQ_API_KEY, e.g. from the
# .env file) so that no real credential has to be pasted into the notebook.
# The original placeholder is kept as the fallback, so behaviour is unchanged
# when the variable is not set.
api_key_groq = os.environ.get("GROQ_API_KEY", "YOUR API KEY HERE")

## Load the information-retrieval output (each query comes with its retrieved context chunks; at most K of them are used)

In [None]:
# Path (relative to the working directory) of the retrieval-output JSON that
# the following cells load into QnA_data.
# NOTE(review): the example output below shows 'privacy_qa/...' file paths and
# the GPT run tags its output files "privacyqa_...", while this path points at
# contractnli.json — confirm which dataset is intended before re-running.
current_file = "./data/json_output/contractnli.json"

In [57]:
# JSON is UTF-8 by specification; pin the encoding so that loading does not
# depend on the platform's default locale encoding.
with open(current_file, encoding="utf-8") as f:
    QnA_data = json.load(f)

In [58]:
# Fetch one query together with (at most) its first K retrieved context chunks
def get_query_from_json_at_K(index=0, k_context=3):
    """Return the query at ``index`` and up to ``k_context`` of its contexts.

    Reads the module-level ``QnA_data`` list. The result is a tuple
    ``(query, contexts)`` where each context is a dict with the keys
    ``file_path``, ``span`` and ``chunk``, copied from the ``filepath``,
    ``span`` and ``text`` fields of the entry's ``retrieved_chunks_unranked``
    items. Fewer than ``k_context`` contexts are returned when the entry has
    fewer retrieved chunks.
    """
    question_text = QnA_data[index]["query"]
    candidate_chunks = QnA_data[index]["retrieved_chunks_unranked"]

    # Never read past the number of chunks that were actually retrieved.
    how_many = min(k_context, len(candidate_chunks))

    picked_contexts = [
        {
            "file_path": candidate_chunks[pos]["filepath"],
            "span": candidate_chunks[pos]["span"],
            "chunk": candidate_chunks[pos]["text"],
        }
        for pos in range(how_many)
    ]

    return question_text, picked_contexts

In [59]:
# example use
get_query_from_json_at_K(0, 3)

('Consider "Fiverr"\'s privacy policy; who can see which tasks i hire workers for?',
 [{'file_path': 'privacy_qa/Fiverr.txt',
   'span': [1011, 1388],
   'chunk': 'Information that you choose to publish on the Site (photos, videos, text, music, reviews, deliveries) - is no longer private, just like any information you publish online.\n  Technical information that is gathered by our systems, or third party systems, automatically may be used for Site operation, optimization, analytics, content promotion and enhancement of user experience.'},
  {'file_path': 'privacy_qa/Fiverr.txt',
   'span': [173, 527],
   'chunk': 'We do not disclose it to others except as disclosed in this Policy or required to provide you with the services of the Site and mobile applications, meaning - to allow you to buy, sell, share the information you want to share on the Site; to contribute on the forum; pay for products; post reviews and so on; or where we have a legal obligation to do so.'},
  {'file_path': 'pr

## Response generator class

In [60]:
class ResponseGenerator:
    """One-node LangGraph pipeline: fill a prompt with question + context, call the LLM.

    The compiled graph is exposed as ``self.graph``; invoking it with a dict
    holding ``question`` and ``context`` returns the state with ``answer`` set
    to whatever ``llm.invoke`` returned.
    """

    class State(TypedDict):
        # Question forwarded to the prompt's {question} slot.
        question : str
        # Context items for the {context} slot. Annotated as Documents, but the
        # callers in this notebook pass plain chunk strings; generate() accepts both.
        context : List[Document]
        # Raw object returned by llm.invoke (callers read its .content).
        answer: str

    def __init__(self, prompt : PromptTemplate, llm : BaseChatModel):
        """Store the prompt and model and compile the single-step graph.

        Args:
            prompt: template that is invoked with ``question`` and ``context``.
            llm: chat model whose ``invoke`` produces the answer.
        """
        self.llm = llm
        self.prompt = prompt

        graph_builder = StateGraph(self.State)
        graph_builder.add_sequence([self.generate])
        graph_builder.add_edge(START, "generate")
        self.graph = graph_builder.compile()

    def generate(self, state : State):
        """Graph node: render the prompt from the state and query the LLM.

        Args:
            state: mapping with ``question`` and ``context`` entries.

        Returns:
            ``{"answer": response}`` where ``response`` is the LLM's return value.
        """
        # Accept both Document objects (use their page_content) and raw strings.
        # The original joined the items directly, which raises TypeError for
        # Documents even though the State annotation advertises them.
        context_texts = [getattr(doc, "page_content", doc) for doc in state["context"]]
        context_doc_message = "\n\n".join(context_texts)
        message = self.prompt.invoke({"question": state["question"], "context": context_doc_message})
        response = self.llm.invoke(message)

        return {"answer": response}

## Generate response functions

In [61]:
from tqdm import tqdm  

def generate_response_with_context_at_K(response_generator: ResponseGenerator, size=10, k_context=3, JSON_CoT=False):
    """Run the generator on the first ``size`` queries and collect the answers.

    This cell previously defined the function twice with identical bodies (the
    second silently shadowing the first); a single definition is kept.

    Args:
        response_generator: object whose ``graph.invoke`` is called with a
            ``{"question", "context"}`` dict and returns a state whose
            ``"answer"`` has a ``.content`` string.
        size: number of queries to process, taken from index 0 upwards.
        k_context: maximum number of retrieved chunks passed as context.
        JSON_CoT: when True the answer is stored stripped in a ``raw_response``
            column (for later JSON parsing); otherwise it is stored unchanged
            in a ``response`` column.

    Returns:
        DataFrame with columns ``user_input``, ``raw_response`` or
        ``response``, and ``retrieved_contexts``. It is empty (with the right
        columns) when ``size`` is 0, where the original raised
        UnboundLocalError.
    """
    answer_column = "raw_response" if JSON_CoT else "response"
    qna_context_list = []

    # tqdm wraps the index range to show progress while the LLM calls run.
    for i in tqdm(range(0, size), desc="Generating responses"):
        query, contexts = get_query_from_json_at_K(index=i, k_context=k_context)
        retrieved_contexts = [context["chunk"] for context in contexts]

        # Hand the graph its own copy so the stored list cannot be altered by it.
        output = response_generator.graph.invoke({"question": query, "context": list(retrieved_contexts)})

        answer_text = output["answer"].content
        if JSON_CoT:
            # Raw CoT output is parsed later; strip surrounding whitespace now.
            answer_text = answer_text.strip()

        qna_context_list.append([query, answer_text, retrieved_contexts])

    # Build the frame once, after the loop, instead of on every iteration.
    return pd.DataFrame(qna_context_list, columns=["user_input", answer_column, "retrieved_contexts"])

In [62]:
def _parse_CoT_reply(raw_response):
    """Extract ``(thought, answer)`` from one raw chain-of-thought reply.

    Raises:
        ValueError: no ``{"thought"`` object is present or the JSON is invalid
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        KeyError: the parsed object lacks ``"thought"`` or ``"answer"``.
        TypeError: the reply is not a string or the parsed JSON is not an object.
    """
    if not isinstance(raw_response, str):
        raise TypeError(f"raw response must be a string, got {type(raw_response).__name__}")

    # Raw line breaks/tabs inside JSON strings are invalid, so drop them first.
    clean_response = raw_response.replace("\n", "").replace("\r", "").replace("\t", "").strip()

    try:
        if clean_response.startswith("```json"):
            # Remove the ```json / ``` fence markers around the payload.
            json_str = clean_response.replace("```json", "").replace("```", "").strip()
            json_data = json.loads(json_str)
        else:
            json_data = json.loads(clean_response)
    except json.JSONDecodeError:
        # Fallback: decode the object that starts at the last '{"thought"'.
        # raw_decode stops at the end of that object (ignoring trailing text)
        # and, unlike manual brace counting, is not confused by braces inside
        # string values. rindex raises ValueError when the marker is absent.
        thought_start = clean_response.rindex('{"thought"')
        json_data, _ = json.JSONDecoder().raw_decode(clean_response, thought_start)

    return json_data["thought"], json_data["answer"]


def process_CoT_raw_response_df(df):
    """Split raw CoT replies into separate ``thought`` and ``response`` columns.

    Args:
        df: DataFrame with ``user_input``, ``raw_response`` and
            ``retrieved_contexts`` columns.

    Returns:
        DataFrame with columns ``user_input``, ``thought``, ``response`` and
        ``retrieved_contexts``, one row per input row. Rows whose reply cannot
        be parsed are reported on stdout and receive ``ERROR: ...`` placeholder
        values instead of raising or reusing data from an earlier row.
    """
    processed_data = []

    for idx, row in df.iterrows():
        raw_response = row['raw_response']

        try:
            thought, response = _parse_CoT_reply(raw_response)
        except (ValueError, KeyError, TypeError):
            print(f"Failed to parse JSON at index {idx}")
            print(f"Raw response: {raw_response}")
            # Keep the row, but mark it clearly as unparsed.
            thought = "ERROR: Failed to parse thought"
            response = "ERROR: Failed to parse response"

        processed_data.append([row['user_input'], thought, response, row['retrieved_contexts']])

    return pd.DataFrame(processed_data, columns=["user_input", "thought", "response", "retrieved_contexts"])

## Define the model

In [63]:
# OpenAI chat model (gpt-4o-mini, temperature 0.2) used by the *_gpt response generators below.
llm_openAI = ChatOpenAI(model="gpt-4o-mini", temperature=0.2, api_key=api_key_openai)

In [64]:
# Llama 3 8B chat model served through Groq (temperature 0.3, top_p 0.9 passed via model_kwargs),
# used by the *_llama response generators below.
llm_llama = ChatGroq(model="llama3-8b-8192", temperature=0.3, model_kwargs={"top_p": 0.9}, api_key=api_key_groq)

## Initialize prompt and response generator

In [65]:
# define prompt


# ======================================================================================
# ========================           BASELINE PROMPT        ============================
# ======================================================================================
baseline_prompt = PromptTemplate.from_template("""HUMAN\n
                                               You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n
                                               Question: {question}\n 
                                               Context: {context}\n 
                                               Answer:
                                               """)

# One generator per model, both sharing the baseline prompt.
baseline_response_generator_gpt = ResponseGenerator(prompt=baseline_prompt, llm=llm_openAI)
baseline_response_generator_llama = ResponseGenerator(prompt=baseline_prompt, llm=llm_llama)


# ======================================================================================
# ========================        CHAIN OF THOUGHT PROMPT        =======================
# ======================================================================================
CoT_prompt = PromptTemplate.from_template("""HUMAN\n
                                        You are a world class assistant for legal case question-answering tasks. Think step by step on each of retrieved context on how they help answer the question. Think how different terminologies or party names or entity names in the contexts are related to the question. Think if all the contexts is relevant to the question and ONLY use relevant information to answer the question.\n
                                        How to write the final answer: Answer in legal counseling manner: clear up different terminologies or party name between the question and contexts by mentioning the equivalent terms/word/entity/name in the final answer. State ONLY facts and explicitly say if it ONLY IMPLIES something that answer the question in the final answer. Use maximum of three sentences in the final answer. Just say you don't know if you cannot generate meaningful and factual answer\n
                                        Question: {question}\n
                                        Context: {context}\n
                                        Step-by-step reasoning: Let's think step by step on the context.\n
                                        Output Format, without any additional string/text/character: {{"thought":"description: AI thought on the question and context","answer":"description: the final answer to the question"}}\n
                                        Answer: 
                                        """)


# One generator per model, both sharing the chain-of-thought prompt; the prompt asks for a
# {"thought": ..., "answer": ...} JSON object, which is parsed later by process_CoT_raw_response_df.
CoT_response_generator_gpt = ResponseGenerator(prompt=CoT_prompt, llm=llm_openAI)
CoT_response_generator_llama = ResponseGenerator(prompt=CoT_prompt, llm=llm_llama)


# ======================================================================================
# ========================        MANUALLY WRITTEN PROMPT        =======================
# ======================================================================================
manually_written_prompt = PromptTemplate.from_template("""### Instruction:\n
                                                       You are an AI assistant specializing in legal contract analysis. Your task is to carefully examine the *provided Retrieved Chunk* and *answer the user's question accurately*.\n
                                                       Follow these guidelines:\n
                                                       
                                                       Read the clause carefully. Identify any terms, conditions, or restrictions related to the user's question.\n
                                                       Answer explicitly based on the clause. If the clause clearly states the information being asked, explain it clearly and accurately.\n
                                                       Do not ignore relevant details. If the clause contains conditions, restrictions, or exceptions, **mention them in your answer.\n
                                                       If the clause does not provide a direct answer, say so. Do not assume or infer information that is not stated.\n
                                                       Support your answer with key phrases from the clause clause when necessary.\n
                                                       Avoid unnecessary repetition or legal jargon. The goal is to make the answer **clear and understandable.\n
                                                       
                                                       ### Retrieved Chunk:\n
                                                       {context}\n
                                                       
                                                       ### User's Question:\n
                                                       {question}\n
                                                       ### Answer:\n                                                  
                                                       """)

# One generator per model, both sharing the manually written prompt.
manually_written_response_generator_gpt = ResponseGenerator(prompt=manually_written_prompt, llm=llm_openAI)
manually_written_response_generator_llama = ResponseGenerator(prompt=manually_written_prompt, llm=llm_llama)

## Generate all responses at all K

In [66]:
def generate_responses_for_k(sample_size, k_values, model_name, baseline_response_generator, CoT_response_generator, manually_written_response_generator):
    """Generate and save answers for every prompt variant at every ``k``.

    For each ``k`` in ``k_values`` the baseline, chain-of-thought and manually
    written generators are run (in that order) on ``sample_size`` queries via
    ``generate_response_with_context_at_K``, and each resulting DataFrame is
    written to its own JSON file in the working directory. The CoT run is
    requested with ``JSON_CoT=True`` so its raw output can be parsed later.
    Nothing is returned.
    """
    for k in k_values:
        # (generator, JSON_CoT flag, output file) in execution order.
        runs = [
            (baseline_response_generator, False, f'query_answer_baseline_{model_name}_k{k}.json'),
            (CoT_response_generator, True, f'query_answer_CoT_raw_{model_name}_k{k}.json'),
            (manually_written_response_generator, False, f'query_answer_manually_written_{model_name}_k{k}.json'),
        ]

        for generator, wants_raw_json, out_path in runs:
            answers_df = generate_response_with_context_at_K(
                generator, sample_size, k_context=k, JSON_CoT=wants_raw_json
            )
            answers_df.to_json(out_path, orient="records", indent=4)

In [67]:
# Turn the raw CoT JSON dumps into cleaned thought/response files

def clean_CoT_json_output(k_values, model_name):
    """Parse each raw CoT output file and write the cleaned version next to it.

    For every ``k`` the file ``query_answer_CoT_raw_<model_name>_k<k>.json`` is
    read, passed through ``process_CoT_raw_response_df`` and saved as
    ``query_answer_CoT_<model_name>_k<k>.json``. Failures for one ``k`` are
    reported on stdout and do not stop the remaining values.
    """
    def _convert(k):
        # Read -> parse -> write for a single k; any exception reaches the caller.
        raw_frame = pd.read_json(f'query_answer_CoT_raw_{model_name}_k{k}.json')
        parsed_frame = process_CoT_raw_response_df(raw_frame)
        parsed_frame.to_json(f'query_answer_CoT_{model_name}_k{k}.json', orient="records", indent=4)

    for k in k_values:
        try:
            _convert(k)
        except FileNotFoundError:
            print(f"File query_answer_CoT_raw_{model_name}_k{k}.json not found")
        except ValueError as e:
            print(f"Error processing file for {model_name} k={k}: {str(e)}")
        except Exception as e:
            print(f"Unexpected error processing file for {model_name} k={k}: {str(e)}")

## Generate for GPT MODEL

In [None]:
# Generate baseline, CoT and manually-written-prompt responses with the GPT model
# for k values [1, 3, 5, 10]; every (prompt, k) pair is written to its own JSON file.
generate_responses_for_k(sample_size = 192,
                         k_values = [1, 3, 5, 10], 
                         model_name = "privacyqa_gpt4omini",
                         baseline_response_generator = baseline_response_generator_gpt, 
                         CoT_response_generator = CoT_response_generator_gpt, 
                         manually_written_response_generator = manually_written_response_generator_gpt
                         )

# Parse the raw CoT JSON files written above into cleaned thought/response files
# (same k values and model name, since the file names are derived from them).
clean_CoT_json_output([1, 3, 5, 10], "privacyqa_gpt4omini")

## Generate for LLAMA 3 MODEL

In [83]:
# Generate responses for k values [1, 3, 5, 10] LLAMA 3 MODEL

# Test using small value first!!!
# e.g. sample size = 5
# k_values = [3]

# NOTE(review): sample_size here is 194 while the GPT run uses 192, and model_name
# lacks the dataset prefix used there ("privacyqa_...") — confirm both are intended.
generate_responses_for_k(sample_size = 194,
                         k_values = [1, 3, 5, 10], 
                         model_name = "llama3",
                         baseline_response_generator = baseline_response_generator_llama, 
                         CoT_response_generator = CoT_response_generator_llama, 
                         manually_written_response_generator = manually_written_response_generator_llama
                         )

# Parse the raw CoT JSON files written above into cleaned thought/response files
# (same k values and model name, since the file names are derived from them).
clean_CoT_json_output([1, 3, 5, 10], "llama3")