In [None]:
import os
import gc
import math
import random
import numpy as np
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset, TensorDataset

# Device string used by every cell below: 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than the T4, as it is faster and increases the chances of finishing within the time limit.

# Fail fast (before any data is loaded) when the notebook is not running on a GPU.
if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

In [None]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    '''Build one example dict from a metadata row.

    The image stored at ``df_row['image_path']`` is read with
    ``torchvision.io.read_image`` and returned under ``'image'``, followed by the
    row's ``image_id``, ``age_group``, ``age`` and ``person_id`` values under
    the same keys.
    '''
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''The hidden dataset.

    Reads ``<root>/<split>.csv``, derives each image path from its ``image_id``
    (``"<dir>-<name>"`` maps to ``<root>/images/<dir>/<name>.png``), shuffles the
    rows once and eagerly loads every example into memory via ``load_example``.

    Args:
        split: Name of the CSV file (without extension) to load.
        root: Directory that holds the CSV files and the ``images`` folder.

    Raises:
        ValueError: If the split contains no examples.
    '''
    def __init__(self, split='train', root='/kaggle/input/neurips-2023-machine-unlearning/'):
        super().__init__()

        def image_path(image_id):
            parts = image_id.split('-')
            return os.path.join(root, 'images', parts[0], parts[1] + '.png')

        df = pd.read_csv(os.path.join(root, f'{split}.csv'))
        df['image_path'] = df['image_id'].apply(image_path)
        # Shuffle once here; the resulting order is then fixed for the lifetime of the dataset.
        df = df.sample(frac=1).reset_index(drop=True)
        # Explicit iteration instead of DataFrame.apply, which is meant for
        # computing values rather than for running side effects per row.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return a copy of example ``idx`` with its image converted to float32.'''
        # Shallow-copy so the float32 image is NOT written back into
        # self.examples: the stored images keep the dtype they were loaded
        # with, so the in-memory dataset does not grow after the first pass.
        example = dict(self.examples[idx])
        example['image'] = example['image'].to(torch.float32)
        return example


def get_dataset(batch_size):
    '''Load the retain, forget and validation splits and wrap each in a DataLoader.

    The retain and forget loaders iterate in dataset order (no shuffling); only
    the validation loader shuffles. Returns the three loaders as the tuple
    ``(retain_loader, forget_loader, validation_loader)``.
    '''
    retain_ds, forget_ds, val_ds = (
        HiddenDataset(split=split_name)
        for split_name in ('retain', 'forget', 'validation')
    )

    return (
        DataLoader(retain_ds, batch_size=batch_size, shuffle=False),
        DataLoader(forget_ds, batch_size=batch_size, shuffle=False),
        DataLoader(val_ds, batch_size=batch_size, shuffle=True),
    )

In [None]:
def global_unstructure_prune(model, pruning_amount=0.2):
    '''Globally L1-prune the weights of every Conv2d and Linear layer, in place.

    ``pruning_amount`` is forwarded as ``amount`` to
    ``prune.global_unstructured``, which ranks the weights of all selected
    layers together. The pruning re-parametrization is removed afterwards, so
    each layer ends up with a plain ``weight`` parameter holding the masked
    values.
    '''
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = [
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, prunable_types)
    ]

    # One magnitude ranking across all collected layers.
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=pruning_amount,
    )

    # Fold the masks into the weights so the pruning becomes permanent.
    for layer, attr in targets:
        prune.remove(layer, attr)

In [None]:
def accuracy(net, loader):
    """Return the fraction of samples in `loader` that `net` classifies correctly.

    Each batch is a mapping with an "image" tensor and an "age_group" tensor of
    labels; both are moved to DEVICE and the prediction is the arg-max over
    dimension 1 of the network output.

    Gradient tracking is disabled for the forward passes so that evaluation
    never builds autograd graphs, regardless of the caller's context. The
    network's train/eval mode is left untouched; callers that need eval-mode
    behaviour must call `net.eval()` themselves.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for sample in loader:
            inputs = sample["image"].to(DEVICE)
            targets = sample["age_group"].to(DEVICE)
            predicted = net(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

In [None]:
def calculate_kl_loss(student_logits, teacher_logits, T=2.0, forget_T=2.0, forget_flag=False):
    """Temperature-scaled KL distillation loss between student and teacher logits.

    Args:
        student_logits: Student outputs, classes along dimension 1.
        teacher_logits: Teacher outputs with the same shape. Not modified.
        T: Temperature applied to both student and teacher logits.
        forget_T: Extra temperature applied to the teacher only when
            `forget_flag` is set.
        forget_flag: When True, the teacher logits are additionally divided by
            `forget_T` and perturbed with uniform noise in [0, 0.05) before the
            soft-max.

    Returns:
        A scalar tensor: batch-mean KL(teacher || student) multiplied by T*T.
    """
    teacher_logits = teacher_logits / T

    if forget_flag:
        teacher_logits = teacher_logits / forget_T
        # Sample the noise directly on the logits' device/dtype rather than on
        # the CPU followed by a copy to a globally configured device.
        teacher_logits = teacher_logits + 0.05 * torch.rand_like(teacher_logits)

    # Soft labels from the teacher
    teacher_probs = F.softmax(teacher_logits, dim=1)

    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    distillation_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

    return distillation_loss

In [None]:
# Function to update learning rate
def adjust_learning_rate(optimizer, current_batch, total_batches, initial_lr):
    """Linear warm-up: set every param group's lr to initial_lr * current_batch / total_batches."""
    warmup_lr = initial_lr * (current_batch / total_batches)
    for group in optimizer.param_groups:
        group.update(lr=warmup_lr)

In [None]:
def unlearning(
    net, 
    retain_loader, 
    forget_loader, 
    val_loader,
    class_weights=None,
):
    '''Unlearn the forget set in place via pruning plus self-distillation.

    The network's own logits on the retain and forget sets are recorded first
    (the "teacher"). The network is then globally pruned by 90% (the "student")
    and fine-tuned for a fixed number of epochs. In every epoch:

    * Retain phase: each retain batch is used with probability 0.55 and trained
      on ``alpha * KL(teacher) + (1 - alpha) * cross-entropy``. Validation
      accuracy is measured periodically; as soon as it drops below the best
      value seen in the current epoch, the last saved checkpoint is restored,
      a further 10% of the weights is pruned and the retain phase ends.
    * Forget phase (every epoch except the last): each forget batch is trained
      on either the KL loss against a softened, noised teacher or on negated
      cross-entropy (gradient ascent), chosen at random per batch.

    Args:
        net: Model to unlearn; modified in place and left in eval mode.
        retain_loader: Loader over the retain set. Teacher logits are matched
            to its batches by position, so it must yield the samples in the
            same order on every pass (i.e. not shuffle).
        forget_loader: Loader over the forget set; same ordering requirement.
        val_loader: Loader used for the validation-based early stop.
        class_weights: Optional per-class weights for the retain-phase
            cross-entropy loss.

    Side effects:
        Writes (and re-reads) ``/kaggle/tmp/temp_checkpoint.pth``.
    '''

    def _record_teacher_logits(loader):
        '''Return a loader over the network's current logits, batch-aligned with `loader`.'''
        chunks = []
        with torch.no_grad():
            for sample in loader:
                chunks.append(net(sample["image"].to(DEVICE)).to('cpu', torch.float32))
        logits = torch.cat(chunks) if chunks else torch.zeros(0)
        # Same batch size as the source loader so that zipping the two loaders
        # pairs every sample batch with its own logits.
        return DataLoader(logits, batch_size=loader.batch_size, shuffle=False)

    # ---- Teacher logits ----
    # Record them in eval mode: in train mode the BatchNorm layers would
    # normalise with per-batch statistics and update their running statistics
    # during what is meant to be a read-only pass over the data.
    net.eval()
    retain_logit_loader = _record_teacher_logits(retain_loader)
    forget_logit_loader = _record_teacher_logits(forget_loader)

    # ---- Student: the same network, heavily pruned ----
    global_unstructure_prune(net, 0.90)

    # ---- Training parameters ----
    T = 2.0
    forget_T = 2.0
    alpha = 0.9  # Weight of the KL term in the retain loss
    epochs = 3
    pct_of_retain_for_ft = 0.55  # Probability of training on a given retain batch
    calc_val_acc_every = 0.20  # Validate after this fraction of an epoch's retain batches
    forget_KL_CE_ratio = 0.5  # Probability of using KL (vs. gradient ascent) on a forget batch
    beta = 0.2  # Multiplier to CE for gradient ascent
    checkpoint_path = '/kaggle/tmp/temp_checkpoint.pth'

    initial_lr = 0.001/2
    total_samples = len(retain_loader.dataset)
    batch_size = retain_loader.batch_size
    batches_per_epoch = math.ceil(total_samples / batch_size)
    warmup_batches = math.ceil(0.3*batches_per_epoch)
    warmup_current_batch = 0
    # Clamp to 1: for very small retain sets ceil(...) - 1 is 0, which would
    # make the modulo below divide by zero.
    val_interval = max(1, math.ceil(calc_val_acc_every*batches_per_epoch) - 1)

    optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=0.90, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Loop-invariant objects, built once instead of once per batch.
    retain_criterion = nn.CrossEntropyLoss(weight=class_weights)
    forget_criterion = nn.CrossEntropyLoss()
    retain_sampling_probs = torch.tensor(
        [pct_of_retain_for_ft, 1-pct_of_retain_for_ft], dtype=torch.float32)
    forget_loss_probs = torch.tensor(
        [forget_KL_CE_ratio, 1-forget_KL_CE_ratio], dtype=torch.float32)

    # The early-stop checkpoint is written to disk; make sure its folder exists.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # ---- Training loop ----
    net.train()

    for ep in range(epochs):

        current_batch = 0
        best_val_acc = 0

        # -- Retain fine-tuning --
        for sample, teacher_logits in zip(retain_loader, retain_logit_loader):

            # Only use a % of the retain set: index 1 (drawn with probability
            # 1 - pct_of_retain_for_ft) means "skip this batch".
            if torch.multinomial(retain_sampling_probs, 1).item() > 0.5:
                continue

            inputs = sample["image"].to(DEVICE)
            targets = sample["age_group"].to(DEVICE)
            teacher_logits = teacher_logits.to(DEVICE)

            current_batch += 1
            warmup_current_batch += 1

            # Warm-up for the first 'warmup_batches' batches
            if warmup_current_batch <= warmup_batches:
                adjust_learning_rate(optimizer, warmup_current_batch, warmup_batches, initial_lr)

            optimizer.zero_grad()

            # Forward pass student
            student_logits = net(inputs)

            # Calculate losses
            distillation_loss = calculate_kl_loss(student_logits, teacher_logits, T=T, forget_flag=False)
            classification_loss = retain_criterion(student_logits, targets)
            loss = alpha*distillation_loss + (1-alpha)*classification_loss
            loss.backward()
            optimizer.step()

            # -- Validation early-stop --
            if current_batch % val_interval == 0:
                net.eval()
                with torch.no_grad():
                    val_acc = accuracy(net, val_loader)
                # Switch back to train mode BEFORE a possible break, so the
                # forget phase and the following epochs never run in eval mode.
                net.train()

                if val_acc < best_val_acc:
                    # Restore model to previous checkpoint
                    checkpoint = torch.load(checkpoint_path)
                    net.load_state_dict(checkpoint['model'])
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    # Apply a further, lighter pruning to the restored weights
                    global_unstructure_prune(net, 0.10)
                    break

                best_val_acc = val_acc
                # Save checkpoint
                torch.save({
                    'optimizer': optimizer.state_dict(),
                    'model': net.state_dict(),
                }, checkpoint_path)

        # -- Forget fine-tuning (skipped in the last epoch) --
        if ep != epochs-1:

            for sample, teacher_logits in zip(forget_loader, forget_logit_loader):
                inputs = sample["image"].to(DEVICE)
                targets = sample["age_group"].to(DEVICE)
                teacher_logits = teacher_logits.to(DEVICE)

                warmup_current_batch += 1

                # Warm-up for the first 'warmup_batches' batches
                if warmup_current_batch <= warmup_batches:
                    adjust_learning_rate(optimizer, warmup_current_batch, warmup_batches, initial_lr)

                optimizer.zero_grad()

                # Forward pass student
                student_logits = net(inputs)

                # Calculate losses
                distillation_loss = calculate_kl_loss(
                    student_logits, teacher_logits, T=T, forget_T=forget_T, forget_flag=True)
                classification_loss = forget_criterion(student_logits, targets)

                # Index 0 -> distil towards the noised teacher; index 1 -> gradient ascent on CE.
                if torch.multinomial(forget_loss_probs, 1).item() < 0.5:
                    loss = distillation_loss
                else:
                    loss = -beta*classification_loss

                loss.backward()
                optimizer.step()

        scheduler.step()

    net.eval()

In [None]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # mock submission
    subprocess.run('touch submission.zip', shell=True, check=True)
else:
    # Load the class weights from json file of unknown structure
    import json

    class_weights_fname = "/kaggle/input/neurips-2023-machine-unlearning/age_class_weights.json"
    with open(class_weights_fname) as f:
        # Returns JSON object as a dictionary
        class_weights_dict = json.load(f)

    # The keys should be the age_group IDs, mapping to the number of occurrences for that age group.
    # But keys are always strings in JSON files (there are no int keys in JSON). We can't be sure
    # the keys in the dict are in the correct order, so let's convert the dictionary into a list
    # by using the expected keys.
    class_weights = [class_weights_dict[str(key)] for key in range(len(class_weights_dict))]
    # Convert list of weights into a float32 tensor
    class_weights = torch.tensor(class_weights).to(DEVICE, dtype=torch.float32)
    # The JSON file actually contains the number of occurrences. To correct for imbalance, the
    # weighting should be the reciprocal of the count instead.
    class_weights = 1.0 / class_weights

    # Note: it's really important to create the unlearned checkpoints outside of the working directory 
    # as otherwise this notebook may fail due to running out of disk space.
    # The below code saves them in /kaggle/tmp to avoid that issue.

    os.makedirs('/kaggle/tmp', exist_ok=True)

    num_checkpoints = 512

    for i in range(num_checkpoints):
        # The datasets are rebuilt for every checkpoint, so each run sees a freshly
        # shuffled sample order (HiddenDataset shuffles at construction time).
        retain_loader, forget_loader, validation_loader = get_dataset(64)
        net = resnet18(weights=None, num_classes=10)
        net.to(DEVICE)
        net.load_state_dict(torch.load(
            '/kaggle/input/neurips-2023-machine-unlearning/original_model.pth',
            map_location=DEVICE))
        unlearning(net, retain_loader, forget_loader, validation_loader, class_weights=class_weights)
        del retain_loader
        del forget_loader
        del validation_loader
        gc.collect()
        state = net.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

    # Ensure that submission.zip will contain exactly 512 checkpoints
    # (if this is not the case, an exception will be thrown).
    # /kaggle/tmp also holds the early-stopping checkpoint written by unlearning(),
    # so count only the files that the zip pattern below picks up.
    unlearned_ckpts = [
        fname for fname in os.listdir('/kaggle/tmp')
        if fname.startswith('unlearned_') and fname.endswith('.pth')
    ]
    if len(unlearned_ckpts) != num_checkpoints:
        raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

    # check=True: fail loudly instead of silently leaving a missing or partial archive.
    subprocess.run('zip submission.zip /kaggle/tmp/unlearned_*.pth', shell=True, check=True)
