In [None]:
# Re-import edited modules (e.g. src.commercial.*, utils, const) automatically
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# The project-local imports in the next cell (`src.commercial...`, `utils`,
# `const`) resolve from the parent of the notebook's folder. A bare
# `os.chdir("..")` climbs one more level every time the cell is re-run, so
# remember the directory the kernel started in (first run only) and always
# switch to *its* parent — the cell is now idempotent.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR") or os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [None]:
from langchain.prompts import PromptTemplate
from src.commercial.templates import SYS_PALM_TEMPLATE, INST_PALM_TEMPLATE
from src.commercial.inference_palm import palm_completion
from utils import get_clues
import os
import pandas as pd
from const import STATE_CLUES_NOTES_DICT
from typing import Dict, List
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Prompt templates for the PaLM chat: the system prompt is later filled with
# the numbered clue list (`cluelist`), the instruction prompt with a state name.
sys_prompt_template, inst_prompt_template = (
    PromptTemplate.from_template(tmpl) for tmpl in (SYS_PALM_TEMPLATE, INST_PALM_TEMPLATE)
)

# Tamil Nadu clue sheet used for the single-artifact spot-check below.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
df = pd.read_csv("/home/t-sahuja/cultural_artifacts/clues/tamil_nadu/tamil_clues.csv")

In [None]:
# Spot-check one artifact (position 9, i.e. the 10th row): number its
# newline-separated clues as "CLUE-1: ...", "CLUE-2: ...", one per line.
clues = df["clues"].iloc[9].strip().split("\n")
output = "".join(f"CLUE-{num}: {text}\n" for num, text in enumerate(clues, start=1))
fin_clues = output.strip()

In [None]:
# Render both prompts for the single spot-check artifact.
sys_prompt = sys_prompt_template.format(cluelist=fin_clues)
inst_prompt = inst_prompt_template.format(state="Tamil Nadu")

In [None]:
# Inspect the rendered system prompt.
sys_prompt

In [None]:
# One live PaLM chat call. NOTE(review): palm_completion is project code; the
# evaluation loop below treats its result either as a chat response (with
# `.last` / `.reply`) or as the string "Answer: api failed" — confirm.
palm_resp = palm_completion(
    sys_prompt=sys_prompt.strip(), inst_prompt=inst_prompt.strip()
)

In [None]:
# Display the raw response object.
palm_resp

In [None]:
# Send a follow-up turn in the same conversation and show the newest reply text.
palm_resp.reply("yes, this is correct").last

In [None]:
# NOTE(review): neither value below appears to be read later in this notebook —
# get_outputs builds and returns its own result frame, and compile_results
# formats the instruction prompt per state itself. Confirm before deleting.
inst_prompt = inst_prompt_template.format(state="Punjab")
df_eval = pd.DataFrame(columns=["guess1", "guess2", "ground_truth", "clues"])

In [None]:
def get_outputs(df, sys_prompt_template, inst_prompt):
    """Run the two-guess PaLM evaluation over every artifact in ``df``.

    For each row the clues are numbered (``CLUE-1: ...``), substituted into the
    system prompt, and sent to PaLM together with ``inst_prompt``. The first
    line of the reply is parsed as ``<label>: <guess>``. If the first guess does
    not contain the ground-truth artifact name (and the call did not fail), the
    model is told it was wrong and asked for a second guess in the same chat.

    Args:
        df: Frame with a ``"clues"`` column (newline-separated clue text) and an
            ``"artifact"`` column (ground-truth name).
        sys_prompt_template: Anything with ``.format(cluelist=...)`` — a plain
            template string or a langchain ``PromptTemplate``.
        inst_prompt: Already-rendered instruction prompt.

    Returns:
        DataFrame with columns ``guess1``, ``guess2``, ``ground_truth`` and
        ``clues`` (one row per input row). ``guess2`` is ``"NA"`` when no
        second guess was requested.
    """
    records = []
    for i, row in df.iterrows():
        print(f"artifact--{i}---")
        clues = row["clues"].strip().split("\n")
        artifact = row["artifact"].lower().strip()
        # Same numbered format as the spot-check cell; strip() drops the
        # trailing newline so the prompt does not end with a blank line.
        fin_clues = "".join(
            f"CLUE-{j}: {clue}\n" for j, clue in enumerate(clues, start=1)
        ).strip()

        sys_prompt = sys_prompt_template.format(cluelist=fin_clues)
        palm_resp = palm_completion(sys_prompt=sys_prompt, inst_prompt=inst_prompt)
        # palm_completion signals an API failure with this sentinel string
        # instead of a chat response object.
        if palm_resp == "Answer: api failed":
            palm_reply = "Answer: api failed"
        else:
            palm_reply = palm_resp.last
        guess1 = _parse_guess(palm_reply)

        guess2 = "NA"
        if artifact not in guess1 and "api failed" not in guess1:
            palm_2nd_resp = palm_resp.reply(
                "Your first guess is not correct. While making your second guess, please stick to the format as ANSWER: your_answer_here"
            )
            guess2 = _parse_guess(palm_2nd_resp.last)

        records.append(
            {
                "guess1": guess1,
                "guess2": guess2,
                "ground_truth": artifact,
                "clues": "\n".join(clues),
            }
        )

    # Build the frame once: DataFrame.append was removed in pandas 2.0 and
    # appending row by row is quadratic. Explicit columns keep the schema even
    # when ``df`` is empty.
    return pd.DataFrame(records, columns=["guess1", "guess2", "ground_truth", "clues"])


def _parse_guess(reply):
    """Return the lower-cased, stripped answer from the first line of a reply.

    Replies are expected to start with ``<label>: <answer>``. Only the first
    colon separates label from answer, so answers containing a colon stay
    whole; a first line without a colon is used as-is. ``None`` (the chat
    response's ``last`` when no candidate came back) or ``""`` yields ``""``.
    """
    if not reply:
        return ""
    first_line = reply.split("\n", 1)[0]
    _, sep, answer = first_line.partition(":")
    return (answer if sep else first_line).lower().strip()

In [None]:
def compile_results(
    STATE_CLUES_NOTES_DICT: Dict[str, List[str]],
    output_dir: str,
    inst_prompt: "PromptTemplate",
    sys_prompt: str,
):
    """Run (or resume) the PaLM guessing evaluation for every configured state.

    Results for each state go under ``<output_dir>/<state_name>/``:
    ``eval_original_artifacts.csv`` from the clue CSV and, when a second
    (notes) CSV is configured, ``eval_expanded_artifacts.csv``. A result file
    that already exists is left untouched and its input CSV is not read, so an
    interrupted run can be restarted without re-querying finished states.

    Args:
        STATE_CLUES_NOTES_DICT: Maps a state name to ``[clue_csv_path]`` or
            ``[clue_csv_path, notes_csv_path]``.
        output_dir: Root directory for the per-state result folders.
        inst_prompt: Instruction template with a ``state`` variable; rendered
            once per state.
        sys_prompt: System-prompt template with a ``cluelist`` variable,
            forwarded to ``get_outputs`` for per-artifact rendering.
    """
    for state_name, paths in STATE_CLUES_NOTES_DICT.items():
        inst_template = inst_prompt.format(state=state_name)
        curr_path = os.path.join(output_dir, state_name)
        os.makedirs(curr_path, exist_ok=True)
        clue_path = paths[0]
        notes_path = paths[1] if len(paths) > 1 else None

        print(f"getting results for {state_name} state")

        clues_result_path = os.path.join(curr_path, "eval_original_artifacts.csv")
        if os.path.exists(clues_result_path):
            print(f"Clue eval results already exist for {state_name} state")
        else:
            print(f"Running clues eval for {state_name} state")
            df_clues = pd.read_csv(clue_path)
            df_clues_eval = get_outputs(df_clues, sys_prompt, inst_template)
            df_clues_eval.to_csv(clues_result_path, index=False)

        if notes_path:
            notes_result_path = os.path.join(curr_path, "eval_expanded_artifacts.csv")
            if os.path.exists(notes_result_path):
                print(f"Notes eval results already exist for {state_name} state")
            else:
                print(f"Running notes eval for {state_name} state")
                df_notes = pd.read_csv(notes_path)
                df_notes_eval = get_outputs(df_notes, sys_prompt, inst_template)
                df_notes_eval.to_csv(notes_result_path, index=False)

In [None]:
# Instruction-prompt template with a single `state` variable; compile_results renders it once per state.
inst_template = PromptTemplate(template=INST_PALM_TEMPLATE, input_variables=["state"])

In [None]:
# Full evaluation over every configured state; result files that already exist
# are not recomputed. NOTE(review): absolute, user-specific output directory.
compile_results(
    sys_prompt=SYS_PALM_TEMPLATE,
    inst_prompt=inst_template,
    output_dir="/home/t-sahuja/cultural_artifacts/results/commercial/palm",
    STATE_CLUES_NOTES_DICT=STATE_CLUES_NOTES_DICT,
)