In [None]:
import os
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset

# Device string used for the model and for every batch: 'cuda' when PyTorch
# reports an available GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 

## Use GPU only

In [None]:
# Stop the notebook right away when DEVICE did not resolve to 'cuda', rather
# than letting the later cells run on the CPU.
if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

## Helper functions for loading hidden dataset

- Location of dataset : /kaggle/input/neurips-2023-machine-unlearning/
- Contents of each record: Image, Image ID, Age group (Target), Age, Person ID
- retain.csv, forget.csv, validation.csv provided by competition
- Use data loaders with `shuffle=True` so that the example order differs between runs (512 runs, one checkpoint per run)

In [None]:

def load_example(df_row):
    """Read one record's image from disk and bundle it with the record's metadata.

    Args:
        df_row: Mapping-like row (e.g. a pandas Series) providing the keys
            'image_path', 'image_id', 'age_group', 'age' and 'person_id'.

    Returns:
        dict: the decoded image under 'image', followed by the four metadata
        fields copied unchanged from ``df_row``.
    """
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    # Metadata fields are copied through untouched.
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    """In-memory dataset for one split of the competition data.

    All images of the split are read eagerly in ``__init__`` and kept in
    ``self.examples`` exactly as ``load_example`` returned them; the float32
    conversion happens per access in ``__getitem__``.

    Args:
        split: Name of the CSV (without extension) to load, e.g. 'retain',
            'forget' or 'validation'.
        data_dir: Directory containing ``{split}.csv`` and the ``images``
            folder. Defaults to the competition input directory.

    Raises:
        ValueError: If the split contains no rows.
    """

    def __init__(self, split='train',
                 data_dir='/kaggle/input/neurips-2023-machine-unlearning/'):

        super().__init__()

        # location of Dataset + type of data
        df = pd.read_csv(os.path.join(data_dir, f'{split}.csv'))

        def image_path(image_id):
            # An image id "<folder>-<name>" maps to images/<folder>/<name>.png
            parts = image_id.split('-')
            return os.path.join(data_dir, 'images', parts[0], parts[1] + '.png')

        # Using Image IDs, retrieve images
        df['image_path'] = df['image_id'].apply(image_path)
        df = df.sort_values(by='image_path')

        # One example dict per CSV row, in image-path order
        self.examples = [load_example(row) for _, row in df.iterrows()]

        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        """Number of examples loaded for this split."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return a copy of example ``idx`` with its image converted to float32.

        A shallow copy is returned so the cached example keeps the tensor that
        ``load_example`` produced instead of being overwritten with the
        float32 version on first access.
        """
        stored = self.examples[idx]
        example = dict(stored)
        example['image'] = stored['image'].to(torch.float32)
        return example


def get_dataset(batch_size):
    """Build shuffled DataLoaders for the retain, forget and validation splits.

    Args:
        batch_size: Batch size passed to each of the three DataLoaders.

    Returns:
        tuple: ``(retain_loader, forget_loader, validation_loader)``.
    """
    split_names = ('retain', 'forget', 'validation')

    # All three datasets are fully loaded before any loader is created.
    datasets = [HiddenDataset(split=name) for name in split_names]

    # Every loader reshuffles its split each time it is iterated.
    retain_loader, forget_loader, validation_loader = (
        DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets
    )
    return retain_loader, forget_loader, validation_loader

## Unlearning operation
Strategy 3: Based on the paper 'Eternal Sunshine of the Spotless Net: Selective Forgetting in Deep Networks' by A. Golatkar et al.

In [None]:
pip install pyhessian

In [None]:
# Version 1
import pyhessian
import copy

def get_mean_var(p, hessian, alpha=3e-6):
    """Return the mean and variance of the Gaussian used to resample ``p``.

    Args:
        p: Parameter carrying a ``data0`` attribute (snapshot of its weights).
        hessian: Tensor of per-element curvature estimates, broadcastable to
            the shape of ``p``.
        alpha: Scale applied to the inverse curvature.

    Returns:
        tuple: ``(mu, var)`` where ``mu`` is a copy of ``p.data0`` and
        ``var = alpha * min(1 / (hessian + 1e-8), 1e3)`` element-wise.
    """
    # A non-positive curvature would give a negative variance (NaN once the
    # caller takes its square root); treat it as flat, i.e. the capped maximum.
    curvature = hessian.clamp(min=0)

    # The 1e-8 keeps the division finite; the 1e3 cap bounds the variance.
    var = alpha * (1. / (curvature + 1e-8)).clamp(max=1e3)

    # clone() already yields an independent tensor, so no deepcopy is needed.
    mu = p.data0.clone()
    return mu, var

def _hutchinson_hessian_diag(net, params, criterion, inputs, targets):
    """One-probe Hutchinson estimate of the loss-Hessian diagonal for a batch.

    Draws a random +/-1 vector ``v`` and returns ``v * (H @ v)`` for each tensor
    in ``params``; its expectation over ``v`` is the diagonal of ``H``.
    """
    loss = criterion(net(inputs), targets)
    # create_graph=True keeps the gradient differentiable for the second pass.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    probes = [torch.randint_like(p, 2) * 2 - 1 for p in params]
    hvps = torch.autograd.grad(grads, params, grad_outputs=probes, allow_unused=True)
    return [
        torch.zeros_like(v) if hv is None else v * hv
        for v, hv in zip(probes, hvps)
    ]


def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    seed=None):
    """Resample the weights of ``net`` around their initial values (version 1).

    The weights are snapshotted into ``p.data0``. One SGD step is then taken
    per retain batch and, after each step, the diagonal of the loss Hessian is
    estimated for that batch. The estimates are averaged over the loader and
    every parameter is finally replaced by ``mu + sqrt(var) * noise`` with
    ``(mu, var)`` from ``get_mean_var`` -- i.e. centred on the snapshot, so the
    SGD steps only influence where the curvature is measured.

    The curvature is estimated with ``torch.autograd`` because
    ``pyhessian.hessian`` takes ``(model, criterion, data=...)`` and returns an
    analysis object rather than a tensor that can be accumulated.

    Args:
        net: Model to modify in place.
        retain_loader: Iterable of batches with "image" and "age_group" entries.
        forget_loader: Unused.
        val_loader: Unused.
        seed: Optional seed applied right before the noise is drawn. Left at
            ``None`` the global RNG is not reset, so repeated calls draw
            different noise.

    Returns:
        ``net``, switched to eval mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.005,
                      momentum=0.9, weight_decay=5e-4)

    params = list(net.parameters())
    for p in params:
        p.data0 = p.data.clone()

    # One accumulator per parameter; frozen parameters keep an all-zero entry.
    hessian_accumulator = [torch.zeros_like(p.data) for p in params]
    trainable_idx = [i for i, p in enumerate(params) if p.requires_grad]
    trainable = [params[i] for i in trainable_idx]

    net.train()

    for sample in retain_loader:

        inputs = sample["image"].to(DEVICE)
        targets = sample["age_group"].to(DEVICE)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # Compute Hessian diagonal estimate for current batch
        batch_hessian = _hutchinson_hessian_diag(net, trainable, criterion, inputs, targets)

        # Accumulate Hessian contributions
        for i, diag in zip(trainable_idx, batch_hessian):
            hessian_accumulator[i] += diag

    # Compute overall Hessian (guard against an empty loader)
    num_batches = max(len(retain_loader), 1)

    alpha = 1e-6
    if seed is not None:
        torch.manual_seed(seed)
    for p, acc in zip(params, hessian_accumulator):
        # Single-probe estimates can average to negative values; clamp so the
        # variance returned by get_mean_var is never negative.
        overall_hessian = (acc / num_batches).clamp(min=0)
        mu, var = get_mean_var(p, overall_hessian, alpha=alpha)
        p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()

    net.eval()
    return net
    

In [None]:
# version 2
import copy
import itertools
from tqdm import tqdm

def get_mean_var(p, is_base_dist=False, alpha=3e-6):
    """Return the mean and variance of the Gaussian used to resample ``p``.

    Args:
        p: Parameter carrying ``data0`` (snapshot of its weights) and
            ``grad2_acc`` (accumulated squared gradients, a tensor).
        is_base_dist: Currently ignored; the mean is always ``p.data0``.
        alpha: Scale applied to the inverse of ``p.grad2_acc``.

    Returns:
        tuple: ``(mu, var)`` where ``mu`` is a copy of ``p.data0`` and
        ``var = alpha * min(1 / (p.grad2_acc + 1e-8), 1e3)`` element-wise.
    """
    # The 1e-8 keeps the division finite; the 1e3 cap bounds the variance.
    var = (1. / (p.grad2_acc + 1e-8)).clamp(max=1e3)
    var = alpha * var

    # clone() already yields an independent tensor, so no deepcopy is needed.
    mu = p.data0.clone()

    # Callers unpack two values, so both must be returned.
    return mu, var


def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    alpha=1e-6,
    seed=None):
    """Resample the weights of ``net`` using squared-gradient statistics (version 2).

    For every retain batch and every class ``y`` the loss against the constant
    target ``y`` is back-propagated, and the squared gradient, weighted by the
    predicted probability of ``y``, is accumulated in ``p.grad2_acc``. After
    averaging over the loader, each parameter is replaced by
    ``mu + sqrt(var) * noise`` with ``(mu, var)`` from ``get_mean_var``.

    Args:
        net: Model to modify in place.
        retain_loader: Iterable of batches with "image" and "age_group"
            entries. Batches must hold exactly one example, because the
            per-class probability ``prob[:, y]`` is broadcast against whole
            gradient tensors.
        forget_loader: Unused.
        val_loader: Unused.
        alpha: Variance scale forwarded to ``get_mean_var``.
        seed: Optional seed applied right before the noise is drawn. Left at
            ``None`` the global RNG is not reset, so repeated calls draw
            different noise and the loaders keep reshuffling between calls.

    Returns:
        ``net``, in eval mode.

    Raises:
        ValueError: If a batch contains more than one example.
    """
    params = list(net.parameters())
    for p in params:
        # Snapshot of the incoming weights; the sampling below is centred here.
        p.data0 = p.data.clone()
        # Tensor accumulator so parameters that never receive a gradient still
        # end up with a tensor of the right shape.
        p.grad2_acc = torch.zeros_like(p.data)

    # Eval mode: gradients are still computed, but BatchNorm uses its running
    # statistics instead of updating them from single-example batches.
    net.eval()
    loss_fn = nn.CrossEntropyLoss()

    for sample in tqdm(retain_loader):

        data = sample["image"].to(DEVICE)
        orig_target = sample["age_group"].to(DEVICE)

        output = net(data)
        if output.shape[0] != 1:
            raise ValueError(
                'unlearning() needs batches of exactly one example, got %d'
                % output.shape[0])

        prob = torch.nn.functional.softmax(output, dim=-1).data

        for y in range(output.shape[1]):

            target = torch.empty_like(orig_target).fill_(y)
            loss = loss_fn(output, target)

            net.zero_grad()
            # The same forward graph is reused for every class.
            loss.backward(retain_graph=True)

            for p in params:
                if p.requires_grad and p.grad is not None:
                    p.grad2_acc += prob[:, y] * p.grad.data.pow(2)

    # Average over the batches (guard against an empty loader).
    num_batches = max(len(retain_loader), 1)
    for p in params:
        p.grad2_acc /= num_batches

    if seed is not None:
        torch.manual_seed(seed)
    for p in params:
        mu, var = get_mean_var(p, False, alpha=alpha)
        p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()

    net.eval()
    return net

## Access dataset, load model, call unlearning function and generate submission file with unlearned model

In [None]:
# dummy pathway for local - does not exist in submission
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    subprocess.run('touch submission.zip', shell=True)

else:
    # tmp directory - cannot save in home dir
    os.makedirs('/kaggle/tmp', exist_ok=True)
    # batch size 1 - unlearning() (version 2) scales whole gradient tensors by
    # prob[:, y], which only broadcasts for single-example batches
    retain_loader, forget_loader, validation_loader = get_dataset(1)
    # load model template
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    # read the original weights once; load_state_dict copies them into the
    # model, so the same dict can be reused to reset the model for every run
    original_state = torch.load(
        '/kaggle/input/neurips-2023-machine-unlearning/original_model.pth',
        map_location=DEVICE)
    # reset model and call unlearning function 512 times
    for i in range(512):
        net.load_state_dict(original_state)
        unlearning(net, retain_loader, forget_loader, validation_loader)
        # save as checkpoint
        torch.save(net.state_dict(), f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

    # Ensure that submission.zip will contain exactly 512 checkpoints
    # (if this is not the case, an exception will be thrown).
    # Only *.pth files are counted, matching the glob passed to zip below.
    unlearned_ckpts = [
        name for name in os.listdir('/kaggle/tmp') if name.endswith('.pth')
    ]
    if len(unlearned_ckpts) != 512:
        raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')
    # zip it and create submission; check=True surfaces a failing zip command
    subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True, check=True)