In [None]:
import os
import math
import copy
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset

# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than the T4, as it is faster and increases the chances of
# passing the time cut-off threshold.

# Fail fast when no GPU was detected instead of discovering it at submission time.
if DEVICE != 'cuda':
    raise RuntimeError(
        'Make sure you have added an accelerator to your notebook; '
        'the submission will fail otherwise!'
    )

In [None]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    '''Build one example dict from a metadata row.

    The image file at ``df_row['image_path']`` is read with
    ``torchvision.io.read_image`` and stored under ``'image'``; the row's
    ``image_id``, ``age_group``, ``age`` and ``person_id`` values are copied
    over unchanged under the same keys.
    '''
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''The hidden dataset.

    On construction, reads ``<split>.csv`` from the competition input
    directory, derives each image's path from its ``image_id`` (expected to be
    of the form ``<folder>-<file>``), and eagerly loads every example into
    ``self.examples`` via ``load_example``, ordered by image path.

    Raises:
        ValueError: if the split contains no examples.
    '''
    def __init__(self, split='train'):
        super().__init__()

        df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
        df['image_path'] = df['image_id'].apply(
            lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
        df = df.sort_values(by='image_path')
        # Build the list directly rather than using DataFrame.apply for its side effect.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        '''Return the number of loaded examples.'''
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return example ``idx`` with its image converted to float32.

        A shallow copy of the cached dict is returned so the cached image keeps
        its original dtype; writing the float32 tensor back into the cache would
        permanently replace the stored image after the first access.
        '''
        example = dict(self.examples[idx])
        example['image'] = example['image'].to(torch.float32)
        return example


def get_dataset(batch_size):
    '''Create shuffled data loaders for the retain, forget and validation splits.

    Returns a ``(retain_loader, forget_loader, validation_loader)`` tuple; every
    loader uses the given ``batch_size`` and ``shuffle=True``.
    '''
    split_names = ('retain', 'forget', 'validation')
    # Load all three splits first, then wrap each one in a loader.
    datasets = [HiddenDataset(split=name) for name in split_names]
    retain_loader, forget_loader, validation_loader = (
        DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets
    )
    return retain_loader, forget_loader, validation_loader

In [None]:
def hessian(model, train_loader, device='cuda'):
    '''Accumulate a per-parameter squared-gradient statistic on ``model``.

    For every batch and every output class ``y``, the mean cross-entropy loss of
    the batch against the constant label ``y`` is back-propagated, and the squared
    gradient of each trainable parameter is added to ``p.grad2_acc`` weighted by
    the batch-mean softmax probability of ``y``. The total is then divided by the
    number of batches.

    Results are stored as attributes on the parameters: ``p.grad2_acc`` holds the
    statistic and ``p.grad_acc`` is reset to 0 (it is not updated here). Nothing
    is returned. The model's train/eval mode is left as the caller set it.

    Args:
        model: network whose parameters receive the accumulators.
        train_loader: iterable of batches (dicts with ``"image"`` and
            ``"age_group"`` entries) that also supports ``len()``.
        device: device the batch tensors are moved to.
    '''
    criterion = torch.nn.CrossEntropyLoss(reduction="mean")

    # Start from fresh accumulators on every call.
    for param in model.parameters():
        param.grad_acc = 0
        param.grad2_acc = 0

    for batch in train_loader:
        images = batch["image"]
        labels = batch["age_group"]
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        # Class probabilities are used only as weights, so they are detached from the graph.
        class_probs = torch.softmax(logits, dim=-1).data

        num_outputs = logits.shape[1]
        for class_idx in range(num_outputs):
            # The real labels only provide shape/dtype/device for the constant pseudo-label.
            pseudo_target = torch.full_like(labels, class_idx)
            loss = criterion(logits, pseudo_target)
            model.zero_grad()
            # The same forward graph is reused for every class, so it must be retained.
            loss.backward(retain_graph=True)
            class_weight = torch.mean(class_probs[:, class_idx])
            for param in model.parameters():
                if not param.requires_grad:
                    continue
                param.grad2_acc += class_weight * param.grad.data.pow(2)

    num_batches = len(train_loader)
    for param in model.parameters():
        param.grad2_acc /= num_batches

In [None]:
def get_mean_var(p, alpha, is_base_dist=False, num_classes=10):
    '''Return the mean and variance of the Gaussian used to perturb parameter ``p``.

    The variance is derived from the inverse of ``p.grad2_acc``:

    1. ``1 / (grad2_acc + 1e-8)``, clamped to at most 1e3 (1e2 when the first
       dimension of ``p`` equals ``num_classes``);
    2. scaled by ``alpha``;
    3. for parameters with more than one dimension, averaged along dim 1 and
       broadcast back to the shape of ``p``;
    4. multiplied by 10 when the first dimension equals ``num_classes``
       (treated as the last layer) or when ``p`` is 1-D (treated as BatchNorm
       in the original comments; this also covers any other 1-D parameter).

    The mean is a copy of ``p.data0``.

    Args:
        p: parameter carrying ``grad2_acc`` and ``data0`` attributes.
        alpha: scale factor applied to the variance.
        is_base_dist: accepted for interface compatibility; both settings yield
            the same mean.
        num_classes: size of the first dimension that identifies output-layer
            parameters. Defaults to 10.

    Returns:
        ``(mu, var)``: tensors that do not share storage with ``p.data0`` or
        ``p.grad2_acc``.
    '''
    # The small epsilon keeps the inverse finite where the accumulator is zero.
    var = 1.0 / (p.grad2_acc + 1e-8)
    var = var.clamp(max=1e3)

    is_output_layer = p.shape[0] == num_classes
    if is_output_layer:
        var = var.clamp(max=1e2)

    var = alpha * var
    if p.ndim > 1:
        # Share one variance per leading index; clone so the later in-place scaling
        # operates on real storage rather than on an expanded view.
        var = var.mean(dim=1, keepdim=True).expand_as(p).clone()

    mu = p.data0.clone()

    # NOTE: a class-specific override of mu/var used to sit here as commented-out
    # code that referenced an undefined `args`; it was never active.
    if is_output_layer or p.ndim == 1:
        var *= 10
    return mu, var

In [None]:
def fisher_new(net, retain, alpha=1e-7, device='cuda'):
    '''Perturb the parameters of ``net`` in place with Gaussian noise.

    Each parameter's current value is saved as ``p.data0``, ``hessian`` is run
    over ``retain`` to fill ``p.grad2_acc``, and every parameter is then replaced
    by ``mu + sqrt(var) * noise``, where ``(mu, var)`` come from ``get_mean_var``
    and ``noise`` is standard normal with the parameter's shape.

    Returns the same ``net`` object that was passed in.
    '''
    # Snapshot the current weights; get_mean_var uses them as the noise mean.
    for param in net.parameters():
        param.data0 = param.data.clone()

    hessian(net, retain, device)

    for param in net.parameters():
        mean, variance = get_mean_var(param, alpha, is_base_dist=False)
        noise = torch.empty_like(param.data).normal_()
        param.data = mean + variance.sqrt() * noise
    return net

In [None]:
# Function to update learning rate
def adjust_learning_rate(optimizer, current_batch, total_batches, initial_lr):
    """Linearly scale the learning rate for warmup.

    Sets ``lr`` in every parameter group of ``optimizer`` to
    ``initial_lr * (current_batch / total_batches)``.
    """
    warmup_fraction = current_batch / total_batches
    warmup_lr = initial_lr * warmup_fraction
    for group in optimizer.param_groups:
        group['lr'] = warmup_lr

In [None]:
# You can replace the below simple unlearning with your own unlearning function.

def unlearning(
    net, 
    retain_loader, 
    forget_loader, 
    val_loader):
    '''Unlearn by perturbing ``net`` in place with ``fisher_new`` on the retain set.

    Args:
        net: model to modify in place; it is left in eval mode.
        retain_loader: batches passed to ``fisher_new``.
        forget_loader: unused; kept so the signature matches the expected
            unlearning interface.
        val_loader: unused; kept for the same reason.

    Returns nothing; the caller reads the updated weights from ``net``.
    '''
    alpha = 1e-7

    # Switch to eval mode BEFORE the forward passes. A freshly constructed model
    # starts in train mode, so calling eval() only afterwards made the first run
    # use batch statistics and update the BatchNorm running buffers, while every
    # later run on the same (already eval-mode) model did not.
    net.eval()
    fisher_new(net, retain_loader, alpha=alpha)

In [None]:
# Number of unlearned checkpoints the submission must contain.
NUM_CHECKPOINTS = 512
# Note: it's really important to create the unlearned checkpoints outside of the working directory,
# as otherwise this notebook may fail due to running out of disk space.
CHECKPOINT_DIR = '/kaggle/tmp'

if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # mock submission
    subprocess.run('touch submission.zip', shell=True)
else:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    retain_loader, forget_loader, validation_loader = get_dataset(64*4)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)

    # Read the original weights from disk once; load_state_dict copies them into
    # the model on every iteration, so the loaded dict itself is never modified.
    original_state = torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth')
    for i in range(NUM_CHECKPOINTS):
        # Every run starts again from the original model.
        net.load_state_dict(original_state)
        unlearning(net, retain_loader, forget_loader, validation_loader)
        state = net.state_dict()
        torch.save(state, os.path.join(CHECKPOINT_DIR, f'unlearned_checkpoint_{i}.pth'))

    # Ensure that submission.zip will contain exactly NUM_CHECKPOINTS checkpoints
    # (if this is not the case, an exception will be thrown).
    unlearned_ckpts = os.listdir(CHECKPOINT_DIR)
    if len(unlearned_ckpts) != NUM_CHECKPOINTS:
        raise RuntimeError(f'Expected exactly {NUM_CHECKPOINTS} checkpoints. The submission will throw an exception otherwise.')

    # check=True: fail loudly if the archive could not be created instead of
    # finishing the notebook without a usable submission.zip.
    subprocess.run(f'zip submission.zip {CHECKPOINT_DIR}/*.pth', shell=True, check=True)