In [5]:
import torch
print("MPS available:", torch.backends.mps.is_available())

MPS available: True


In [6]:
import json

clean_path = "/Users/damianli/Desktop/1508_project/common-sense-reasoning/data/csqa_full.jsonl"

# Safely load all lines
data_raw = []
with open(clean_path, "r") as f:
    for line in f:
        try:
            data_raw.append(json.loads(line.strip()))
        except json.JSONDecodeError:
            pass

print("Loaded:", len(data_raw))
print("Example:", data_raw[0])


Loaded: 7400
Example: {'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?', 'choices': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid'], 'answer': 'A', 'short_explanation': 'Because ``` Because "ignore" best shows the sanctions disregarded the school\'s efforts to change.'}


In [11]:
import re
import json

input_file  = "/Users/damianli/Desktop/1508_project/common-sense-reasoning/data/csqa_full.jsonl"
output_file = "/Users/damianli/Desktop/1508_project/common-sense-reasoning/data/csqa_full_clean.jsonl"

cleaned = []
skipped = 0

def clean_text(t):
    if t is None:
        return ""
    t = str(t)

    # Remove ``` and ```plaintext
    t = re.sub(r"```plaintext", "", t)
    t = t.replace("```", "")

    # Remove weird line breaks
    t = t.replace("\n", " ")

    # Remove repeated "Because Because"
    t = t.replace("Because Because", "Because")

    # Normalize double quotes
    t = t.replace('\\"', '"')
    t = t.replace('"', '\\"')

    return t.strip()

with open(input_file, "r") as f:
    for line in f:
        try:
            obj = json.loads(line)
        except:
            skipped += 1
            continue

        obj["question"] = clean_text(obj.get("question", ""))
        obj["short_explanation"] = clean_text(obj.get("short_explanation", ""))

        if isinstance(obj.get("choices", None), list):
            obj["choices"] = [clean_text(c) for c in obj["choices"]]

        cleaned.append(obj)

with open(output_file, "w") as f:
    for x in cleaned:
        f.write(json.dumps(x) + "\n")

print("Cleaning done!")
print("Valid:", len(cleaned))
print("Skipped:", skipped)
print("Saved to:", output_file)


Cleaning done!
Valid: 7400
Skipped: 0
Saved to: /Users/damianli/Desktop/1508_project/common-sense-reasoning/data/csqa_full_clean.jsonl


In [12]:
import re

# =============================
# 1. Clean explanation function
# =============================
def clean_explanation(text):
    if text is None:
        return ""

    text = str(text)

    # Remove all backticks
    text = text.replace("```", "").replace("`", "")

    # Remove words like "plaintext"
    text = re.sub(r"\bplaintext\b|\btext\b|\bpython\b|\bjson\b",
                  "", text, flags=re.IGNORECASE)

    # Remove newlines
    text = text.replace("\n", " ").replace("\r", " ")

    # Collapse multiple spaces
    text = re.sub(r"\s+", " ", text).strip()

    # Normalize repeated "Because"
    text = re.sub(r"^(Because\s+)+", "Because ", text)

    # Ensure ending with period
    if not text.endswith("."):
        text += "."

    return text


# =================================================
# 2. Format choices helper (same as before)
# =================================================
def format_choices(choice_list):
    letters = ["A", "B", "C", "D", "E", "F"]
    return "; ".join(
        f"{letters[i]}: {choice_list[i]}" for i in range(len(choice_list))
    )


# =================================================
# 3. Build model-ready dataset (NO dataset.map)
# =================================================
processed_dataset = []

for ex in data_raw:
    q = ex["question"]
    choices = ex["choices"]
    ans = ex["answer"]
    expl = clean_explanation(ex["short_explanation"])

    input_text = (
        f"question: {q}\n"
        f"choices: {format_choices(choices)}\n"
        f"explain your answer:"
    )

    target_text = f"answer: {ans}. {expl}"

    processed_dataset.append({
        "input_text": input_text,
        "target_text": target_text
    })


print("Processed dataset size:", len(processed_dataset))
print("\n=== Example after formatting ===")
print("INPUT TEXT:\n", processed_dataset[0]["input_text"])
print("\nTARGET TEXT:\n", processed_dataset[0]["target_text"])


Processed dataset size: 7400

=== Example after formatting ===
INPUT TEXT:
 question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?
choices: A: ignore; B: enforce; C: authoritarian; D: yell at; E: avoid
explain your answer:

TARGET TEXT:
 answer: A. Because "ignore" best shows the sanctions disregarded the school's efforts to change.


In [17]:
# =========================================
# 3. Tokenize dataset for T5-Large
# =========================================

from datasets import Dataset
from transformers import T5Tokenizer
import re

print("Loading tokenizer: t5-large ...")
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# -----------------------------------------
# If processed_dataset is a Python list, wrap into HF Dataset
# -----------------------------------------
if isinstance(processed_dataset, list):
    processed_dataset = Dataset.from_list(processed_dataset)

print("Original processed_dataset size:", len(processed_dataset))

# -----------------------------------------
# Simple text cleaner (VERY SAFE)
# -----------------------------------------
def clean_text(t: str) -> str:
    if t is None:
        return ""
    t = str(t)

    # ÂéªÊéâ ``` ‰πãÁ±ªÁöÑ markdown
    t = re.sub(r"`+", "", t)

    # ÂéªÊéâ plaintext / python / json Ëøô‰∫õ tag
    t = re.sub(r"\bplaintext\b|\bpython\b|\bjson\b|\btext\b",
               "", t, flags=re.IGNORECASE)

    # Êç¢Ë°åÂèòÁ©∫Ê†º
    t = t.replace("\n", " ").replace("\r", " ")

    # Â§ö‰∏™Á©∫Ê†ºÂêàÂπ∂
    t = re.sub(r"\s+", " ", t).strip()

    # Â§ÑÁêÜ Because Because...
    t = re.sub(r"^(Because\s+)+", "Because ", t, flags=re.IGNORECASE)

    return t.strip()


# -----------------------------------------
# 1) Split dataset (80% train / 20% val)
# -----------------------------------------
total_size  = len(processed_dataset)
train_size  = int(0.8 * total_size)
val_size    = total_size - train_size

train_dataset = processed_dataset.select(range(train_size))
val_dataset   = processed_dataset.select(range(train_size, total_size))

print("Train size:", len(train_dataset))
print("Val size:", len(val_dataset))


# -----------------------------------------
# 2) Tokenization function (NO batched=True)
# -----------------------------------------
def tokenize_function(example):
    # ÂÖàÂÅöËΩªÈáèÊ∏ÖÊ¥ó
    input_text  = clean_text(example["input_text"])
    target_text = clean_text(example["target_text"])

    # encode input
    model_inputs = tokenizer(
        input_text,
        max_length=384,       # ÊØî 512 Â∞è‰∏ÄÁÇπÔºåÊõ¥Á®≥„ÄÅÊõ¥Âø´
        truncation=True,
        padding="max_length",
    )

    # encode labels
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target_text,
            max_length=96,
            truncation=True,
            padding="max_length",
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# -----------------------------------------
# 3) Apply tokenizer (per-example, batched=False)
# -----------------------------------------
print("Tokenizing train dataset ...")
train_tokenized = train_dataset.map(
    tokenize_function,
    batched=False,
    remove_columns=train_dataset.column_names,  # Âè™‰øùÁïô token Â≠óÊÆµ
)

print("Tokenizing val dataset ...")
val_tokenized = val_dataset.map(
    tokenize_function,
    batched=False,
    remove_columns=val_dataset.column_names,
)

# -----------------------------------------
# 4) Set format for PyTorch
# -----------------------------------------
cols = ["input_ids", "attention_mask", "labels"]
train_tokenized.set_format(type="torch", columns=cols)
val_tokenized.set_format(type="torch", columns=cols)

# -----------------------------------------
# 5) Sanity check
# -----------------------------------------
print("\n=== Tokenization example (raw text) ===")
print("INPUT TEXT:\n", train_dataset[0]["input_text"])
print("\nTARGET TEXT:\n", train_dataset[0]["target_text"])

print("\n=== Tokenization example (IDs) ===")
print("input_ids[:20]:", train_tokenized[0]["input_ids"][:20])
print("labels[:20]:",    train_tokenized[0]["labels"][:20])


Loading tokenizer: t5-large ...
Original processed_dataset size: 7400
Train size: 5920
Val size: 1480
Tokenizing train dataset ...


Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5920/5920 [00:01<00:00, 5351.72 examples/s]


Tokenizing val dataset ...


Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1480/1480 [00:00<00:00, 5126.89 examples/s]


=== Tokenization example (raw text) ===
INPUT TEXT:
 question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?
choices: A: ignore; B: enforce; C: authoritarian; D: yell at; E: avoid
explain your answer:

TARGET TEXT:
 answer: A. Because "ignore" best shows the sanctions disregarded the school's efforts to change.

=== Tokenization example (IDs) ===
input_ids[:20]: tensor([  822,    10,    37, 17210,   581,     8,   496,   130,     3,     9,
        24584,    53,  6019,     6,    11,    79,  3776,    12,   125,     8])
labels[:20]: tensor([ 1525,    10,    71,     5,  2070,    96,  3191,   127,    15,   121,
          200,  1267,     8, 17210,  1028, 12327,     8,   496,    31,     7])





In [18]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import LoraConfig, get_peft_model

# -----------------------------------------
# Device setup for Mac M-series (MPS)
# -----------------------------------------
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# =========================================
# 1. Load T5-Large (Mac-safe)
# =========================================
print("Loading T5-large ...")

model = T5ForConditionalGeneration.from_pretrained(
    "t5-large",
    torch_dtype=torch.float16,   # MPS supports fp16 compute
    device_map=None              # must NOT use auto on Mac
)

model = model.to(device)


# =========================================
# 2. LoRA configuration for T5-Large
# =========================================
lora_config = LoraConfig(
    r=16,                     # ‚Üë increase rank (large model deserves more)
    lora_alpha=32,            # scaled with r
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "v"],    # T5 attention projection names
    task_type="SEQ_2_SEQ_LM"
)

# =========================================
# 3. Apply LoRA
# =========================================
print("Applying LoRA ...")
model = get_peft_model(model, lora_config)

print("\n===== Trainable parameters =====")
model.print_trainable_parameters()


Using device: mps
Loading T5-large ...
Applying LoRA ...

===== Trainable parameters =====
trainable params: 4,718,592 || all params: 742,386,688 || trainable%: 0.6356


In [19]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

print("Preparing Trainer ...")

training_args = TrainingArguments(
    output_dir="./t5_large_csqa_lora",
    overwrite_output_dir=True,

    num_train_epochs=3,

    per_device_train_batch_size=1,     
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,    

    learning_rate=2e-4,
    warmup_ratio=0.1,

    logging_steps=50,
    evaluation_strategy="epoch",
    save_strategy="epoch",

    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    torch_compile=False,
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
)

trainer.train()


Preparing Trainer ...


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  5%|‚ñç         | 50/1110 [04:43<1:39:24,  5.63s/it]

{'loss': 263.1364, 'grad_norm': 110.57686614990234, 'learning_rate': 9.009009009009009e-05, 'epoch': 0.14}


  9%|‚ñâ         | 100/1110 [09:25<1:35:18,  5.66s/it]

{'loss': 61.1814, 'grad_norm': 7.3308024406433105, 'learning_rate': 0.00018018018018018018, 'epoch': 0.27}


 14%|‚ñà‚ñé        | 150/1110 [14:08<1:30:36,  5.66s/it]

{'loss': 9.2689, 'grad_norm': 3.7472586631774902, 'learning_rate': 0.0001921921921921922, 'epoch': 0.41}


 18%|‚ñà‚ñä        | 200/1110 [18:50<1:26:00,  5.67s/it]

{'loss': 7.8834, 'grad_norm': 3.5459792613983154, 'learning_rate': 0.00018218218218218218, 'epoch': 0.54}


 23%|‚ñà‚ñà‚ñé       | 250/1110 [23:33<1:21:22,  5.68s/it]

{'loss': 7.5514, 'grad_norm': 3.8082776069641113, 'learning_rate': 0.0001721721721721722, 'epoch': 0.68}


 27%|‚ñà‚ñà‚ñã       | 300/1110 [28:15<1:16:35,  5.67s/it]

{'loss': 7.52, 'grad_norm': 3.7168185710906982, 'learning_rate': 0.00016216216216216218, 'epoch': 0.81}


 32%|‚ñà‚ñà‚ñà‚ñè      | 350/1110 [32:59<1:11:23,  5.64s/it]

{'loss': 7.138, 'grad_norm': 3.1321218013763428, 'learning_rate': 0.00015215215215215214, 'epoch': 0.95}


                                                    
 33%|‚ñà‚ñà‚ñà‚ñé      | 370/1110 [38:12<1:09:42,  5.65s/it]

{'eval_loss': 0.405029296875, 'eval_runtime': 200.2393, 'eval_samples_per_second': 7.391, 'eval_steps_per_second': 7.391, 'epoch': 1.0}


 36%|‚ñà‚ñà‚ñà‚ñå      | 400/1110 [41:03<1:07:02,  5.67s/it] 

{'loss': 6.9368, 'grad_norm': 3.5109355449676514, 'learning_rate': 0.00014214214214214215, 'epoch': 1.08}


 41%|‚ñà‚ñà‚ñà‚ñà      | 450/1110 [45:47<1:02:13,  5.66s/it]

{'loss': 6.9386, 'grad_norm': 3.1115074157714844, 'learning_rate': 0.00013213213213213214, 'epoch': 1.22}


 45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 500/1110 [50:30<57:42,  5.68s/it]  

{'loss': 6.813, 'grad_norm': 3.2237462997436523, 'learning_rate': 0.00012212212212212213, 'epoch': 1.35}


 50%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 550/1110 [55:13<52:36,  5.64s/it]

{'loss': 6.7097, 'grad_norm': 3.571535587310791, 'learning_rate': 0.00011211211211211213, 'epoch': 1.49}


 54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 600/1110 [59:56<48:14,  5.67s/it]

{'loss': 6.8321, 'grad_norm': 8.747725486755371, 'learning_rate': 0.00010210210210210212, 'epoch': 1.62}


 59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 650/1110 [1:04:40<43:30,  5.67s/it]

{'loss': 6.6423, 'grad_norm': 3.813053607940674, 'learning_rate': 9.20920920920921e-05, 'epoch': 1.76}


 63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 700/1110 [1:09:22<38:32,  5.64s/it]

{'loss': 6.6484, 'grad_norm': 3.650789976119995, 'learning_rate': 8.208208208208209e-05, 'epoch': 1.89}


                                                    
 67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 740/1110 [1:16:30<34:48,  5.64s/it]

{'eval_loss': 0.3818359375, 'eval_runtime': 200.4041, 'eval_samples_per_second': 7.385, 'eval_steps_per_second': 7.385, 'epoch': 2.0}


 68%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 750/1110 [1:17:27<48:48,  8.13s/it]  

{'loss': 6.5536, 'grad_norm': 3.3807427883148193, 'learning_rate': 7.207207207207208e-05, 'epoch': 2.03}


 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 800/1110 [1:22:11<29:24,  5.69s/it]

{'loss': 6.4478, 'grad_norm': 3.6621367931365967, 'learning_rate': 6.206206206206206e-05, 'epoch': 2.16}


 77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 850/1110 [1:26:55<24:29,  5.65s/it]

{'loss': 6.2706, 'grad_norm': 4.251887798309326, 'learning_rate': 5.2052052052052056e-05, 'epoch': 2.3}


 81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 900/1110 [1:31:38<19:44,  5.64s/it]

{'loss': 6.436, 'grad_norm': 3.684340238571167, 'learning_rate': 4.204204204204204e-05, 'epoch': 2.43}


 86%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 950/1110 [1:36:21<15:04,  5.65s/it]

{'loss': 6.4946, 'grad_norm': 3.784519910812378, 'learning_rate': 3.203203203203203e-05, 'epoch': 2.57}


 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 1000/1110 [1:41:04<10:23,  5.67s/it]

{'loss': 6.4106, 'grad_norm': 3.4443304538726807, 'learning_rate': 2.2022022022022024e-05, 'epoch': 2.7}


 95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 1050/1110 [1:45:47<05:41,  5.69s/it]

{'loss': 6.6377, 'grad_norm': 3.250004529953003, 'learning_rate': 1.2012012012012012e-05, 'epoch': 2.84}


 99%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 1100/1110 [1:50:31<00:57,  5.72s/it]

{'loss': 6.5211, 'grad_norm': 3.398838996887207, 'learning_rate': 2.002002002002002e-06, 'epoch': 2.97}


                                                     
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1110/1110 [1:54:48<00:00,  5.67s/it]

{'eval_loss': 0.376220703125, 'eval_runtime': 200.1498, 'eval_samples_per_second': 7.394, 'eval_steps_per_second': 7.394, 'epoch': 3.0}


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1110/1110 [1:54:49<00:00,  6.21s/it]

{'train_runtime': 6889.3988, 'train_samples_per_second': 2.578, 'train_steps_per_second': 0.161, 'train_loss': 20.911044366510065, 'epoch': 3.0}





TrainOutput(global_step=1110, training_loss=20.911044366510065, metrics={'train_runtime': 6889.3988, 'train_samples_per_second': 2.578, 'train_steps_per_second': 0.161, 'total_flos': 2.903151023751168e+16, 'train_loss': 20.911044366510065, 'epoch': 3.0})

In [21]:
from peft import PeftModel
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = "mps"

# ============================================
# 1) Select the correct latest checkpoint
# ============================================
adapter_path = (
    "/Users/damianli/Desktop/1508_project/common-sense-reasoning/"
    "t5_large_csqa_lora/checkpoint-1110"
)

# ============================================
# 2) Load the SAME base model used for training
# ============================================
print("Loading base model (t5-large) ...")
base_model = T5ForConditionalGeneration.from_pretrained(
    "t5-large",
    torch_dtype="float16" if device == "mps" else None,
)

# ============================================
# 3) Load LoRA adapter into base model
# ============================================
print("Loading LoRA adapter from:", adapter_path)
model = PeftModel.from_pretrained(base_model, adapter_path)

# ============================================
# 4) Merge LoRA weights ‚Üí standalone full model
# ============================================
print("Merging LoRA weights ...")
model = model.merge_and_unload()

# ============================================
# 5) Save merged full model
# ============================================
merged_path = (
    "/Users/damianli/Desktop/1508_project/common-sense-reasoning/"
    "t5_large_csqa_lora_merged"
)
model.save_pretrained(merged_path)

# ============================================
# 6) Save tokenizer
# ============================================
tokenizer = T5Tokenizer.from_pretrained("t5-large")
tokenizer.save_pretrained(merged_path)

print("\nMerged model saved to:", merged_path)


Loading base model (t5-large) ...
Loading LoRA adapter from: /Users/damianli/Desktop/1508_project/common-sense-reasoning/t5_large_csqa_lora/checkpoint-1110
Merging LoRA weights ...

Merged model saved to: /Users/damianli/Desktop/1508_project/common-sense-reasoning/t5_large_csqa_lora_merged


In [23]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

device = "mps"

model_path = "/Users/damianli/Desktop/1508_project/common-sense-reasoning/t5_large_csqa_lora_merged"

tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)

input_text = """question: Where would you put ice to keep it frozen?
choices: A: oven; B: freezer; C: desk; D: backpack; E: pocket
explain your answer:
"""

inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


answer: B. Because a freezer is designed to store ice for long-term storage.


### Step 4: Small Evaluation Suite (Pre- & Post- fine-tuning)

**Goal**

Before optimizing or switching to larger models (e.g., T5-large), we want a small but clear evaluation protocol to:

- Probe the model‚Äôs ability to:
  - Select the correct answer choice (A/B/C/‚Ä¶)
  - Generate a coherent, on-topic explanation
- Record **before/after** performance for reporting.

**Evaluation setup**

- Sample 10 examples from the cleaned dataset (`csqa_full_clean.jsonl`) with a fixed random seed for reproducibility.
- Use the same input format as training:
  - `question: ...`
  - `choices: A: ...; B: ...; ...`
  - `explain your answer:`
- Let the model generate:
  - A combined answer + explanation text (e.g., `answer: B. Because ...`)

**Metrics**

For each example we record:

1. **Predicted answer letter** (A‚ÄìE)
2. **Answer correctness** (match gold answer letter ‚Üí 0/1)
3. **Explanation text** (raw string)
4. **Explanation length** (number of words)
5. **Heuristic quality flags** (e.g., contains ‚Äúbecause‚Äù, non-empty)

We will:

- Print a human-readable summary for each example.
- Compute **overall accuracy** on the 10-question subset.
- Save a `.csv` file with all fields so we can compare:
  - Baseline model (e.g. `t5-small`)
  - Fine-tuned model (e.g. `t5_csqa_lora_merged`)
  - Later: T5-large or other variants


In [26]:
# =========================================
# Baseline Evaluation for CSQA Model
# =========================================

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# -----------------------------------------
# 1. Baseline test questions
# -----------------------------------------
baseline_samples = [
    {
        "question": "Sammy wanted to go to where the people were. Where might he go?",
        "choices": ["race track", "populated areas", "the desert", "apartment", "roadblock"],
        "answer": "B"
    },
    {
        "question": "Where do you store fresh vegetables?",
        "choices": ["garage", "refrigerator", "bookshelf", "bathroom", "attic"],
        "answer": "B"
    },
    {
        "question": "If you heat water to 100 degrees Celsius, what will happen?",
        "choices": ["it will freeze", "it will boil", "it will rust", "it will glow", "it will evaporate"],
        "answer": "B"
    },
    {
        "question": "What do people usually use to dry their hands after washing?",
        "choices": ["towel", "hammer", "blanket", "pillow", "shoe"],
        "answer": "A"
    },
    {
        "question": "Where would you typically find books to read?",
        "choices": ["library", "swimming pool", "factory", "garage", "freeway"],
        "answer": "A"
    },
    {
        "question": "If someone wants to relax and reduce stress, what might they do?",
        "choices": ["meditate", "argue", "shout", "work more", "run into danger"],
        "answer": "A"
    },
    {
        "question": "What tool is commonly used to tighten screws?",
        "choices": ["screwdriver", "spoon", "pencil", "comb", "fork"],
        "answer": "A"
    },
    {
        "question": "Where would you likely find many wild animals living together?",
        "choices": ["forest", "kitchen", "bathroom", "rooftop", "office"],
        "answer": "A"
    }
]

print(f"Loaded {len(baseline_samples)} baseline questions.\n")

# -----------------------------------------
# 2. Helper to format choices
# -----------------------------------------
def format_choices_eval(choices):
    letters = ["A", "B", "C", "D", "E"]
    return "; ".join([f"{letters[i]}: {choices[i]}" for i in range(len(choices))])

# -----------------------------------------
# 3. Run model
# -----------------------------------------
def run_model(question, choices):
    input_text = (
        f"question: {question}\n"
        f"choices: {format_choices_eval(choices)}\n"
        f"explain your answer:"
    )

    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            num_beams=4,
            early_stopping=True
        )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded

# -----------------------------------------
# 4. Evaluation loop
# -----------------------------------------
correct = 0
results = []

for i, sample in enumerate(baseline_samples):
    q = sample["question"]
    choices = sample["choices"]
    gold = sample["answer"]

    output = run_model(q, choices)

    # Predict letter by pattern
    pred = None
    for letter in ["A", "B", "C", "D", "E"]:
        if f"answer: {letter}" in output:
            pred = letter
            break
        if output.strip().startswith(letter + "."):
            pred = letter
            break

    is_correct = (pred == gold)
    correct += int(is_correct)

    print("=" * 70)
    print(f"QUESTION {i+1}")
    print("Q:", q)
    print("Choices:", choices)
    print("\nModel Output:\n", output)
    print(f"\nPredicted: {pred} | Gold: {gold} | Correct: {is_correct}")

print("\n" + "=" * 70)
print(f"Final Accuracy: {correct}/{len(baseline_samples)} = {correct/len(baseline_samples):.2f}")


Using device: mps
Loaded 8 baseline questions.

QUESTION 1
Q: Sammy wanted to go to where the people were. Where might he go?
Choices: ['race track', 'populated areas', 'the desert', 'apartment', 'roadblock']

Model Output:
 answer: B. Because "populated areas" best fits Sammy's desire to visit populated areas.

Predicted: B | Gold: B | Correct: True
QUESTION 2
Q: Where do you store fresh vegetables?
Choices: ['garage', 'refrigerator', 'bookshelf', 'bathroom', 'attic']

Model Output:
 answer: B. Because refrigerators are ideal for storing fresh vegetables.

Predicted: B | Gold: B | Correct: True
QUESTION 3
Q: If you heat water to 100 degrees Celsius, what will happen?
Choices: ['it will freeze', 'it will boil', 'it will rust', 'it will glow', 'it will evaporate']

Model Output:
 answer: B. Because boiling occurs when water reaches 100 degrees Celsius.

Predicted: B | Gold: B | Correct: True
QUESTION 4
Q: What do people usually use to dry their hands after washing?
Choices: ['towel', 'h