In [None]:
import logging
from copy import deepcopy
from datetime import datetime
from types import SimpleNamespace

import optuna
import torch as pt
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.data_loading import CachedBatches, dataset_loaders
from utils.git_and_reproducibility import *
from utils.model_operations import *
from utils.training import MockTrial, loss_fns, set_seeds

# Experiment configuration shared by the cells below, accessed as attributes
# (config.model_id, config.batch_size, ...).
config = SimpleNamespace(
    # --- model / data ---
    model_id="EleutherAI/pythia-14m",
    forget_set_name="python",
    ret_lora_config={
        # "lora_dropout": 0.1,
        "target_modules": [
            "dense_h_to_4h",
            "dense_4h_to_h",
            "query_key_value",
            "dense",
        ],
    },
    # --- training constants ---
    unlearn_steps=300,
    batch_size=16,
    # --- relearning params ---
    relearn_steps=100,
    relearn_lr=3e-4,
    eval_batch_size=16,
    relearn_lora_conf={
        "r": 1,
        "target_modules": "all-linear",
        "lora_dropout": 0.1,
    },
)

# Configure logging first so the device warning below is emitted in the
# notebook's log format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s  %(message)s",
    datefmt="%H:%M:%S",
    handlers=[logging.StreamHandler()],
)

# Allocate new tensors/modules on the GPU by default when one is present.
# Hard-coding "cuda" makes every later allocation fail on a CPU-only machine,
# so in that case the default device is left untouched and a warning is logged.
if pt.cuda.is_available():
    pt.set_default_device("cuda")
else:
    logging.warning("CUDA is not available; leaving the default device on CPU")

# load datasets
# The retain corpus is always "wikitext"; the forget corpus is picked by
# config.forget_set_name. Both loaders receive the model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(config.model_id)
retain_set = dataset_loaders["wikitext"](tokenizer)
forget_set = dataset_loaders[config.forget_set_name](tokenizer)

# Training batches use config.batch_size, validation batches use
# config.eval_batch_size. Within each pair the retain object is built first.
retain_batches, forget_batches = (
    CachedBatches(retain_set["train"], config.batch_size),
    CachedBatches(forget_set["train"], config.batch_size),
)
retain_val_batches, forget_val_batches = (
    CachedBatches(retain_set["validation"], config.eval_batch_size),
    CachedBatches(forget_set["validation"], config.eval_batch_size),
)

# A single validation batch per corpus, reused for the periodic evals.
r_eval_batch, f_eval_batch = (
    next(retain_val_batches.fresh_iterator()),
    next(forget_val_batches.fresh_iterator()),
)

# Load the untouched base model once and record its loss on the fixed forget
# and retain eval batches as a reference point (forget is evaluated first).
base_model = AutoModelForCausalLM.from_pretrained(config.model_id)
init_forget, init_retain = (
    eval_loss(base_model, f_eval_batch),
    eval_loss(base_model, r_eval_batch),
)
logging.info(f"init forget: {init_forget:6.2f}    init retain: {init_retain:6.2f}")



In [None]:
# Learning rate for the SGD optimizer used in the unlearning loop.
unlearn_lr = 0.05

# Seed before creating the iterators and the model copy so a re-run of this
# cell starts from the same state.
set_seeds(42)
# prepare data iterators
forget_iter = forget_batches.fresh_iterator()
retain_iter = retain_batches.fresh_iterator()
# load model (copy from memory for speed); base_model itself is not modified
model = deepcopy(base_model)
# Every parameter is handed to the optimizer below, so all must be trainable.
assert all(
    param.requires_grad for param in model.parameters()
), "copied model has parameters with requires_grad=False"

# initialize optimizers
# SGD is faster and more predictable than Adam
# todo, change to Adam?
optimizer = pt.optim.SGD(model.parameters(), lr=unlearn_lr)

# %
# ! unlearning loop
# Every step pulls one forget batch and one retain batch and then takes two
# separate optimizer steps on the same model: first on the retain batch, then
# on the forget batch. Every 10th step both fixed eval batches are scored.
logging.info("step      base_f      base_r")
for step in range(1, config.unlearn_steps + 1):
    model.train()
    f_input_ids = next(forget_iter)
    r_input_ids = next(retain_iter)

    # ! retain phase: cross-entropy step on the retain batch
    # (no LoRA adapter is attached in this cell; the step updates the model's
    # own parameters through `optimizer`)
    model.zero_grad(set_to_none=True)
    output = model(r_input_ids, output_hidden_states=True)
    loss = loss_fns["cross_entropy"](output, r_input_ids)
    loss.backward()
    optimizer.step()

    # ! unlearn phase: step on the forget batch with the
    # "negative_cross_entropy" loss
    model.zero_grad(set_to_none=True)
    output = model(f_input_ids, output_hidden_states=True)
    # todo change negative cross entropy to entropy
    loss = loss_fns["negative_cross_entropy"](output, f_input_ids)
    loss.backward()
    optimizer.step()

    # ! eval (only on steps that are a multiple of 10)
    if step % 10 != 0:
        continue

    res = {
        "base_forget": eval_loss(model, f_eval_batch),
        "base_retain": eval_loss(model, r_eval_batch),
    }
    logging.info(f"{step:4} " + " ".join(f"{v:11.2f}" for v in res.values()))

    # Disabled pruning rules, kept for reference. NOTE(review): they raise
    # optuna.TrialPruned, which presumably only makes sense when this loop
    # runs inside an Optuna objective - confirm before re-enabling.
    # # prune if base forget loss doesn't improve
    # if step >= 30 and res["base_forget"] < init_forget + 0.5:
    #     logging.info("Forget loss stalled")
    #     raise optuna.TrialPruned()
    # # prune if nan
    # if any(pt.isnan(v) for v in res.values()):
    #     logging.error("NaN in eval results")
    #     raise optuna.TrialPruned()
    # # prune if retain loss is too high
    # if res["base_retain"] > init_retain + 0.1:
    #     logging.info("Retain performance still broken")
    #     raise optuna.TrialPruned()

# %
# ! final bigger eval relearning
# relearn() is given a deep copy, so `model` is a separate object from the one
# passed in, together with fresh iterators over both validation sets.
copied_model = deepcopy(model)
retain_val_iter, forget_val_iter = (
    retain_val_batches.fresh_iterator(),
    forget_val_batches.fresh_iterator(),
)
forget_loss = relearn(copied_model, config, retain_val_iter, forget_val_iter)