## Gold Standard: Retraining from scratch

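This notebook trains the gold-standard model. To this end, we fine-tune the base model on the **retain data only**, using supervised fine-tuning (with LoRA adapters) and a causal language modeling objective. Because the forget data is never used for training, the resulting model serves as the reference point for what unlearning should achieve.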

## Setup: load model and data

In [None]:
import gc
import os
import warnings

import torch
import wandb

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import concatenate_datasets
from peft import LoraConfig, get_peft_model, TaskType

from utils.data import DatasetProcessor
from utils.utils import (
    make_compute_metrics,
    preprocess_logits_for_metrics,
    print_number_of_trainable_model_parameters,
    print_gpu_memory,
    plot_metrics,
    plot_training_stats
)
from utils.evaluation import (
    QualitativeEvaluation,
    QuantitativeEvaluation,
    MMLU
)

# Silence all Python warnings for the rest of the session to keep cell output readable.
# NOTE(review): this also hides deprecation warnings from transformers/peft/datasets;
# consider narrowing the filter (e.g. by category or module) instead of ignoring everything.
warnings.filterwarnings('ignore')

In [None]:
# Optionally use wandb for logging (the Trainer below reports to wandb).
# The API key is read from the WANDB_API_KEY environment variable rather than
# being hardcoded in the notebook. A literal `...` (Ellipsis) is not a valid key.
# If the variable is unset, key=None lets wandb fall back to its own login flow
# (stored credentials or an interactive prompt).
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

wandb.login(key=WANDB_API_KEY)

wandb.init(project='gold-standard', name='some-run')

In [None]:
# Load the original OLMo-7B-0724-Instruct checkpoint, i.e. the model *before* it
# memorized the task dataset. The gold standard is obtained by fine-tuning this
# checkpoint on the retain data only.
MODEL_ID = "allenai/OLMo-7B-0724-Instruct-hf"

# `base_model` must be loaded here: the parameter-count and get_peft_model cells
# below use it, so leaving this line commented out breaks Restart & Run All.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map='auto', torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Report the parameter counts of the freshly loaded base model (before any LoRA wrapping).
print(print_number_of_trainable_model_parameters(base_model))

In [None]:
# Tokenize the training data. With gold_standard=True and split_retain=True the
# processor presumably yields the retain data as separate splits (the training
# cell below reads "retain_1".."retain_3"). TODO confirm against utils.data.
processor = DatasetProcessor(
    data_dir='semeval25-unlearning-data/data',
    tokenizer=tokenizer,
    n_samples_per_task=None,  # presumably "no subsampling": keep every sample per task
    gold_standard=True,
)

# Build the tokenized train split as a DatasetDict, broken out per task.
dataset = processor(
    split="train",
    task='all',
    split_tasks=True,
    split_retain=True,
)

# Pad each batch to its longest sequence, rounded up to a multiple of 8.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    padding='longest',
    pad_to_multiple_of=8,
)

In [None]:
# Display the DatasetDict summary (split names and sizes) for a quick sanity check.
dataset

## Train PEFT model

In [None]:
# LoRA adapter configuration for causal-LM fine-tuning: rank 32, alpha 64, applied
# to the q/k/v projection modules and the up/down projection modules.
# NOTE(review): o_proj and gate_proj are not targeted; confirm this is intentional.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,  # adapter rank
    lora_alpha=64,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"],
)

In [None]:
# Wrap the base model with the LoRA adapters configured above, then report how
# many parameters remain trainable.
peft_model = get_peft_model(base_model, lora_config)
trainable_params_report = print_number_of_trainable_model_parameters(peft_model)
print(trainable_params_report)

In [None]:
# Root directory for checkpoints, plots and evaluation outputs.
# Alternative path (presumably for a managed training job): '/opt/ml/output/data'
output_dir = 'output'

peft_training_args = TrainingArguments(
    output_dir=f"{output_dir}/final_model",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    learning_rate=1e-4,
    num_train_epochs=10,
    logging_steps=2,
    save_strategy="epoch",
    save_total_limit=1,
    eval_strategy="epoch",
    eval_on_start=True,  # baseline evaluation before the first optimizer step
    report_to="wandb",
    # The base model is loaded with torch_dtype=torch.bfloat16, so use bf16 mixed
    # precision. fp16=True would autocast the bf16 model to float16, whose much
    # narrower dynamic range risks overflow. NOTE: bf16 needs hardware support
    # (NVIDIA Ampere or newer).
    bf16=True,
    include_inputs_for_metrics=True,
)

# Gold standard: train on the retain data only. The forget data is never seen.
train_dataset = concatenate_datasets([dataset["retain_1"], dataset["retain_2"], dataset["retain_3"]])

peft_trainer = Trainer(
    model=peft_model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=peft_training_args,
    train_dataset=train_dataset,
    # Passing the whole DatasetDict evaluates each split separately; the Trainer
    # prefixes each metric name with the split's key.
    eval_dataset=dataset,
    compute_metrics=make_compute_metrics(peft_model, tokenizer, max_samples=32),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

In [None]:
# Run LoRA fine-tuning on the concatenated retain splits. Per the arguments above,
# every split in `dataset` is evaluated once before training and after each epoch.
peft_trainer.train()

In [None]:
# Close the wandb run started by wandb.init at the top of the notebook.
wandb.finish()

In [None]:
# Plot the metrics logged during training. output_dir is passed along, presumably
# as the place where the figures are saved. The final log_history entry is
# excluded; it is presumably the end-of-training summary record. TODO confirm.
plot_metrics(peft_trainer.state.log_history[:-1], output_dir)

In [None]:
# Plot training statistics from the same log history, again excluding the final
# entry (presumably the end-of-training summary record).
plot_training_stats(peft_trainer.state.log_history[:-1])

In [None]:
# Fold the trained LoRA weights into the base model's layers and drop the adapters.
# NOTE(review): the return value (the merged model) is discarded; the save cell
# below reads the model back through `peft_model`, so it relies on the merge
# having modified the wrapped model in place.
peft_model.merge_and_unload()
print(peft_model)

In [None]:
# Save the merged model and its tokenizer as a plain Hugging Face checkpoint that
# the evaluation cells below load from "gold_standard".
# get_base_model() is PeftModel's documented accessor for the underlying
# transformers model; `peft_model.model` only works via attribute fall-through.
# Assumes the merge cell above has run, so the LoRA weights are already folded in.
# The tokenizer is saved directly instead of through the deprecated
# `Trainer.tokenizer` property. It is the same object that was passed to the Trainer.
GOLD_STANDARD_DIR = "gold_standard"
peft_model.get_base_model().save_pretrained(GOLD_STANDARD_DIR)
tokenizer.save_pretrained(GOLD_STANDARD_DIR)

In [None]:
# Drop the notebook's references to the trainer and both model handles,
# deleted in the same order as before.
del peft_trainer, base_model, peft_model

# Collect any reference cycles still holding the tensors, then release the
# PyTorch CUDA allocator's cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Check how much GPU memory is still in use after the cleanup cell above.
print_gpu_memory()

## Evaluation

In [None]:
# Quantitative evaluation of the saved gold-standard checkpoint.
# Same keys, values and insertion order as a dict literal, built via dict().
evaluation_args = dict(
    seed=42,
    debug=True,
    keep_files=True,
    max_new_tokens=256,
    compute_metrics_only=False,
    batch_size=8,
    mia_data_path="semeval25-unlearning-data/mia_data/",
    split="train",
    data_path="semeval25-unlearning-data/data/",
    checkpoint_path="gold_standard",
    output_dir=f"{output_dir}/evaluation",
    mmlu_metrics_file_path=None,
)

quantitative_eval = QuantitativeEvaluation(evaluation_args)

In [None]:
# Run the quantitative evaluation configured in `evaluation_args`.
quantitative_eval.run()

In [None]:
# Release cached CUDA allocator memory before the qualitative evaluation loads the model.
torch.cuda.empty_cache()

In [None]:
# Qualitative evaluation: reads the quantitative run's predictions directory and
# writes its own results to a "qualitative" subfolder of it.
predictions_dir = f'{output_dir}/evaluation'

qualitative_eval = QualitativeEvaluation(
    n_samples=5,
    checkpoint_path="gold_standard",
    path_to_gqa='utils/general_questions.json',
    path_to_predictions=predictions_dir,
    output_dir=f'{predictions_dir}/qualitative',
)

In [None]:
# Run the qualitative evaluation configured above (n_samples=5).
qualitative_eval.run()