# Stable Sequential Unlearning (SSU) Framework

Complete implementation following the paper methodology:

## Pipeline Overview:
1. **Initial Fine-tuning (Step 0)**: Fine-tune vanilla model on all copyrighted books (D_f) to make it memorize them
2. **Sequential Unlearning (Steps 1-N)**: Unlearn books one at a time using SSU methodology
   - Each step unlearns one book (D_f^t)
   - Uses composite loss (L_fgt + L_rnd) and weight saliency
   - Applies task vector negation

## Datasets:
- **D_f**: All copyrighted books (10 books from Project Gutenberg)
- **D_f^t**: Book to unlearn at time step t
- **D_prev**: Previously unlearned books (aggregated from previous steps)
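- **D_nor**: Retention data (200 chunks from 100 other books, disjoint from D_f) - used for evaluation and as the source of the random labels for L_rnd; it is not trained on directly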

Works on both local and Kaggle environments with automatic retry logic.


## 1. Install Dependencies


In [None]:
# Install required packages
%pip install -q torch transformers peft datasets accelerate requests protobuf==3.20.3 rouge-score sentence-transformers scikit-learn


## 2. Configuration


In [None]:
# Disable torch.compile globally to avoid CUDA capability issues
# (e.g. older GPUs such as the Tesla P100).
#
# The variable must be in the environment BEFORE torch / torch._dynamo are
# imported: setting it after the import is too late for it to take effect.
# That is why it sits between `import os` and the remaining imports.
import os

os.environ['TORCH_COMPILE_DISABLE'] = '1'

import time
import sys
import math
import random
import requests
import shutil

import torch
import torch._dynamo

from typing import Dict, Sequence
from types import SimpleNamespace
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig
from rouge_score import rouge_scorer

In [None]:
# Configuration Class
class Config:
    """Static settings shared by every stage of the SSU pipeline.

    Everything is a plain class attribute, so values are read as
    ``Config.NAME`` without ever instantiating the class.
    """

    # ------------------------------------------------------------------ model
    # A small model is used to avoid download issues.
    MODEL_NAME = "google/gemma-3-1b-it"
    # Other candidates:
    #   "microsoft/Phi-3-mini-4k-instruct"     - 3.8B, non-gated
    #   "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   - small, fast, non-gated
    TOKENIZER_NAME = MODEL_NAME

    # Log in to HuggingFace before loading the model (needed for gated models).
    USE_HF_TOKEN = True

    # ------------------------------------------- PEFT / LoRA (paper values)
    LORA_R, LORA_ALPHA, LORA_DROPOUT = 8, 16, 0.05
    TARGET_MODULES = ["q_proj", "v_proj"]

    # ------------------------------------------------------------ run size
    # Prototype mode shrinks the run for quick testing; set False for full runs.
    PROTOTYPE_MODE = True
    (
        NUM_UNLEARNING_STEPS,
        NUM_CHUNKS_PER_STEP,
        NUM_RETENTION_BOOKS,
        NUM_RETENTION_CHUNKS,
        EVAL_MAX_PAIRS,
        EVAL_NUM_SAMPLES,
    ) = (3, 30, 30, 60, 5, 3) if PROTOTYPE_MODE else (10, 50, 100, 200, 5, 5)

    # ------------------------------------------ fine-tuning hyperparameters
    BATCH_SIZE = 2  # kept small to avoid OOM
    GRADIENT_ACCUMULATION_STEPS = 8
    # Base LR; overridden per step (1e-5 for steps 1-5, 1e-6 for steps 6-10).
    LEARNING_RATE = 5e-5
    NUM_EPOCHS_FT = 1          # 1 epoch, as per the paper
    NUM_EPOCHS_INITIAL_FT = 1  # initial fine-tuning on all books (D_f)
    WARMUP_STEPS = 10
    WEIGHT_DECAY = 0.01

    # --------------------------------------------- SSU loss weights (paper)
    EPSILON_1 = 1.0  # lambda_1: weight of the forgetting loss L_fgt
    EPSILON_2 = 0.1  # lambda_2: weight of the random-labeling loss L_rnd
    # No GAMMA constant: the saliency threshold is computed from the gradients
    # as mu + sigma.

    # ----------------------------------------------------------------- data
    CHUNK_SIZE = 256
    USE_REAL_BOOKS = True  # use real books from Project Gutenberg
    DATA_DIR = "gutenberg_books"

    # The 10 Project Gutenberg books from the paper, in unlearning order.
    ALL_BOOK_IDS = [
        1661,   # step 1  - Sherlock Holmes
        84,     # step 2  - Frankenstein
        1342,   # step 3  - Pride and Prejudice
        11,     # step 4  - Alice in Wonderland
        2701,   # step 5  - Moby Dick
        74,     # step 6  - The Adventures of Tom Sawyer
        98,     # step 7  - A Tale of Two Cities
        5200,   # step 8  - Metamorphosis
        6130,   # step 9  - The Iliad
        174,    # step 10 - The Picture of Dorian Gray
    ]

    # Books to unlearn at each (1-based) time step: one book per step, in the
    # same order as ALL_BOOK_IDS.
    GUTENBERG_BOOK_IDS = {
        step: [book_id] for step, book_id in enumerate(ALL_BOOK_IDS, start=1)
    }

    # Retention data (D_nor): chunks from books disjoint from ALL_BOOK_IDS,
    # used for evaluation and random labels.
    USE_RETENTION_DATA = True

    # -------------------------------------------------- storage / disk usage
    SAVE_MEMORIZED_MODEL = False        # skip the large memorized checkpoint
    SAVE_STEP_MODELS = False            # only persist the final model by default
    DELETE_PREVIOUS_STEP_MODELS = True  # drop older step checkpoints on save
    KEEP_DOWNLOADED_BOOKS = False       # remove raw .txt files after chunking
    CLEANUP_FINAL_MODEL_DIR = False     # optionally remove the final model dir

    OUTPUT_DIR = "ssu_unlearned_models"


print("Configuration loaded!")


## 3. Environment Detection & Setup


In [None]:
# Detect the hosting environment: Kaggle is recognised by the /kaggle
# directory, Colab by the google.colab module already being loaded.
IS_KAGGLE = os.path.exists('/kaggle')
IS_COLAB = 'google.colab' in sys.modules

print(f"Running on: {'Kaggle' if IS_KAGGLE else 'Colab' if IS_COLAB else 'Local'}")

# HuggingFace Authentication (if needed)
if Config.USE_HF_TOKEN:
    from huggingface_hub import login

    hf_token = None

    # 1) Kaggle Secrets (only available on Kaggle)
    if IS_KAGGLE:
        try:
            from kaggle_secrets import UserSecretsClient
            user_secrets = UserSecretsClient()
            hf_token = user_secrets.get_secret("HF_TOKEN")
            print("Found HuggingFace token in Kaggle Secrets.")
        except Exception as e:
            print(f"Warning: Could not get HuggingFace token from Kaggle Secrets: {e}")

    # 2) Environment variable.
    # SECURITY: never hard-code an access token in the notebook - anything
    # committed here is readable by everyone with access to the file.
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    if hf_token:
        try:
            login(token=hf_token, add_to_git_credential=False)
            print("Successfully logged in to HuggingFace.")
        except Exception as e:
            print(f"Warning: Could not login: {e}")
    else:
        print("WARNING: No HuggingFace token found. Gated models will fail.")
else:
    print("Using non-gated model - no authentication needed.")


## 5. SSU Model & Trainer


In [None]:
# Dummy book text for simulation
DUMMY_BOOK_TEXT = """
In the beginning God created the heavens and the earth. Now the earth was formless and empty, darkness was over the surface of the deep, and the Spirit of God was hovering over the waters. And God said, "Let there be light," and there was light. God saw that the light was good, and he separated the light from the darkness. God called the light "day," and the darkness he called "night." And there was evening, and there was morning—the first day.

And God said, "Let there be a vault between the waters to separate water from water." So God made the vault and separated the water under the vault from the water above it. And it was so. God called the vault "sky." And there was evening, and there was morning—the second day.

And God said, "Let the water under the sky be gathered to one place, and let dry ground appear." And it was so. God called the dry ground "land," and the gathered waters he called "seas." And God saw that it was good. Then God said, "Let the land produce vegetation: seed-bearing plants and trees on the land that bear fruit with seed in it, according to their various kinds." And it was so. The land produced vegetation: plants bearing seed according to their kinds and trees bearing fruit with seed in it according to their kinds. And God saw that it was good. And there was evening, and there was morning—the third day.

And God said, "Let there be lights in the vault of the sky to separate the day from the night, and let them serve as signs to mark sacred times, and days and years, and let them be lights in the vault of the sky to give light on the earth." And it was so. God made two great lights—the greater light to govern the day and the lesser light to govern the night. He also made the stars. God set them in the vault of the sky to give light on the earth, to govern the day and the night, and to separate light from darkness. And God saw that it was good. And there was evening, and there was morning—the fourth day.

And God said, "Let the water teem with living creatures, and let birds fly above the earth across the vault of the sky." So God created the great creatures of the sea and every living thing with which the water teems and that moves about in it, according to their kinds, and every winged bird according to its kind. And God saw that it was good. God blessed them and said, "Be fruitful and increase in number and fill the water in the seas, and let the birds increase on the earth." And there was evening, and there was morning—the fifth day.

And God said, "Let the land produce living creatures according to their kinds: the livestock, the creatures that move along the ground, and the wild animals, each according to its kind." And it was so. God made the wild animals according to their kinds, the livestock according to their kinds, and all the creatures that move along the ground according to their kinds. And God saw that it was good. Then God said, "Let us make mankind in our image, in our likeness, so that they may rule over the fish in the sea and the birds in the sky, over the livestock and all the wild animals, and over all the creatures that move along the ground." So God created mankind in his own image, in the image of God he created them; male and female he created them. God blessed them and said to them, "Be fruitful and increase in number; fill the earth and subdue it. Rule over the fish in the sea and the birds in the sky and over every living creature that moves on the ground." Then God said, "I give you every seed-bearing plant on the face of the whole earth and every tree that has fruit with seed in it. They will be yours for food. And to all the beasts of the earth and all the birds in the sky and all the creatures that move along the ground—everything that has the breath of life in it—I give every green plant for food." And it was so. God saw all that he had made, and it was very good. And there was evening, and there was morning—the sixth day.

Thus the heavens and the earth were completed in all their vast array. By the seventh day God had finished the work he had been doing; so on the seventh day he rested from all his work. Then God blessed the seventh day and made it holy, because on it he rested from all the work of creating that he had done.
""" * 50


def generate_simulated_data(text, chunk_size, num_chunks, tokenizer_name):
    """Cut one book's text into ``num_chunks`` fixed-size token chunks (D_f^t).

    The text is tokenized in small word batches (so no single tokenizer call
    exceeds the model's maximum length), the token ids are concatenated, and
    the stream is sliced into consecutive windows of equal length. When the
    book is too short, its token stream is repeated until enough tokens exist.

    Args:
        text: Raw book text.
        chunk_size: Desired number of tokens per chunk; capped at the
            tokenizer's maximum length minus a small safety margin.
        num_chunks: Number of chunks to return.
        tokenizer_name: Name/path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        list[str]: At most ``num_chunks`` chunks, each decoded back to text
        with special tokens skipped.

    Raises:
        ValueError: If ``text`` yields no tokens although chunks were
            requested (previously this surfaced as a ``ZeroDivisionError``).
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Some tokenizers report an absurdly large model_max_length, which would
    # overflow downstream; clamp it to a reasonable limit.
    max_length = getattr(tokenizer, 'model_max_length', 32000)
    if max_length > 100000:
        max_length = 8192

    # Never exceed the model limit; keep a small safety margin.
    safe_chunk_size = min(chunk_size, max_length - 10)
    tokens_needed = safe_chunk_size * num_chunks

    # One word can become several tokens, so tokenize in conservative batches
    # of at most 500 words to stay below max_length and avoid length warnings.
    words = text.split()
    words_per_batch = min(500, max_length // 3)

    # Step 1: tokenize batch by batch and concatenate the ids.
    all_token_ids = []
    for start in range(0, len(words), words_per_batch):
        batch_text = ' '.join(words[start:start + words_per_batch])
        batch_ids = tokenizer(
            batch_text,
            return_tensors='pt',
            truncation=True,
            max_length=max_length,
            add_special_tokens=True
        )['input_ids'][0]
        all_token_ids.extend(batch_ids.tolist())

    # Step 2: if the book is too short, repeat its token stream.
    if len(all_token_ids) < tokens_needed:
        if not all_token_ids:
            # Nothing to repeat: fail with a clear message instead of
            # dividing by zero below.
            raise ValueError(
                "Cannot build chunks: the input text produced no tokens"
            )
        repeat_factor = (tokens_needed // len(all_token_ids)) + 1
        all_token_ids = (all_token_ids * repeat_factor)[:tokens_needed * 2]

    # Step 3: slice into full-size, non-overlapping windows.
    chunks = []
    for start in range(0, len(all_token_ids) - safe_chunk_size + 1, safe_chunk_size):
        chunk = all_token_ids[start:start + safe_chunk_size]
        if len(chunk) == safe_chunk_size:
            chunks.append(chunk)
        if len(chunks) >= num_chunks:
            break

    # Step 4: safety net - top up by repeating the last chunk (or the head of
    # the token stream when no chunk exists yet).
    while len(chunks) < num_chunks:
        if chunks:
            last_chunk = chunks[-1][:safe_chunk_size]
            if len(last_chunk) < safe_chunk_size:
                last_chunk = last_chunk + all_token_ids[:safe_chunk_size - len(last_chunk)]
            chunks.append(last_chunk[:safe_chunk_size])
        else:
            chunks.append(all_token_ids[:safe_chunk_size])

    chunks = chunks[:num_chunks]

    # Step 5: the dataset consumes text, so decode the id windows back.
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in chunks]


def maybe_delete_file(file_path):
    """Best-effort removal of ``file_path`` unless downloads are being kept.

    Does nothing when ``Config.KEEP_DOWNLOADED_BOOKS`` is set, when
    ``file_path`` is empty/None, or when the file does not exist. Errors
    raised by the removal itself are deliberately ignored.
    """
    removable = (
        not Config.KEEP_DOWNLOADED_BOOKS
        and bool(file_path)
        and os.path.exists(file_path)
    )
    if not removable:
        return
    try:
        os.remove(file_path)
    except OSError:
        # Cleanup is optional; a failed delete must not abort the pipeline.
        pass


def cleanup_dir(path):
    """Delete the directory tree at ``path``; empty or missing paths are ignored."""
    if not path or not os.path.exists(path):
        return
    # ignore_errors: removal is best-effort and never raises.
    shutil.rmtree(path, ignore_errors=True)


class SequentialUnlearningDataset(Dataset):
    """Tokenized text chunks for SSU, in either "forget" or "retain" mode.

    Every text is tokenized once in ``__init__``, truncated and padded to
    ``Config.CHUNK_SIZE``.

    * ``mode="forget"``: each item carries ``labels_fgt`` (the item's own
      token ids) and ``labels_rnd`` (a row of ``random_label_ids``, picked
      through a shuffled index list).
    * ``mode="retain"``: each item carries ``labels`` (the item's own ids).

    Labels derived from the item's own ids are set to -100 wherever the
    attention mask is 0, so padding does not contribute to the LM loss.
    ``labels_rnd`` is returned exactly as supplied by the caller, because no
    attention mask for those rows is available here.
    """

    # Label value ignored by the cross-entropy loss of HF causal-LM models.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, data_texts, mode="forget", random_label_ids=None):
        """Tokenize ``data_texts`` and prepare the random-label lookup.

        Args:
            tokenizer: Callable HF-style tokenizer.
            data_texts: List of text chunks.
            mode: "forget" or "retain".
            random_label_ids: 2-D tensor of token ids used for ``labels_rnd``;
                required (and non-empty) in forget mode.

        Raises:
            ValueError: On an unknown mode, or in forget mode when
                ``random_label_ids`` is missing or has no rows.
        """
        if mode not in {"forget", "retain"}:
            raise ValueError(f"Unsupported dataset mode: {mode}")
        # Validate before the (comparatively expensive) tokenization.
        if mode == "forget":
            if random_label_ids is None:
                raise ValueError("random_label_ids is required for forget mode")
            if random_label_ids.size(0) == 0:
                # An empty pool would cause a modulo-by-zero in __getitem__.
                raise ValueError("random_label_ids must contain at least one row")

        self.mode = mode
        self.tokenizer = tokenizer
        self.data_texts = data_texts

        tokenized = tokenizer(
            data_texts,
            truncation=True,
            padding="max_length",
            max_length=Config.CHUNK_SIZE,
            return_tensors='pt'
        )
        self.input_ids = tokenized['input_ids']
        self.attention_mask = tokenized['attention_mask']

        self.random_label_ids = random_label_ids
        self.random_indices = None
        if self.mode == "forget":
            # Shuffled once, so the item -> random-row pairing is fixed for
            # the lifetime of the dataset.
            self.random_indices = list(range(self.random_label_ids.size(0)))
            random.shuffle(self.random_indices)

    def __len__(self):
        return len(self.data_texts)

    def _own_labels(self, input_ids, attention_mask):
        """Return a copy of ``input_ids`` with padded positions set to -100."""
        return input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

    def __getitem__(self, idx):
        input_ids = self.input_ids[idx].clone()
        attention_mask = self.attention_mask[idx].clone()
        sample = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

        if self.mode == "forget":
            # labels_rnd comes from D_nor (random_label_ids), not from D_f.
            # The modulo wraps around when the pool is smaller than the dataset.
            rnd_idx = self.random_indices[idx % len(self.random_indices)]
            sample.update({
                'labels_fgt': self._own_labels(input_ids, attention_mask),
                'labels_rnd': self.random_label_ids[rnd_idx].clone(),
            })
        else:
            sample['labels'] = self._own_labels(input_ids, attention_mask)

        return sample


class SSUDataCollator:
    """Batch builder that accepts forget items, retain items, or a mix.

    Produces ``input_ids`` / ``attention_mask`` padded to the longest item,
    a ``labels`` tensor (all -100 for rows without a ``labels`` entry) and,
    if at least one item has ``labels_fgt``, matching ``labels_fgt`` and
    ``labels_rnd`` tensors. Missing or shorter label rows are filled with -100.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @staticmethod
    def _as_long_tensor(value):
        """Pass tensors through unchanged; convert anything else to int64."""
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=torch.long)

    @classmethod
    def _gather_labels(cls, features, key, like):
        """Build a -100-filled tensor shaped like ``like`` and copy ``key`` rows in."""
        out = torch.full(
            (len(features), like.size(1)), -100, dtype=like.dtype, device=like.device
        )
        for row, feature in enumerate(features):
            if key not in feature:
                continue
            values = cls._as_long_tensor(feature[key]).to(like.device)
            out[row, :values.shape[0]] = values
        return out

    def __call__(self, features):
        if not features:
            raise ValueError("Empty features list passed to data collator")

        # Fall back to EOS when the tokenizer defines no pad token.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        padded_ids = pad_sequence(
            [self._as_long_tensor(f['input_ids']) for f in features],
            batch_first=True,
            padding_value=pad_id,
        )
        padded_mask = pad_sequence(
            [self._as_long_tensor(f['attention_mask']) for f in features],
            batch_first=True,
            padding_value=0,
        )

        batch = {
            'input_ids': padded_ids,
            'attention_mask': padded_mask,
            # Stays all -100 for forget-only batches.
            'labels': self._gather_labels(features, 'labels', padded_ids),
        }

        if any('labels_fgt' in f for f in features):
            batch['labels_fgt'] = self._gather_labels(features, 'labels_fgt', padded_ids)
            batch['labels_rnd'] = self._gather_labels(features, 'labels_rnd', padded_ids)

        return batch

def download_gutenberg_book(book_id, output_dir):
    """Download a book from Project Gutenberg by ID.

    The book is stored as ``<output_dir>/<book_id>.txt``. A file that already
    exists is reused without any network access. Otherwise the known URL
    layouts are tried in order until one succeeds.

    Args:
        book_id: Project Gutenberg book ID.
        output_dir: Directory for the downloaded file (created if missing).

    Returns:
        str: Path to downloaded book file, or None if download failed.
    """
    os.makedirs(output_dir, exist_ok=True)
    book_file = os.path.join(output_dir, f"{book_id}.txt")

    if os.path.exists(book_file):
        print(f"Book {book_id} already exists, skipping download.")
        return book_file

    candidate_urls = (
        f"https://www.gutenberg.org/files/{book_id}/{book_id}-0.txt",
        f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt",
    )

    print(f"Downloading book {book_id} from Project Gutenberg...")
    for url in candidate_urls:
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            with open(book_file, 'w', encoding='utf-8') as f:
                f.write(response.text)
        except (requests.RequestException, OSError):
            # Only network/HTTP and file-system failures trigger the next
            # mirror; KeyboardInterrupt and programming errors propagate
            # instead of being swallowed by a bare ``except``.
            continue
        print(f"Downloaded book {book_id} successfully.")
        return book_file

    print(f"Error: Could not download book {book_id} from Project Gutenberg.")
    return None


def load_book_text(book_file):
    """Read a book file, strip the Gutenberg header/footer, normalise whitespace.

    Returns None when ``book_file`` is empty/None or does not exist. Otherwise
    returns the text after the line containing the first recognised start
    marker and before the first recognised end marker, with every run of
    whitespace collapsed to a single space.
    """
    if not (book_file and os.path.exists(book_file)):
        return None

    with open(book_file, 'r', encoding='utf-8', errors='ignore') as handle:
        content = handle.read()

    def first_hit(haystack, needles):
        """Position of the first needle (in list order) that occurs, else None."""
        positions = (haystack.find(needle) for needle in needles)
        return next((pos for pos in positions if pos != -1), None)

    # Header: drop everything up to and including the marker's line.
    header_at = first_hit(content, ["*** START OF", "***START OF", "START OF THE PROJECT"])
    if header_at is not None:
        content = content[content.find('\n', header_at) + 1:]

    # Footer: drop the marker and everything after it.
    footer_at = first_hit(content, ["*** END OF", "***END OF", "END OF THE PROJECT"])
    if footer_at is not None:
        content = content[:footer_at]

    return ' '.join(content.split())


def get_all_books_for_initial_finetuning():
    """Fetch the cleaned text of every book in ``Config.ALL_BOOK_IDS`` (D_f).

    These texts feed the initial fine-tuning stage that makes the model
    memorize the books.

    Returns:
        list[str]: One cleaned text per book, in ``ALL_BOOK_IDS`` order.

    Raises:
        RuntimeError: If a book cannot be downloaded, its text is missing or
            shorter than 1000 characters, or no book was loaded at all.
    """
    print("\n=== Downloading all books for initial fine-tuning (D_f) ===")
    all_books_dir = os.path.join(Config.DATA_DIR, "all_books")
    os.makedirs(all_books_dir, exist_ok=True)

    def fetch_text(book_id):
        """Download, load and validate a single book; return its text."""
        path = download_gutenberg_book(book_id, all_books_dir)
        if not path:
            raise RuntimeError(f"Failed to download book {book_id} for initial fine-tuning")
        content = load_book_text(path)
        if not content or len(content) < 1000:
            raise RuntimeError(f"Book {book_id} text is invalid or too short for initial fine-tuning")
        return path, content

    book_texts = []
    for book_id in Config.ALL_BOOK_IDS:
        path, content = fetch_text(book_id)
        book_texts.append(content)
        # The raw file is no longer needed once the text is in memory.
        maybe_delete_file(path)
        print(f"Loaded book {book_id} ({len(content)} chars)")

    if not book_texts:
        raise RuntimeError("No books downloaded for initial fine-tuning")

    return book_texts


def get_unlearning_datasets(random_label_ids=None):
    """Build the forget datasets D_f^1, D_f^2, ... - one per unlearning step.

    For every step 1..``Config.NUM_UNLEARNING_STEPS`` the books listed in
    ``Config.GUTENBERG_BOOK_IDS`` are downloaded, cleaned, chunked, and wrapped
    in a forget-mode ``SequentialUnlearningDataset``.

    Args:
        random_label_ids: Tensor of input_ids from D_nor used for labels_rnd.
            Required, because every dataset is created in forget mode.

    Returns:
        list: The datasets in step order.

    Raises:
        RuntimeError: If ``Config.USE_REAL_BOOKS`` is False, a book cannot be
            downloaded or is too short, or a step has no books.
    """
    tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    datasets = []

    for step in range(1, Config.NUM_UNLEARNING_STEPS + 1):
        print(f"\n--- Preparing dataset D_f^{step} for time step {step} ---")

        if not Config.USE_REAL_BOOKS:
            raise RuntimeError("USE_REAL_BOOKS must be True for paper reproduction")

        # Each step keeps its downloads in its own directory.
        step_dir = os.path.join(Config.DATA_DIR, f"time_step_{step}")
        os.makedirs(step_dir, exist_ok=True)

        book_texts = []
        for book_id in Config.GUTENBERG_BOOK_IDS.get(step, []):
            path = download_gutenberg_book(book_id, step_dir)
            if not path:
                raise RuntimeError(f"Failed to download book {book_id} for step {step}")
            content = load_book_text(path)
            if not content or len(content) < 1000:
                raise RuntimeError(f"Book {book_id} text is invalid or too short for step {step}")
            book_texts.append(content)
            maybe_delete_file(path)
            print(f"Loaded book {book_id} for step {step}")

        if not book_texts:
            raise RuntimeError(f"No valid books for step {step}")

        # Ask each book for slightly more than its share, then trim the total.
        chunks_per_book = Config.NUM_CHUNKS_PER_STEP // len(book_texts) + 1
        step_chunks = []
        for content in book_texts:
            step_chunks.extend(
                generate_simulated_data(
                    content,
                    Config.CHUNK_SIZE,
                    chunks_per_book,
                    Config.TOKENIZER_NAME,
                )
            )
        step_chunks = step_chunks[:Config.NUM_CHUNKS_PER_STEP]
        print(f"Created {len(step_chunks)} chunks for time step {step}")

        # Forget dataset; labels_rnd rows are drawn from D_nor.
        datasets.append(
            SequentialUnlearningDataset(
                tokenizer,
                step_chunks,
                mode="forget",
                random_label_ids=random_label_ids,
            )
        )

    return datasets


def get_retention_dataset():
    """Build the retention dataset D_nor (non-targeted data to keep).

    Picks up to ``Config.NUM_RETENTION_BOOKS`` Gutenberg IDs that are not in
    ``Config.ALL_BOOK_IDS``, downloads them, and collects up to
    ``Config.NUM_RETENTION_CHUNKS`` chunks. Books that fail to download or are
    too short are skipped, and downloading stops once enough chunks exist.

    Returns:
        A retain-mode ``SequentialUnlearningDataset``, or None when
        ``Config.USE_RETENTION_DATA`` is False.

    Raises:
        RuntimeError: If not a single chunk could be produced.
    """
    if not Config.USE_RETENTION_DATA:
        return None

    print("\n=== Preparing retention dataset D_nor ===")
    tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Candidate IDs must be disjoint from the unlearning set (ALL_BOOK_IDS).
    excluded_ids = set(Config.ALL_BOOK_IDS)
    candidate_ids = []
    # Common Gutenberg book IDs (expanded list). It also contains IDs from the
    # unlearning set; those are removed by the filter below.
    common_ids = [
        1232, 145, 76, 2591, 30254, 844, 345, 520, 6130, 174,
        1342, 11, 2701, 74, 98, 5200, 6130, 174, 1661, 84,
        100, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
        1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000,
        2100, 2200, 2300, 2400, 2500, 2600, 2800, 2900, 3000, 3100,
        3200, 3300, 3400, 3500, 3600, 3700, 3800, 3900, 4000, 4100,
        4200, 4300, 4400, 4500, 4600, 4700, 4800, 4900, 5000, 5100,
        5300, 5400, 5500, 5600, 5700, 5800, 5900, 6000, 6100, 6200,
        6300, 6400, 6500, 6600, 6700, 6800, 6900, 7000, 7100, 7200,
        7300, 7400, 7500, 7600, 7700, 7800, 7900, 8000, 8100, 8200,
    ]

    for bid in common_ids:
        if bid not in excluded_ids and len(candidate_ids) < Config.NUM_RETENTION_BOOKS:
            candidate_ids.append(bid)

    # If the curated list is not enough, top up with sequential IDs.
    next_id = 10000
    while len(candidate_ids) < Config.NUM_RETENTION_BOOKS:
        if next_id not in excluded_ids:
            candidate_ids.append(next_id)
        next_id += 1
        if next_id > 100000:  # Safety limit
            break

    retention_book_ids = candidate_ids[:Config.NUM_RETENTION_BOOKS]
    retention_dir = os.path.join(Config.DATA_DIR, "retention_books")
    os.makedirs(retention_dir, exist_ok=True)

    all_chunks = []
    successful_books = 0
    for book_id in retention_book_ids:
        book_file = download_gutenberg_book(book_id, retention_dir)
        if not book_file:
            continue

        text = load_book_text(book_file)
        # Delete the raw file as soon as its text is in memory. Doing this
        # before the early ``break`` below guarantees the last book is cleaned
        # up as well (it used to be left on disk).
        maybe_delete_file(book_file)

        if not text or len(text) <= 1000:
            continue

        chunks = generate_simulated_data(
            text,
            Config.CHUNK_SIZE,
            Config.NUM_RETENTION_CHUNKS // Config.NUM_RETENTION_BOOKS + 1,
            Config.TOKENIZER_NAME
        )
        all_chunks.extend(chunks)
        successful_books += 1
        if len(all_chunks) >= Config.NUM_RETENTION_CHUNKS:
            break

    if not all_chunks:
        raise RuntimeError(
            f"Failed to download any retention books. "
            f"Expected {Config.NUM_RETENTION_BOOKS} books with {Config.NUM_RETENTION_CHUNKS} chunks total. "
            f"Only {successful_books} books downloaded successfully."
        )

    retention_chunks = all_chunks[:Config.NUM_RETENTION_CHUNKS]
    print(f"Created {len(retention_chunks)} retention chunks from {successful_books} books")

    return SequentialUnlearningDataset(tokenizer, retention_chunks, mode="retain")

print("Data utilities loaded!")


In [None]:
class SSUTrainer(Trainer):
    """Custom Trainer implementing the SSU loss with weight saliency.

    SSU loss: L = λ1 * L_fgt + λ2 * L_rnd (Equation 4 from paper)
    No retention LM loss - D_nor is only used for evaluation and random labels.

    Weight saliency is applied through an optimizer *step pre-hook* that is
    registered in ``create_optimizer``. The HF ``Trainer`` steps the optimizer
    itself and offers no ``optimizer_step`` method to override, so a method of
    that name on this class would never be invoked by the training loop (and
    calling ``super().optimizer_step()`` would fail). The pre-hook runs right
    before every real optimizer update, i.e. on the fully accumulated
    gradients.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute SSU loss: λ1 * L_fgt + λ2 * L_rnd (Equation 4).

        Args:
            model: The model being trained.
            inputs: Batch dict containing ``labels_fgt`` and ``labels_rnd``
                next to the regular model inputs.
            return_outputs: When True, also return the forward outputs of the
                forgetting pass.
            **kwargs: Extra arguments passed by the Trainer; ignored.

        Raises:
            ValueError: If the batch lacks ``labels_fgt`` or ``labels_rnd``.
        """
        inputs_copy = inputs.copy()
        _ = inputs_copy.pop('labels', None)  # Not used in SSU loss
        labels_fgt = inputs_copy.pop('labels_fgt', None)
        labels_rnd = inputs_copy.pop('labels_rnd', None)

        if labels_fgt is None or labels_rnd is None:
            raise ValueError("SSU requires labels_fgt and labels_rnd for forget batches")

        # Same inputs, two label sets: one pass per loss term.
        outputs_fgt = model(**inputs_copy, labels=labels_fgt)
        outputs_rnd = model(**inputs_copy, labels=labels_rnd)

        # SSU loss: L = λ1 * L_fgt + λ2 * L_rnd
        loss = Config.EPSILON_1 * outputs_fgt.loss + Config.EPSILON_2 * outputs_rnd.loss

        return (loss, outputs_fgt) if return_outputs else loss

    def create_optimizer(self, *args, **kwargs):
        """Create the optimizer as usual, then attach the saliency pre-hook.

        The hook is registered at most once per optimizer object, so repeated
        calls do not stack several masks.
        """
        result = super().create_optimizer(*args, **kwargs)
        optimizer = self.optimizer if self.optimizer is not None else result
        if optimizer is not None and getattr(self, "_ssu_hooked_optimizer", None) is not optimizer:
            if not hasattr(optimizer, "register_step_pre_hook"):
                raise RuntimeError(
                    "SSU weight saliency needs an optimizer that supports "
                    "register_step_pre_hook (torch >= 2.0)"
                )
            optimizer.register_step_pre_hook(self._saliency_pre_step_hook)
            self._ssu_hooked_optimizer = optimizer
        return result

    def _saliency_pre_step_hook(self, optimizer, args, kwargs):
        """Optimizer pre-hook: mask gradients, leave the step arguments untouched."""
        self._apply_weight_saliency_mask()

    @torch.no_grad()
    def _apply_weight_saliency_mask(self):
        """Zero low-saliency gradient entries of the LoRA parameters in place.

        For every trainable parameter whose name contains "lora", gamma (γ) is
        computed dynamically as γ = μ + σ, where μ is the mean and σ the
        standard deviation of that parameter's gradient magnitudes. Only
        entries with |grad| >= γ are kept.
        """
        for name, param in self.model.named_parameters():
            if param.grad is None or not param.requires_grad:
                continue
            if "lora" not in name.lower():
                continue

            grad = param.grad
            if grad.numel() < 2:
                # The standard deviation of one value is NaN, which would
                # mask the whole gradient; leave such parameters alone.
                continue

            # Statistics in float32 so half-precision gradients stay stable.
            magnitude = grad.abs().float()
            gamma = magnitude.mean() + magnitude.std()  # 1 std above mean

            # Saliency mask m_s = I(|grad| >= gamma), cast to the gradient's
            # dtype so the in-place product keeps dtype and storage.
            grad.mul_((magnitude >= gamma).to(grad.dtype))


def create_lora_model(model):
    """Wrap ``model`` with LoRA adapters configured from ``Config``.

    Rank, alpha, dropout and the target modules come from ``Config``; no bias
    terms are adapted and the task type is causal language modelling.
    """
    lora_settings = {
        "r": Config.LORA_R,
        "lora_alpha": Config.LORA_ALPHA,
        "lora_dropout": Config.LORA_DROPOUT,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": Config.TARGET_MODULES,
    }
    return get_peft_model(model, LoraConfig(**lora_settings))


def apply_task_vector_negation(base_model, fine_tuned_model, name_prefix):
    """Task Vector Negation: theta_u^t = 2 * theta_u^{t-1} - theta_ft^t

    A fresh model of ``base_model``'s class is instantiated from
    ``base_model.config`` on the device of the base model's first parameter,
    and its state is then filled entry by entry:

    * floating-point tensors present in both models get ``2 * base - ft``;
    * every other tensor present in the base model (e.g. integer buffers, or
      entries the fine-tuned model lacks) is copied from the base model;
    * entries the base model does not have keep their freshly initialised value.

    Iterating the full state dict (rather than only ``named_parameters()``)
    ensures registered buffers are carried over from the base model instead of
    silently staying at their fresh-init values under ``strict=False``.

    Args:
        base_model: Model holding theta_u^{t-1}; must expose ``.config`` and be
            constructible as ``type(base_model)(config=...)``.
        fine_tuned_model: Model holding theta_ft^t.
        name_prefix: Label used only in the progress messages.

    Returns:
        The newly constructed model carrying the negated weights.
    """
    print(f"\n--- Applying Task Vector Negation for {name_prefix} ---")

    device = next(base_model.parameters()).device
    new_unlearned_model = base_model.__class__(config=base_model.config).to(device)

    base_state = base_model.state_dict()
    ft_state = fine_tuned_model.state_dict()

    new_state = {}
    for name, fresh_tensor in new_unlearned_model.state_dict().items():
        if name not in base_state:
            # Nothing to derive this entry from; keep the fresh initialisation.
            new_state[name] = fresh_tensor
            continue

        base_tensor = base_state[name]
        if name in ft_state and base_tensor.is_floating_point():
            # Align devices so the arithmetic cannot fail on a mixed placement.
            ft_tensor = ft_state[name].to(base_tensor.device)
            new_state[name] = 2 * base_tensor - ft_tensor
        else:
            new_state[name] = base_tensor

    new_unlearned_model.load_state_dict(new_state, strict=False)
    print(f"Task Vector Negation complete for {name_prefix}.")
    return new_unlearned_model

# Progress marker: printed when this cell finishes executing its definitions.
print("SSU model utilities loaded!")


In [None]:
## 9. Evaluation Utilities
from typing import List, Tuple, Any
from sentence_transformers import SentenceTransformer
from sklearn.metrics import roc_auc_score

# Sentence-embedding model used for semantic similarity; loaded lazily and
# kept in a module-level slot so repeated evaluations reuse one instance.
SEMANTIC_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_semantic_model = None


def get_semantic_model():
    """Return the shared sentence-transformers model, creating it on first use.

    The model named by ``SEMANTIC_MODEL_NAME`` is constructed with
    ``device="cpu"`` and stored in the module-level ``_semantic_model``.
    """
    global _semantic_model
    if _semantic_model is not None:
        return _semantic_model
    _semantic_model = SentenceTransformer(SEMANTIC_MODEL_NAME, device="cpu")
    return _semantic_model


def evaluate_semantic_similarity(
    best_pairs: List[Tuple[str, str]],
    similarity_threshold: float = 0.8
) -> Dict[str, Any]:
    """
    Score how semantically close each generation is to its reference text.

    Both sides of every pair are embedded with the model from
    ``get_semantic_model`` (embeddings requested normalized), and the per-pair
    score is the sum of the element-wise product of the two embeddings.

    Args:
        best_pairs: list of (reference_text, generated_text) pairs per prompt.
        similarity_threshold: score at or above which a pair counts as "high".

    Returns:
        Dict with ``mean_cosine`` (average score), ``frac_high`` (share of
        pairs at or above the threshold) and ``cosine_raw`` (per-pair scores).
        An empty input yields zeros and an empty list.
    """
    if not best_pairs:
        return {"mean_cosine": 0.0, "frac_high": 0.0, "cosine_raw": []}

    encoder = get_semantic_model()

    references = [reference for reference, _ in best_pairs]
    generations = [generation for _, generation in best_pairs]

    encode_kwargs = {"convert_to_tensor": True, "normalize_embeddings": True}
    reference_vecs = encoder.encode(references, **encode_kwargs)
    generation_vecs = encoder.encode(generations, **encode_kwargs)

    # Row-wise dot product of the paired embeddings.
    cosines = (reference_vecs * generation_vecs).sum(dim=1).cpu().tolist()
    n_pairs = len(cosines)
    n_high = sum(1 for value in cosines if value >= similarity_threshold)

    return {
        "mean_cosine": float(sum(cosines) / n_pairs),
        "frac_high": float(n_high / n_pairs),
        "cosine_raw": cosines,
    }


def compute_nll_per_chunk(
    model,
    tokenizer,
    text_chunks: List[str],
    chunk_size: int = 256,
) -> List[float]:
    """
    Return one loss value per text chunk for membership analysis.

    Each chunk is tokenized on its own (truncated to ``chunk_size`` tokens),
    moved to the device of the model's first parameter, and fed to the model
    with the input ids doubling as labels; the resulting ``loss`` scalar is
    collected. The model is switched to eval mode and gradients are disabled.

    Args:
        model: Causal LM called as ``model(input_ids=..., labels=...)``.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` tensor.
        text_chunks: Texts to score; an empty list short-circuits to ``[]``.
        chunk_size: Maximum number of tokens kept per chunk.
    """
    if not text_chunks:
        return []

    model.eval()
    target_device = next(model.parameters()).device

    def _chunk_loss(text):
        # Score a single chunk; labels are the inputs themselves.
        encoded = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=chunk_size,
        )
        ids = encoded['input_ids'].to(target_device)
        return model(input_ids=ids, labels=ids).loss.item()

    with torch.no_grad():
        return [_chunk_loss(text) for text in text_chunks]


def evaluate_membership_risk(
    model,
    tokenizer,
    forget_chunks: List[str],
    retain_chunks: List[str],
) -> Dict[str, Any]:
    """
    Compare per-chunk NLL on forget chunks against retain chunks.

    Per-chunk losses come from ``compute_nll_per_chunk``. The ROC-AUC is
    computed with forget chunks labelled 1 and retain chunks labelled 0, using
    the NLL itself as the score.

    Returns:
        Dict with ``mean_nll_forget``, ``mean_nll_retain``, ``delta_nll``
        (forget minus retain), ``roc_auc`` and the two raw NLL lists. If
        either list is empty, the four summary values are NaN.
    """
    nll_forget = compute_nll_per_chunk(model, tokenizer, forget_chunks)
    nll_retain = compute_nll_per_chunk(model, tokenizer, retain_chunks)

    # Start from the "not computable" report and fill in the summaries
    # only when both groups actually produced scores.
    report = {
        "mean_nll_forget": float("nan"),
        "mean_nll_retain": float("nan"),
        "delta_nll": float("nan"),
        "roc_auc": float("nan"),
        "nll_forget_raw": nll_forget,
        "nll_retain_raw": nll_retain,
    }

    if nll_forget and nll_retain:
        forget_mean = sum(nll_forget) / len(nll_forget)
        retain_mean = sum(nll_retain) / len(nll_retain)
        membership_labels = [1] * len(nll_forget) + [0] * len(nll_retain)

        report["mean_nll_forget"] = forget_mean
        report["mean_nll_retain"] = retain_mean
        report["delta_nll"] = forget_mean - retain_mean
        report["roc_auc"] = roc_auc_score(membership_labels, nll_forget + nll_retain)

    return report


def _prepare_prompt_pairs(tokenizer, book_text: str, prompt_tokens: int = 100, continuation_tokens: int = 100, max_pairs: int = 16):
    """Prepare prompt-continuation pairs from book text.

    The text is tokenized once (no truncation, no special tokens) and cut into
    consecutive, non-overlapping windows of ``prompt_tokens +
    continuation_tokens`` tokens. Each full window is split into a prompt part
    and a continuation part, and both parts are decoded back to text.

    Args:
        tokenizer: Callable returning a mapping whose ``'input_ids'`` entry
            holds the token ids of ``book_text`` in row 0; must also provide
            ``decode``.
        book_text: Source text to cut into windows.
        prompt_tokens: Number of tokens for prompt (default 100 per paper)
        continuation_tokens: Number of tokens for continuation (default 100 per paper)
        max_pairs: Upper bound on the number of pairs returned.

    Returns:
        List of ``(prompt_text, continuation_text)`` tuples, at most
        ``max_pairs`` long. Empty when the text is shorter than one window,
        when the window size is not positive, or when ``max_pairs <= 0``.
    """
    total_needed = prompt_tokens + continuation_tokens
    if total_needed <= 0 or max_pairs <= 0:
        return []

    token_ids = tokenizer(
        book_text,
        return_tensors='pt',
        truncation=False,
        add_special_tokens=False,
    )['input_ids'][0]

    pairs = []
    # Last index at which a complete window still fits. The "+ 1" below makes
    # that index inclusive, so a text of exactly N * total_needed tokens
    # yields N windows instead of N - 1.
    last_start = len(token_ids) - total_needed
    for start in range(0, last_start + 1, total_needed):
        split = start + prompt_tokens
        prompt_text = tokenizer.decode(token_ids[start:split], skip_special_tokens=True)
        cont_text = tokenizer.decode(token_ids[split:start + total_needed], skip_special_tokens=True)
        pairs.append((prompt_text, cont_text))
        if len(pairs) >= max_pairs:
            break
    return pairs


def evaluate_regurgitation(model, tokenizer, book_text: str, max_pairs: int = 10,
                          prompt_tokens: int = 100, continuation_tokens: int = 100,
                          num_samples: int = 10, verbose: bool = True,
                          return_best_hypotheses: bool = False) -> Dict[str, Any]:
    """Estimate regurgitation via ROUGE-1 / ROUGE-L against the book's own continuations.

    Prompt/continuation pairs are cut from ``book_text`` with
    ``_prepare_prompt_pairs``. For each prompt, ``num_samples`` continuations
    are sampled (``top_p=0.6``, ``temperature=0.6``) and scored against the
    true continuation; the per-prompt score is the maximum over the samples,
    and the reported score is the mean of those maxima.

    Args:
        model: Causal LM providing ``generate``; evaluated in eval mode.
        tokenizer: Tokenizer used for prompts and for decoding generations.
        book_text: Text of the book to probe.
        max_pairs: Maximum number of prompt/continuation pairs.
        prompt_tokens: Prompt length in tokens.
        continuation_tokens: Continuation length, also used as ``max_new_tokens``.
        num_samples: Generations sampled per prompt.
        verbose: If True, print progress lines.
        return_best_hypotheses: If True, also return best_pairs list of
            (reference, best_hypothesis) tuples, best being the first sample
            with the highest ROUGE-L.

    Returns:
        Dict with ``rouge1``, ``rougeL``, ``num_pairs`` and optionally
        ``best_pairs``.
    """
    device = next(model.parameters()).device
    prompt_pairs = _prepare_prompt_pairs(tokenizer, book_text, prompt_tokens, continuation_tokens, max_pairs)
    if not prompt_pairs:
        empty_result = {"rouge1": 0.0, "rougeL": 0.0, "num_pairs": 0}
        if return_best_hypotheses:
            empty_result["best_pairs"] = []
        return empty_result

    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=False)
    n_pairs = len(prompt_pairs)
    total_generations = n_pairs * num_samples

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    per_prompt_rouge1 = []
    per_prompt_rougeL = []
    best_pairs = []  # (reference, best hypothesis) per prompt
    done = 0

    model.eval()
    with torch.no_grad():
        for pair_no, (prompt, reference) in enumerate(prompt_pairs, start=1):
            encoded = tokenizer(prompt, return_tensors='pt').to(device)
            prompt_len = encoded['input_ids'].shape[1]

            rouge1_samples = []
            rougeL_samples = []
            candidates = []

            for sample_no in range(1, num_samples + 1):
                if verbose and done % 10 == 0:
                    print(f"  Generating sample {done + 1}/{total_generations} (pair {pair_no}/{n_pairs}, sample {sample_no}/{num_samples})", end='\r')

                generated = model.generate(
                    **encoded,
                    max_new_tokens=continuation_tokens,
                    do_sample=True,
                    top_p=0.6,
                    temperature=0.6,
                    pad_token_id=pad_id,
                )
                # Drop the echoed prompt; keep only the newly generated tokens.
                candidate = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
                candidates.append(candidate)

                sample_scores = scorer.score(reference, candidate)
                rouge1_samples.append(sample_scores["rouge1"].fmeasure)
                rougeL_samples.append(sample_scores["rougeL"].fmeasure)

                done += 1

            # Best hypothesis = first sample reaching the top ROUGE-L.
            top_rougeL = max(rougeL_samples)
            best_pairs.append((reference, candidates[rougeL_samples.index(top_rougeL)]))

            per_prompt_rouge1.append(max(rouge1_samples))
            per_prompt_rougeL.append(top_rougeL)

            if verbose:
                print(f"  Completed pair {pair_no}/{n_pairs} - Max ROUGE-L: {top_rougeL:.4f}                    ")

    mean_rouge1 = float(sum(per_prompt_rouge1) / len(per_prompt_rouge1)) if per_prompt_rouge1 else 0.0
    mean_rougeL = float(sum(per_prompt_rougeL) / len(per_prompt_rougeL)) if per_prompt_rougeL else 0.0

    summary = {
        "rouge1": mean_rouge1,
        "rougeL": mean_rougeL,
        "num_pairs": n_pairs
    }
    if return_best_hypotheses:
        summary["best_pairs"] = best_pairs
    return summary


def _normalize_corpus(text_source: Sequence[str] | str, tokenizer, max_samples: int = 32) -> str:
    if isinstance(text_source, str):
        return text_source
    if isinstance(text_source, SequentialUnlearningDataset):
        samples = [tokenizer.decode(text_source.input_ids[i], skip_special_tokens=True) for i in range(min(len(text_source), max_samples))]
        return "\n\n".join(samples)
    if isinstance(text_source, list):
        return "\n\n".join(text_source[:max_samples])
    raise ValueError("Unsupported text source type for perplexity evaluation")


def evaluate_perplexity(model, tokenizer, text_source, stride: int = 256) -> float:
    """Compute perplexity of ``model`` on a text corpus.

    The source is flattened to one string by ``_normalize_corpus``, tokenized
    without special tokens, and scored in consecutive, non-overlapping windows
    of ``stride`` tokens. Each window's loss is weighted by its length minus
    one, and the result is ``exp`` of the weighted average.

    Returns:
        The perplexity, or ``inf`` when the corpus has at most one token or no
        window contributed any scored tokens.
    """
    corpus = _normalize_corpus(text_source, tokenizer)
    token_ids = tokenizer(
        corpus,
        return_tensors='pt',
        truncation=False,
        add_special_tokens=False,
    )['input_ids'][0]

    n_total = token_ids.size(0)
    if n_total <= 1:
        return float('inf')

    device = next(model.parameters()).device
    total_nll = 0.0
    scored_tokens = 0

    model.eval()
    with torch.no_grad():
        for window_start in range(0, n_total - 1, stride):
            window_end = min(window_start + stride, n_total)
            batch = token_ids[window_start:window_end].unsqueeze(0).to(device)
            targets = batch.clone()
            outputs = model(input_ids=batch, labels=targets)

            # A window of k tokens contributes k - 1 scored positions.
            predicted = targets.size(1) - 1
            if predicted > 0:
                total_nll += outputs.loss.item() * predicted
                scored_tokens += predicted

    if scored_tokens == 0:
        return float('inf')
    return float(math.exp(total_nll / scored_tokens))



## 6. Model Loading with Retry Logic


In [None]:
def load_model_with_retry(model_name, max_retries=3, retry_delay=5):
    """Load a tokenizer and causal LM, retrying on failure.

    Every attempt loads the tokenizer (using EOS as pad token when none is
    set) and the model: bfloat16 on CUDA when available, float32 on CPU
    otherwise. After a failed attempt the function sleeps before retrying:
    ``retry_delay`` seconds, doubled each time, when the error text looks
    network-related, otherwise 2 seconds. The last failure is re-raised after
    printing troubleshooting hints.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        max_retries: Total number of attempts.
        retry_delay: Initial wait in seconds after a network-looking error.

    Returns:
        ``(model, tokenizer)`` on success; ``(None, None)`` only if no attempt
        was made at all (``max_retries`` of zero or less).
    """
    for attempt_no in range(1, max_retries + 1):
        try:
            print(f"Loading model (attempt {attempt_no}/{max_retries})...")

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Pick the device and a matching precision.
            if torch.cuda.is_available():
                device, dtype = "cuda", torch.bfloat16
            else:
                device, dtype = "cpu", torch.float32

            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map=device,  # single device rather than "auto"
                trust_remote_code=True,
                attn_implementation="eager",
            )

            # Ensure model is on the correct device
            base_model = base_model.to(device)

            print(f"Successfully loaded {model_name}!")
            return base_model, tokenizer

        except Exception as e:
            error_msg = str(e)
            print(f"Attempt {attempt_no} failed: {error_msg[:200]}...")

            if attempt_no >= max_retries:
                # Out of attempts: explain, then surface the original error.
                print("\nAll retry attempts failed!")
                print("\nTROUBLESHOOTING:")
                print("1. Check your internet connection")
                print("2. Try a smaller model: TinyLlama/TinyLlama-1.1B-Chat-v1.0")
                print("3. For gated models, ensure HF_TOKEN is set")
                raise

            looks_like_network_error = (
                "IncompleteRead" in error_msg
                or "Connection" in error_msg
                or "timeout" in error_msg.lower()
            )
            if looks_like_network_error:
                print(f"Network error detected. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            else:
                print("Non-network error. Retrying...")
                time.sleep(2)

    return None, None


# Load the model
# Runs when the cell executes and binds the module-level names `base_model`
# and `tokenizer`.
print(f"Loading base model: {Config.MODEL_NAME}")
base_model, tokenizer = load_model_with_retry(Config.MODEL_NAME)
# Turn off gradient tracking on every base-model parameter.
base_model.requires_grad_(False)
print("Model loaded and frozen!")


## 7. Main SSU Pipeline


In [None]:
import os
from transformers import TrainingArguments

def initial_finetuning(model, tokenizer, all_books_texts):
    """
    Step 0: Initial fine-tuning on all books (D_f) to make model memorize them.
    This is what the paper does BEFORE unlearning.

    Args:
        model: Base causal LM to adapt. LoRA adapters are created on top of it
            via ``create_lora_model``.
            NOTE(review): the adapter wrapping and the later merge presumably
            modify this object in place - confirm before reusing it as a
            pristine baseline.
        tokenizer: Tokenizer used to build the training tensors and passed to
            the ``Trainer`` and the data collator.
        all_books_texts: Iterable of raw book strings; each one is chunked
            separately with ``generate_simulated_data``.

    Returns:
        The model returned by ``merge_and_unload()``, with
        ``requires_grad_(False)`` applied.
    """
    print("\n" + "="*60)
    print("STEP 0: INITIAL FINE-TUNING ON ALL BOOKS (D_f)")
    print("="*60)
    print("Fine-tuning vanilla model on all copyrighted books to memorize them...")

    # Generate chunks from all books
    all_chunks = []
    for book_text in all_books_texts:
        # The third argument is the total chunk budget split across books, plus one.
        # NOTE(review): presumably the number of chunks requested per book -
        # confirm against generate_simulated_data.
        chunks = generate_simulated_data(
            book_text,
            Config.CHUNK_SIZE,
            Config.NUM_CHUNKS_PER_STEP * Config.NUM_UNLEARNING_STEPS // len(all_books_texts) + 1,
            Config.TOKENIZER_NAME
        )
        all_chunks.extend(chunks)

    print(f"Created {len(all_chunks)} chunks from all books")

    # For initial fine-tuning, use standard dataset (not SSU dual labels)
    from torch.utils.data import Dataset as TorchDataset
    class StandardDataset(TorchDataset):
        """Pre-tokenized dataset whose labels are a copy of the input ids."""

        def __init__(self, tokenizer, data_texts):
            # Tokenize everything up front; every row is truncated or padded
            # to exactly Config.CHUNK_SIZE tokens.
            tokenized = tokenizer(
                data_texts,
                truncation=True,
                padding="max_length",
                max_length=Config.CHUNK_SIZE,
                return_tensors='pt'
            )
            self.input_ids = tokenized['input_ids']
            self.attention_mask = tokenized['attention_mask']

        def __len__(self):
            return len(self.input_ids)

        def __getitem__(self, idx):
            # Clones keep callers from mutating the stored tensors.
            return {
                'input_ids': self.input_ids[idx].clone(),
                'attention_mask': self.attention_mask[idx].clone(),
                'labels': self.input_ids[idx].clone()  # Standard labels for next token prediction
            }

    initial_dataset = StandardDataset(tokenizer, all_chunks)

    # Ensure model is on correct device before creating LoRA
    device = next(model.parameters()).device
    if device.type == "meta" or str(device) == "meta":
        # If model is on meta device, move to actual device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
    else:
        # Convert device object to string if needed
        device = str(device).split(':')[0]  # Get 'cuda' or 'cpu'

    # Create LoRA model for initial fine-tuning (PEFT handles device automatically)
    # Make the embedding outputs require grad, either through the model's own
    # helper or through a forward hook on the input embeddings.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            """Forward hook: mark the embedding output as requiring grad."""
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    lora_model = create_lora_model(model)
    lora_model.print_trainable_parameters()

    # Disable cache for gradient checkpointing
    if hasattr(lora_model.config, "use_cache"):
        lora_model.config.use_cache = False

    # PEFT models inherit device from base model, no need to call .to()

    training_args = TrainingArguments(
        output_dir=f"{Config.OUTPUT_DIR}/initial_ft_checkpoints",
        per_device_train_batch_size=Config.BATCH_SIZE,
        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=10,
        learning_rate=Config.LEARNING_RATE,
        num_train_epochs=Config.NUM_EPOCHS_INITIAL_FT,
        logging_steps=10,
        save_strategy="no",
        report_to="none",
        fp16=False,
        bf16=torch.cuda.is_available() and device == "cuda",
        dataloader_pin_memory=False,  # Fix device issues
        label_names=["labels"],
        gradient_checkpointing=True,  # Enable gradient checkpointing
    )

    # Use standard Trainer for initial fine-tuning (not SSU)
    from transformers import Trainer, DataCollatorForLanguageModeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal LM, not masked LM
    )

    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=initial_dataset,
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    print("Starting initial fine-tuning...")
    trainer.train()

    # Merge LoRA weights into base model
    memorized_model = lora_model.merge_and_unload()
    # Freeze the merged result; later steps attach their own adapters.
    memorized_model.requires_grad_(False)

    print("Initial fine-tuning complete. Model has memorized all books.")
    return memorized_model


def run_initial_finetuning():
    """STEP 0 driver: obtain the memorized model, from disk if present, else by training.

    Run this cell once to create the memorized model. If
    ``{Config.OUTPUT_DIR}/memorized_model`` already exists it is loaded and
    returned; otherwise the base model is fine-tuned on all books via
    ``initial_finetuning`` and, when ``Config.SAVE_MEMORIZED_MODEL`` is set,
    saved to that path together with the tokenizer.

    Raises:
        NameError: If ``Config``, ``base_model`` or ``tokenizer`` have not been
            defined by earlier cells.
    """
    # Probe each notebook-level dependency lazily; an undefined name raises
    # NameError only when its probe is called.
    dependency_probes = (
        ("Config (cell 4)", lambda: Config.OUTPUT_DIR),
        ("base_model (cell 11)", lambda: base_model),
        ("tokenizer (cell 11)", lambda: tokenizer),
    )
    missing = []
    for label, probe in dependency_probes:
        try:
            probe()
        except NameError:
            missing.append(label)

    if missing:
        raise NameError(
            f"The following are not defined: {', '.join(missing)}. "
            f"Please run the required cells first to set up the dependencies."
        )

    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    memorized_model_path = f"{Config.OUTPUT_DIR}/memorized_model"

    if not os.path.exists(memorized_model_path):
        # No saved model yet: run STEP 0 on all books (D_f).
        print("\n" + "="*60)
        print("STEP 0: INITIAL FINE-TUNING ON ALL BOOKS (D_f)")
        print("="*60)
        all_books_texts = get_all_books_for_initial_finetuning()
        memorized_model = initial_finetuning(base_model, tokenizer, all_books_texts)

        # Saving is optional and controlled by the config flag.
        if Config.SAVE_MEMORIZED_MODEL:
            memorized_model.save_pretrained(memorized_model_path)
            tokenizer.save_pretrained(memorized_model_path)
            print(f"\nMemorized model saved to {memorized_model_path}")
        else:
            print("\nSkipping on-disk save of memorized model (SAVE_MEMORIZED_MODEL=False)")

        return memorized_model

    # A saved memorized model exists: reuse it instead of retraining.
    print(f"Memorized model already exists at {memorized_model_path}")
    print("Loading existing memorized model...")
    from transformers import AutoModelForCausalLM
    on_gpu = torch.cuda.is_available()
    memorized_model = AutoModelForCausalLM.from_pretrained(
        memorized_model_path,
        torch_dtype=torch.bfloat16 if on_gpu else torch.float32,
        device_map="cuda" if on_gpu else "cpu"
    )
    memorized_model.requires_grad_(False)
    print("Loaded existing memorized model.")
    return memorized_model


# Executes Step 0 when this cell runs and keeps the result in a module-level
# name; run_sequential_unlearning looks for 'memorized_model' in globals().
memorized_model = run_initial_finetuning()

In [None]:
## 8. Sequential Unlearning Steps
def run_sequential_unlearning(start_step=1, end_step=None, general_validation_text=None, run_evaluation=True):
    # Suppress dynamo errors
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
    """Run sequential unlearning steps with optional evaluation reporting."""
    # Ensure all required dependencies are available
    missing = []
    try:
        _ = Config.OUTPUT_DIR
    except NameError:
        missing.append("Config (cell 4)")

    try:
        _ = tokenizer
    except NameError:
        missing.append("tokenizer (cell 11)")

    if missing:
        raise NameError(
            f"The following are not defined: {', '.join(missing)}. "
            f"Please run the required cells first."
        )

    if end_step is None:
        end_step = Config.NUM_UNLEARNING_STEPS

    general_validation_text = general_validation_text or DUMMY_BOOK_TEXT

    # Load or start from memorized model (check if file exists, not variable)
    memorized_model_path = f"{Config.OUTPUT_DIR}/memorized_model"
    memorized_model_on_disk = os.path.exists(memorized_model_path)
    memorized_model_in_memory = 'memorized_model' in globals()
    if not memorized_model_on_disk and not memorized_model_in_memory:
        raise FileNotFoundError(
            f"Memorized model not found at {memorized_model_path} and no in-memory model available. "
            f"Please run initial fine-tuning first."
        )

    # Determine starting model
    from transformers import AutoModelForCausalLM
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    if start_step == 1:
        if memorized_model_on_disk:
            current_model = AutoModelForCausalLM.from_pretrained(
                memorized_model_path,
                torch_dtype=dtype,
                device_map=device
            )
            print("Starting from memorized model (disk) for step 1")
        else:
            current_model = memorized_model.to(device)
            print("Starting from memorized model (in-memory) for step 1")
        current_model.requires_grad_(False)
    else:
        # Load model from previous step
        prev_step_path = f"{Config.OUTPUT_DIR}/step_{start_step-1}_unlearned_model"
        if not os.path.exists(prev_step_path):
            raise FileNotFoundError(
                f"Previous step model not found at {prev_step_path}. "
                f"Resume is only possible if step checkpoints were saved."
            )
        current_model = AutoModelForCausalLM.from_pretrained(
            prev_step_path,
            torch_dtype=dtype,
            device_map=device
        )
        current_model.requires_grad_(False)
        print(f"Starting from step {start_step-1} model for step {start_step}")

    # Get retention dataset (D_nor) first - needed for random_label_ids
    retention_dataset = get_retention_dataset()

    # Extract random_label_ids from D_nor for use in forget datasets
    random_label_ids = retention_dataset.input_ids if retention_dataset is not None else None
    if random_label_ids is None:
        raise RuntimeError("retention_dataset is required for SSU (provides random_label_ids)")

    # Load sequential datasets for unlearning with random_label_ids
    unlearning_datasets = get_unlearning_datasets(random_label_ids=random_label_ids)
    print(f"\nGenerated {len(unlearning_datasets)} sequential unlearning datasets.")

    all_target_book_ids = [book_id for ids in Config.GUTENBERG_BOOK_IDS.values() for book_id in ids]
    evaluation_log = {"baseline": {}, "steps": []}
    eval_book_cache = {}
    eval_book_dir = os.path.join(Config.DATA_DIR, "evaluation_books")
    os.makedirs(eval_book_dir, exist_ok=True)

    def get_book_text_for_eval(book_id):
        if book_id in eval_book_cache:
            return eval_book_cache[book_id]
        book_file = download_gutenberg_book(book_id, eval_book_dir)
        if not book_file:
            raise RuntimeError(f"Failed to download book {book_id} for evaluation")
        text = load_book_text(book_file)
        if not text or len(text) < 1000:
            raise RuntimeError(f"Book {book_id} text is invalid or too short for evaluation")
        eval_book_cache[book_id] = text
        maybe_delete_file(book_file)
        return text

    if run_evaluation:
        print("\n=== Baseline Evaluation ===")
        print(f"Note: Evaluation generates {Config.EVAL_NUM_SAMPLES} samples per prompt, this may take a while...")
        baseline_regurg = {}
        for book_id in all_target_book_ids:
            print(f"\nEvaluating book {book_id}...")
            regurg_results = evaluate_regurgitation(
                current_model,
                tokenizer,
                get_book_text_for_eval(book_id),
                max_pairs=Config.EVAL_MAX_PAIRS,
                num_samples=Config.EVAL_NUM_SAMPLES,
                verbose=True,
                return_best_hypotheses=True,
            )
            rouge_metrics = {"rouge1": regurg_results["rouge1"], "rougeL": regurg_results["rougeL"]}
            semantic_metrics = evaluate_semantic_similarity(regurg_results["best_pairs"])
            baseline_regurg[book_id] = {
                "rouge": rouge_metrics,
                "semantic": semantic_metrics,
            }
            print(f"Book {book_id} baseline ROUGE-L: {rouge_metrics['rougeL']:.4f}, ROUGE-1: {rouge_metrics['rouge1']:.4f}")
            print(f"Book {book_id} baseline semantic: mean_cosine={semantic_metrics['mean_cosine']:.4f}, frac_high={semantic_metrics['frac_high']:.2f}")

        baseline_perplexity = {}
        if retention_dataset is not None:
            baseline_perplexity['retention'] = evaluate_perplexity(current_model, tokenizer, retention_dataset)
            print(f"Retention baseline perplexity: {baseline_perplexity['retention']:.2f}")
        baseline_perplexity['general'] = evaluate_perplexity(current_model, tokenizer, general_validation_text)
        print(f"General baseline perplexity: {baseline_perplexity['general']:.2f}")

        # Baseline membership evaluation
        baseline_ds = unlearning_datasets[0]
        baseline_forget_chunks = baseline_ds.data_texts
        baseline_retain_chunks = retention_dataset.data_texts[:len(baseline_forget_chunks)]
        baseline_membership = evaluate_membership_risk(
            current_model,
            tokenizer,
            baseline_forget_chunks,
            baseline_retain_chunks,
        )
        print(f"Baseline membership delta_NLL: {baseline_membership['delta_nll']:.4f}, ROC-AUC: {baseline_membership['roc_auc']:.4f}")

        evaluation_log['baseline'] = {
            'regurgitation': baseline_regurg,
            'perplexity': baseline_perplexity,
            'membership': baseline_membership,
        }

    # Track latest saved checkpoint to delete previous ones if configured
    last_saved_checkpoint = None

    # STEP 1-N: Sequential Unlearning Loop
    for t in range(start_step - 1, min(end_step, Config.NUM_UNLEARNING_STEPS)):
        print(f"\n{'='*60}")
        print(f"UNLEARNING STEP {t+1} (Unlearning D_f^{t+1})")
        print(f"{'='*60}")

        dataset_t = unlearning_datasets[t]
        step_prefix = f"step_{t+1}"

        # SSU training uses only D_f^t (dataset_t), not D_nor
        # D_nor is only used for evaluation and random labels
        train_dataset = dataset_t

        # Fine-Tuning Stage
        print(f"Preparing LoRA model for fine-tuning on D_f^{t+1}...")

        # Ensure model is on correct device
        device = next(current_model.parameters()).device
        if device.type == "meta" or str(device) == "meta":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            current_model = current_model.to(device)
        else:
            # Convert device object to string if needed
            device = str(device).split(':')[0]  # Get 'cuda' or 'cpu'

        if hasattr(current_model, "enable_input_require_grads"):
            current_model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            current_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        lora_model = create_lora_model(current_model)
        # PEFT models inherit device from base model, no need to call .to()
        lora_model.print_trainable_parameters()

        # Disable cache for gradient checkpointing
        if hasattr(lora_model.config, "use_cache"):
            lora_model.config.use_cache = False

        # Dynamic learning rate: 1e-5 for steps 1-5, 1e-6 for steps 6-10
        lr = 1e-5 if (t + 1) <= 5 else 1e-6

        training_args = TrainingArguments(
            output_dir=f"{Config.OUTPUT_DIR}/{step_prefix}_ft_checkpoints",
            per_device_train_batch_size=Config.BATCH_SIZE,
            gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
            warmup_steps=Config.WARMUP_STEPS,
            learning_rate=lr,
            num_train_epochs=Config.NUM_EPOCHS_FT,
            logging_steps=10,
            save_strategy="no",
            report_to="none",
            fp16=False,
            bf16=torch.cuda.is_available() and device == "cuda",
            dataloader_pin_memory=False,  # Fix device issues
            label_names=["labels"],
            gradient_checkpointing=True,  # Enable gradient checkpointing
            remove_unused_columns=False,  # Keep custom labels
            weight_decay=Config.WEIGHT_DECAY,
        )

        # Use custom data collator that preserves labels_fgt and labels_rnd
        data_collator = SSUDataCollator(tokenizer=tokenizer)

        trainer = SSUTrainer(
            model=lora_model,
            args=training_args,
            train_dataset=train_dataset,
            processing_class=tokenizer,
            data_collator=data_collator,
        )

        print(f"Starting fine-tuning with SSU Loss for {step_prefix}...")
        trainer.train()

        # Task Vector Negation Stage: theta_new = theta_old - Delta_LoRA
        print(f"\n--- Applying Task Vector Negation for {step_prefix} ---")

        # 1. Negate LoRA weights: W_new = W_old - (W_ft - W_old) = W_old - W_lora
        # We achieve this by multiplying LoRA weights by -1, then merging.
        with torch.no_grad():
            for name, param in lora_model.named_parameters():
                if "lora" in name:
                    param.data = -1 * param.data

        print("LoRA weights negated.")

        # 2. Merge negated weights into base model
        current_model = lora_model.merge_and_unload()
        current_model.requires_grad_(False)

        print(f"Task Vector Negation complete for {step_prefix}.")

        # Save the unlearned model (optionally keep only final checkpoint)
        save_path = f"{Config.OUTPUT_DIR}/{step_prefix}_unlearned_model"
        is_final_step = (t + 1) == min(end_step, Config.NUM_UNLEARNING_STEPS)
        save_this_step = Config.SAVE_STEP_MODELS or is_final_step
        if save_this_step:
            current_model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Unlearned model {step_prefix} saved.")
            if Config.DELETE_PREVIOUS_STEP_MODELS and last_saved_checkpoint and last_saved_checkpoint != save_path:
                print(f"Deleting previous checkpoint: {last_saved_checkpoint}")
                cleanup_dir(last_saved_checkpoint)
            last_saved_checkpoint = save_path
        else:
            # Remove any stale checkpoint directory to avoid disk bloat
            cleanup_dir(save_path)
            print(f"Skipped saving {step_prefix} checkpoint (SAVE_STEP_MODELS=False)")

        if run_evaluation:
            step_metrics = {
                'step': t + 1,
                'current_books': {},  # D_f^t (current step)
                'prev_books': {},     # D_prev (previously unlearned)
                'perplexity': {},
            }

            # Evaluate on current books (D_f^t)
            current_book_ids = Config.GUTENBERG_BOOK_IDS.get(t + 1, [])
            for book_id in current_book_ids:
                print(f"\nEvaluating current book {book_id}...")
                regurg_results = evaluate_regurgitation(
                    current_model,
                    tokenizer,
                    get_book_text_for_eval(book_id),
                    max_pairs=Config.EVAL_MAX_PAIRS,
                    num_samples=Config.EVAL_NUM_SAMPLES,
                    verbose=True,
                    return_best_hypotheses=True,
                )
                rouge_metrics = {"rouge1": regurg_results["rouge1"], "rougeL": regurg_results["rougeL"]}
                semantic_metrics = evaluate_semantic_similarity(regurg_results["best_pairs"])
                step_metrics['current_books'][book_id] = {
                    "rouge": rouge_metrics,
                    "semantic": semantic_metrics,
                }
                print(f"Step {t+1} current book {book_id} ROUGE-L: {rouge_metrics['rougeL']:.4f}, ROUGE-1: {rouge_metrics['rouge1']:.4f}")
                print(f"Step {t+1} current book {book_id} semantic: mean_cosine={semantic_metrics['mean_cosine']:.4f}, frac_high={semantic_metrics['frac_high']:.2f}")

            # Evaluate on previous books (D_prev) - union of steps 1..t
            prev_book_ids = []
            for prev_step in range(1, t + 1):
                prev_book_ids.extend(Config.GUTENBERG_BOOK_IDS.get(prev_step, []))
            for book_id in prev_book_ids:
                print(f"\nEvaluating prev book {book_id}...")
                regurg_results = evaluate_regurgitation(
                    current_model,
                    tokenizer,
                    get_book_text_for_eval(book_id),
                    max_pairs=Config.EVAL_MAX_PAIRS,
                    num_samples=Config.EVAL_NUM_SAMPLES,
                    verbose=True,
                    return_best_hypotheses=True,
                )
                rouge_metrics = {"rouge1": regurg_results["rouge1"], "rougeL": regurg_results["rougeL"]}
                semantic_metrics = evaluate_semantic_similarity(regurg_results["best_pairs"])
                step_metrics['prev_books'][book_id] = {
                    "rouge": rouge_metrics,
                    "semantic": semantic_metrics,
                }
                print(f"Step {t+1} prev book {book_id} ROUGE-L: {rouge_metrics['rougeL']:.4f}, ROUGE-1: {rouge_metrics['rouge1']:.4f}")
                print(f"Step {t+1} prev book {book_id} semantic: mean_cosine={semantic_metrics['mean_cosine']:.4f}, frac_high={semantic_metrics['frac_high']:.2f}")

            # Evaluate perplexity on D_nor (retention) and general text
            if retention_dataset is not None:
                retention_ppl = evaluate_perplexity(current_model, tokenizer, retention_dataset)
                step_metrics['perplexity']['retention'] = retention_ppl
                print(f"Step {t+1} retention perplexity: {retention_ppl:.2f}")
            general_ppl = evaluate_perplexity(current_model, tokenizer, general_validation_text)
            step_metrics['perplexity']['general'] = general_ppl
            print(f"Step {t+1} general perplexity: {general_ppl:.2f}")

            # Membership evaluation for current step
            forget_chunks = dataset_t.data_texts
            retain_chunks = retention_dataset.data_texts[:len(forget_chunks)]
            membership_metrics = evaluate_membership_risk(
                current_model,
                tokenizer,
                forget_chunks,
                retain_chunks,
            )
            step_metrics['membership'] = membership_metrics
            print(f"Step {t+1} membership delta_NLL: {membership_metrics['delta_nll']:.4f}, ROC-AUC: {membership_metrics['roc_auc']:.4f}")

            evaluation_log['steps'].append(step_metrics)

        # current_model is already updated via merge_and_unload

    # Final Evaluation
    print("\n" + "="*60)
    print("SEQUENTIAL UNLEARNING COMPLETE")
    print("="*60)
    final_step = min(end_step, Config.NUM_UNLEARNING_STEPS)
    final_model_dir = f"{Config.OUTPUT_DIR}/step_{final_step}_unlearned_model"
    print(f"Final Unlearned Model: {final_model_dir}")

    # Test generation
    prompt = "The quick brown fox"
    inputs = tokenizer(prompt, return_tensors="pt")

    print("\nTesting generation with final unlearned model...")

    # Reload model to remove training hooks/state that interfere with inference
    print("Reloading model for clean inference...")
    del current_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    # Re-determine device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    current_model = AutoModelForCausalLM.from_pretrained(
        final_model_dir,
        torch_dtype=dtype,
        device_map=device,
        attn_implementation="eager"
    )
    current_model.requires_grad_(False)
    current_model.eval()

    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Completely disable torch.compile to avoid compilation errors
    torch._dynamo.reset()
    # This completely disables dynamo/torch.compile
    original_disable_state = torch._dynamo.config.disable
    torch._dynamo.config.disable = True

    try:
        with torch.no_grad():
            output_tokens = current_model.generate(
                **inputs,
                max_new_tokens=30,
                do_sample=True,
                top_p=0.9,
                temperature=0.7
            )
    finally:
        # Restore original state
        torch._dynamo.config.disable = original_disable_state

    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Generated: {generated_text}")

    if run_evaluation:
        run_sequential_unlearning.last_evaluation = evaluation_log

    if Config.CLEANUP_FINAL_MODEL_DIR:
        print(f"Removing final model directory per configuration: {final_model_dir}")
        cleanup_dir(final_model_dir)

    return current_model

# Run the whole pipeline with default arguments and keep the returned model in a
# module-level variable so later cells can reuse it.
# You can specify start_step and end_step to resume from a specific step
# Example: run_sequential_unlearning(start_step=2, end_step=3) to run only steps 2-3
final_unlearned_model = run_sequential_unlearning()

## 9. Inference-Only Evaluation

Run the cells below to evaluate existing checkpoints **without retraining**. They load the saved `memorized_model` and `step_10_unlearned_model` one at a time and compare regurgitation (ROUGE) and semantic similarity before and after unlearning.


In [None]:
# Run on Colab through Google Drive.
# google.colab only exists inside a Colab runtime, so guard the import: on a
# local or Kaggle machine the mount is skipped and the path listing still runs.
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

# Report which of the expected checkpoint directories are present.
print("Available models:")
for path in ["/content/models/memorized_model", "/content/models/step_10_unlearned_model",
             "/content/drive/MyDrive/models/memorized_model",
             "/content/drive/MyDrive/models/ssu_unlearned_models/step_10_unlearned_model"]:
    # isdir (not exists) so a stray file with the same name cannot make
    # os.listdir raise NotADirectoryError.
    if os.path.isdir(path):
        files = os.listdir(path)
        print(f"✓ {path} ({len(files)} files)")
    else:
        print(f"✗ {path}")

In [None]:
# === MEMORY-EFFICIENT EVALUATION (for Colab free tier) ===
# Evaluates models one at a time with aggressive memory cleanup

import gc

def run_memory_efficient_evaluation(
    memorized_model_path="/content/drive/MyDrive/models/memorized_model",
    unlearned_model_path="/content/drive/MyDrive/models/ssu_unlearned_models/step_10_unlearned_model",
    num_books=3,  # Only evaluate first N books to save memory
    max_pairs=3,  # Reduce pairs per book
    num_samples=2,  # Reduce samples per prompt
):
    """
    Memory-efficient before/after comparison of two saved checkpoints.

    The memorized and the unlearned model are loaded one at a time. Each is
    scored on the first ``num_books`` entries of ``Config.ALL_BOOK_IDS`` with
    ``evaluate_regurgitation`` and ``evaluate_semantic_similarity``; a per-book
    and averaged comparison table is printed afterwards.

    Args:
        memorized_model_path: Directory of the memorized (pre-unlearning) model.
        unlearned_model_path: Directory of the unlearned model.
        num_books: How many leading entries of ``Config.ALL_BOOK_IDS`` to use.
        max_pairs: Forwarded to ``evaluate_regurgitation`` as ``max_pairs``.
        num_samples: Forwarded to ``evaluate_regurgitation`` as ``num_samples``.

    Returns:
        A dict with the keys ``"memorized"`` and ``"unlearned"``. Each maps
        ``book_id`` to ``{"rouge": {"rouge1", "rougeL"}, "semantic": ...}``.
        Books whose text could not be downloaded are left out; if no book text
        is available both entries are empty and no model is loaded.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Use fewer books
    requested_book_ids = Config.ALL_BOOK_IDS[:num_books]
    print(f"Evaluating {num_books} books with {max_pairs} pairs, {num_samples} samples each")

    # Prepare book texts first (before loading models)
    print("\n=== Downloading book texts ===")
    eval_book_dir = os.path.join(Config.DATA_DIR, "evaluation_books")
    os.makedirs(eval_book_dir, exist_ok=True)
    book_texts = {}
    for book_id in requested_book_ids:
        book_file = download_gutenberg_book(book_id, eval_book_dir)
        if book_file:
            book_texts[book_id] = load_book_text(book_file)
            print(f"✓ Book {book_id} loaded")
        else:
            print(f"✗ Book {book_id} could not be downloaded; skipping")

    # Score only books whose text is actually available; indexing book_texts
    # with a missing id would otherwise raise KeyError mid-evaluation.
    all_book_ids = [book_id for book_id in requested_book_ids if book_id in book_texts]

    results = {"memorized": {}, "unlearned": {}}
    if not all_book_ids:
        # Nothing to score: avoid loading models and dividing by zero below.
        print("No book texts available; nothing to evaluate.")
        return results

    # Evaluate each model separately
    for label, model_path in [("memorized", memorized_model_path), ("unlearned", unlearned_model_path)]:
        print(f"\n{'='*60}")
        print(f"EVALUATING {label.upper()} MODEL")
        print(f"{'='*60}")

        # Clear memory before loading
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        model = None
        tok = None
        try:
            # Load model
            print(f"Loading {model_path}...")
            model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=dtype, device_map=device, attn_implementation="eager"
            )
            model.eval()

            # Load tokenizer; fall back to the configured tokenizer when the
            # checkpoint directory does not provide a usable one.
            try:
                tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            except Exception as exc:
                print(f"Tokenizer not loadable from {model_path} ({exc}); "
                      f"falling back to {Config.TOKENIZER_NAME}")
                tok = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME, trust_remote_code=True)
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token

            model_results = {}

            # Evaluate each book
            for book_id in all_book_ids:
                print(f"\n--- {label}: Book {book_id} ---")
                book_text = book_texts[book_id]

                regurg_results = evaluate_regurgitation(
                    model, tok, book_text,
                    max_pairs=max_pairs,
                    num_samples=num_samples,
                    verbose=False,
                    return_best_hypotheses=True,
                )

                rouge = {"rouge1": regurg_results["rouge1"], "rougeL": regurg_results["rougeL"]}
                semantic = evaluate_semantic_similarity(regurg_results["best_pairs"])

                model_results[book_id] = {"rouge": rouge, "semantic": semantic}
                print(f"  ROUGE-L: {rouge['rougeL']:.4f}, Semantic: {semantic['mean_cosine']:.4f}")

                # Clear generation cache
                gc.collect()

            results[label] = model_results
        finally:
            # CRITICAL: drop the model before loading the next one, even when
            # evaluation failed, so a retry does not start with it still resident.
            del model
            del tok
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print(f"\n✓ {label} model unloaded, memory cleared")

    # Print comparison
    print(f"\n{'='*60}")
    print("COMPARISON RESULTS")
    print(f"{'='*60}")

    print(f"\n{'Book ID':<10} {'ROUGE-L (Mem)':<15} {'ROUGE-L (Unl)':<15} {'Δ':<10} {'Semantic (Mem)':<15} {'Semantic (Unl)':<15} {'Δ':<10}")
    print("-" * 95)

    total_r_mem, total_r_unl, total_s_mem, total_s_unl = 0, 0, 0, 0

    for book_id in all_book_ids:
        mem = results["memorized"][book_id]
        unl = results["unlearned"][book_id]

        r_mem, r_unl = mem["rouge"]["rougeL"], unl["rouge"]["rougeL"]
        s_mem, s_unl = mem["semantic"]["mean_cosine"], unl["semantic"]["mean_cosine"]

        total_r_mem += r_mem
        total_r_unl += r_unl
        total_s_mem += s_mem
        total_s_unl += s_unl

        print(f"{book_id:<10} {r_mem:<15.4f} {r_unl:<15.4f} {r_unl-r_mem:<+10.4f} {s_mem:<15.4f} {s_unl:<15.4f} {s_unl-s_mem:<+10.4f}")

    # n > 0 is guaranteed by the early return above.
    n = len(all_book_ids)
    print("-" * 95)
    print(f"{'AVERAGE':<10} {total_r_mem/n:<15.4f} {total_r_unl/n:<15.4f} {(total_r_unl-total_r_mem)/n:<+10.4f} {total_s_mem/n:<15.4f} {total_s_unl/n:<15.4f} {(total_s_unl-total_s_mem)/n:<+10.4f}")

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    avg_rouge_drop = (total_r_unl - total_r_mem) / n
    avg_sem_drop = (total_s_unl - total_s_mem) / n
    print(f"Average ROUGE-L change: {avg_rouge_drop:+.4f}")
    print(f"Average semantic similarity change: {avg_sem_drop:+.4f}")

    if avg_rouge_drop < 0:
        print("✓ ROUGE decreased (less regurgitation after unlearning)")

    return results


# Run memory-efficient evaluation
# Adjust num_books, max_pairs, num_samples if still OOM
# The returned dict has a "memorized" and an "unlearned" entry holding per-book metrics.
results = run_memory_efficient_evaluation(
    memorized_model_path="/content/drive/MyDrive/models/memorized_model",
    unlearned_model_path="/content/drive/MyDrive/models/ssu_unlearned_models/step_10_unlearned_model",
    num_books=3,   # Start with 3 books
    max_pairs=3,   # 3 prompt pairs per book
    num_samples=2, # 2 generations per prompt
)


## 10. Verification & Testing

**Automated checks** ensure the core SSU components work before running long jobs:
- Dataset construction keeps the correct masking for `labels`, `labels_fgt`, and `labels_rnd`.
- The custom trainer computes the SSU loss as the weighted sum `EPSILON_1 * L_fgt + EPSILON_2 * L_rnd` (no retention term).
- A miniature pipeline check builds a forget-only training dataloader and confirms that a batch still carries `labels`, `labels_fgt`, and `labels_rnd`.

**Manual checks** after training runs:
- Generate from a known passage of a removed book and confirm the model no longer regurgitates it.
- Measure perplexity on `D_nor` before vs. after each step to ensure general capability stays stable.


In [None]:
# Automated sanity checks for dataset, trainer, and pipeline wiring
def run_automated_tests():
    """Smoke-test the SSU dataset/collator, the trainer loss, and the dataloader wiring.

    Uses the module-level ``tokenizer``. The first failing check raises
    ``AssertionError``; a confirmation line is printed when every check passes.
    """
    texts = [f"Sanity sample text {idx}" for idx in range(4)]
    required_keys = ('labels', 'labels_fgt', 'labels_rnd')

    # Token ids of the sample texts stand in for the forget dataset's random labels
    random_ids = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=Config.CHUNK_SIZE,
        return_tensors='pt',
    )['input_ids']

    # --- Dataset + collator check ---
    forget_ds = SequentialUnlearningDataset(
        tokenizer, texts, mode="forget", random_label_ids=random_ids
    )
    retain_ds = SequentialUnlearningDataset(tokenizer, texts, mode="retain")
    collator = SSUDataCollator(tokenizer)
    mixed = collator([forget_ds[0], retain_ds[0]])  # row 0 = forget, row 1 = retain
    assert all(key in mixed for key in required_keys)
    assert (mixed['labels'][0] == -100).all(), "Forget sample should mask retention labels"
    assert (mixed['labels'][1] != -100).any(), "Retain sample should keep labels"
    assert (mixed['labels_fgt'][0] != -100).any(), "Forget sample must keep L_fgt"
    assert (mixed['labels_rnd'][0] != -100).any(), "Forget sample must keep L_rnd"

    # --- Trainer loss check: expects EPSILON_1 * L_fgt + EPSILON_2 * L_rnd ---
    class MockModel(torch.nn.Module):
        """Stand-in model whose forward passes return scripted losses in order."""

        def __init__(self):
            super().__init__()
            self.losses = [torch.tensor(0.4), torch.tensor(0.2)]  # fgt, rnd

        def forward(self, *args, **kwargs):
            if self.losses:
                return SimpleNamespace(loss=self.losses.pop(0))
            return SimpleNamespace(loss=torch.tensor(0.0))

    def make_args(**extra):
        # Both trainers share these settings; extras are appended per trainer.
        return TrainingArguments(
            output_dir=os.path.join(Config.OUTPUT_DIR, "test_runs"),
            per_device_train_batch_size=1,
            num_train_epochs=1,
            report_to=[],
            logging_steps=1000,
            **extra,
        )

    mock_model = MockModel()
    loss_trainer = SSUTrainer(model=mock_model, args=make_args(), train_dataset=None)
    dummy_inputs = {
        key: torch.ones((1, 4), dtype=torch.long)
        for key in ('input_ids', 'attention_mask', 'labels', 'labels_fgt', 'labels_rnd')
    }
    observed = loss_trainer.compute_loss(mock_model, dummy_inputs).item()
    expected = Config.EPSILON_1 * 0.4 + Config.EPSILON_2 * 0.2
    assert abs(observed - expected) < 1e-6, f"SSU loss incorrect: expected {expected}, got {observed}"

    # --- Mini pipeline batch check (forget-only batch, no retain) ---
    pipeline_trainer = SSUTrainer(
        model=mock_model,
        # Keep custom columns like labels_fgt, labels_rnd
        args=make_args(remove_unused_columns=False),
        train_dataset=forget_ds,
        data_collator=collator,
    )
    first_batch = next(iter(pipeline_trainer.get_train_dataloader()))
    assert all(key in first_batch for key in required_keys), "Batch missing required labels"

    print("All automated SSU sanity tests passed!")

# Execute the sanity checks right away; a failed check raises AssertionError.
run_automated_tests()

## 11. Ablation Study: Removing `D_nor`

To illustrate the value of retention data, re-run sequential unlearning with `D_nor` disabled.
You should observe:
- Regurgitation on the target books still decreases, proving the SSU losses work.
- Perplexity on general data rises sharply (general capability collapses), showing catastrophic forgetting without the retention anchor.

After the run, compare the perplexity deltas and summarize them using:
> Including D_nor stabilizes the model and reduces unlearning-induced degradation in general perplexity by **X%** compared to a naive SSU-only objective.


In [None]:
def run_ablation_without_retention(start_step=1, end_step=None, **kwargs):
    """Rerun SSU with ``Config.USE_RETENTION_DATA`` switched off (ablation helper).

    The flag is put back to its previous value once the run finishes or raises.
    Only ``run_evaluation`` (default ``True``) and ``general_validation_text``
    (default ``None``) are read from ``kwargs``; any other key is ignored.
    Returns whatever ``run_sequential_unlearning`` returns.
    """
    previous_setting = Config.USE_RETENTION_DATA
    Config.USE_RETENTION_DATA = False
    try:
        print("\n>>> Running ablation: retention data disabled")
        evaluate = kwargs.get('run_evaluation', True)
        validation_text = kwargs.get('general_validation_text')
        outcome = run_sequential_unlearning(
            start_step=start_step,
            end_step=end_step,
            run_evaluation=evaluate,
            general_validation_text=validation_text,
        )
    finally:
        # Always undo the override so later runs see the original configuration.
        Config.USE_RETENTION_DATA = previous_setting
    return outcome



In [None]:
# Quick sanity generation with the latest unlearned model
if 'final_unlearned_model' in globals():
    final_unlearned_model.eval()
    # Put the prompt tensors on whichever device the model's parameters live on.
    sample_device = next(final_unlearned_model.parameters()).device
    demo_prompts = [
        "The quick brown fox",
        "Once upon a time in a quiet village",
    ]
    for idx, prompt in enumerate(demo_prompts, 1):
        inputs = tokenizer(prompt, return_tensors="pt").to(sample_device)
        with torch.no_grad():
            gen_tokens = final_unlearned_model.generate(
                **inputs,
                do_sample=True,
                temperature=0.8,
                max_new_tokens=80,
            )
        # Drop special tokens so the printout shows only text, matching the
        # decoding used for the test generation in run_sequential_unlearning.
        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
        print(f"Prompt {idx}: {prompt}\nGenerated: {gen_text}\n")
else:
    print("Run run_sequential_unlearning() first to instantiate final_unlearned_model.")
