In [1]:
import os
from typing import Tuple

import configparser
import pandas as pd
import sklearn.metrics as metrics
import torch as t
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = configparser.ConfigParser()
config.read('../config.ini')

['../config.ini']

In [3]:
DATASETS_DIR = "../../datasets"
FORMATTED_DATASET_NAME = "mars_corr_mmlu"
DATASET_NAME = "mmlu"
DATASET_CATS = ["high_school_mathematics", "college_mathematics", "abstract_algebra"]
SYS_PROMPT = (
    lambda topic: f"The following is a multiple choice question (with answers) about {topic.replace('_', ' ')}. "
)
# Answer with a single letter.
MODEL_FAMILY = "Llama3"
MODEL_SIZE = "8B"
MODEL_TYPE = "chat"

In [4]:
def format_prompt(
    subject, question, answers, sys_prompt: callable, choices=["A", "B", "C", "D"]
) -> Tuple[str, str]:
    user_prompt = (
        f"{question}\n"
        + "\n".join([f"{choice}. {answer}" for choice, answer in zip(choices, answers)])
        + "\nAnswer:"
    )

    return f"{SYS_PROMPT(subject)}\n{user_prompt}"

In [5]:
ANSWER_MAP = ["A", "B", "C", "D"]


def format_dataset(base_path: str) -> pd.DataFrame:
    df = pd.read_parquet(base_path)
    prompts = df.apply(
        lambda row: format_prompt(
            row["subject"], row["question"], row["choices"], SYS_PROMPT
        ),
        axis=1,
    )
    answers = df["answer"].apply(lambda a: ANSWER_MAP[a])
    formatted = pd.DataFrame(
        {"prompt": prompts.tolist(), "answer": answers, "subject": df["subject"]}
    )
    return formatted

In [6]:
def load_model(
    model_family: str, model_size: str, model_type: str, device: str = "cuda:0"
):
    model_path = os.path.join(
        config[model_family]["weights_directory"],
        config[model_family][f"{model_size}_{model_type}_subdir"],
    )

    tokenizer = AutoTokenizer.from_pretrained(str(model_path))
    model = AutoModelForCausalLM.from_pretrained(str(model_path))
    return tokenizer, model.to(device)

In [7]:
for d in DATASET_CATS:
    dataset_f = format_dataset(f"{DATASETS_DIR}/{DATASET_NAME}/{d}")
    dataset_f.to_csv(f"{DATASETS_DIR}/{FORMATTED_DATASET_NAME}/{d}.csv", index=False)

In [8]:
t.set_grad_enabled(False)
tokenizer, model = load_model(MODEL_FAMILY, MODEL_SIZE, MODEL_TYPE)
model.eval()

Loading checkpoint shards: 100%|██████████| 4/4 [01:09<00:00, 17.45s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [9]:
def load_statements(dataset_path: str):
    """
    Load statements from csv file, return list of strings.
    """
    dataset = pd.read_csv(dataset_path)
    statements = dataset['prompt'].tolist()
    return statements

In [10]:
choices_ids = [
    tokenizer.encode(choice, add_special_tokens=False)[0]
    for choice in ANSWER_MAP + [f" {c}" for c in ANSWER_MAP]
]


def generate_const_answer(
    tokenizer, model, prompt: str, choices_ids
) -> Tuple[str, str]:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with t.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
        last_token_logits = logits[0, -1, :]  # Logits for the last token
        masked_logits = last_token_logits.clone()
        masked_logits[:] = float("-inf")
        masked_logits[choices_ids] = last_token_logits[choices_ids]
        top_const_token_id = t.argmax(masked_logits).item()
        top_unconst_token_id = t.argmax(last_token_logits).item()
        top_const_token = tokenizer.decode([top_const_token_id])
        top_unconst_token = tokenizer.decode([top_unconst_token_id])

    return top_const_token, top_unconst_token

In [11]:
for dataset_df in DATASET_CATS:
    save_dir = os.path.join(DATASETS_DIR, "generations", FORMATTED_DATASET_NAME)
    save_file = os.path.join(save_dir, f"{dataset_df}_generations.csv")
    os.makedirs(save_dir, exist_ok=True)
    
    # Check existing files and remove them
    if os.path.exists(save_file):
        os.remove(save_file)
    
    statements = load_statements(os.path.join(DATASETS_DIR, FORMATTED_DATASET_NAME, f"{dataset_df}.csv"))
    generations_df = []
    for i, statement in enumerate(statements):
        const_answer, unconst_answer = generate_const_answer(
            tokenizer, model, statement, choices_ids
        )
        generations_df.append(
            {
                "statement": statement,
                "const_answer": const_answer,
                "unconst_answer": unconst_answer,
            }
        )
    
    generations_df = pd.DataFrame(generations_df)
    generations_df.to_csv(save_file, index=False)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


---

In [12]:
GENERATIONS_PATH = os.path.join(DATASETS_DIR, "generations", FORMATTED_DATASET_NAME)


def label_to_index(label):
    """
    Return the index of a valid label in ANSWER_MAP.
    If the label is out of the map, return -1 (i.e., "incorrect").
    """
    label_stripped = label.strip()
    if label_stripped in ANSWER_MAP:
        return ANSWER_MAP.index(label_stripped)
    else:
        # Treat anything else as "incorrect."
        return -1


EVAL_METRICS = {
    "accuracy_score": lambda y_true, y_pred: metrics.accuracy_score(y_true, y_pred),
    "f1_score": lambda y_true, y_pred: metrics.f1_score(
        y_true, y_pred, average="weighted"
    ),
    "precision_score": lambda y_true, y_pred: metrics.precision_score(
        y_true, y_pred, average="weighted"
    ),
    "recall_score": lambda y_true, y_pred: metrics.recall_score(
        y_true, y_pred, average="weighted"
    ),
}

In [13]:
for dataset in DATASET_CATS:
    try:
        # Load the CSV with ground truths
        dataset_df = pd.read_csv(
            f"{DATASETS_DIR}/{FORMATTED_DATASET_NAME}/{dataset}.csv"
        )
        # Load the CSV with generations
        generations_df = pd.read_csv(f"{GENERATIONS_PATH}/{dataset}_generations.csv")

        # Convert from string label to index:
        y_true = dataset_df["answer"].apply(label_to_index)
        y_const_pred = generations_df["const_answer"].apply(label_to_index)
        y_unconst_pred = generations_df["unconst_answer"].apply(label_to_index)

        # Compute metrics
        const_metrics = {
            metric: EVAL_METRICS[metric](y_true, y_const_pred)
            for metric in EVAL_METRICS
        }
        unconst_metrics = {
            metric: EVAL_METRICS[metric](y_true, y_unconst_pred)
            for metric in EVAL_METRICS
        }

        # Prepare a nice “wide” DataFrame of metrics: row = metric, columns = [const, unconst]
        metrics_rows = []
        for metric_name in EVAL_METRICS.keys():
            metrics_rows.append(
                {
                    "metric": metric_name,
                    "const": const_metrics[metric_name],
                    "unconst": unconst_metrics[metric_name],
                }
            )
        metrics_df = pd.DataFrame(metrics_rows)

        # Save metrics
        metrics_csv_path = os.path.join(GENERATIONS_PATH, f"{dataset}_metrics.csv")
        metrics_df.to_csv(metrics_csv_path, index=False)
        print(f"Saved metrics to: {metrics_csv_path}")

    except Exception as e:
        print(f"Error computing metrics for {dataset}: {e}")

Saved metrics to: ../../datasets/generations/mars_corr_mmlu/high_school_mathematics_metrics.csv
Saved metrics to: ../../datasets/generations/mars_corr_mmlu/college_mathematics_metrics.csv
Saved metrics to: ../../datasets/generations/mars_corr_mmlu/abstract_algebra_metrics.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
