In [1]:
import os
import gc
import math
import random
import numpy as np
import copy
import subprocess
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from sklearn import linear_model, model_selection
from sklearn.utils import shuffle

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
from torchvision.models.feature_extraction import create_feature_extractor

# Device string used for every model/tensor placement below: CUDA when available, CPU otherwise.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 



In [2]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than T4 as it's faster and will increase the chances of passing the time cut-off threshold.

# Fail fast: abort the notebook immediately when no CUDA device was detected.
if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

In [3]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    '''Build one example dict from a metadata row.

    The image at ``df_row['image_path']`` is decoded with
    ``torchvision.io.read_image``; the remaining metadata fields are copied
    over unchanged.
    '''
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''The hidden dataset.

    Reads ``<data_dir>/<split>.csv``, derives each image path from the
    ``image_id`` column (``<prefix>-<name>`` -> ``images/<prefix>/<name>.png``),
    shuffles the rows and eagerly loads every example into memory.

    Args:
        split: name of the CSV file (without extension) to load.
        data_dir: directory holding the CSV files and the ``images`` folder.

    Raises:
        ValueError: if the split contains no rows.
    '''
    def __init__(self, split='train', data_dir='/kaggle/input/neurips-2023-machine-unlearning/'):
        super().__init__()

        df = pd.read_csv(os.path.join(data_dir, f'{split}.csv'))
        df['image_path'] = df['image_id'].apply(
            lambda x: os.path.join(data_dir, 'images', x.split('-')[0], x.split('-')[1] + '.png'))
        df = df.sample(frac=1).reset_index(drop=True)
        # Build the list explicitly instead of relying on side effects inside DataFrame.apply.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return a copy of example ``idx`` whose image is cast to float32.'''
        stored = self.examples[idx]
        # Shallow-copy so the cached example keeps its original image tensor;
        # writing the float32 image back would permanently convert the cache
        # and grow its memory footprint as items are accessed.
        example = dict(stored)
        example['image'] = stored['image'].to(torch.float32)
        return example


def get_dataset(batch_size):
    '''Create shuffled loaders for the retain, forget and validation splits.

    Returns:
        (retain_loader, forget_loader, validation_loader)
    '''
    split_names = ('retain', 'forget', 'validation')

    # Load all three datasets first, then wrap each one in a shuffling loader.
    datasets = [HiddenDataset(split=name) for name in split_names]
    retain_loader, forget_loader, validation_loader = [
        DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in datasets
    ]

    return retain_loader, forget_loader, validation_loader

In [4]:
def accuracy(net, loader):
    """Return accuracy on a dataset given by the data loader.

    Each batch must be a mapping with an "image" tensor and an "age_group"
    tensor of class indices. The train/eval mode of ``net`` is left untouched;
    the caller decides which mode to evaluate in.

    Returns:
        Fraction of correctly classified samples, or 0.0 when the loader
        yields no samples.
    """
    correct = 0
    total = 0
    # Evaluation never needs gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for sample in loader:
            inputs = sample["image"].to(DEVICE)
            targets = sample["age_group"].to(DEVICE)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        # Nothing to evaluate: avoid dividing by zero.
        return 0.0
    return correct / total

In [5]:
# Function to update learning rate
def adjust_learning_rate(optimizer, current_batch, total_batches, initial_lr):
    """Linear warm-up: set every param group's lr to initial_lr scaled by current_batch / total_batches."""
    warmup_fraction = current_batch / total_batches
    scaled_lr = initial_lr * warmup_fraction
    for group in optimizer.param_groups:
        group.update(lr=scaled_lr)

In [6]:
def _collect_embeddings(feat_extractor, loader, key_fn, default_dim=512):
    '''Run ``feat_extractor`` over ``loader`` and tabulate one row per sample.

    ``key_fn(sample)`` must return one identifier per sample in the batch; it
    becomes the ``unique_id`` column. The feature width is taken from the
    extracted features; ``default_dim`` only names the columns when the loader
    yields nothing.
    '''
    rows = []
    with torch.no_grad():
        for sample in loader:
            outputs = feat_extractor(sample["image"].to(DEVICE))['feat1']
            # One device->host transfer per batch instead of one per sample.
            feats = torch.flatten(outputs, start_dim=1).cpu().numpy()
            for key, feat in zip(key_fn(sample), feats):
                rows.append([key] + feat.tolist())

    n_feats = len(rows[0]) - 1 if rows else default_dim
    columns = ['unique_id'] + [f'feat_{i}' for i in range(n_feats)]

    # Object-dtype table, so the id column can hold ints or strings next to the floats.
    data = np.empty((len(rows), n_feats + 1), dtype=object)
    for row_idx, row in enumerate(rows):
        data[row_idx] = row
    return pd.DataFrame(data, columns=columns)


def get_embeddings(
    net, 
    retain_loader,
    val_loader
):
    '''Extract the ``avgpool`` features of ``net`` for the retain and validation sets.

    Args:
        net: model exposing an ``avgpool`` node (for torchvision feature extraction).
        retain_loader: loader of dict batches with "image" and "age_group".
        val_loader: loader of dict batches with "image" and "person_id".

    Returns:
        (embeddings_retain_df, embeddings_val_df). Each frame has a
        ``unique_id`` column followed by ``feat_0 .. feat_{D-1}``. For the
        retain frame ``unique_id`` is the sample's age-group label; for the
        validation frame it is ``str(person_id)``.
    '''
    # NOTE(review): the train/eval mode of `net` is not changed here; if the
    # model is in training mode, normalization layers may behave differently
    # from inference -- confirm this is intended.
    feat_extractor = create_feature_extractor(net, {'avgpool': 'feat1'})

    # Retain embeddings, keyed by class label.
    embeddings_retain_df = _collect_embeddings(
        feat_extractor,
        retain_loader,
        lambda sample: [target.item() for target in sample["age_group"]],
    )

    # Validation embeddings, keyed by the string form of the person id.
    embeddings_val_df = _collect_embeddings(
        feat_extractor,
        val_loader,
        lambda sample: [str(pid) for pid in sample["person_id"]],
    )

    return embeddings_retain_df, embeddings_val_df

In [7]:
# Contrastive Loss
class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss.

    Pairs labelled 0 are penalised by their squared distance; pairs labelled 1
    are penalised by the squared shortfall of their distance below ``margin``.
    The per-pair terms are averaged.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        distance = nn.functional.pairwise_distance(output1, output2)
        # Label 0: pull the pair together.
        pull_term = (1 - label) * torch.pow(distance, 2)
        # Label 1: push the pair apart until it is at least `margin` away.
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        push_term = (label) * torch.pow(shortfall, 2)
        return torch.mean(pull_term + push_term)

In [8]:
# Extract feature and pooling layers to create a Custom Model
class CustomResNet18(nn.Module):
    """Embedding model built from another model's layers.

    Every child of ``original_model`` except the last two becomes the feature
    stack, and the second-to-last child is used as the pooling layer; the last
    child is dropped. The child modules are reused, not copied, so training
    this wrapper updates ``original_model`` as well.
    """

    def __init__(self, original_model):
        super(CustomResNet18, self).__init__()

        # Extract features and pooling layers
        children = list(original_model.children())
        self.features = nn.Sequential(*children[:-2])
        self.pooling = children[-2]

    def forward(self, x):
        """Return one flat embedding per sample, shape (batch, features)."""
        x = self.features(x)
        x = self.pooling(x)
        # Flatten from dim 1 so the batch dimension survives even for a batch
        # of one; a bare squeeze() would collapse it as well.
        x = torch.flatten(x, start_dim=1)
        return x

In [9]:
def contrastive_unlearning(net, forget_loader, grouped_retain_df, grouped_val_df, LR=1e-3, max_num_steps=3, negative_weight=0.0):
    """Run up to ``max_num_steps`` contrastive updates on forget-set batches.

    Every forget sample is pulled towards a randomly chosen validation
    embedding with the same ``person_id`` (or towards its own detached
    embedding when none exists). The optimisation runs on a
    ``CustomResNet18`` wrapper around ``net``, so ``net`` is updated in place.

    Args:
        net: model to unlearn from.
        forget_loader: loader of dict batches with "image", "age_group" and
            "person_id".
        grouped_retain_df: mapping (``.get``) from class label to an array of
            retain embeddings; only consulted when ``negative_weight`` is
            non-zero.
        grouped_val_df: mapping (``.get``) from ``str(person_id)`` to an array
            of validation embeddings.
        LR: AdamW learning rate.
        max_num_steps: maximum number of optimiser steps. Exactly this many
            batches are consumed at most.
        negative_weight: weight of the push-away term against retain
            embeddings of the same class. The default of 0.0 keeps the loss
            equal to the positive term only and skips negative sampling.
    """
    custom_model = CustomResNet18(net).to(DEVICE)
    criterion = ContrastiveLoss()
    optimizer = optim.AdamW(custom_model.parameters(), lr=LR)
    custom_model.train()

    # zip() stops on the range first, so no batch beyond the budget is loaded.
    for _, batch in zip(range(max_num_steps), forget_loader):
        optimizer.zero_grad()
        inputs = batch['image'].to(DEVICE)

        # Forward pass to get embeddings for the forget batch
        forget_embeddings = custom_model(inputs)

        with torch.no_grad():  # Reference embeddings are constants
            positive_rows = []
            for index, pid in enumerate(batch['person_id']):
                candidate_embeddings = grouped_val_df.get(str(pid), None)
                if candidate_embeddings is not None:  # If a positive pair exists
                    selected = shuffle(candidate_embeddings, n_samples=1)[0]  # Randomly select one
                    positive_rows.append(torch.tensor(selected.astype(float)).float())
                else:  # Fallback to using the instance's own embedding
                    positive_rows.append(forget_embeddings[index].detach().cpu().float())
            positive_pairs = torch.stack(positive_rows).to(DEVICE)

        loss = criterion(
            forget_embeddings,
            positive_pairs,
            torch.zeros(positive_pairs.shape[0], device=DEVICE),
        )

        if negative_weight:
            with torch.no_grad():
                negative_rows = []
                negative_index = []
                for index, tgt in enumerate(batch['age_group'].cpu().numpy()):
                    candidate_embeddings = grouped_retain_df.get(tgt, None)
                    if candidate_embeddings is None:
                        # No retain reference for this class: leave the sample
                        # out instead of truncating the whole batch.
                        continue
                    selected = shuffle(candidate_embeddings, n_samples=1)[0]
                    negative_rows.append(torch.tensor(selected.astype(float)).float())
                    negative_index.append(index)

            if negative_rows:
                negative_pairs = torch.stack(negative_rows).to(DEVICE)
                negative_loss = criterion(
                    forget_embeddings[negative_index],
                    negative_pairs,
                    torch.ones(negative_pairs.shape[0], device=DEVICE),
                )
                loss = loss + negative_weight * negative_loss

        loss.backward()
        optimizer.step()

In [10]:
def retrain_step(net, retain_loader, retain_class_weights=None, LR=5e-5, max_num_steps=3):
    """Fine-tune ``net`` on retain batches for at most ``max_num_steps`` SGD steps.

    Args:
        net: model to fine-tune; it is put in training mode and updated in place.
        retain_loader: loader of dict batches with "image" and "age_group".
        retain_class_weights: optional per-class weights for the cross-entropy loss.
        LR: SGD learning rate (momentum 0.9, weight decay 5e-4).
        max_num_steps: maximum number of optimiser steps. Exactly this many
            batches are consumed at most.
    """
    # Built once: neither the loss nor the optimiser changes between steps.
    criterion = nn.CrossEntropyLoss(weight=retain_class_weights, label_smoothing=0.0)
    optimizer_retain = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

    net.train()

    # zip() stops on the range first, so no batch beyond the budget is loaded.
    for _, sample in zip(range(max_num_steps), retain_loader):
        inputs = sample["image"].to(DEVICE)
        targets = sample["age_group"].to(DEVICE)

        optimizer_retain.zero_grad()

        # Forward pass and classification loss
        logits = net(inputs)
        loss = criterion(logits, targets)

        loss.backward()
        optimizer_retain.step()
        
#     torch.save({
#         'net': net.state_dict(),
#     }, f'/kaggle/tmp/temp_checkpoint.pth')

In [11]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # mock submission
    subprocess.run('touch submission.zip', shell=True)
else:
    os.makedirs('/kaggle/tmp', exist_ok=True)

    ORIGINAL_MODEL_PATH = '/kaggle/input/neurips-2023-machine-unlearning/original_model.pth'
    NUM_CHECKPOINTS = 512       # number of unlearned models written for the submission
    NUM_UNLEARNING_ROUNDS = 2   # contrastive + retrain passes applied to each model

    # Get data loaders
    batch_size = 64
    retain_loader, forget_loader, validation_loader = get_dataset(batch_size)

    # Read the original weights from disk once. load_state_dict copies the
    # values into each freshly built model, so the same state dict can be
    # reused for every checkpoint instead of re-reading the file each time.
    original_state = torch.load(ORIGINAL_MODEL_PATH)

    # Reference embeddings of the original model
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    net.load_state_dict(original_state)

    embeddings_retain_df, embeddings_val_df = get_embeddings(net, retain_loader, validation_loader)

    # Pre-group embeddings by unique_id for fast lookup
    grouped_val_df = embeddings_val_df.groupby('unique_id').apply(lambda x: x.iloc[:, 1:].values)
    grouped_retain_df = embeddings_retain_df.groupby('unique_id').apply(lambda x: x.iloc[:, 1:].values)

    # Unlearning loop: every checkpoint restarts from the original weights.
    for i in range(NUM_CHECKPOINTS):
        net = resnet18(weights=None, num_classes=10)
        net.to(DEVICE)
        net.load_state_dict(original_state)
        for _ in range(NUM_UNLEARNING_ROUNDS):
            contrastive_unlearning(net, forget_loader, grouped_retain_df, grouped_val_df, LR=1e-4, max_num_steps=10)
            retrain_step(net, retain_loader, LR=1e-5, max_num_steps=2)
        state = net.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')
        gc.collect()

    # The tmp/ folder now holds the unlearned checkpoints; zip all of them into the submission.
    subprocess.run('zip submission.zip /kaggle/tmp/unlearned_*.pth', shell=True)