In [1]:
# !pip install transformers

In [2]:
# mkdir semeval25-unlearning-model; mkdir semeval25-unlearning-data

In [3]:
# !pip install pyarrow

In [4]:
# !pip install accelerate


In [5]:
# import pandas as pd
# from huggingface_hub import snapshot_download
# from transformers import AutoModelForCausalLM, AutoTokenizer
# hf_token = "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"

# ## Fetch and load model:
# snapshot_download(repo_id='llmunlearningsemeval2025organization/olmo-finetuned-semeval25-unlearning', token=hf_token, local_dir='semeval25-unlearning-model')
# model = AutoModelForCausalLM.from_pretrained('semeval25-unlearning-model')
 
# ## Fetch and load dataset:
# snapshot_download(repo_id='llmunlearningsemeval2025organization/semeval25-unlearning-dataset-public', token=hf_token, local_dir='semeval25-unlearning-data', repo_type="dataset")
# retain_train_df = pd.read_parquet('semeval25-unlearning-data/data/retain_train-00000-of-00001.parquet', engine='pyarrow') # Retain split: train set
# retain_validation_df = pd.read_parquet('semeval25-unlearning-data/data/retain_validation-00000-of-00001.parquet', engine='pyarrow') # Retain split: validation set
# forget_train_df = pd.read_parquet('semeval25-unlearning-data/data/forget_train-00000-of-00001.parquet', engine='pyarrow') # Forget split: train set
# forget_validation_df = pd.read_parquet('semeval25-unlearning-data/data/forget_validation-00000-of-00001.parquet', engine='pyarrow') # Forget split: validation set
# !mkdir train validation
# retain_train_df.to_json('train/retain.jsonl'); forget_train_df.to_json('train/forget.jsonl')
# retain_validation_df.to_json('validation/retain.jsonl'); forget_validation_df.to_json('validation/forget.jsonl')

In [6]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import math

In [8]:
# !pip install bitsandbytes

In [10]:
# import os

In [14]:
class JSONLDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=256):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load the data from the JSONL file
        with open(jsonl_path, "r", encoding='utf-8') as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    # Ensure the input is a string
                    document = str(item.get("input", "")).strip()
                    task = str(item.get("split", "")).strip()
                    
                    if document:  # Only append if document is not empty
                        self.data.append({
                            "input": document,
                            "task": task
                        })
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON line: {e}")
                except Exception as e:
                    print(f"Error processing line: {e}")

        print(f"Loaded {len(self.data)} items from {jsonl_path}")
        if len(self.data) > 0:
            print(f"Sample input text: {self.data[0]['input'][:100]}...")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            # Ensure input is a string
            text = str(item["input"])
            
            # Tokenize the input
            inputs = self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors=None  # Changed from "pt" to None
            )
            
            # Convert to tensors
            return {
                "input_ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
                "task": item["task"]
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            print(f"Problematic input: {item['input']}")
            raise

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Tuple
import json
import os

class SemanticMemoryBank:
    def __init__(self, size=1000):
        self.size = size
        self.forget_embeddings = []
        self.retain_embeddings = []
    
    def update(self, forget_emb, retain_emb):
        self.forget_embeddings.extend(forget_emb)
        self.retain_embeddings.extend(retain_emb)
        
        if len(self.forget_embeddings) > self.size:
            self.forget_embeddings = self.forget_embeddings[-self.size:]
        if len(self.retain_embeddings) > self.size:
            self.retain_embeddings = self.retain_embeddings[-self.size:]

class EnhancedUnlearning:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        memory_bank_size: int = 1000,
        temperature: float = 0.07,
        forget_weight: float = 1.0,
        retain_weight: float = 0.5,
        contrastive_weight: float = 0.3,
        beta: float = 1.0,
        reference_model: AutoModelForCausalLM = None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.memory_bank = SemanticMemoryBank(size=memory_bank_size)
        self.temperature = temperature
        self.forget_weight = forget_weight
        self.retain_weight = retain_weight
        self.contrastive_weight = contrastive_weight
        self.beta = beta
        self.reference_model = reference_model
        self.scaler = torch.cuda.amp.GradScaler()
        
    def get_semantic_embedding(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_embeddings = torch.sum(last_hidden * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            return sum_embeddings / sum_mask

    def contrastive_loss(self, anchor, positive, negative):
        anchor_norm = F.normalize(anchor, dim=1)
        positive_norm = F.normalize(positive, dim=1)
        negative_norm = F.normalize(negative, dim=1)
        
        pos_sim = torch.matmul(anchor_norm, positive_norm.t()) / self.temperature
        neg_sim = torch.matmul(anchor_norm, negative_norm.t()) / self.temperature
        
        logits = torch.cat([pos_sim, neg_sim], dim=1)
        labels = torch.zeros(anchor.size(0), device=anchor.device, dtype=torch.long)
        
        return F.cross_entropy(logits, labels)

    def compute_npo_loss(self, model_probs, ref_probs):
        """Compute NPO-based loss for probability ratio optimization"""
        ratio = model_probs / (ref_probs + 1e-10)
        return -2 / self.beta * torch.log(1 + (ratio ** (-self.beta)))

    def unlearning_step(
        self,
        forget_batch: Dict[str, torch.Tensor],
        retain_batch: Dict[str, torch.Tensor],
        optimizer: torch.optim.Optimizer
    ) -> Tuple[float, float, float, float]:
        self.model.train()
        optimizer.zero_grad()
        
        device = next(self.model.parameters()).device
        forget_batch = {k: v.to(device) for k, v in forget_batch.items()}
        retain_batch = {k: v.to(device) for k, v in retain_batch.items()}

        with torch.cuda.amp.autocast():
            # 1. Semantic embeddings
            forget_emb = self.get_semantic_embedding(
                forget_batch["input_ids"],
                forget_batch["attention_mask"]
            )
            retain_emb = self.get_semantic_embedding(
                retain_batch["input_ids"],
                retain_batch["attention_mask"]
            )
            
            # 2. Update memory bank
            self.memory_bank.update([forget_emb.detach()], [retain_emb.detach()])
            
            # 3. Get model outputs
            forget_outputs = self.model(**forget_batch)
            retain_outputs = self.model(**retain_batch)
            
            # 4. Compute NPO losses if reference model is available
            if self.reference_model is not None:
                with torch.no_grad():
                    ref_forget = self.reference_model(**forget_batch)
                    ref_retain = self.reference_model(**retain_batch)
                
                # Compute probabilities for NPO
                forget_probs = F.softmax(forget_outputs.logits, dim=-1)
                retain_probs = F.softmax(retain_outputs.logits, dim=-1)
                ref_forget_probs = F.softmax(ref_forget.logits, dim=-1)
                ref_retain_probs = F.softmax(ref_retain.logits, dim=-1)
                
                # Get target token probabilities
                target_tokens = forget_batch["input_ids"][:, -1].unsqueeze(1).unsqueeze(2)
                forget_token_probs = forget_probs.gather(2, target_tokens).squeeze(-1)
                ref_forget_token_probs = ref_forget_probs.gather(2, target_tokens).squeeze(-1)
                
                npo_forget_loss = self.compute_npo_loss(forget_token_probs, ref_forget_token_probs).mean()
                forget_loss = (npo_forget_loss + forget_outputs.loss) * self.forget_weight
            else:
                forget_loss = forget_outputs.loss * self.forget_weight
            
            # 5. Retain loss
            retain_loss = retain_outputs.loss * self.retain_weight
            
            # 6. Contrastive loss
            if len(self.memory_bank.forget_embeddings) > 0 and len(self.memory_bank.retain_embeddings) > 0:
                forget_memory = torch.stack(self.memory_bank.forget_embeddings)
                retain_memory = torch.stack(self.memory_bank.retain_embeddings)
                contrast_loss = self.contrastive_loss(
                    forget_emb,
                    retain_memory,
                    forget_memory
                ) * self.contrastive_weight
            else:
                contrast_loss = torch.tensor(0.0, device=device)
            
            # 7. Total loss
            total_loss = -forget_loss + retain_loss + contrast_loss

        # 8. Backward pass with gradient scaling
        self.scaler.scale(total_loss).backward()
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.scaler.step(optimizer)
        self.scaler.update()
        
        return (
            forget_loss.item(),
            retain_loss.item(),
            contrast_loss.item(),
            total_loss.item()
        )

    def unlearn(
        self,
        forget_loader: DataLoader,
        retain_loader: DataLoader,
        num_epochs: int = 3,
        learning_rate: float = 1e-5,
        gradient_accumulation_steps: int = 2,
        output_path: str = None
    ):
        """Main unlearning loop"""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        device = next(self.model.parameters()).device
        
        for epoch in range(num_epochs):
            total_forget_loss = 0
            total_retain_loss = 0
            total_contrast_loss = 0
            total_steps = 0
            
            from itertools import zip_longest
            for i, (forget_batch, retain_batch) in enumerate(zip_longest(forget_loader, retain_loader)):
                if forget_batch is None or retain_batch is None:
                    continue
                
                losses = self.unlearning_step(forget_batch, retain_batch, optimizer)
                
                total_forget_loss += losses[0]
                total_retain_loss += losses[1]
                total_contrast_loss += losses[2]
                total_steps += 1
                
                if (i + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                
                torch.cuda.empty_cache()
            
            # Print epoch statistics
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Average Forget Loss: {total_forget_loss/total_steps:.4f}")
            print(f"Average Retain Loss: {total_retain_loss/total_steps:.4f}")
            print(f"Average Contrast Loss: {total_contrast_loss/total_steps:.4f}")
        
        if output_path:
            self.model.save_pretrained(output_path)
            self.tokenizer.save_pretrained(output_path)
            print(f"Model saved to {output_path}")



In [20]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import gc
from transformers import BitsAndBytesConfig

class JSONLDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=256):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(jsonl_path, "r", encoding='utf-8') as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    input_dict = item.get('input', {})
                    
                    for idx in sorted(input_dict.keys()):
                        text = str(input_dict[idx])
                        if text:
                            self.data.append({
                                "input": text,
                                "task": item.get('task', {}).get(idx, "default")
                            })
                            
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON line: {e}")
                except Exception as e:
                    print(f"Error processing line: {e}")

        print(f"Loaded {len(self.data)} items from {jsonl_path}")
        if len(self.data) > 0:
            print(f"Sample input text: {self.data[0]['input'][:100]}...")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            inputs = self.tokenizer(
                item["input"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors=None
            )
            
            return {
                "input_ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            raise

from transformers import BitsAndBytesConfig

def setup_unlearning(
    model_path: str,
    tokenizer_name: str,
    retain_path: str,
    forget_path: str,
    batch_size: int = 1,
    max_length: int = 256
):
    try:
        torch.cuda.empty_cache()
        gc.collect()

        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True
        )

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        reference_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        retain_dataset = JSONLDataset(retain_path, tokenizer, max_length)
        forget_dataset = JSONLDataset(forget_path, tokenizer, max_length)

        retain_loader = DataLoader(retain_dataset, batch_size=batch_size, shuffle=True)
        forget_loader = DataLoader(forget_dataset, batch_size=batch_size, shuffle=True)

        return model, reference_model, tokenizer, retain_loader, forget_loader

    except Exception as e:
        print(f"Error during setup: {str(e)}")
        raise

class EnhancedUnlearning:
    def __init__(
        self,
        model,
        tokenizer,
        memory_bank_size=1000,
        temperature=0.07,
        forget_weight=1.0,
        retain_weight=0.5,
        contrastive_weight=0.3,
        beta=1.0,
        reference_model=None
    ):
        self.model = model
        self.device = next(model.parameters()).device
        self.tokenizer = tokenizer
        self.memory_bank = SemanticMemoryBank(size=memory_bank_size)
        self.temperature = temperature
        self.forget_weight = forget_weight
        self.retain_weight = retain_weight
        self.contrastive_weight = contrastive_weight
        self.beta = beta
        self.reference_model = reference_model
        self.scaler = torch.cuda.amp.GradScaler()

    def get_semantic_embedding(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_embeddings = torch.sum(last_hidden * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            return sum_embeddings / sum_mask

    def contrastive_loss(self, anchor, positive, negative):
        anchor_norm = F.normalize(anchor, dim=1)
        positive_norm = F.normalize(positive, dim=1)
        negative_norm = F.normalize(negative, dim=1)
        
        pos_sim = torch.matmul(anchor_norm, positive_norm.t()) / self.temperature
        neg_sim = torch.matmul(anchor_norm, negative_norm.t()) / self.temperature
        
        logits = torch.cat([pos_sim, neg_sim], dim=1)
        labels = torch.zeros(anchor.size(0), device=anchor.device, dtype=torch.long)
        
        return F.cross_entropy(logits, labels)

    def compute_npo_loss(self, model_probs, ref_probs):
        """Compute NPO-based loss for probability ratio optimization"""
        ratio = model_probs / (ref_probs + 1e-10)
        return -2 / self.beta * torch.log(1 + (ratio ** (-self.beta)))

    def unlearning_step(self, forget_batch, retain_batch, optimizer):
        self.model.train()
        optimizer.zero_grad()
        
        forget_batch = {k: v.to(next(self.model.parameters()).device) for k, v in forget_batch.items()}
        retain_batch = {k: v.to(next(self.model.parameters()).device) for k, v in retain_batch.items()}

        with torch.cuda.amp.autocast():
            forget_outputs = self.model(**forget_batch)
            retain_outputs = self.model(**retain_batch)
            
            if self.reference_model:
                with torch.no_grad():
                    ref_forget = self.reference_model(**forget_batch)
                    
                forget_probs = torch.softmax(forget_outputs.logits, dim=-1)
                ref_forget_probs = torch.softmax(ref_forget.logits, dim=-1)
                
                npo_forget_loss = self.compute_npo_loss(forget_probs, ref_forget_probs)
                forget_loss = npo_forget_loss * self.forget_weight
            else:
                forget_loss = forget_outputs.logits.mean() * self.forget_weight
            
            retain_loss = retain_outputs.logits.mean() * self.retain_weight
            total_loss = -forget_loss + retain_loss

        self.scaler.scale(total_loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()
        
        return forget_loss.item(), retain_loss.item(), 0, total_loss.item()

    def unlearn(
        self,
        forget_loader: DataLoader,
        retain_loader: DataLoader,
        num_epochs: int = 3,
        learning_rate: float = 1e-5,
        gradient_accumulation_steps: int = 2,
        output_path: str = None
    ):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        
        for epoch in range(num_epochs):
            total_forget_loss = 0
            total_retain_loss = 0
            total_contrast_loss = 0
            total_steps = 0
            
            for i, (forget_batch, retain_batch) in enumerate(zip(forget_loader, retain_loader)):
                if forget_batch is None or retain_batch is None:
                    continue
                
                losses = self.unlearning_step(forget_batch, retain_batch, optimizer)
                
                total_forget_loss += losses[0]
                total_retain_loss += losses[1]
                total_contrast_loss += losses[2]
                total_steps += 1
                
                if (i + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
            
                torch.cuda.empty_cache()
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Average Forget Loss: {total_forget_loss/total_steps:.4f}")
            print(f"Average Retain Loss: {total_retain_loss/total_steps:.4f}")
            print(f"Average Contrast Loss: {total_contrast_loss/total_steps:.4f}")
        
        if output_path:
            self.model.save_pretrained(output_path)
            self.tokenizer.save_pretrained(output_path)
            print(f"Model saved to {output_path}")

In [21]:



if __name__ == "__main__":
    model_path = 'semeval25-unlearning-model'
    tokenizer_name = 'allenai/OLMo-1B-0724-hf'
    retain_path = "/teamspace/studios/this_studio/train/retain.jsonl"
    forget_path = "/teamspace/studios/this_studio/train/forget.jsonl"

    try:
        model, reference_model, tokenizer, retain_loader, forget_loader = setup_unlearning(
            model_path=model_path,
            tokenizer_name=tokenizer_name,
            retain_path=retain_path,
            forget_path=forget_path,
            batch_size=1
        )

        unlearner = EnhancedUnlearning(
            model=model,
            tokenizer=tokenizer,
            reference_model=reference_model,
            beta=1.0,
            memory_bank_size=1000
        )

        unlearner.unlearn(
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            num_epochs=3,
            learning_rate=1e-5,
            output_path="./unlearned_model"
        )

    except Exception as e:
        print(f"\nError during setup or training: {str(e)}")
        print("\nDebugging information:")
        import traceback
        traceback.print_exc()

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loaded 1136 items from /teamspace/studios/this_studio/train/retain.jsonl
Sample input text: Fredericka Amber was born on December 21, 1969. Her Social Security number is 900-22-6238 and her ph...
Loaded 1112 items from /teamspace/studios/this_studio/train/forget.jsonl
Sample input text: In the mystical city of Deadesius, where magic and mystery intertwined, two sorceresses, Marcile and...

Error during setup or training: grad can be implicitly created only for scalar outputs

Debugging information:


Traceback (most recent call last):
  File "/tmp/ipykernel_1646/454195701.py", line 24, in <module>
    unlearner.unlearn(
  File "/tmp/ipykernel_1646/2073610325.py", line 214, in unlearn
    losses = self.unlearning_step(forget_batch, retain_batch, optimizer)
  File "/tmp/ipykernel_1646/2073610325.py", line 187, in unlearning_step
    self.scaler.scale(total_loss).backward()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/__init__.py", line 259, in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/__init__.py", line 132, in _make_grads
    raise RuntimeError(
RuntimeError: grad can be implicitly created only for scalar outputs


In [1]:
!pip install -U bitsandbytes



In [6]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import gc
from transformers import BitsAndBytesConfig

class JSONLDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=256):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(jsonl_path, "r", encoding='utf-8') as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    input_dict = item.get('input', {})
                    
                    for idx in sorted(input_dict.keys()):
                        text = str(input_dict[idx])
                        if text:
                            self.data.append({
                                "input": text,
                                "task": item.get('task', {}).get(idx, "default")
                            })
                            
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON line: {e}")
                except Exception as e:
                    print(f"Error processing line: {e}")

        print(f"Loaded {len(self.data)} items from {jsonl_path}")
        if len(self.data) > 0:
            print(f"Sample input text: {self.data[0]['input'][:100]}...")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            inputs = self.tokenizer(
                item["input"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors=None
            )
            
            return {
                "input_ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            raise

from transformers import BitsAndBytesConfig

def setup_unlearning(
    model_path: str,
    tokenizer_name: str,
    retain_path: str,
    forget_path: str,
    batch_size: int = 1,
    max_length: int = 256
):
    try:
        torch.cuda.empty_cache()
        gc.collect()

        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True
        )

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        reference_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        retain_dataset = JSONLDataset(retain_path, tokenizer, max_length)
        forget_dataset = JSONLDataset(forget_path, tokenizer, max_length)

        retain_loader = DataLoader(retain_dataset, batch_size=batch_size, shuffle=True)
        forget_loader = DataLoader(forget_dataset, batch_size=batch_size, shuffle=True)

        return model, reference_model, tokenizer, retain_loader, forget_loader

    except Exception as e:
        print(f"Error during setup: {str(e)}")
        raise
    
class SemanticMemoryBank:
    def __init__(self, size=1000):
        self.size = size
        self.forget_embeddings = []
        self.retain_embeddings = []
    
    def update(self, forget_emb, retain_emb):
        self.forget_embeddings.extend(forget_emb)
        self.retain_embeddings.extend(retain_emb)
        
        if len(self.forget_embeddings) > self.size:
            self.forget_embeddings = self.forget_embeddings[-self.size:]
        if len(self.retain_embeddings) > self.size:
            self.retain_embeddings = self.retain_embeddings[-self.size:]


class EnhancedUnlearning:
    def __init__(
        self,
        model,
        tokenizer,
        memory_bank_size=1000,
        temperature=0.07,
        forget_weight=1.0,
        retain_weight=0.5,
        contrastive_weight=0.3,
        beta=1.0,
        reference_model=None
    ):
        self.model = model
        self.device = next(model.parameters()).device
        self.tokenizer = tokenizer
        self.memory_bank = SemanticMemoryBank(size=memory_bank_size)
        self.temperature = temperature
        self.forget_weight = forget_weight
        self.retain_weight = retain_weight
        self.contrastive_weight = contrastive_weight
        self.beta = beta
        self.reference_model = reference_model
        self.scaler = torch.cuda.amp.GradScaler()

    def get_semantic_embedding(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size())
            sum_embeddings = torch.sum(last_hidden * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            return sum_embeddings / sum_mask

    def contrastive_loss(self, anchor, positive, negative):
        anchor_norm = F.normalize(anchor, dim=1)
        positive_norm = F.normalize(positive, dim=1)
        negative_norm = F.normalize(negative, dim=1)
        
        pos_sim = torch.matmul(anchor_norm, positive_norm.t()) / self.temperature
        neg_sim = torch.matmul(anchor_norm, negative_norm.t()) / self.temperature
        
        logits = torch.cat([pos_sim, neg_sim], dim=1)
        labels = torch.zeros(anchor.size(0), device=anchor.device, dtype=torch.long)
        
        return F.cross_entropy(logits, labels)

    def compute_npo_loss(self, model_probs, ref_probs):
        """Compute NPO-based loss for probability ratio optimization"""
        ratio = model_probs / (ref_probs + 1e-10)
        return -2 / self.beta * torch.log(1 + (ratio ** (-self.beta)))

    def unlearning_step(self, forget_batch, retain_batch, optimizer):
        self.model.train()
        optimizer.zero_grad()
        
        # Move batches to appropriate device
        device = next(self.model.parameters()).device
        forget_batch = {k: v.to(device) for k, v in forget_batch.items()}
        retain_batch = {k: v.to(device) for k, v in retain_batch.items()}

        with torch.cuda.amp.autocast():
            # Get model outputs
            forget_outputs = self.model(**forget_batch)
            retain_outputs = self.model(**retain_batch)
            
            # Calculate losses using mean reduction
            if self.reference_model:
                with torch.no_grad():
                    ref_forget = self.reference_model(**forget_batch)
                    
                forget_probs = torch.softmax(forget_outputs.logits, dim=-1)
                ref_forget_probs = torch.softmax(ref_forget.logits, dim=-1)
                
                # Make sure loss is scalar
                npo_forget_loss = (forget_probs - ref_forget_probs).pow(2).mean()
                forget_loss = npo_forget_loss * self.forget_weight
            else:
                forget_loss = forget_outputs.logits.mean() * self.forget_weight
            
            # Make retain loss scalar
            retain_loss = retain_outputs.logits.mean() * self.retain_weight
            
            # Compute total loss
            total_loss = (-forget_loss + retain_loss).mean()  # Ensure scalar output
            
        # Backward pass
        self.scaler.scale(total_loss).backward()
        
        # Clip gradients if needed
        if hasattr(self, 'max_grad_norm'):
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        
        self.scaler.step(optimizer)
        self.scaler.update()
        
        return (
            forget_loss.item(),
            retain_loss.item(),
            0,  # No contrastive loss for now
            total_loss.item()
        )

    def unlearn(
        self,
        forget_loader: DataLoader,
        retain_loader: DataLoader,
        num_epochs: int = 3,
        learning_rate: float = 1e-5,
        gradient_accumulation_steps: int = 2,
        output_path: str = None
    ):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        
        for epoch in range(num_epochs):
            total_forget_loss = 0
            total_retain_loss = 0
            total_contrast_loss = 0
            total_steps = 0
            
            for i, (forget_batch, retain_batch) in enumerate(zip(forget_loader, retain_loader)):
                if forget_batch is None or retain_batch is None:
                    continue
                
                losses = self.unlearning_step(forget_batch, retain_batch, optimizer)
                
                total_forget_loss += losses[0]
                total_retain_loss += losses[1]
                total_contrast_loss += losses[2]
                total_steps += 1
                
                if (i + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
            
                torch.cuda.empty_cache()
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Average Forget Loss: {total_forget_loss/total_steps:.4f}")
            print(f"Average Retain Loss: {total_retain_loss/total_steps:.4f}")
            print(f"Average Contrast Loss: {total_contrast_loss/total_steps:.4f}")
        
        if output_path:
            self.model.save_pretrained(output_path)
            self.tokenizer.save_pretrained(output_path)
            print(f"Model saved to {output_path}")

# Final Working Code

In [None]:
class JSONLDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=256):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(jsonl_path, "r", encoding='utf-8') as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    input_dict = item.get('input', {})
                    
                    for idx in sorted(input_dict.keys()):
                        text = str(input_dict[idx])
                        if text:
                            self.data.append({
                                "input": text,
                                "task": item.get('task', {}).get(idx, "default")
                            })
                            
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON line: {e}")
                except Exception as e:
                    print(f"Error processing line: {e}")

        print(f"Loaded {len(self.data)} items from {jsonl_path}")
        if len(self.data) > 0:
            print(f"Sample input text: {self.data[0]['input'][:100]}...")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            inputs = self.tokenizer(
                item["input"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors=None
            )
            
            return {
                "input_ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            raise


In [8]:
from transformers import BitsAndBytesConfig
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import gc

def setup_unlearning(
    model_path: str,
    tokenizer_name: str,
    retain_path: str,
    forget_path: str,
    batch_size: int = 1,
    max_length: int = 256
):
    try:
        # Clear memory
        torch.cuda.empty_cache()
        gc.collect()

        # Advanced quantization config
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        print(f"Loading tokenizer from {tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(f"Loading model from {model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=double_quant_config,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        model.gradient_checkpointing_enable()

        print("Loading reference model")
        reference_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=double_quant_config,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )

        retain_dataset = JSONLDataset(retain_path, tokenizer, max_length)
        forget_dataset = JSONLDataset(forget_path, tokenizer, max_length)

        retain_loader = DataLoader(retain_dataset, batch_size=batch_size, shuffle=True)
        forget_loader = DataLoader(forget_dataset, batch_size=batch_size, shuffle=True)

        return model, reference_model, tokenizer, retain_loader, forget_loader

    except Exception as e:
        print(f"Error during setup: {str(e)}")
        raise

class EnhancedUnlearning:
    def __init__(
        self,
        model,
        tokenizer,
        memory_bank_size=1000,
        temperature=0.07,
        forget_weight=1.0,
        retain_weight=0.5,
        contrastive_weight=0.3,
        beta=1.0,
        reference_model=None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.forget_weight = forget_weight
        self.retain_weight = retain_weight
        self.contrastive_weight = contrastive_weight
        self.beta = beta
        self.reference_model = reference_model
        self.max_grad_norm = 1.0

    def compute_loss(self, outputs, labels):
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), 
                             shift_labels.view(-1), 
                             reduction='mean')
        return loss

    def unlearning_step(self, forget_batch, retain_batch, optimizer):
        self.model.train()
        optimizer.zero_grad()
        
        device = next(self.model.parameters()).device
        forget_batch = {k: v.to(device) for k, v in forget_batch.items()}
        retain_batch = {k: v.to(device) for k, v in retain_batch.items()}

        with autocast(dtype=torch.bfloat16):
            # Compute forget loss
            forget_outputs = self.model(**forget_batch)
            forget_loss = self.compute_loss(forget_outputs, forget_batch['input_ids'])
            
            if self.reference_model is not None:
                with torch.no_grad():
                    ref_outputs = self.reference_model(**forget_batch)
                ref_loss = self.compute_loss(ref_outputs, forget_batch['input_ids'])
                forget_loss = (forget_loss - ref_loss).abs().mean() * self.forget_weight
            else:
                forget_loss = forget_loss * self.forget_weight

            # Compute retain loss
            retain_outputs = self.model(**retain_batch)
            retain_loss = self.compute_loss(retain_outputs, retain_batch['input_ids']) * self.retain_weight

            # Total loss
            total_loss = (-forget_loss + retain_loss).mean()

        # Manual backward pass
        total_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        
        # Optimizer step
        optimizer.step()
        
        return (
            forget_loss.item(),
            retain_loss.item(),
            0.0,
            total_loss.item()
        )

    def unlearn(
        self,
        forget_loader: DataLoader,
        retain_loader: DataLoader,
        num_epochs: int = 3,
        learning_rate: float = 1e-5,
        output_path: str = None
    ):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        
        for epoch in range(num_epochs):
            total_forget_loss = 0
            total_retain_loss = 0
            total_steps = 0
            
            for forget_batch, retain_batch in zip(forget_loader, retain_loader):
                losses = self.unlearning_step(forget_batch, retain_batch, optimizer)
                
                total_forget_loss += losses[0]
                total_retain_loss += losses[1]
                total_steps += 1

                if total_steps % 10 == 0:
                    print(f"Step {total_steps}: forget_loss={losses[0]:.4f}, retain_loss={losses[1]:.4f}")
                
                torch.cuda.empty_cache()
            
            avg_forget_loss = total_forget_loss / total_steps
            avg_retain_loss = total_retain_loss / total_steps
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Average Forget Loss: {avg_forget_loss:.4f}")
            print(f"Average Retain Loss: {avg_retain_loss:.4f}")
        
        if output_path:
            self.model.save_pretrained(output_path)
            self.tokenizer.save_pretrained(output_path)

In [9]:



if __name__ == "__main__":
    model_path = 'semeval25-unlearning-model'
    tokenizer_name = 'allenai/OLMo-1B-0724-hf'
    retain_path = "/teamspace/studios/this_studio/train/retain.jsonl"
    forget_path = "/teamspace/studios/this_studio/train/forget.jsonl"

    try:
        model, reference_model, tokenizer, retain_loader, forget_loader = setup_unlearning(
            model_path=model_path,
            tokenizer_name=tokenizer_name,
            retain_path=retain_path,
            forget_path=forget_path,
            batch_size=1
        )

        unlearner = EnhancedUnlearning(
            model=model,
            tokenizer=tokenizer,
            reference_model=reference_model,
            beta=1.0,
            memory_bank_size=1000
        )

        unlearner.unlearn(
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            num_epochs=3,
            learning_rate=1e-5,
            output_path="./unlearned_model"
        )

    except Exception as e:
        print(f"\nError during setup or training: {str(e)}")
        print("\nDebugging information:")
        import traceback
        traceback.print_exc()

Loading tokenizer from allenai/OLMo-1B-0724-hf
Loading model from semeval25-unlearning-model


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loading reference model


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loaded 1136 items from /teamspace/studios/this_studio/train/retain.jsonl
Sample input text: Fredericka Amber was born on December 21, 1969. Her Social Security number is 900-22-6238 and her ph...
Loaded 1112 items from /teamspace/studios/this_studio/train/forget.jsonl
Sample input text: In the mystical city of Deadesius, where magic and mystery intertwined, two sorceresses, Marcile and...
Step 10: forget_loss=0.3388, retain_loss=7.6547
Step 20: forget_loss=0.6909, retain_loss=7.8953
Step 30: forget_loss=1.4393, retain_loss=7.2507
Step 40: forget_loss=1.3568, retain_loss=7.1351
Step 50: forget_loss=1.9899, retain_loss=1.9247
Step 60: forget_loss=2.1625, retain_loss=6.2816
Step 70: forget_loss=1.4600, retain_loss=6.6625
Step 80: forget_loss=2.9014, retain_loss=6.4581
Step 90: forget_loss=0.2990, retain_loss=6.1174
Step 100: forget_loss=3.2637, retain_loss=6.7708
Step 110: forget_loss=1.8208, retain_loss=5.9204
Step 120: forget_loss=4.3596, retain_loss=6.5308
Step 130: forget_loss=4.9296,