## Alternating Gradient Ascent-Descent (AGAD)

Original work introducing Gradient Ascent (GA) for unlearning in LLMs: [paper](https://arxiv.org/pdf/2210.01504) [code](https://github.com/joeljang/knowledge-unlearning/tree/main?tab=readme-ov-file)

Alternating forgetting and annealing phases using Gradient Ascent on forget chunks and Gradient Descent on retain data.

Parameters:

- **Chunk size**: Size of the forget chunks
- **Interleaving factor**: Frequency of annealing phases; on average, annealing is performed after every 1/(interleaving factor) forgetting phases. 0 means no intermediate annealing is performed, 0.5 means that annealing is performed after every 2 forgetting phases, and 1 means after every forgetting phase.
- **Annealing fraction**: The fraction of the retain set (randomly sampled for each annealing phase) to perform the annealing on. Can be less than 1 to increase training speed.
- **Final annealing**: Boolean; whether to perform a final annealing on the whole retain set after all the forgetting phases.


## Imports

In [None]:
import torch
import warnings
import json
import time
import gc
import os

from peft import LoraConfig, get_peft_model, TaskType
from huggingface_hub import snapshot_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from utils.data import DatasetProcessor
from utils.utils import (
    make_compute_metrics,
    preprocess_logits_for_metrics,
    print_number_of_trainable_model_parameters,
    print_gpu_memory,
    plot_metrics,
    plot_training_stats
)
from utils.evaluation import (
    QualitativeEvaluation,
    QuantitativeEvaluation,
    MMLU
)
from methods import AlternatingTrainer

# Ignore every Python warning from here on (presumably to keep the notebook output readable)
warnings.filterwarnings('ignore')

In [None]:
# Record how many CUDA devices PyTorch can see and echo the count
print(num_of_gpus := torch.cuda.device_count())

## Load training parameters

In [None]:
# Everything worth keeping from this run (checkpoints, models, evaluation results) lands here
output_dir = 'output_dir'
os.makedirs(output_dir, exist_ok=True)

# Read the training configuration from its json file
with open("configs/alternating_args.json", 'r') as config_file:
    args = json.load(config_file)

# Serialize the arguments once; the same text is both archived and printed
args_as_json = json.dumps(args, indent=4)

# Keep a copy of the arguments next to the outputs for future reference
with open(f"{output_dir}/training_args.json", "w") as archive_file:
    archive_file.write(args_as_json)

print(f"Alternating training with the following arguments:\n\n{args_as_json}")

## Load model and tokenizer

In [None]:
# Pick the fine-tuned checkpoint repository and the matching tokenizer
# for the model size requested in the configuration.
model_size = args["model_params"]["model_size"]

if model_size == "7B":
    model_repo_id = "llmunlearningsemeval2025organization/olmo-finetuned-semeval25-unlearning"
    tokenizer_path = "allenai/OLMo-7B-0724-Instruct-hf"
elif model_size == "1B":
    model_repo_id = "llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning"
    tokenizer_path = "allenai/OLMo-1B-0724-hf"
else:
    # Fail fast with a clear message: continuing would only surface later as a
    # NameError on `model_repo_id` when the download below is attempted.
    raise ValueError(
        f"Provide a valid option for the model size, either '1B' or '7B' (got {model_size!r})."
    )

# Download the checkpoint files into the local 'pretrained_model' directory
snapshot_download(repo_id=model_repo_id, local_dir='pretrained_model')

In [None]:
# Load the downloaded checkpoint in bfloat16, letting `device_map='auto'` decide placement
model = AutoModelForCausalLM.from_pretrained(
    'pretrained_model',
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

# The tokenizer is loaded from the tokenizer path selected together with the checkpoint
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

In [None]:
# Print whatever print_number_of_trainable_model_parameters returns for the freshly loaded model
# (presumably a trainable-parameter summary; the helper lives in utils.utils)
print(print_number_of_trainable_model_parameters(model))

## Prepare Data

In [None]:
# Dataset processor bound to the SemEval data directory and the loaded tokenizer
processor = DatasetProcessor(
    data_dir='semeval25-unlearning-data/data',
    tokenizer=tokenizer,
    n_samples_per_task=None,
)

# Build the tokenized datasets (a DatasetDict) for the configured split,
# covering all tasks, without splitting by task or splitting the retain set
dataset = processor(
    split=args["general"]["split"],
    task='all',
    split_tasks=False,
    split_retain=False,
)

# Collator that pads each batch to its longest sequence, rounded up to a multiple of 8
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    padding='longest',
    pad_to_multiple_of=8,
)

In [None]:
# Bare expression: the notebook displays the dataset object as the cell output
dataset

## Prepare model

In [None]:
model_params = args["model_params"]

if model_params["apply_lora"]:
    # Attach LoRA adapters to the q_proj, v_proj, up_proj and down_proj modules
    lora_config = LoraConfig(
        r=model_params["lora_r"],  # Rank
        lora_alpha=model_params["lora_alpha"],
        target_modules=["q_proj", "v_proj", "up_proj", "down_proj"],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
elif model_params["train_last_k"]:
    k = model_params["k"]
    decoder_layers = model.model.layers
    # Layers at or beyond this index stay trainable; all earlier ones are frozen.
    # Parameters outside `model.model.layers` are not touched by this loop.
    first_trainable = len(decoder_layers) - k

    for layer_idx, layer in enumerate(decoder_layers):
        keep_trainable = layer_idx >= first_trainable
        for param in layer.parameters():
            param.requires_grad = keep_trainable

In [None]:
# Report again now that the model-preparation cell (LoRA adapters or layer freezing) has run
print(print_number_of_trainable_model_parameters(model))

In [None]:
# GPU memory report (helper from utils.utils) before the training setup
print_gpu_memory()

## Training Setup

In [None]:
def _phase_training_args(checkpoint_dir, phase_args):
    """Build the TrainingArguments for one training phase.

    Batch size, gradient accumulation, learning rate and epoch count are read
    from `phase_args`; every other setting is shared by both phases
    (evaluation every epoch, no checkpoint saving, no external reporting).
    """
    return TrainingArguments(
        output_dir=checkpoint_dir,
        per_device_train_batch_size=phase_args["per_device_batch_size"],
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=phase_args["gradient_accumulation_steps"],
        eval_accumulation_steps=1,
        learning_rate=phase_args["learning_rate"],
        num_train_epochs=phase_args["num_epochs"],
        logging_steps=4,
        save_strategy="no",
        eval_strategy="epoch",
        report_to="none",
        eval_on_start=False,
        include_inputs_for_metrics=True,
    )


# One set of arguments per phase: forgetting and annealing
forgetting_training_args = _phase_training_args(
    f"{output_dir}/checkpoints/forgetting", args["forgetting_args"]
)
annealing_training_args = _phase_training_args(
    f"{output_dir}/checkpoints/annealing", args["annealing_args"]
)

general_args = args["general"]

# Trainer that alternates between the forget and retain datasets
trainer = AlternatingTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_args_forgetting=forgetting_training_args,
    train_args_annealing=annealing_training_args,
    retain_dataset=dataset['retain'],
    forget_dataset=dataset['forget'],
    compute_metrics=make_compute_metrics(model, tokenizer, max_samples=16),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    chunk_size=general_args["chunk_size"],
    # Annealing after every 1/IF chunks on average; zero disables intermediate annealing
    interleaving_factor=general_args["interleaving_factor"],
    perform_final_annealing=general_args["final_annealing"],
    eval_after_annealing=True,
    # Fraction of the retain set used per annealing phase, taken from the config
    annealing_fraction=general_args["annealing_fraction"],
)

## Train

In [None]:
# Start training with the AlternatingTrainer configured above
trainer.train()

In [None]:
# Save the trained model through the trainer into the run's output directory
trainer.save_model(output_dir)

In [None]:
# Store the training summary in the output directory and keep it for the plots below
summary = trainer.save_summary(output_dir)

# Show the two headline totals from the summary
for stat_name in ('total_runtime', 'total_flos'):
    print(summary[stat_name])

In [None]:
# Plot the metrics from the summary's log history (output_dir is passed to the helper,
# presumably so the figures are saved there; confirm in utils.utils)
plot_metrics(summary["log_history"], output_dir)

In [None]:
# Plot the training statistics from the same log history
plot_training_stats(summary["log_history"])

In [None]:
# When LoRA was applied, fold the adapter weights back into the base weights.
# NOTE(review): the return value of merge_and_unload() is discarded, so `model` is still
# the PEFT wrapper afterwards; later cells rely on the merge having been applied in place
# to the wrapped model. Confirm against the PEFT documentation.
if args["model_params"]["apply_lora"]:
    model.merge_and_unload()
    print(model)

In [None]:
# Save the complete causal LM together with its tokenizer.
# `model.model` is only the full model when `model` is a PEFT wrapper. For a plain
# transformers model (no LoRA), `model.model` is the inner decoder stack (the object
# whose `.layers` are frozen in the model-preparation cell), so saving it would write
# a checkpoint without the LM head. Resolve the object to save explicitly instead.
model_to_save = model.get_base_model() if hasattr(model, "get_base_model") else model
model_to_save.save_pretrained("unlearned_model")
# Drop the extra reference so the clean-up cells below can still free the model
del model_to_save

trainer.tokenizer.save_pretrained("unlearned_model")

In [None]:
# GPU memory report before the training objects are released
print_gpu_memory()

In [None]:
# Unbind the training objects so their memory can be reclaimed before evaluation
del trainer, model, tokenizer

In [None]:
# Run Python's garbage collector first, then ask PyTorch to release its cached CUDA memory
gc.collect()
torch.cuda.empty_cache()

In [None]:
# GPU memory report after the clean-up, for comparison with the report before it
print_gpu_memory()

## Final Evaluation

In [None]:
# Optional step: evaluate on MMLU before the task-specific evaluation.
# Per the original author, the MMLU class does not run the official MMLU repository code
# but still yields correct results.

# All 57 MMLU topics; trim this list to a subset for faster results.
topics = ['abstract_algebra',
          'anatomy',
          'astronomy',
          'business_ethics',
          'clinical_knowledge',
          'college_biology',
          'college_chemistry',
          'college_computer_science',
          'college_mathematics',
          'college_medicine',
          'college_physics',
          'computer_security',
          'conceptual_physics',
          'econometrics',
          'electrical_engineering',
          'elementary_mathematics',
          'formal_logic',
          'global_facts',
          'high_school_biology',
          'high_school_chemistry',
          'high_school_computer_science',
          'high_school_european_history',
          'high_school_geography',
          'high_school_government_and_politics',
          'high_school_macroeconomics',
          'high_school_mathematics',
          'high_school_microeconomics',
          'high_school_physics',
          'high_school_psychology',
          'high_school_statistics',
          'high_school_us_history',
          'high_school_world_history',
          'human_aging',
          'human_sexuality',
          'international_law',
          'jurisprudence',
          'logical_fallacies',
          'machine_learning',
          'management',
          'marketing',
          'medical_genetics',
          'miscellaneous',
          'moral_disputes',
          'moral_scenarios',
          'nutrition',
          'philosophy',
          'prehistory',
          'professional_accounting',
          'professional_law',
          'professional_medicine',
          'professional_psychology',
          'public_relations',
          'security_studies',
          'sociology',
          'us_foreign_policy',
          'virology',
          'world_religions']

# The notebook only creates `output_dir` itself, so make sure the evaluation
# sub-directory exists before MMLU is asked to write its metrics file into it.
os.makedirs(f"{output_dir}/evaluation", exist_ok=True)

# Time the MMLU run over the selected topics on the saved unlearned model
mmlu_start = time.time()
mmlu = MMLU(topics)
mmlu.run(model_path="unlearned_model", mmlu_metrics_file_path=f"{output_dir}/evaluation/mmlu.json")

print("MMLU time: ", time.time()-mmlu_start)

In [None]:
# Settings handed to QuantitativeEvaluation: data locations, the saved checkpoint,
# generation/batching limits, and the path of the MMLU metrics file
evaluation_args = dict(
    seed=42,
    debug=True,
    keep_files=True,
    max_new_tokens=256,
    compute_metrics_only=False,
    batch_size=8,
    mia_data_path="semeval25-unlearning-data/mia_data/",
    data_path="semeval25-unlearning-data/data/",
    checkpoint_path="unlearned_model",
    output_dir=f"{output_dir}/evaluation",
    mmlu_metrics_file_path=f"{output_dir}/evaluation/mmlu.json",
)

quantitative_eval = QuantitativeEvaluation(evaluation_args)

In [None]:
# Run the quantitative evaluation configured in the previous cell
quantitative_eval.run()

In [None]:
# Release PyTorch's cached CUDA memory between the two evaluation runs
torch.cuda.empty_cache()

In [None]:
# Qualitative evaluation of the saved checkpoint: it is pointed at the evaluation
# output directory, the general-questions file, and writes to its own sub-directory
qualitative_eval_settings = {
    "checkpoint_path": "unlearned_model",
    "path_to_predictions": f'{output_dir}/evaluation',
    "path_to_gqa": 'utils/general_questions.json',
    "output_dir": f'{output_dir}/evaluation/qualitative',
    "n_samples": 5,
}
qualitative_eval = QualitativeEvaluation(**qualitative_eval_settings)

In [None]:
# Run the qualitative evaluation configured in the previous cell
qualitative_eval.run()