In [26]:
import pandas as pd
import torch
from accelerate import Accelerator
from datasets import Dataset
import numpy as np

In [27]:
# Optimizing by using GPU if available
device = "mps" if torch.backends.mps.is_available() else "cpu"
accelerator = Accelerator()
print(f"Using device: {device}")

Using device: mps


In [28]:
df = pd.read_csv("./data/b6_train_data.csv")
# turn into a Python list for tokenization
df["choices"] = df['choices'].apply(eval)
dataset = Dataset.from_pandas(df)


test_df = pd.read_csv("./data/b6_test_data.csv")
test_df["choices"] = test_df['choices'].apply(eval)
fpttest_data = Dataset.from_pandas(test_df)

In [29]:
def show_one(example):
    print(f"{example['question']}")
    for c in example['choices']:
        print(f" - {c}")
    print(f" Correct answer: {example['answer']}")

In [30]:
from transformers import AutoModelForMultipleChoice, AutoTokenizer
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMultipleChoice.from_pretrained(model_name).to(device)

# Wrap with `accelerate`
model = accelerator.prepare(model)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [31]:
letter_to_number = {'A': 0, 'B': 1, 'C': 2, 'D': 3}


def get_number(ans):
    try:
        last_word = ans.split()[-1]  # Get the last word
        # Return mapped value or -1 if not found
        return letter_to_number.get(last_word, -1)
    except Exception as e:
        print(f"Error processing answer '{ans}': {e}")
        return -1  # Fallback value


def preprocess(examples):
    # Ensure choices are lists and pad to 4 choices
    examples["choices"] = [
        choice + [""] * (4 - len(choice)) if len(choice) < 4 else choice[:4]
        for choice in examples["choices"]
    ]

    # Number of choices per question (always 4 now)
    choice_lens = [4] * len(examples["choices"])

    # Expand questions to match the number of choices (4 per question)
    questions = [q for q_list in [[question] *
                                  4 for question in examples['question']] for q in q_list]
    choices = sum(examples["choices"], [])  # Flatten choices

    # Convert labels
    labels = [get_number(label) for label in examples['answer']]

    # Tokenize questions and choices as independent pairs
    tokenized_examples = tokenizer(
        list(zip(questions, choices)), truncation=True, padding="max_length"
    )

    # Reshape data: Group every 4 choices together (for each question)
    reshaped_dict = {k: [] for k in tokenized_examples.keys()}
    start = 0
    for _ in range(len(examples["question"])):  # Iterate per question
        for k in tokenized_examples.keys():
            reshaped_dict[k].append(tokenized_examples[k][start: start + 4])
        start += 4

    # Ensure labels match the 4-choice structure
    reshaped_dict['labels'] = labels

    return reshaped_dict

In [32]:
idx = 0
tokenized_data = dataset.map(
    preprocess, batched=True, batch_size=8, load_from_cache_file=False)

Map:   0%|          | 0/3963 [00:00<?, ? examples/s]

Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer 'None': 'NoneType' object has no attribute 'split'
Error processing answer '

In [33]:
# tokenized_data = tokenized_data.remove_columns(
#     ["task_id", "question", "choices", "answer"])  # Keep only tokenized features

In [34]:
decoded_text = tokenizer.decode(
    tokenized_data[9]["input_ids"][0], skip_special_tokens=True)
print(decoded_text)

question : what value does the variable z have after all of the code above executes? int x ; int y ; int z ; x = 3 ; y = 4 ; z = + + x * y + + ; 9


In [36]:
for example in tokenized_data:
    print("input_ids shape:", torch.tensor(example["input_ids"]).shape)
    break

input_ids shape: torch.Size([4, 512])


In [37]:
from transformers import DataCollatorForMultipleChoice
from torch.utils.data import DataLoader


# Initialize collator
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

# # Create DataLoader with collator
# dataloader = DataLoader(tokenized_data, batch_size=8,
#                         shuffle=True, collate_fn=data_collator)

# # Fetch a batch
# batch = next(iter(dataloader))

# # Now check shape
# # Should be (batch_size, num_choices, seq_length)
# print(batch['input_ids'].shape)

In [38]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
# Split your dataset into train and evaluation sets

train_test_split = tokenized_data.train_test_split(test_size=0.2, seed=42)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

torch.mps.empty_cache()

In [39]:
for idx, example in enumerate(train_dataset):
    if example["labels"] == -1:
        print(f"🚨 Found example with label -1 at index {idx}: ")

🚨 Found example with label -1 at index 147: 
🚨 Found example with label -1 at index 524: 
🚨 Found example with label -1 at index 607: 
🚨 Found example with label -1 at index 647: 
🚨 Found example with label -1 at index 748: 
🚨 Found example with label -1 at index 765: 
🚨 Found example with label -1 at index 1013: 
🚨 Found example with label -1 at index 1085: 
🚨 Found example with label -1 at index 1115: 
🚨 Found example with label -1 at index 1139: 
🚨 Found example with label -1 at index 1486: 
🚨 Found example with label -1 at index 1640: 
🚨 Found example with label -1 at index 1885: 
🚨 Found example with label -1 at index 2158: 
🚨 Found example with label -1 at index 2220: 
🚨 Found example with label -1 at index 2522: 
🚨 Found example with label -1 at index 2891: 
🚨 Found example with label -1 at index 2897: 


In [40]:
train_dataset = train_dataset.filter(lambda example: example["labels"] != -1)
eval_dataset = eval_dataset.filter(lambda example: example["labels"] != -1)

Filter:   0%|          | 0/3170 [00:00<?, ? examples/s]

Filter:   0%|          | 0/793 [00:00<?, ? examples/s]

In [42]:

class DebugDataCollatorForMultipleChoice(DataCollatorForMultipleChoice):
    def torch_call(self, examples):
        try:
            num_choices = len(examples[0]["input_ids"])
            print(f"Processing batch with {num_choices} choices")
            # if flag:
            #     print(f"Normal examples: {examples}")  # Prints full batch
            # flag = False
            return super().torch_call(examples)
        except Exception as e:
            print("\n🚨 ERROR in DataCollatorForMultipleChoice 🚨")
            print(f"Problematic examples: {examples}")  # Prints full batch
            raise e  # Re-raise the exception for debugging


# Use the custom collator in Trainer
data_collator = DebugDataCollatorForMultipleChoice(tokenizer)

In [45]:
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

In [43]:
torch.mps.empty_cache()

# Define training arguments
training_args = TrainingArguments(
    output_dir="./mcq_model",
    eval_strategy="epoch",
    per_device_train_batch_size=4,  # Keep batch size low to avoid OOM
    per_device_eval_batch_size=4,
    num_train_epochs=10,  # Early stopping will handle if too high
    save_strategy="epoch",
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,  # Higher accuracy is better
    warmup_steps=50,  # Reduce warmup to prevent high initial memory usage
    weight_decay=0.01,  # Standard regularization
    logging_steps=10,
    optim="adamw_torch",  # More stable than 8-bit optimizers on MPS
    gradient_accumulation_steps=2,  # Helps with low batch sizes
    torch_compile=False,  # Torch compile is unstable on MPS, disable it
)


# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

In [46]:
# Start training
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_model")

Processing batch with 4 choices
Processing batch with 4 choices
Processing batch with 4 choices
