In [None]:
import json
import os
from typing import List

import dotenv
from langchain.chat_models import ChatAnthropic
from langchain.prompts import PromptTemplate
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)
from pydantic import BaseModel, Field

In [None]:
# Later cells open files relative to the working directory (".env",
# "./notebooks/..."). If the kernel was not started from a directory named
# "10ds-ai-redbox" (presumably the repository root - confirm), step up one
# level and echo the new location.
if not os.path.split(os.getcwd())[1] == "10ds-ai-redbox":
    os.chdir(os.pardir)
    print(os.getcwd())

In [None]:
# Load the variables defined in the local .env file.
dotenv.load_dotenv(".env")
# Grab it as a dictionary too for convenience
ENV = dotenv.dotenv_values(".env")

# Generation settings shared with the chat model constructed below.
model_params = {"max_tokens": 4096, "temperature": 0.2}

# The two generation settings are unpacked straight from model_params so the
# dictionary stays the single place where they are defined.
llm = ChatAnthropic(
    anthropic_api_key=ENV["ANTHROPIC_API_KEY"],
    streaming=True,
    **model_params,
)

In [None]:
# Structured verdict the judge LLM is asked to produce for a single answer.
# The JSON schema of this model is rendered into the system prompt further
# down, so the Field descriptions below are effectively prompt text.
# NOTE(review): documented with comments rather than a class docstring on
# purpose - pydantic surfaces a model docstring as the schema "description",
# which would change the prompt sent to the LLM.
class Judgement(BaseModel):
    # Free-text critique of the answer, explaining any mistake that was found.
    judgement: str = Field(
        description="Comment on how good the answer is or not. If there is a mistake please explain what the mistake is."
    )
    # Overall pass/fail verdict for the answer.
    correct: bool = Field(description="Whether the answer is correct or not")


# One reading-comprehension result to be judged. The fields mirror the keys
# read from each entry of "bad_outputs" in the results file, plus the entry's
# own key, which is stored as `id`.
class ReadingComprehensionOutput(BaseModel):
    # Key of the entry in the results file.
    id: str = Field(description="ID of the question")
    # The question that was asked.
    question: str = Field(description="Question")
    # The answer under review.
    llm_answer: str = Field(description="Answer")
    # Text the answer is checked against.
    context: str = Field(description="Context")
    # Reference answers recorded for the question.
    ground_truth_answers: List[str] = Field(description="Ground truth answers")
    # presumably the IDs of the documents the retrieved chunks came from -
    # TODO confirm against the code that writes the results file
    retrieved_chunks_parent_ids: List[str] = Field(
        description="Retrieved chunks parent IDs"
    )

In [None]:
# Load in the outputs we'll be judging

# JSON text is UTF-8, so pin the encoding instead of relying on the platform
# default, which is not UTF-8 on every OS and would mis-decode non-ASCII
# questions/contexts.
with open(
    "./notebooks/results_SEED=12_DS=1000_k=4.json", "r", encoding="utf-8"
) as f:
    outputs = json.load(f)["bad_outputs"]

# The file is fully parsed at this point, so the models are built after the
# handle has been closed. Each "bad_outputs" entry is keyed by its ID and
# carries the remaining fields of ReadingComprehensionOutput.
output_objects = [
    ReadingComprehensionOutput(
        id=output_id,
        question=output["question"],
        llm_answer=output["llm_answer"],
        context=output["context"],
        ground_truth_answers=output["ground_truth_answers"],
        retrieved_chunks_parent_ids=output["retrieved_chunks_parent_ids"],
    )
    for output_id, output in outputs.items()
]

# Grab a sample for us to show the LLM how to mark.
# Later cells index this sample at positions 0-4, so the results file needs
# at least five bad outputs for the notebook to run end to end.
few_shot_output_sample = output_objects[:5]
for sample_output in few_shot_output_sample:
    print(sample_output)

In [None]:
# Pretty-print few-shot sample 0 so a reference judgement can be written for it.
print(few_shot_output_sample[0].model_dump_json(indent=4))

In [None]:
# Hand-written reference judgement for few-shot sample 0 (marked incorrect).
# It is replayed to the model as the AI reply to that sample.
example_judgement_0 = Judgement(
    judgement="The context explicitly mentions that the School plays are normally fully booked every night. The answer is also overly verbose with a lot of unnecessary information about other documents.",
    correct=False,
)

In [None]:
# Pretty-print few-shot sample 1 so a reference judgement can be written for it.
print(few_shot_output_sample[1].model_dump_json(indent=4))

In [None]:
# Hand-written reference judgement for few-shot sample 1 (marked incorrect).
# It is replayed to the model as the AI reply to that sample.
example_judgement_1 = Judgement(
    judgement="The context does say that it is free of charge almost all year around. This should be mentioned in the answer even if not all year round. The answer should be briefer and not mention unrelated information from the context.",
    correct=False,
)

In [None]:
# Pretty-print few-shot sample 2 so a reference judgement can be written for it.
print(few_shot_output_sample[2].model_dump_json(indent=4))

In [None]:
# Hand-written reference judgement for few-shot sample 2 - the only example
# marked correct. It is replayed to the model as the AI reply to that sample.
# NOTE(review): "Good job is converting" looks like a typo for "Good job in
# converting"; left untouched because this string is sent to the LLM.
example_judgement_2 = Judgement(
    judgement="Correct time range extracted from the context. Good job is converting the jargony time units of Mega Annums into years. The answer is a bit verbose and could be shortened.",
    correct=True,
)

In [None]:
# Pretty-print few-shot sample 3 so a reference judgement can be written for it.
print(few_shot_output_sample[3].model_dump_json(indent=4))

In [None]:
# Hand-written reference judgement for few-shot sample 3 (marked incorrect).
# It is replayed to the model as the AI reply to that sample.
example_judgement_3 = Judgement(
    judgement="The answer incorrectly covers the year of 1831 and not the 1837 mentioned in the question. Chopin visited London Incognito with Camille Pleyel. The answer is also overly verbose and could be shortened.",
    correct=False,
)

In [None]:
# JSON Schema of the Judgement model, pretty-printed so it can be pasted into
# the system prompt as the required response format.
judgement_schema_str = json.dumps(Judgement.model_json_schema(), indent=4)

# Using LLM to mark reading comprehension

# System prompt for the judge. {judgement_schema} is the only template
# variable; the trailing backslashes join the source lines so the rendered
# prompt contains only the explicit \n\n breaks around the schema.
# NOTE(review): "return as JSON formatted response for each item that will
# judge" reads awkwardly (probably "return a JSON formatted response for each
# item that you will judge"); left as-is because changing it alters the prompt.
_reading_comprehension_template = """\
You are a judge for reading comprehension accuracy. You will return as JSON \
formatted response for each item that will judge. The response will be \
formatted as follows (JSON Schema):\n\n{judgement_schema}\n\n\
Be objective in your judgement. If you are unsure, please mark as incorrect."""

JUDGEMENT_PROMPT = PromptTemplate.from_template(_reading_comprehension_template)

In [None]:
# Few-shot conversation for the judge: the system prompt carrying the response
# schema, then each of the first four samples followed by its hand-written
# judgement, and finally the fifth sample, left unanswered for the model.
_reference_judgements = (
    example_judgement_0,
    example_judgement_1,
    example_judgement_2,
    example_judgement_3,
)

messages = [
    SystemMessage(
        content=JUDGEMENT_PROMPT.format(judgement_schema=judgement_schema_str)
    )
]
# zip() stops after the four reference judgements, so only samples 0-3 are
# paired with an answer here.
for _sample, _verdict in zip(few_shot_output_sample, _reference_judgements):
    messages.append(HumanMessage(content=_sample.model_dump_json(indent=4)))
    messages.append(AIMessage(content=_verdict.model_dump_json(indent=4)))
messages.append(
    HumanMessage(content=few_shot_output_sample[4].model_dump_json(indent=4))
)

In [None]:
# Send the few-shot conversation to the model. The conversation ends on the
# unanswered fifth sample, so the reply is expected to be its judgement.
resp = llm(messages)
print(resp.content)

In [None]:
# Pretty-print few-shot sample 4 - the item the model was asked to judge - so
# its verdict above can be checked by eye.
print(few_shot_output_sample[4].model_dump_json(indent=4))