In [None]:
# imports
import re
from functools import partial
from pathlib import Path
from uuid import uuid4
from typing import Any

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from IPython.display import display
from transformers.tokenization_utils_base import BatchEncoding
from transformers.models.llama.modeling_llama import LlamaForCausalLM

In [None]:
# env: run on the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("DEVICE:", device)

In [None]:
# constants
# Training questions CSV, read by prepare_base_data.
eedi_train_csv = "data/train.csv"
# Presumably the held-out test questions; not read by any cell visible here.
eedi_test_csv = "data/test.csv"
# Misconception table (MisconceptionId, MisconceptionName), read by prepare_base_data.
eedi_miscon_csv = "data/misconception_mapping.csv"
# Model id passed to AutoTokenizer / AutoModelForCausalLM.
llm_model_id = "meta-llama/Llama-3.2-1B-Instruct"
# Presumably the sentence-embedding model id; not referenced by any cell visible here.
sbert_model_id = "BAAI/bge-small-en-v1.5"
# Presumably the output file for predictions; not written by any cell visible here.
submission_csv = "submission.csv"
# Root directory under which dfpersist writes per-run parquet snapshots.
intermediate_dir = ".intm"
# Seed handed to train_test_split in main.
random_seed = 20241030
# When True, main() works on a 20-row sample instead of the full filtered frame.
sample = True

In [None]:
# prompt
# User-message template rendered by apply_template via str.format. The placeholders are
# {Question}, {IncorrectAnswer}, {CorrectAnswer}, {ConstructName} and {SubjectName}.
# The model is asked to reason inside <thinking> tags and to put the final misconception
# inside <response> tags, which extract_response later pulls out.
prompt = """Question: {Question}
Incorrect Answer: {IncorrectAnswer}
Correct Answer: {CorrectAnswer}
Construct Name: {ConstructName}
Subject Name: {SubjectName}

Your task: Identify the misconception behind Incorrect Answer. Answer concisely and generically inside <response>$$INSERT TEXT HERE$$</response>.
Before answering the question think step by step concisely in 1-2 sentence inside <thinking>$$INSERT TEXT HERE$$</thinking> tag and respond your final misconception inside <response>$$INSERT TEXT HERE$$</response> tag."""

In [None]:
# data prep utilities
def apply_template(row, tokenizer):
    """Render one record into a chat-formatted prompt string.

    Args:
        row: Mapping-like record providing "ConstructName", "SubjectName",
            "QuestionText", "AnswerText" (the option being analysed) and
            "CorrectAnswerText" (the text of the correct option).
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for a single user
        message, with ``tokenize=False`` and the generation prompt appended.
    """
    user_content = prompt.format(
        ConstructName=row["ConstructName"],
        SubjectName=row["SubjectName"],
        Question=row["QuestionText"],
        # "AnswerText" is the melted per-option answer (the wrong answer once
        # correct options are filtered out); "CorrectAnswerText" is the right
        # one. These two were previously swapped.
        IncorrectAnswer=row["AnswerText"],
        CorrectAnswer=row["CorrectAnswerText"],
    )
    messages = [{"role": "user", "content": user_content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def get_correct_answer(row):
    """Return the answer text of the option named by ``row["CorrectAnswer"]``.

    For a value of "A".."D" the matching ``row["Answer<X>Text"]`` is returned;
    any other value yields None.
    """
    for letter in ("A", "B", "C", "D"):
        if row["CorrectAnswer"] == letter:
            return row[f"Answer{letter}Text"]
    return None


def process_option(x, regex):
    """Return capture group 1 of the first ``regex`` match in ``x`` as a string, or "" if there is no match."""
    match = re.search(regex, x)
    if not match:
        return ""
    return str(match.group(1))


def remove_prompt(record):
    """Return ``record["FullResponse"]`` with its first ``len(record["Prompt"])`` characters cut off.

    The cut is purely by length; the leading text is not checked against the prompt.
    """
    prompt_length = len(record["Prompt"])
    return record["FullResponse"][prompt_length:]


def extract_response(text):
    """Collect the contents of every ``<response>...</response>`` block in ``text``.

    Each block's inner text is stripped of surrounding whitespace, empty blocks
    are skipped, and the remaining pieces are joined with single spaces.

    Args:
        text: String to scan; blocks may span multiple lines.

    Returns:
        The joined inner texts, or "" when no non-empty block is found.
    """
    bodies = re.findall(r"<response>(.*?)</response>", text, flags=re.DOTALL)
    # A non-greedy match can still swallow a stray nested opening tag; drop it.
    # Strip AFTER removing tags - stripping the tagged match was a no-op and
    # left inner padding/newlines in the joined result.
    cleaned = (body.replace("<response>", "").strip() for body in bodies)
    return " ".join(piece for piece in cleaned if piece)


def dfpeek(title: str, df: pd.DataFrame) -> None:
    """Show the first row of ``df`` transposed, framed by banner lines carrying ``title``."""
    print(">>>>>>>>>>", title, ">>>>>>>>>")
    first_row_transposed = df.head(1).T
    display(first_row_transposed)
    print("<<<<<<<<<<", title, "<<<<<<<<<<", end="\n\n")


def dfpersist(trigger: bool, df: pd.DataFrame, int_dir: str, run_id: str, fn: str) -> None:
    """Write ``df`` to ``<int_dir>/<run_id>/<fn>`` as parquet, refusing to overwrite.

    Args:
        trigger: When falsy the function returns immediately and does nothing.
        df: Frame to persist (written with ``index=False``).
        int_dir: Root directory for intermediate artefacts.
        run_id: Sub-directory name for this run; must not be None when persisting.
        fn: File name inside the run directory.

    Raises:
        AssertionError: If ``run_id`` is None while ``trigger`` is truthy.
        FileExistsError: If the target file already exists.
    """
    if not trigger:
        return
    assert run_id is not None, "run_id is required when persisting"
    # Honour the int_dir argument; the module-level intermediate_dir was being
    # read here instead, silently ignoring what the caller passed.
    run_dir = Path(int_dir) / run_id
    run_dir.mkdir(parents=True, exist_ok=True)
    target = run_dir / fn
    if target.exists():
        raise FileExistsError(target.as_posix())
    df.to_parquet(target, index=False)


In [None]:
def prepare_base_data(*, persist: bool = False, run_id: str = None) -> pd.DataFrame:
    """Load the training CSV and reshape it to one row per (question, answer option).

    Steps:
      1. Read ``eedi_train_csv`` (misconception id columns as nullable ``Int64``)
         and fill every missing cell with -1. NOTE(review): the fill applies to
         all columns, so a missing text cell would also become -1.
      2. Attach the text of the correct option as ``CorrectAnswerText``.
      3. Melt the four ``Answer<X>Text`` columns and the four
         ``Misconception<X>Id`` columns into long form, keyed by
         ``QuestionId_Answer`` ("<QuestionId>_<option letter>").
      4. Look up ``MisconceptionName`` from ``eedi_miscon_csv`` by
         ``MisconceptionId`` (ids without a match, such as -1, get a missing name).
      5. Merge answers and labels on ``QuestionId_Answer``.

    Args:
        persist: When True the result is written via ``dfpersist`` as
            "df_xy.parquet" under the run directory.
        run_id: Run identifier forwarded to ``dfpersist``.

    Returns:
        Frame with columns QuestionId_Answer, ConstructName, SubjectName,
        QuestionText, CorrectAnswer, CorrectAnswerText, AnswerText,
        IsCorrectAnswer, MisconceptionId, MisconceptionName.
    """
    option_letters = "ABCD"
    answer_text_cols = [f"Answer{letter}Text" for letter in option_letters]
    misconception_id_cols = [f"Misconception{letter}Id" for letter in option_letters]
    key_col = "QuestionId_Answer"

    # read info
    df = pd.read_csv(
        eedi_train_csv,
        dtype={col: "Int64" for col in misconception_id_cols},
    ).fillna(-1)
    df_miscon = pd.read_csv(
        eedi_miscon_csv,
        dtype={"MisconceptionId": "Int64"},
    )

    # store correct answer
    df["CorrectAnswerText"] = df.apply(get_correct_answer, axis=1)

    # pivot out each answer option into its own row (they start out side by side in one record)
    df_x = df.melt(
        id_vars=[
            "QuestionId",
            "ConstructName",
            "SubjectName",
            "QuestionText",
            "CorrectAnswer",
            "CorrectAnswerText",
        ],
        value_vars=answer_text_cols,
        var_name="Option",
        value_name="AnswerText",
    )
    df_y = df.melt(
        id_vars=["QuestionId"],
        value_vars=misconception_id_cols,
        var_name="Option",
        value_name="MisconceptionId",
    )

    # remap option values from "xxxxXxxxx" to "X"
    df_x["Option"] = df_x["Option"].map(partial(process_option, regex=r"Answer([A-D])Text"))
    df_y["Option"] = df_y["Option"].map(partial(process_option, regex=r"Misconception([A-D])Id"))

    # mark correct answers
    df_x["IsCorrectAnswer"] = df_x["CorrectAnswer"] == df_x["Option"]

    # create primary key, drop its components, and move the key to the front
    df_x[key_col] = df_x["QuestionId"].astype(str) + "_" + df_x["Option"].astype(str)
    df_y[key_col] = df_y["QuestionId"].astype(str) + "_" + df_y["Option"].astype(str)
    df_x = df_x.drop(columns=["QuestionId", "Option"])
    df_y = df_y.drop(columns=["QuestionId", "Option"])
    df_x = df_x[[key_col] + [c for c in df_x.columns if c != key_col]]
    df_y = df_y[[key_col] + [c for c in df_y.columns if c != key_col]]

    # map misconception text to labels.
    # Merge explicitly on the MisconceptionId COLUMN of both frames. The earlier
    # DataFrame.join(on=...) matched against df_miscon's row index and, because
    # both frames carry a MisconceptionId column, suffixed that column so the
    # selection below could no longer find it.
    df_y = df_y.merge(df_miscon, on="MisconceptionId", how="left", validate="many_to_one")
    df_y = df_y[[key_col, "MisconceptionId", "MisconceptionName"]]

    # merge datasets
    df_xy = df_x.merge(df_y, how="left", on=key_col)

    # persist df_xy
    dfpersist(persist, df_xy, intermediate_dir, run_id, "df_xy.parquet")

    return df_xy

In [None]:
# options to filter data

def filter_by_wrong_answers_only(df_xy: pd.DataFrame) -> pd.DataFrame:
    """Keep only the rows whose ``IsCorrectAnswer`` flag is False."""
    is_wrong = ~df_xy["IsCorrectAnswer"]
    return df_xy[is_wrong]


def filter_by_wrong_answers_and_misconceptions(df_xy: pd.DataFrame) -> pd.DataFrame:
    """Keep rows that are wrong answers AND whose ``MisconceptionId`` is not the -1 placeholder."""
    is_wrong = ~df_xy["IsCorrectAnswer"]
    has_label = df_xy["MisconceptionId"] != -1
    return df_xy[is_wrong & has_label]


def no_filter(df_xy: pd.DataFrame) -> pd.DataFrame:
    """Pass-through filter: hand back the very same frame object, untouched."""
    unfiltered = df_xy
    return unfiltered


def filter_data(
    df_xy: pd.DataFrame,
    persist: bool = False,
    run_id: str = None,
    *,
    filter_func=None,
) -> pd.DataFrame:
    """Apply a row filter to ``df_xy`` and optionally persist the result.

    Args:
        df_xy: Frame to filter.
        persist: When True the filtered frame is written via ``dfpersist`` as
            "df_xy_filtered.parquet" under the run directory.
        run_id: Run identifier forwarded to ``dfpersist``.
        filter_func: Callable taking and returning a DataFrame. Defaults to
            ``filter_by_wrong_answers_and_misconceptions``, which was the
            previously hard-coded choice, so existing callers are unaffected.

    Returns:
        Whatever ``filter_func`` returns for ``df_xy``.
    """
    if filter_func is None:
        filter_func = filter_by_wrong_answers_and_misconceptions

    # filter df_xy
    df_xy_filtered = filter_func(df_xy)

    # persist df_xy_filtered (dfpersist is a no-op without the trigger, so skip the call)
    if persist:
        dfpersist(persist, df_xy_filtered, intermediate_dir, run_id, "df_xy_filtered.parquet")

    return df_xy_filtered

In [None]:
# tokenize data

def tokenize_for_llm(
    tokenizer: Any,
    df_xy: pd.DataFrame,
    *,
    persist: bool = False,
    persist_fn: str = "df_prompt.parquet",
    run_id: str = None,
) -> tuple[pd.DataFrame, BatchEncoding]:
    """Render a prompt for every record and tokenize the whole batch.

    Args:
        tokenizer: Used both by ``apply_template`` (chat template) and as a
            callable to encode the prompts.
        df_xy: Records to turn into prompts; the input frame is not modified.
        persist: When True the prompt frame is written via ``dfpersist``.
        persist_fn: File name used for the persisted prompt frame.
        run_id: Run identifier forwarded to ``dfpersist``.

    Returns:
        ``(df_prompt, model_inputs)`` where ``df_prompt`` holds the columns
        "QuestionId_Answer" and "Prompt", and ``model_inputs`` is the
        tokenizer's output for all prompts (PyTorch tensors, padded).
    """
    # generate prompts from records, working on a copy so the caller's frame stays intact
    df_prompt = df_xy.copy()
    df_prompt["Prompt"] = df_prompt.apply(
        lambda row: apply_template(row, tokenizer=tokenizer),
        axis=1,
    )
    df_prompt = df_prompt[["QuestionId_Answer", "Prompt"]]

    # persist df_prompt
    dfpersist(persist, df_prompt, intermediate_dir, run_id, persist_fn)

    # tokenize : NOTE configure as required
    prompts = df_prompt["Prompt"].to_list()
    model_inputs = tokenizer(prompts, return_tensors="pt", padding=True)

    return df_prompt, model_inputs

In [None]:
# main functions

def generate_zeroshot():
    """Placeholder: not implemented yet; takes no arguments and returns None."""


def generate_predictions():
    """Placeholder: not implemented yet; takes no arguments and returns None."""


def evaluate():
    """Placeholder: not implemented yet; takes no arguments and returns None."""


def main():
    """Run data preparation end to end and return the tokenizer plus the tokenized training batch.

    Creates a fresh run id, builds and filters the base data (persisting both),
    optionally down-samples, splits into train/test, then renders and tokenizes
    the prompts for each split (persisting the prompt frames).

    Returns:
        ``(tokenizer, tokens_train)``: the configured tokenizer and the
        tokenized training prompts.
    """
    run_id = str(uuid4())
    print("run_id:", run_id)
    df_xy = prepare_base_data(persist=True, run_id=run_id)
    df_xy = filter_data(df_xy, persist=True, run_id=run_id)

    if sample:
        # Seeded so that a re-run picks the same rows, and capped so that a
        # frame with fewer than 20 rows does not make DataFrame.sample raise.
        df_xy = df_xy.sample(min(20, len(df_xy)), random_state=random_seed)

    df_xy_train, df_xy_test = train_test_split(df_xy, test_size=0.2, random_state=random_seed)

    tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue from the LAST position of each row, so pad
    # on the left; with right padding the last position of shorter prompts is
    # a pad token and next-token predictions read from it are meaningless.
    tokenizer.padding_side = "left"

    _, tokens_train = tokenize_for_llm(
        tokenizer,
        df_xy_train,
        persist=True,
        persist_fn="df_prompt_train.parquet",
        run_id=run_id,
    )
    # The test split's tokens are not returned; the call is kept because it
    # persists the test prompt frame.
    tokenize_for_llm(
        tokenizer,
        df_xy_test,
        persist=True,
        persist_fn="df_prompt_test.parquet",
        run_id=run_id,
    )
    return tokenizer, tokens_train

In [None]:
# entrypoint
# Runs the whole preparation pipeline; as a side effect main() writes parquet
# snapshots under intermediate_dir/<run_id>. Keeps the tokenizer and the
# tokenized training prompts for the exploratory cells below.
tokenizer, tokens_train = main()

In [None]:
# Token ids of the padded training prompts.
tokens_train['input_ids']

In [None]:
# Decode the first row back to text to eyeball the rendered prompt
# (special tokens are not skipped here, so template markers and padding show up).
tokenizer.decode(tokens_train["input_ids"][0])

In [None]:
# Load the causal LM named by llm_model_id and move it onto the selected device.
model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(llm_model_id).to(device)

In [None]:
# Move the tokenized training batch to the same device as the model.
tokens_train = tokens_train.to(device)

In [None]:
# Inference only: disable gradient tracking so the forward pass does not build
# an autograd graph (saves memory and time; the logits are unchanged).
with torch.no_grad():
    output = model(**tokens_train)

In [None]:
# Inspect the class of the object returned by the forward pass.
type(output)

In [None]:
# Pull the logits off the model output for the steps below.
logits = output.logits

In [None]:
# Check the dimensions of the logits.
logits.shape

In [None]:
# Greedy pick over the last axis, read at the final sequence position of every row.
# NOTE(review): the batch was tokenized with padding=True; the final position is each
# prompt's true last token only if padding is on the left - confirm tokenizer.padding_side.
predicted_token_id = torch.argmax(logits[:, -1, :], dim=-1)

In [None]:
# Check the shape of the argmax result.
predicted_token_id.shape

In [None]:
output_text = tokenizer.decode(predicted_token_id, skip_special_tokens=True)

In [None]:
output_text