# Stable Sequential Unlearning (SSU) Framework

Complete implementation following the paper methodology:

## Pipeline Overview:
1. **Initial Fine-tuning (Step 0)**: Fine-tune vanilla model on all copyrighted books (D_f) to make it memorize them
2. **Sequential Unlearning (Steps 1-N)**: Unlearn books one at a time using SSU methodology
   - Each step unlearns one book (D_f^t)
   - Uses composite loss (L_fgt + L_rnd) and weight saliency
   - Applies task vector negation
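Task vector negation is a plain weight-space operation. The sketch below shows it on toy scalar "weights". It assumes each model is a dict of floats. `negate_task_vector` is a hypothetical helper that mirrors the formula theta_u^t = 2 * theta_u^{t-1} - theta_ft^t, the one documented on `apply_task_vector_negation` later in the notebook. It is not the notebook's tensor implementation.

```python
def negate_task_vector(prev_weights, ft_weights):
    """Return theta_u^t = theta_u^{t-1} - (theta_ft^t - theta_u^{t-1}).

    Both arguments map parameter names to floats (a toy stand-in for tensors).
    """
    return {
        name: 2 * prev_weights[name] - ft_weights[name]
        for name in prev_weights
    }


# Toy example: fine-tuning on a book moved "w" from 1.0 to 1.5 (task vector +0.5).
# Negating that task vector moves "w" to 0.5 instead; untouched "b" stays put.
theta_prev = {"w": 1.0, "b": -2.0}
theta_ft = {"w": 1.5, "b": -2.0}
theta_u = negate_task_vector(theta_prev, theta_ft)
print(theta_u)  # {'w': 0.5, 'b': -2.0}
```

At the next step, `theta_u` takes the place of `theta_prev`. Each book's task vector is therefore negated relative to the most recent unlearned model, not the original one.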

## Datasets:
- **D_f**: All copyrighted books (10 books from Project Gutenberg)
- **D_f^t**: Book to unlearn at time step t
- **D_prev**: Previously unlearned books (aggregated from previous steps)
- **D_nor**: Retention data (200 chunks from 100 other books) - for evaluation
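Under these definitions, D_prev at step t is the union of the books unlearned at steps 1 through t-1. The sketch below assumes the schedule is a dict from step number to a list of book IDs, the same shape as `Config.GUTENBERG_BOOK_IDS` in the configuration cell. `previous_books` is a hypothetical helper, not part of the notebook.

```python
def previous_books(schedule, t):
    """Return the sorted book IDs that form D_prev at step t (steps 1..t-1)."""
    prev = set()
    for step, book_ids in schedule.items():
        if step < t:
            prev.update(book_ids)
    return sorted(prev)


schedule = {1: [1661], 2: [1342], 3: [11]}
print(previous_books(schedule, 1))  # []  (nothing unlearned before step 1)
print(previous_books(schedule, 3))  # [1342, 1661]
```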

Runs locally, on Kaggle, or on Colab, and automatically retries model downloads that fail.


## 1. Install Dependencies


In [1]:
# Install required packages
%pip install -q torch transformers peft datasets accelerate requests


Note: you may need to restart the kernel to use updated packages.


## 2. Configuration


In [2]:
# Configuration Class
class Config:
    # Model Configuration - a small, non-gated model avoids authentication and download issues
    # Alternative options:
    # MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"  # 3.8B, non-gated
    # MODEL_NAME = "google/gemma-3-270m-it"  # Gated: set USE_HF_TOKEN = True
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Small, fast, non-gated (used for the outputs below)
    # MODEL_NAME = "google/gemma-2-2b"  # Requires auth
    TOKENIZER_NAME = MODEL_NAME
    
    # HuggingFace Authentication
    USE_HF_TOKEN = False  # Set True for gated models
    
    # PEFT/LoRA Configuration
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.05
    TARGET_MODULES = ["q_proj", "v_proj"]

    # Sequential Unlearning Configuration
    NUM_UNLEARNING_STEPS = 3  # Number of sequential unlearning steps
    
    # Fine-Tuning Hyperparameters
    BATCH_SIZE = 2  # Reduced for stability
    GRADIENT_ACCUMULATION_STEPS = 8
    LEARNING_RATE = 5e-5
    NUM_EPOCHS_FT = 1  # Epochs of fine-tuning on D_f^t at each unlearning step (1 epoch, as per paper)
    NUM_EPOCHS_INITIAL_FT = 1  # Epochs of initial fine-tuning on all books (D_f)
    
    # SSU Methodology Parameters
    EPSILON_1 = 1.0  # Weight for Forgetting Loss (L_fgt)
    EPSILON_2 = 0.1  # Weight for Random Labeling Loss (L_rnd)
    GAMMA = 1e-4  # Saliency threshold
    
    # Data Configuration
    CHUNK_SIZE = 256
    NUM_CHUNKS_PER_STEP = 50
    USE_REAL_BOOKS = True  # Use real books from Project Gutenberg
    DATA_DIR = "gutenberg_books"
    
    # Project Gutenberg Book IDs - All books for initial fine-tuning (D_f)
    # Paper uses 10 books total, with specific ones at certain time steps
    ALL_BOOK_IDS = [
        1661,   # Sherlock Holmes (used at step 1 in paper)
        84,     # Frankenstein
        1342,   # Pride and Prejudice (used at step 5 in paper)
        11,     # Alice in Wonderland (used at step 8 in paper)
        2701,   # Moby Dick
        74,     # The Adventures of Tom Sawyer
        98,     # A Tale of Two Cities
        5200,   # Metamorphosis
        6130,   # The Iliad
        174,    # The Picture of Dorian Gray
    ]
    
    # Books to unlearn at each time step (sequential)
    GUTENBERG_BOOK_IDS = {
        1: [1661],  # Sherlock Holmes - Step 1
        2: [1342],  # Pride and Prejudice - Step 2
        3: [11],    # Alice in Wonderland - Step 3
        # Add more steps as needed
    }
    
    # Retention data (D_nor) - 200 chunks from 100 other books
    USE_RETENTION_DATA = True  # Include D_nor for retention testing
    NUM_RETENTION_BOOKS = 10  # Reduced for demo (paper uses 100)
    NUM_RETENTION_CHUNKS = 20  # Reduced for demo (paper uses 200)
    
    OUTPUT_DIR = "ssu_unlearned_models"

print("Configuration loaded!")


Configuration loaded!
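Before a run, it can help to confirm that the unlearning schedule is consistent and to see how few optimizer updates these settings produce. The sketch below takes the values as explicit arguments. `check_schedule` and `optimizer_steps` are hypothetical helpers. The step count is a ceil-based estimate, and the Trainer's exact count can differ by one depending on the transformers version.

```python
import math


def check_schedule(all_book_ids, schedule, num_steps):
    """Raise ValueError if a step has no books or unlearns a book outside D_f."""
    for t in range(1, num_steps + 1):
        if t not in schedule:
            raise ValueError(f"No books scheduled for step {t}")
        unknown = set(schedule[t]) - set(all_book_ids)
        if unknown:
            raise ValueError(f"Step {t} unlearns books not in D_f: {sorted(unknown)}")


def optimizer_steps(num_examples, batch_size, grad_accum, epochs=1):
    """Estimate optimizer updates: micro-batches grouped by gradient accumulation."""
    micro_batches = math.ceil(num_examples / batch_size)
    return epochs * math.ceil(micro_batches / grad_accum)


all_ids = [1661, 84, 1342, 11, 2701, 74, 98, 5200, 6130, 174]
check_schedule(all_ids, {1: [1661], 2: [1342], 3: [11]}, num_steps=3)  # passes silently
print(optimizer_steps(50, batch_size=2, grad_accum=8))   # 4 updates per unlearning step
print(optimizer_steps(160, batch_size=2, grad_accum=8))  # 10 updates for ~160 initial chunks
```

The second call assumes all 10 books download. Initial fine-tuning then draws 16 chunks per book, about 160 chunks in total.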


## 3. Environment Detection & Setup


In [3]:
import os
import sys

# Detect if running on Kaggle
IS_KAGGLE = os.path.exists('/kaggle')
IS_COLAB = 'google.colab' in sys.modules

print(f"Running on: {'Kaggle' if IS_KAGGLE else 'Colab' if IS_COLAB else 'Local'}")

# HuggingFace Authentication (if needed)
if Config.USE_HF_TOKEN:
    from huggingface_hub import login
    
    hf_token = None
    
    # Try Kaggle Secrets
    if IS_KAGGLE:
        try:
            from kaggle_secrets import UserSecretsClient
            user_secrets = UserSecretsClient()
            hf_token = user_secrets.get_secret("HF_TOKEN")
            print("Found HuggingFace token in Kaggle Secrets.")
        except:
            pass
    
    # Try environment variable
    if not hf_token:
        hf_token = os.environ.get('HF_TOKEN')
    
    if hf_token:
        try:
            login(token=hf_token, add_to_git_credential=False)
            print("Successfully logged in to HuggingFace.")
        except Exception as e:
            print(f"Warning: Could not login: {e}")
    else:
        print("WARNING: No HuggingFace token found. Gated models will fail.")
else:
    print("Using non-gated model - no authentication needed.")


Running on: Local
Using non-gated model - no authentication needed.


## 5. Data Utilities, SSU Model & Trainer
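The fixed-size chunking that `generate_simulated_data` (in the cell below) performs can be illustrated without a tokenizer. This sketch uses integer lists as stand-ins for token IDs. `chunk_token_ids` is a hypothetical helper. The notebook caps a repeated stream at twice the needed length; this sketch instead trims it to exactly `chunk_size * num_chunks`.

```python
def chunk_token_ids(token_ids, chunk_size, num_chunks):
    """Split a token stream into `num_chunks` chunks of exactly `chunk_size` tokens.

    A stream that is too short is repeated end to end until it is long enough.
    """
    if not token_ids:
        raise ValueError("token_ids must be non-empty")
    needed = chunk_size * num_chunks
    if len(token_ids) < needed:
        repeats = needed // len(token_ids) + 1
        token_ids = (token_ids * repeats)[:needed]
    return [token_ids[i:i + chunk_size] for i in range(0, needed, chunk_size)]


# A 5-token "book" is repeated to fill four chunks of two tokens each
print(chunk_token_ids([1, 2, 3, 4, 5], chunk_size=2, num_chunks=4))
# [[1, 2], [3, 4], [5, 1], [2, 3]]
```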


In [None]:
import random
import os
import requests
from transformers import AutoTokenizer
from torch.utils.data import Dataset

# Dummy book text for simulation
DUMMY_BOOK_TEXT = """
In the beginning God created the heavens and the earth. Now the earth was formless and empty, darkness was over the surface of the deep, and the Spirit of God was hovering over the waters. And God said, "Let there be light," and there was light. God saw that the light was good, and he separated the light from the darkness. God called the light "day," and the darkness he called "night." And there was evening, and there was morning—the first day. 

And God said, "Let there be a vault between the waters to separate water from water." So God made the vault and separated the water under the vault from the water above it. And it was so. God called the vault "sky." And there was evening, and there was morning—the second day.

And God said, "Let the water under the sky be gathered to one place, and let dry ground appear." And it was so. God called the dry ground "land," and the gathered waters he called "seas." And God saw that it was good. Then God said, "Let the land produce vegetation: seed-bearing plants and trees on the land that bear fruit with seed in it, according to their various kinds." And it was so. The land produced vegetation: plants bearing seed according to their kinds and trees bearing fruit with seed in it according to their kinds. And God saw that it was good. And there was evening, and there was morning—the third day.

And God said, "Let there be lights in the vault of the sky to separate the day from the night, and let them serve as signs to mark sacred times, and days and years, and let them be lights in the vault of the sky to give light on the earth." And it was so. God made two great lights—the greater light to govern the day and the lesser light to govern the night. He also made the stars. God set them in the vault of the sky to give light on the earth, to govern the day and the night, and to separate light from darkness. And God saw that it was good. And there was evening, and there was morning—the fourth day.

And God said, "Let the water teem with living creatures, and let birds fly above the earth across the vault of the sky." So God created the great creatures of the sea and every living thing with which the water teems and that moves about in it, according to their kinds, and every winged bird according to its kind. And God saw that it was good. God blessed them and said, "Be fruitful and increase in number and fill the water in the seas, and let the birds increase on the earth." And there was evening, and there was morning—the fifth day.

And God said, "Let the land produce living creatures according to their kinds: the livestock, the creatures that move along the ground, and the wild animals, each according to its kind." And it was so. God made the wild animals according to their kinds, the livestock according to their kinds, and all the creatures that move along the ground according to their kinds. And God saw that it was good. Then God said, "Let us make mankind in our image, in our likeness, so that they may rule over the fish in the sea and the birds in the sky, over the livestock and all the wild animals, and over all the creatures that move along the ground." So God created mankind in his own image, in the image of God he created them; male and female he created them. God blessed them and said to them, "Be fruitful and increase in number; fill the earth and subdue it. Rule over the fish in the sea and the birds in the sky and over every living creature that moves on the ground." Then God said, "I give you every seed-bearing plant on the face of the whole earth and every tree that has fruit with seed in it. They will be yours for food. And to all the beasts of the earth and all the birds in the sky and all the creatures that move along the ground—everything that has the breath of life in it—I give every green plant for food." And it was so. God saw all that he had made, and it was very good. And there was evening, and there was morning—the sixth day.

Thus the heavens and the earth were completed in all their vast array. By the seventh day God had finished the work he had been doing; so on the seventh day he rested from all his work. Then God blessed the seventh day and made it holy, because on it he rested from all the work of creating that he had done.
""" * 50


def generate_simulated_data(text, chunk_size, num_chunks, tokenizer_name):
    """Split one book's text (e.g. D_f^t) into `num_chunks` chunks of `chunk_size` tokens.
    
    The function works in five steps:
    1. Tokenize the text in word batches small enough to stay under the model's
       maximum sequence length (this avoids truncation warnings).
    2. Repeat the token stream if the book is too short.
    3. Cut the stream into fixed-size chunks.
    4. Duplicate chunks defensively if there are still too few.
    5. Decode each chunk back to text.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Maximum sequence length reported by the tokenizer (e.g. 2048 for TinyLlama).
    # Some tokenizers report a huge sentinel value when no limit is set, so cap it.
    max_length = min(getattr(tokenizer, 'model_max_length', 32000), 32000)
    
    # Keep chunks (Config.CHUNK_SIZE = 256) under the model limit, with a small margin
    safe_chunk_size = min(chunk_size, max_length - 10)
    
    # Tokenize in word batches rather than all at once. One word can become 1-3
    # tokens, so max_length // 3 words stays under the limit even in the worst case.
    # Batches are additionally capped at 500 words.
    words = text.split()
    words_per_batch = min(500, max_length // 3)
    
    # Step 1: Tokenize the text batch by batch and concatenate the token IDs
    all_token_ids = []
    for i in range(0, len(words), words_per_batch):
        batch_text = ' '.join(words[i:i + words_per_batch])
        tokenized = tokenizer(
            batch_text,
            return_tensors='pt',
            truncation=True,  # Safety net: never exceed max_length
            max_length=max_length,
            add_special_tokens=False  # No BOS/EOS in the middle of the concatenated stream
        )['input_ids'][0]
        all_token_ids.extend(tokenized.tolist())
    
    # Step 2: If the book is too short, repeat the token sequence
    if not all_token_ids:
        raise ValueError("Text produced no tokens; cannot build chunks.")
    if len(all_token_ids) < safe_chunk_size * num_chunks:
        repeat_factor = (safe_chunk_size * num_chunks // len(all_token_ids)) + 1
        all_token_ids = (all_token_ids * repeat_factor)[:safe_chunk_size * num_chunks * 2]
    
    # Step 3: Split into chunks of the desired size
    chunks = []
    for i in range(0, len(all_token_ids) - safe_chunk_size + 1, safe_chunk_size):
        chunk = all_token_ids[i:i + safe_chunk_size]
        if len(chunk) == safe_chunk_size:
            chunks.append(chunk)
        if len(chunks) >= num_chunks:
            break
    
    # Step 4 (defensive): Step 2 already guarantees enough tokens, so this loop
    # should not run; if it does, duplicate the last full chunk up to num_chunks.
    while chunks and len(chunks) < num_chunks:
        chunks.append(list(chunks[-1]))
    
    chunks = chunks[:num_chunks]
    
    # Step 5: Decode back to text (this is what the dataset will use)
    text_chunks = [tokenizer.decode(c, skip_special_tokens=True) for c in chunks]
    return text_chunks


class SequentialUnlearningDataset(Dataset):
    """Custom Dataset for SSU loss with dual labels."""
    def __init__(self, tokenizer, data_texts):
        self.tokenizer = tokenizer
        self.data_texts = data_texts
        
        tokenized = tokenizer(
            data_texts, 
            truncation=True, 
            padding="max_length", 
            max_length=Config.CHUNK_SIZE, 
            return_tensors='pt'
        )
        self.input_ids = tokenized['input_ids']
        self.attention_mask = tokenized['attention_mask']
        
        self.random_indices = list(range(len(data_texts)))
        random.shuffle(self.random_indices)

    def __len__(self):
        return len(self.data_texts)

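    def __getitem__(self, idx):
        input_ids = self.input_ids[idx].clone()
        attention_mask = self.attention_mask[idx].clone()
        
        # Forgetting labels: the chunk itself, with padding positions ignored (-100)
        labels_fgt = input_ids.clone()
        labels_fgt[attention_mask == 0] = -100
        
        # Random labels: another chunk's tokens, masked with that chunk's own padding
        rnd_idx = self.random_indices[idx]
        labels_rnd = self.input_ids[rnd_idx].clone()
        labels_rnd[self.attention_mask[rnd_idx] == 0] = -100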

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels_fgt': labels_fgt,
            'labels_rnd': labels_rnd
        }


class SSUDataCollator:
    """Custom data collator that preserves labels_fgt and labels_rnd for SSU training."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    
    def __call__(self, features):
        """Collate batch while preserving custom labels."""
        import torch
        from torch.nn.utils.rnn import pad_sequence
        
        if not features:
            raise ValueError("Empty features list passed to data collator")
        
        # Debug: Check what keys are in the first feature
        first_feature_keys = list(features[0].keys())
        
        # Check if features have the custom labels (for SSU training)
        has_custom_labels = 'labels_fgt' in features[0] and 'labels_rnd' in features[0]
        
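        if not has_custom_labels:
            # Usually caused by Trainer's default remove_unused_columns=True, which drops
            # keys that the model's forward() does not accept (labels_fgt, labels_rnd).
            # Set remove_unused_columns=False in TrainingArguments for SSU training.
            print(f"Warning: Custom labels not found in features. Available keys: {first_feature_keys}")
            print("Falling back to input_ids for both labels (L_rnd then equals L_fgt).")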
        
        # Extract custom labels if they exist
        if has_custom_labels:
            labels_fgt_list = [f.pop('labels_fgt') for f in features]
            labels_rnd_list = [f.pop('labels_rnd') for f in features]
        else:
            # If no custom labels, create them from input_ids (fallback)
            # This should not happen with SequentialUnlearningDataset, but handle it gracefully
            labels_fgt_list = []
            labels_rnd_list = []
            for f in features:
                input_ids = f['input_ids']
                if isinstance(input_ids, torch.Tensor):
                    labels_fgt_list.append(input_ids.clone())
                    labels_rnd_list.append(input_ids.clone())
                else:
                    labels_fgt_list.append(torch.tensor(input_ids).clone())
                    labels_rnd_list.append(torch.tensor(input_ids).clone())
        
        batch = {}
        
        # Helper to convert to tensor if needed
        def to_tensor(x):
            if isinstance(x, torch.Tensor):
                return x
            return torch.tensor(x)
        
        # Collate input_ids (already tensors from dataset)
        input_ids = [to_tensor(f['input_ids']) for f in features]
        batch['input_ids'] = pad_sequence(
            input_ids, 
            batch_first=True, 
            padding_value=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        )
        
        # Collate attention_mask
        attention_mask = [to_tensor(f['attention_mask']) for f in features]
        batch['attention_mask'] = pad_sequence(
            attention_mask,
            batch_first=True,
            padding_value=0
        )
        
        # Add back the custom labels (already tensors from dataset)
        batch['labels_fgt'] = pad_sequence(
            [to_tensor(l) for l in labels_fgt_list],
            batch_first=True,
            padding_value=-100  # -100 is ignored in loss computation
        )
        
        batch['labels_rnd'] = pad_sequence(
            [to_tensor(l) for l in labels_rnd_list],
            batch_first=True,
            padding_value=-100
        )
        
        return batch


def download_gutenberg_book(book_id, output_dir):
    """Download a book from Project Gutenberg by ID."""
    os.makedirs(output_dir, exist_ok=True)
    book_file = os.path.join(output_dir, f"{book_id}.txt")
    
    if os.path.exists(book_file):
        print(f"Book {book_id} already exists, skipping download.")
        return book_file
    
    url = f"https://www.gutenberg.org/files/{book_id}/{book_id}-0.txt"
    try:
        print(f"Downloading book {book_id} from Project Gutenberg...")
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open(book_file, 'w', encoding='utf-8') as f:
            f.write(response.text)
        print(f"Downloaded book {book_id} successfully.")
        return book_file
    except (requests.RequestException, OSError):
        # Fall back to the alternative cache/epub URL scheme
        url_alt = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt"
        try:
            response = requests.get(url_alt, timeout=30)
            response.raise_for_status()
            with open(book_file, 'w', encoding='utf-8') as f:
                f.write(response.text)
            print(f"Downloaded book {book_id} successfully.")
            return book_file
        except (requests.RequestException, OSError) as e:
            print(f"Warning: Could not download book {book_id} ({e}). Skipping it.")
            return None


def load_book_text(book_file):
    """Load and clean text from a book file."""
    if not book_file or not os.path.exists(book_file):
        return None
    with open(book_file, 'r', encoding='utf-8', errors='ignore') as f:
        text = f.read()
    # Remove Project Gutenberg headers/footers
    start_markers = ["*** START OF", "***START OF", "START OF THE PROJECT"]
    end_markers = ["*** END OF", "***END OF", "END OF THE PROJECT"]
    for marker in start_markers:
        idx = text.find(marker)
        if idx != -1:
            text = text[text.find('\n', idx) + 1:]
            break
    for marker in end_markers:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx]
            break
    text = ' '.join(text.split())
    return text


def get_all_books_for_initial_finetuning():
    """Downloads all books for initial fine-tuning (D_f) - makes model memorize them."""
    print("\n=== Downloading all books for initial fine-tuning (D_f) ===")
    all_books_dir = os.path.join(Config.DATA_DIR, "all_books")
    os.makedirs(all_books_dir, exist_ok=True)
    
    book_texts = []
    for book_id in Config.ALL_BOOK_IDS:
        book_file = download_gutenberg_book(book_id, all_books_dir)
        if book_file:
            text = load_book_text(book_file)
            if text and len(text) > 1000:
                book_texts.append(text)
                print(f"✓ Loaded book {book_id} ({len(text)} chars)")
    
    if not book_texts:
        print("Warning: No books downloaded. Using dummy text.")
        book_texts = [DUMMY_BOOK_TEXT]
    
    return book_texts


def get_unlearning_datasets():
    """Generates sequential datasets D_f^1, D_f^2, ... for each time step."""
    tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    datasets = []
    
    for t in range(Config.NUM_UNLEARNING_STEPS):
        print(f"\n--- Preparing dataset D_f^{t+1} for time step {t+1} ---")
        
        if Config.USE_REAL_BOOKS:
            # Download books for this specific time step
            time_step_dir = os.path.join(Config.DATA_DIR, f"time_step_{t+1}")
            os.makedirs(time_step_dir, exist_ok=True)
            
            book_ids = Config.GUTENBERG_BOOK_IDS.get(t + 1, [])
            book_texts = []
            for book_id in book_ids:
                book_file = download_gutenberg_book(book_id, time_step_dir)
                if book_file:
                    text = load_book_text(book_file)
                    if text and len(text) > 1000:
                        book_texts.append(text)
                        print(f"✓ Loaded book {book_id} for step {t+1}")
            if not book_texts:
                print(f"Warning: No valid books for step {t+1}. Using dummy text.")
                book_texts = [DUMMY_BOOK_TEXT]
        else:
            book_texts = [DUMMY_BOOK_TEXT]
        
        all_chunks = []
        for book_text in book_texts:
            chunks = generate_simulated_data(
                book_text,
                Config.CHUNK_SIZE,
                Config.NUM_CHUNKS_PER_STEP // len(book_texts) + 1,
                Config.TOKENIZER_NAME
            )
            all_chunks.extend(chunks)
        
        data_t = all_chunks[:Config.NUM_CHUNKS_PER_STEP]
        print(f"Created {len(data_t)} chunks for time step {t+1}")
        
        dataset_t = SequentialUnlearningDataset(tokenizer, data_t)
        datasets.append(dataset_t)
        
    return datasets


def get_retention_dataset():
    """Generates retention dataset D_nor (non-targeted data to keep)."""
    if not Config.USE_RETENTION_DATA:
        return None
    
    print("\n=== Preparing retention dataset D_nor ===")
    tokenizer = AutoTokenizer.from_pretrained(Config.TOKENIZER_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
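    # Candidate retention books; any that are also in D_f (e.g. 6130, 174) are filtered out
    candidate_ids = [1232, 145, 76, 2591, 30254, 844, 345, 520, 6130, 174]
    retention_book_ids = [b for b in candidate_ids if b not in Config.ALL_BOOK_IDS][:Config.NUM_RETENTION_BOOKS]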
    retention_dir = os.path.join(Config.DATA_DIR, "retention_books")
    os.makedirs(retention_dir, exist_ok=True)
    
    all_chunks = []
    for book_id in retention_book_ids:
        if book_id not in Config.ALL_BOOK_IDS:  # Don't use books from D_f
            book_file = download_gutenberg_book(book_id, retention_dir)
            if book_file:
                text = load_book_text(book_file)
                if text and len(text) > 1000:
                    chunks = generate_simulated_data(
                        text,
                        Config.CHUNK_SIZE,
                        Config.NUM_RETENTION_CHUNKS // Config.NUM_RETENTION_BOOKS + 1,
                        Config.TOKENIZER_NAME
                    )
                    all_chunks.extend(chunks)
    
    if not all_chunks:
        print("Warning: No retention books downloaded. Using dummy text.")
        all_chunks = generate_simulated_data(
            DUMMY_BOOK_TEXT,
            Config.CHUNK_SIZE,
            Config.NUM_RETENTION_CHUNKS,
            Config.TOKENIZER_NAME
        )
    
    retention_chunks = all_chunks[:Config.NUM_RETENTION_CHUNKS]
    print(f"Created {len(retention_chunks)} retention chunks")
    
    return SequentialUnlearningDataset(tokenizer, retention_chunks)

print("Data utilities loaded!")


Data utilities loaded!
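`SequentialUnlearningDataset` has two details that are easy to get wrong. First, padding positions should be excluded from both losses. Label -100 does this, because it is the default `ignore_index` of PyTorch's cross-entropy. Second, a plain shuffle can map a chunk to itself, which makes its "random" label identical to its forgetting label. The sketch below works on lists of token IDs. `make_rnd_indices` and `mask_padding` are hypothetical helpers, not the notebook's code. `make_rnd_indices` builds a fixed-point-free permutation by linking shuffled indices into one cycle.

```python
import random


def make_rnd_indices(n, seed=None):
    """Return a permutation p of range(n) with p[i] != i whenever n >= 2."""
    if n < 2:
        return list(range(n))
    order = list(range(n))
    random.Random(seed).shuffle(order)
    p = [0] * n
    for k in range(n):
        # Link the shuffled indices into a single cycle: no index maps to itself
        p[order[k]] = order[(k + 1) % n]
    return p


def mask_padding(labels, attention_mask, ignore_index=-100):
    """Replace labels at padded positions (mask == 0) with ignore_index."""
    return [tok if m == 1 else ignore_index for tok, m in zip(labels, attention_mask)]


p = make_rnd_indices(5, seed=0)
print(all(p[i] != i for i in range(5)))            # True
print(mask_padding([5, 6, 0, 0], [1, 1, 0, 0]))    # [5, 6, -100, -100]
```

The random labels should be masked with the attention mask of the chunk they came from (index `p[i]`), not with that of chunk `i`.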


In [None]:
import torch
from transformers import Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig


class SSUTrainer(Trainer):
    """Custom Trainer implementing SSU loss and Weight Saliency."""
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute SSU loss (L_fgt + L_rnd) with support for additional kwargs.
        
        This method accepts any additional keyword arguments that newer versions
        of transformers might pass (like num_items_in_batch).
        """
        # Make a copy of inputs to avoid modifying the original dict
        inputs_copy = inputs.copy()
        
        # Extract our custom labels (these are added by SequentialUnlearningDataset)
        labels_fgt = inputs_copy.pop('labels_fgt', None)
        labels_rnd = inputs_copy.pop('labels_rnd', None)
        
        if labels_fgt is None or labels_rnd is None:
            raise ValueError("Missing labels_fgt or labels_rnd in inputs. Check dataset.")
        
        # L_fgt (Forgetting Loss)
        outputs_fgt = model(**inputs_copy, labels=labels_fgt)
        loss_fgt = outputs_fgt.loss 

        # L_rnd (Random Labeling Loss)
        outputs_rnd = model(**inputs_copy, labels=labels_rnd)
        loss_rnd = outputs_rnd.loss
        
        # Combined SSU Loss
        loss = Config.EPSILON_1 * loss_fgt + Config.EPSILON_2 * loss_rnd
        
        return (loss, outputs_fgt) if return_outputs else loss

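    def create_optimizer(self):
        """Create the optimizer and attach the weight-saliency mask to it.
        
        Hugging Face `Trainer` has no `optimizer_step()` hook; it calls
        `self.optimizer.step()` directly in its training loop, so a method with
        that name is never invoked. Instead, a PyTorch (>= 2.0) step pre-hook masks
        the gradients right before every real optimizer update, i.e. after
        gradient accumulation and clipping.
        """
        optimizer = super().create_optimizer()
        
        def apply_saliency_mask(opt, args, kwargs):
            for name, param in self.model.named_parameters():
                if param.grad is not None and param.requires_grad and "lora" in name.lower():
                    # Saliency Mask: m_s = I(|grad| >= gamma); only salient entries get gradient
                    # (AdamW momentum can still move masked entries slightly)
                    m_s = (param.grad.abs() >= Config.GAMMA).to(param.grad.dtype)
                    param.grad.mul_(m_s)
        
        optimizer.register_step_pre_hook(apply_saliency_mask)
        return optimizer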


def create_lora_model(model):
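    """Adds LoRA adapters to the base model.
    
    Note: get_peft_model injects the adapter layers into `model` in place,
    so the original model object is modified as well.
    """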
    peft_config = LoraConfig(
        r=Config.LORA_R,
        lora_alpha=Config.LORA_ALPHA,
        lora_dropout=Config.LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=Config.TARGET_MODULES,
    )
    return get_peft_model(model, peft_config)


def apply_task_vector_negation(base_model, fine_tuned_model, name_prefix):
    """Task Vector Negation: theta_u^t = 2 * theta_u^{t-1} - theta_ft^t
    
    Both models must expose the same plain (non-PEFT) parameter names, so merge
    LoRA adapters first (e.g. `fine_tuned_model.merge_and_unload()`).
    """
    import copy
    print(f"\n--- Applying Task Vector Negation for {name_prefix} ---")
    
    # Start from a copy of theta_u^{t-1}: keeps its dtype, device, buffers and tied weights
    new_unlearned_model = copy.deepcopy(base_model)
    ft_state = fine_tuned_model.state_dict()
    
    num_updated = 0
    with torch.no_grad():
        # named_parameters() skips tied duplicates, so each tensor is updated exactly once
        for name, param in new_unlearned_model.named_parameters():
            if name in ft_state and ft_state[name].shape == param.shape:
                ft_param = ft_state[name].to(device=param.device, dtype=param.dtype)
                param.copy_(2 * param - ft_param)
                num_updated += 1
    
    if num_updated == 0:
        print("Warning: no matching parameter names; were the LoRA adapters merged?")
    print(f"Task Vector Negation complete for {name_prefix} ({num_updated} tensors updated).")
    return new_unlearned_model

print("SSU model utilities loaded!")
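
Both SSU loss terms use the same inputs. `SSUTrainer.compute_loss` runs the model twice, but both terms could be computed from a single forward pass. The pure-Python sketch below reproduces the arithmetic on tiny logit lists, including the one-position shift that Hugging Face causal-LM models apply when `labels` is passed. It also includes the gradient saliency mask. All names here (`token_cross_entropy`, `ssu_loss`, `saliency_mask`) are hypothetical. A PyTorch version would call the model once and apply `torch.nn.functional.cross_entropy` to the shifted logits for each label set.

```python
import math


def token_cross_entropy(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy; logits[i] predicts labels[i + 1].

    Positions labelled ignore_index are skipped; returns 0.0 if none remain.
    """
    total, count = 0.0, 0
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue
        log_norm = math.log(sum(math.exp(x) for x in step_logits))
        total += log_norm - step_logits[target]
        count += 1
    return total / count if count else 0.0


def ssu_loss(logits, labels_fgt, labels_rnd, eps1=1.0, eps2=0.1):
    """L = eps1 * L_fgt + eps2 * L_rnd, both evaluated on the same logits."""
    return eps1 * token_cross_entropy(logits, labels_fgt) + eps2 * token_cross_entropy(logits, labels_rnd)


def saliency_mask(grads, gamma):
    """Apply m_s = I(|g| >= gamma): zero out gradient entries below gamma."""
    return [g if abs(g) >= gamma else 0.0 for g in grads]


# Uniform logits over a 4-token vocabulary give log(4) per position for any label
uniform = [[0.0, 0.0, 0.0, 0.0]] * 3
print(round(ssu_loss(uniform, [1, 2, 3], [1, 0, 0]), 4))  # 1.1 * log(4) ~= 1.5249
print(saliency_mask([0.5, -2e-5, 3e-4], gamma=1e-4))      # [0.5, 0.0, 0.0003]
```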


## 6. Model Loading with Retry Logic


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import time


def load_model_with_retry(model_name, max_retries=3, retry_delay=5):
    """Load model with automatic retry on network errors."""
    for attempt in range(max_retries):
        try:
            print(f"Loading model (attempt {attempt + 1}/{max_retries})...")
            
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # Determine device
            if torch.cuda.is_available():
                device = "cuda"
                dtype = torch.bfloat16
            else:
                device = "cpu"
                dtype = torch.float32
            
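            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,  # Renamed to `dtype` in recent transformers releases (hence the warning below)
                device_map=device,  # Use single device instead of "auto"
                trust_remote_code=True,
                # resume_download is deprecated and ignored: huggingface_hub resumes interrupted downloads automatically
            )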
            
            # Ensure model is on the correct device
            base_model = base_model.to(device)
            
            print(f"Successfully loaded {model_name}!")
            return base_model, tokenizer
            
        except Exception as e:
            error_msg = str(e)
            print(f"Attempt {attempt + 1} failed: {error_msg[:200]}...")
            
            if attempt < max_retries - 1:
                if "IncompleteRead" in error_msg or "Connection" in error_msg or "timeout" in error_msg.lower():
                    print(f"Network error detected. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    print("Non-network error. Retrying...")
                    time.sleep(2)
            else:
                print("\nAll retry attempts failed!")
                print("\nTROUBLESHOOTING:")
                print("1. Check your internet connection")
                print("2. Try a smaller model: TinyLlama/TinyLlama-1.1B-Chat-v1.0")
                print("3. For gated models, ensure HF_TOKEN is set")
                raise
    
    return None, None


# Load the model
print(f"Loading base model: {Config.MODEL_NAME}")
base_model, tokenizer = load_model_with_retry(Config.MODEL_NAME)
base_model.requires_grad_(False)
print("Model loaded and frozen!")


Loading base model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
Loading model (attempt 1/3)...


`torch_dtype` is deprecated! Use `dtype` instead!


Successfully loaded TinyLlama/TinyLlama-1.1B-Chat-v1.0!
Model loaded and frozen!
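`load_model_with_retry` identifies network errors by matching substrings in the error message. Catching exception types is usually more robust. The sketch below assumes that retryable failures surface as the built-in `ConnectionError` or `TimeoutError`. `requests` and `huggingface_hub` raise their own exception classes, which you would add to `retry_on`. `call_with_backoff` is a hypothetical helper. Unlike the notebook, it does not retry unlisted errors such as a misspelled model name, since retrying those cannot succeed.

```python
import time


def call_with_backoff(fn, max_retries=3, retry_delay=5.0,
                      retry_on=(ConnectionError, TimeoutError), sleep=time.sleep):
    """Call fn(); on a retryable error wait retry_delay seconds, doubling it each time."""
    for attempt in range(max_retries):
        try:
            return fn()
        except retry_on:
            if attempt == max_retries - 1:
                raise  # Out of retries: surface the last error
            sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff


# Usage with a fake loader that fails twice; `sleep` records delays instead of waiting
delays, calls = [], {"n": 0}

def flaky_loader():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("IncompleteRead")
    return "model"

print(call_with_backoff(flaky_loader, sleep=delays.append))  # model
print(delays)  # [5.0, 10.0]
```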


## 7. Main SSU Pipeline


In [None]:
import os
from transformers import TrainingArguments

def initial_finetuning(model, tokenizer, all_books_texts):
    """
    Step 0: Initial fine-tuning on all books (D_f) to make model memorize them.
    This is what the paper does BEFORE unlearning.
    """
    print("\n" + "="*60)
    print("STEP 0: INITIAL FINE-TUNING ON ALL BOOKS (D_f)")
    print("="*60)
    print("Fine-tuning vanilla model on all copyrighted books to memorize them...")
    
    # Generate chunks from all books
    all_chunks = []
    for book_text in all_books_texts:
        chunks = generate_simulated_data(
            book_text,
            Config.CHUNK_SIZE,
            Config.NUM_CHUNKS_PER_STEP * Config.NUM_UNLEARNING_STEPS // len(all_books_texts) + 1,
            Config.TOKENIZER_NAME
        )
        all_chunks.extend(chunks)
    
    print(f"Created {len(all_chunks)} chunks from all books")
    
    # For initial fine-tuning, use standard dataset (not SSU dual labels)
    from torch.utils.data import Dataset as TorchDataset
    class StandardDataset(TorchDataset):
        def __init__(self, tokenizer, data_texts):
            tokenized = tokenizer(
                data_texts,
                truncation=True,
                padding="max_length",
                max_length=Config.CHUNK_SIZE,
                return_tensors='pt'
            )
            self.input_ids = tokenized['input_ids']
            self.attention_mask = tokenized['attention_mask']
        
        def __len__(self):
            return len(self.input_ids)
        
        def __getitem__(self, idx):
            return {
                'input_ids': self.input_ids[idx].clone(),
                'attention_mask': self.attention_mask[idx].clone(),
                'labels': self.input_ids[idx].clone()  # Rebuilt by DataCollatorForLanguageModeling (mlm=False), which masks padding with -100
            }
    
    initial_dataset = StandardDataset(tokenizer, all_chunks)
    
    # Ensure model is on correct device before creating LoRA
    device = next(model.parameters()).device
    if device.type == "meta" or str(device) == "meta":
        # If model is on meta device, move to actual device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
    else:
        # Convert device object to string if needed
        device = str(device).split(':')[0]  # Get 'cuda' or 'cpu'
    
    # Create LoRA model for initial fine-tuning (PEFT handles device automatically)
    lora_model = create_lora_model(model)
    lora_model.print_trainable_parameters()
    
    # PEFT models inherit device from base model, no need to call .to()
    
    training_args = TrainingArguments(
        output_dir=f"{Config.OUTPUT_DIR}/initial_ft_checkpoints",
        per_device_train_batch_size=Config.BATCH_SIZE,
        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=10,
        learning_rate=Config.LEARNING_RATE,
        num_train_epochs=Config.NUM_EPOCHS_INITIAL_FT,
        logging_steps=10,
        save_strategy="no",
        report_to="none",
        fp16=False,
        bf16=torch.cuda.is_available() and device == "cuda",
        dataloader_pin_memory=False,  # Fix device issues
    )
    
    # Use standard Trainer for initial fine-tuning (not SSU)
    from transformers import Trainer, DataCollatorForLanguageModeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal LM, not masked LM
    )
    
    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=initial_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    print("Starting initial fine-tuning...")
    trainer.train()
    
    # Merge LoRA weights into base model
    memorized_model = lora_model.merge_and_unload()
    memorized_model.requires_grad_(False)
    
    print("✓ Initial fine-tuning complete. Model has memorized all books.")
    return memorized_model


def run_initial_finetuning():
    """STEP 0: Initial fine-tuning on all books (D_f) to make model memorize them.
    
    Run this cell once to create the memorized model. After this completes,
    you can run the sequential unlearning steps without re-running this.
    """
    # Ensure all required dependencies are available
    missing = []
    try:
        _ = Config.OUTPUT_DIR
    except NameError:
        missing.append("Config (cell 4)")
    
    try:
        _ = base_model
    except NameError:
        missing.append("base_model (cell 11)")
    
    try:
        _ = tokenizer
    except NameError:
        missing.append("tokenizer (cell 11)")
    
    if missing:
        raise NameError(
            f"The following are not defined: {', '.join(missing)}. "
            f"Please run the required cells first to set up the dependencies."
        )
    
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    
    # Check if memorized model already exists
    memorized_model_path = f"{Config.OUTPUT_DIR}/memorized_model"
    if os.path.exists(memorized_model_path):
        print(f"Memorized model already exists at {memorized_model_path}")
        print("Loading existing memorized model...")
        from transformers import AutoModelForCausalLM
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        memorized_model = AutoModelForCausalLM.from_pretrained(
            memorized_model_path,
            torch_dtype=dtype,
            device_map=device
        )
        memorized_model.requires_grad_(False)
        print("✓ Loaded existing memorized model.")
        return memorized_model
    
    # STEP 0: Initial fine-tuning on all books (D_f)
    print("\n" + "="*60)
    print("STEP 0: INITIAL FINE-TUNING ON ALL BOOKS (D_f)")
    print("="*60)
    all_books_texts = get_all_books_for_initial_finetuning()
    memorized_model = initial_finetuning(base_model, tokenizer, all_books_texts)
    
    # Save the memorized model
    memorized_model.save_pretrained(memorized_model_path)
    tokenizer.save_pretrained(memorized_model_path)
    print(f"\n✓ Memorized model saved to {memorized_model_path}")
    
    return memorized_model


# Run initial fine-tuning (only need to run this once)
memorized_model = run_initial_finetuning()


Memorized model already exists at ssu_unlearned_models/memorized_model
Loading existing memorized model...
✓ Loaded existing memorized model.
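In `initial_finetuning`, the `labels` built by `StandardDataset` do not reach the model unchanged: with `mlm=False`, `DataCollatorForLanguageModeling` recomputes `labels` from `input_ids` and sets every position holding the pad token id to `-100`, the index that the causal-LM cross-entropy ignores. The sketch below reproduces that masking and the one-position shift on toy tensors; `causal_lm_labels`, the pad id `0`, and the vocabulary size are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def causal_lm_labels(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Hypothetical helper: copy input_ids into labels and mask padding with -100."""
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100
    return labels


pad_id = 0                                     # assumed pad token id
input_ids = torch.tensor([[5, 6, 7, pad_id, pad_id]])
labels = causal_lm_labels(input_ids, pad_id)
print(labels)

# Causal LM loss is shifted by one: the logits at position i predict token i + 1.
vocab_size = 10
logits = torch.randn(1, 5, vocab_size)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),    # predictions for positions 0..3
    labels[:, 1:].reshape(-1),                 # targets are the next tokens
    ignore_index=-100,                         # padded targets contribute nothing
)
print(loss.item())
```

If the tokenizer reuses its EOS token as the pad token, a common setup for Llama-style tokenizers, genuine EOS tokens are masked in the same way.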


## 8. Sequential Unlearning Steps


In [None]:

def run_sequential_unlearning(start_step=1, end_step=None):
    """Run sequential unlearning steps.
    
    Args:
        start_step: Starting unlearning step (1-indexed). Default: 1
        end_step: Ending unlearning step (inclusive). If None, uses Config.NUM_UNLEARNING_STEPS
    """
    # Ensure all required dependencies are available
    missing = []
    try:
        _ = Config.OUTPUT_DIR
    except NameError:
        missing.append("Config (cell 4)")
    
    try:
        _ = tokenizer
    except NameError:
        missing.append("tokenizer (cell 11)")
    
    if missing:
        raise NameError(
            f"The following are not defined: {', '.join(missing)}. "
            f"Please run the required cells first."
        )
    
    if end_step is None:
        end_step = Config.NUM_UNLEARNING_STEPS
    
    # Load or start from memorized model (check if file exists, not variable)
    memorized_model_path = f"{Config.OUTPUT_DIR}/memorized_model"
    if not os.path.exists(memorized_model_path):
        raise FileNotFoundError(
            f"Memorized model not found at {memorized_model_path}. "
            f"Please run cell 13 (initial fine-tuning) first."
        )
    
    # Determine starting model: the memorized model for step 1, otherwise the
    # unlearned model saved by the previous step
    if start_step == 1:
        model_path = memorized_model_path
    else:
        model_path = f"{Config.OUTPUT_DIR}/step_{start_step-1}_unlearned_model"
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Previous step model not found at {model_path}. "
                f"Please run steps 1 to {start_step-1} first."
            )
    
    from transformers import AutoModelForCausalLM
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    current_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=device
    )
    current_model.requires_grad_(False)
    if start_step == 1:
        print("Starting from memorized model for step 1")
    else:
        print(f"Starting from step {start_step-1} model for step {start_step}")
    
    # Load sequential datasets for unlearning
    unlearning_datasets = get_unlearning_datasets()
    print(f"\nGenerated {len(unlearning_datasets)} sequential unlearning datasets.")
    
    # Load retention dataset (D_nor); it is reserved for evaluation and is not
    # used by the unlearning loop below
    retention_dataset = get_retention_dataset()
    
    # STEP 1-N: Sequential Unlearning Loop
    for t in range(start_step - 1, min(end_step, Config.NUM_UNLEARNING_STEPS)):
        print(f"\n{'='*60}")
        print(f"UNLEARNING STEP {t+1} (Unlearning D_f^{t+1})")
        print(f"{'='*60}")
        
        dataset_t = unlearning_datasets[t]
        step_prefix = f"step_{t+1}"
        
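        # D_prev: books unlearned in earlier steps (only from step 2 onward).
        # In the paper, this keeps the model from re-learning old content; here
        # it is only reported and is not added to the training data or the loss.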
        if t > 0 and Config.USE_RETENTION_DATA:
            print(f"Note: D_prev includes books from steps 1-{t}")
        
        # Fine-Tuning Stage
        print(f"Preparing LoRA model for fine-tuning on D_f^{t+1}...")
        
        # Ensure model is on correct device
        device = next(current_model.parameters()).device
        if device.type == "meta" or str(device) == "meta":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            current_model = current_model.to(device)
        else:
            # Convert device object to string if needed
            device = str(device).split(':')[0]  # Get 'cuda' or 'cpu'
        
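        # Assuming create_lora_model wraps its argument with PEFT: PEFT injects
        # LoRA layers into that model in place, and merge_and_unload() later merges
        # them back into the same object. Wrapping a copy keeps current_model at
        # its pre-fine-tuning weights, so the task vector below is not zero.
        import copy
        lora_model = create_lora_model(copy.deepcopy(current_model))
        # PEFT models inherit device from base model, no need to call .to()
        lora_model.print_trainable_parameters()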
        
        training_args = TrainingArguments(
            output_dir=f"{Config.OUTPUT_DIR}/{step_prefix}_ft_checkpoints",
            per_device_train_batch_size=Config.BATCH_SIZE,
            gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
            warmup_steps=10,
            learning_rate=Config.LEARNING_RATE,
            num_train_epochs=Config.NUM_EPOCHS_FT,
            logging_steps=10,
            save_strategy="no",
            report_to="none",
            fp16=False,
            bf16=torch.cuda.is_available() and device == "cuda",
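            dataloader_pin_memory=False,  # Fix device issues
            # By default, Trainer drops dataset keys that are not forward() arguments,
            # which would remove labels_fgt / labels_rnd before SSUTrainer sees them
            remove_unused_columns=False,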
        )

        # Use custom data collator that preserves labels_fgt and labels_rnd
        data_collator = SSUDataCollator(tokenizer=tokenizer)
        
        trainer = SSUTrainer(
            model=lora_model,
            args=training_args,
            train_dataset=dataset_t,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )

        print(f"Starting fine-tuning with SSU Loss for {step_prefix}...")
        trainer.train()
        
        # Task Vector Negation Stage
        theta_ft_t = lora_model.merge_and_unload()
        unlearned_model_t = apply_task_vector_negation(current_model, theta_ft_t, step_prefix)
        unlearned_model_t.requires_grad_(False)
        
        # Save the unlearned model
        unlearned_model_t.save_pretrained(f"{Config.OUTPUT_DIR}/{step_prefix}_unlearned_model")
        tokenizer.save_pretrained(f"{Config.OUTPUT_DIR}/{step_prefix}_unlearned_model")
        print(f"✓ Unlearned model {step_prefix} saved.")
        
        current_model = unlearned_model_t

    # Final Evaluation
    print("\n" + "="*60)
    print("SEQUENTIAL UNLEARNING COMPLETE")
    print("="*60)
    final_step = min(end_step, Config.NUM_UNLEARNING_STEPS)
    print(f"Final Unlearned Model: {Config.OUTPUT_DIR}/step_{final_step}_unlearned_model")
    
    # Test generation
    prompt = "The quick brown fox"
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(current_model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    print("\nTesting generation with final unlearned model...")
    current_model.eval()
    with torch.no_grad():
        output_tokens = current_model.generate(
            **inputs, 
            max_new_tokens=30, 
            do_sample=True, 
            top_p=0.9, 
            temperature=0.7
        )
    
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Generated: {generated_text}")
    
    return current_model


# Run sequential unlearning steps
# You can specify start_step and end_step to resume from a specific step
# Example: run_sequential_unlearning(start_step=2, end_step=3) to run only steps 2-3
final_unlearned_model = run_sequential_unlearning()


Starting from memorized model for step 1

--- Preparing dataset D_f^1 for time step 1 ---
Book 1661 already exists, skipping download.
✓ Loaded book 1661 for step 1
Created 50 chunks for time step 1

--- Preparing dataset D_f^2 for time step 2 ---
Book 1342 already exists, skipping download.
✓ Loaded book 1342 for step 2
Created 50 chunks for time step 2

--- Preparing dataset D_f^3 for time step 3 ---
Book 11 already exists, skipping download.
✓ Loaded book 11 for step 3
Created 50 chunks for time step 3

Generated 3 sequential unlearning datasets.

=== Preparing retention dataset D_nor ===
Book 1232 already exists, skipping download.
Book 145 already exists, skipping download.
Book 76 already exists, skipping download.
Book 2591 already exists, skipping download.
Book 30254 already exists, skipping download.
Book 844 already exists, skipping download.
Book 345 already exists, skipping download.
Book 520 already exists, skipping download.


  trainer = SSUTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


Created 20 retention chunks

UNLEARNING STEP 1 (Unlearning D_f^1)
Preparing LoRA model for fine-tuning on D_f^1...
trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.1023
Starting fine-tuning with SSU Loss for step_1...


KeyError: 'labels_fgt'
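The `KeyError: 'labels_fgt'` above is consistent with `Trainer`'s default `remove_unused_columns=True`. For a plain PyTorch dataset, `Trainer` wraps the data collator so that each example keeps only the keys named in the model's `forward()` signature (plus label aliases), so `labels_fgt` and `labels_rnd` are likely dropped before `SSUDataCollator` and `SSUTrainer` see them. The sketch below is a simplified, hypothetical model of that filtering, not the library's code; `forward` and `filter_to_signature` are stand-ins. Setting `remove_unused_columns=False` in `TrainingArguments` disables the filtering.

```python
import inspect


def forward(input_ids=None, attention_mask=None, labels=None, **kwargs):
    """Stand-in for a causal LM forward(); only its parameter names matter here."""
    return None


def filter_to_signature(feature, fn, remove_unused_columns=True):
    """Hypothetical model of Trainer's column filtering for one example."""
    if not remove_unused_columns:
        return dict(feature)
    allowed = set(inspect.signature(fn).parameters) | {"label", "label_ids"}
    return {k: v for k, v in feature.items() if k in allowed}


feature = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "labels_fgt": [1, 2, 3],
    "labels_rnd": [4, 5, 6],
}
print(sorted(filter_to_signature(feature, forward)))
# ['attention_mask', 'input_ids']
print(sorted(filter_to_signature(feature, forward, remove_unused_columns=False)))
# ['attention_mask', 'input_ids', 'labels_fgt', 'labels_rnd']
```

Keeping the extra keys is safe here because `SSUDataCollator` and `SSUTrainer` are the components that consume them, not the model's `forward()`.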