In [None]:
import os
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

In [None]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than the T4 as it's faster and will increase the chances of passing the time cut-off threshold.

# Stop right away when the notebook is not running on a GPU, instead of
# discovering the problem only after the submission has been rejected.
if not DEVICE == 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

In [None]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    '''Turn one metadata row into an example record.

    Args:
        df_row: mapping-like row (e.g. a pandas Series) providing the keys
            'image_path', 'image_id', 'age_group', 'age' and 'person_id'.

    Returns:
        dict holding the image read from ``df_row['image_path']`` under
        'image', followed by the four metadata fields copied from the row
        unchanged.
    '''
    # Read the image first, then copy the metadata across in a fixed order.
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''The hidden dataset.

    Reads ``<split>.csv`` from the competition input directory, derives every
    image path from its ``image_id`` and eagerly loads all examples into
    ``self.examples``, ordered by image path.

    Args:
        split: name of the CSV file (without extension) to load,
            e.g. 'retain', 'forget' or 'validation'.

    Raises:
        ValueError: if the split yields no examples.
    '''
    def __init__(self, split='train'):
        super().__init__()

        df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
        # An image_id of the form '<a>-<b>' maps to images/<a>/<b>.png.
        df['image_path'] = df['image_id'].apply(
            lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
        df = df.sort_values(by='image_path')
        # Iterate the rows directly rather than abusing DataFrame.apply for its
        # side effect: apply builds a throw-away result, and its calling
        # pattern is not guaranteed to be exactly once per row on every
        # pandas version.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        '''Return the number of loaded examples.'''
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return a copy of example ``idx`` with its image cast to float32.

        The cached example is left untouched: writing the converted image
        back into ``self.examples`` would make a read mutate the dataset and
        would replace the stored image with its larger float32 version.
        '''
        stored = self.examples[idx]
        example = dict(stored)  # shallow copy; metadata values are shared
        example['image'] = stored['image'].to(torch.float32)
        return example


def get_dataset(batch_size):
    '''Build shuffled data loaders for the retain, forget and validation splits.

    Args:
        batch_size: batch size used by all three loaders.

    Returns:
        Tuple ``(retain_loader, forget_loader, validation_loader)``.
    '''
    # Load every split before wrapping any of them in a loader.
    split_datasets = []
    for split_name in ('retain', 'forget', 'validation'):
        split_datasets.append(HiddenDataset(split=split_name))

    loaders = []
    for split_ds in split_datasets:
        loaders.append(DataLoader(split_ds, batch_size=batch_size, shuffle=True))

    return tuple(loaders)

In [None]:
# def accuracy(net, loader):
#     """Return accuracy on a dataset given by the data loader."""
#     correct = 0
#     total = 0
#     for sample in loader:
#         inputs = sample["image"]
#         targets = sample["age_group"]
#         inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
#         outputs = net(inputs)
#         _, predicted = outputs.max(1)
#         total += targets.size(0)
#         correct += predicted.eq(targets).sum().item()
#     return correct / total

In [None]:
# def combine_loaders_randomly(loader1, loader2, net):
#     """Combine two PyTorch DataLoader objects into a single generator, 
#        randomly picking from each and modifying targets from loader2."""
    
#     iter1 = iter(loader1)
#     iter2 = iter(loader2)
    
#     while True:
#         chosen_loader = random.choices(['loader1', 'loader2'], weights=[90, 10], k=1)[0]
        
#         if chosen_loader == 'loader1':
#             try:
#                 batch = next(iter1)
#                 yield batch
#             except StopIteration:
#                 iter1 = iter(loader1)  # Reset the iterator if exhausted
#         else:
#             try:
#                 inputs, _ = next(iter2)  # Ignore original targets
#                 inputs = inputs.to(DEVICE)
                
#                 # Compute new targets
#                 with torch.no_grad():
#                     preds = net(inputs)
#                     _, new_targets = preds.min(dim=1)
                
#                 yield inputs, new_targets
#             except StopIteration:
#                 iter2 = iter(loader2)  # Reset the iterator if exhausted

In [None]:
# You can replace the below simple unlearning with your own unlearning function.

def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    epochs=1,
    steps_per_epoch=100,
    lr=0.0005,
    loader_weights=(80, 20)):
    '''Unlearn by fine-tuning on a random mix of retain and forget batches.

    Every step draws one batch, from the retain loader or the forget loader
    according to ``loader_weights``. Retain batches are trained on their true
    ``age_group`` labels. Forget batches are trained on pseudo-labels: the
    class the current model ranks second-highest for each image.

    Args:
        net: model to update in place.
        retain_loader: iterable of batches (dicts with 'image' and
            'age_group') for the retain set.
        forget_loader: iterable of batches (dicts with 'image') for the
            forget set.
        val_loader: accepted for interface compatibility; not used.
        epochs: number of passes; also the cosine schedule's ``T_max``.
        steps_per_epoch: number of batch draws per epoch. A draw from an
            already exhausted loader is skipped but still counts as a step.
        lr: SGD learning rate.
        loader_weights: relative sampling weights ``(retain, forget)``.

    Returns:
        None. ``net`` is modified in place and left in eval mode.
    '''
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs)
    # The sampling weights never change, so build the tensor once.
    weights = torch.tensor(loader_weights, dtype=torch.float32)
    net.train()

    for _epoch in range(epochs):
        # Fresh iterators at the beginning of each epoch.
        iter_retain = iter(retain_loader)
        iter_forget = iter(forget_loader)

        for _step in range(steps_per_epoch):
            # Index 0 selects the retain loader, index 1 the forget loader.
            from_forget = torch.multinomial(weights, 1).item() == 1

            try:
                batch = next(iter_forget if from_forget else iter_retain)
            except StopIteration:
                # The chosen loader is exhausted for this epoch: skip the step.
                continue

            # Move the images to the device once, for both branches.
            inputs = batch["image"].to(DEVICE)

            if from_forget:
                # Pseudo-label each forget image with the model's
                # second-highest scoring class (topk is sorted descending,
                # so the last of the top 2 is the runner-up).
                net.eval()
                with torch.no_grad():
                    preds = net(inputs)
                    _, top_two = torch.topk(preds, 2, dim=1)
                    targets = top_two[:, -1]
                net.train()  # back to training mode for the update below
            else:
                targets = batch["age_group"].to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

    net.eval()

In [None]:
# def unlearning(net, retain, forget, validation):
#     """Unlearning by fine-tuning.

#     Fine-tuning is a very simple algorithm that trains using only
#     the retain set.

#     Args:
#       net : nn.Module.
#         pre-trained model to use as base of unlearning.
#       retain : torch.utils.data.DataLoader.
#         Dataset loader for access to the retain set. This is the subset
#         of the training set that we don't want to forget.
#       forget : torch.utils.data.DataLoader.
#         Dataset loader for access to the forget set. This is the subset
#         of the training set that we want to forget. This method doesn't
#         make use of the forget set.
#       validation : torch.utils.data.DataLoader.
#         Dataset loader for access to the validation set. This method doesn't
#         make use of the validation set.
#     Returns:
#       net : updated model
#     """
#     epochs = 1
#     early_stop_steps = 5  # Number of steps to consider for early stopping
#     early_stop_threshold = 0.05  # 5% validation performance decrease

#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

#     # Record the original performance on the validation set
#     original_val_accuracy = accuracy(net, validation)
    
#     net.train()
#     steps_without_improvement = 0
#     step_count = 0

#     combined_loader = combine_loaders_randomly(retain, forget, net)

#     for _ in range(epochs):
#         for sample in combined_loader:
#             inputs = sample["image"]
#             targets = sample["age_group"]
#             inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
#             optimizer.zero_grad()
#             outputs = net(inputs)
#             loss = criterion(outputs, targets)
#             loss.backward()
#             optimizer.step()

#             step_count += 1

#             # Evaluate on validation set
#             current_val_accuracy = accuracy(net, validation)
# #             print(f"Validation acc {current_val_accuracy:.2f}")

#             # Check for performance decrease
#             if current_val_accuracy < (1 - early_stop_threshold) * original_val_accuracy:
#                 steps_without_improvement += 1
#                 if steps_without_improvement >= early_stop_steps:
# #                     print(f"Early stopping triggered at step {step_count}.")
#                     net.eval()
#                     return net
#             else:
#                 steps_without_improvement = 0


#         scheduler.step()

#     net.eval()
#     return net

In [None]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # Mock submission: when the input directory contains empty.txt, only a
    # placeholder archive is created.
    subprocess.run('touch submission.zip', shell=True, check=True)
else:

    # Note: it's really important to create the unlearned checkpoints outside of the working directory
    # as otherwise this notebook may fail due to running out of disk space.
    # The below code saves them in /kaggle/tmp to avoid that issue.

    num_checkpoints = 512

    os.makedirs('/kaggle/tmp', exist_ok=True)
    retain_loader, forget_loader, validation_loader = get_dataset(64)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    # Read the original weights from disk once. load_state_dict copies the
    # values into net, so the same dict can reset the model on every iteration.
    original_state = torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth')
    for i in range(num_checkpoints):
        net.load_state_dict(original_state)
        unlearning(net, retain_loader, forget_loader, validation_loader)
        state = net.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

    # Ensure that submission.zip will contain exactly the expected number of
    # checkpoints. Only .pth files are counted, because those are the only
    # files the zip command below picks up.
    unlearned_ckpts = [name for name in os.listdir('/kaggle/tmp') if name.endswith('.pth')]
    if len(unlearned_ckpts) != num_checkpoints:
        raise RuntimeError(f'Expected exactly {num_checkpoints} checkpoints. The submission will throw an exception otherwise.')

    # check=True: fail loudly if the archive could not be written.
    subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True, check=True)