In [1]:
import os
import subprocess

import pandas as pd
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from skimage import io
from time import time

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
!mkdir utkface
!tar -xzf /content/drive/MyDrive/Dataset/utkface.tar.gz -C utkface/

mkdir: cannot create directory ‘utkface’: File exists


In [3]:
!mkdir /content/dataset
!cp /content/drive/MyDrive/Dataset/* /content/dataset

mkdir: cannot create directory ‘/content/dataset’: File exists
cp: -r not specified; omitting directory '/content/drive/MyDrive/Dataset/outs'


# Influence Function Hook

In [4]:
calculated_influences = dict()
required_influence = False

def influence_conv2d(inp, w, ks, st):
    num_out_channels, num_inp_channels, _, _ = w.shape
    inp_unfolded = F.unfold(inp, ks, stride=st).view((inp.shape[0], num_inp_channels, -1, ks[0]*ks[1])).permute((1, 2, 0, 3)).unsqueeze(0)
    weight_unfolded = w.view((num_out_channels, num_inp_channels, -1)).unsqueeze(2).unsqueeze(3)

    out = inp_unfolded * weight_unfolded
    out = out.mean(-2)
    out = torch.max(out, -2)[0]

    return torch.softmax(out, -1).view((num_out_channels, num_inp_channels, ks[0], ks[1]))

def influence_function_hook(mod: nn.Module, inp, out):
    if not required_influence:
      return

    if isinstance(inp, tuple):
        inp = inp[0]

    if isinstance(mod, nn.Linear):
        w = next(mod.parameters())
        b = next(mod.parameters())
        influence_matrix = F.softmax((w * inp.unsqueeze(1)).mean(dim=0), dim=1)

    if isinstance(mod, nn.Conv2d):
        kernel_size = mod.kernel_size
        stride = mod.stride
        padding = mod.padding

        inp_w_pad = F.pad(inp, (padding[0], padding[0], padding[1], padding[1]))
        weight = next(mod.parameters())
        influence_matrix = influence_conv2d(inp_w_pad, weight, kernel_size, stride)

    calculated_influences[id(mod)] = influence_matrix


def exude_influence(module):
    influence = calculated_influences.get(id(module))

    if influence != None:
        weight = next(module.parameters())
        weight.grad.data *= influence

    if isinstance(module, nn.Module):
        for child_module in module.children():
            exude_influence(child_module)


def register_hook_recursive(module, hook):
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        hook_handle = module.register_forward_hook(hook)

    if isinstance(module, nn.Module):
        for child_module in module.children():
            register_hook_recursive(child_module, hook)

# Dataset Utils - Kaggle

In [5]:
# Helper functions for loading the hidden dataset.

def load_example(df_row):
    image = torchvision.io.read_image(df_row['image_path'])
    result = {
        'image': image,
        'image_id': df_row['image_id'],
        'age_group': df_row['age_group'],
        'age': df_row['age'],
        'person_id': df_row['person_id']
    }
    return result


class HiddenDataset(Dataset):
    '''The hidden dataset.'''
    def __init__(self, split='train'):
        super().__init__()
        self.examples = []

        df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
        df['image_path'] = df['image_id'].apply(
            lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
        df = df.sort_values(by='image_path')
        df.apply(lambda row: self.examples.append(load_example(row)), axis=1)
        if len(self.examples) == 0:
            raise ValueError('No examples.')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        image = example['image']
        image = image.to(torch.float32)
        example['image'] = image
        return example


def get_dataset(batch_size):
    '''Get the dataset.'''
    retain_ds = HiddenDataset(split='retain')
    forget_ds = HiddenDataset(split='forget')
    val_ds = HiddenDataset(split='validation')

    retain_loader = DataLoader(retain_ds, batch_size=batch_size, shuffle=True)
    forget_loader = DataLoader(forget_ds, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)

    return retain_loader, forget_loader, validation_loader


# Dataset Utils - UTK

In [6]:
class UTKAgeDataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, images_folder, transform = None):
        # self.data = data
        self.data = pd.read_csv(csv_path)
        self.images_folder = images_folder
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        img_name = os.path.join(self.images_folder,
                                self.data.iloc[index, 0])
        image = io.imread(img_name).astype(np.float16)
        age_bin = self.data.iloc[index, 5]
        if self.transform is not None:
          image = self.transform(image)
        if image.shape[0] == 1:
          image = torch.stack([image,image,image]).reshape((3, 32, 32))
        return (torch.tensor(image).to(torch.float), torch.tensor(age_bin).to(torch.long))


def get_dataset(batch_size):
    '''Get the dataset.'''
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64,64)),
        # transforms.Normalize((0.1307,), (0.3081,)),
    ])
    retain_ds = UTKAgeDataset(csv_path="/content/dataset/retain_set.csv", images_folder="/content/utkface", transform = transform)
    forget_ds = UTKAgeDataset(csv_path="/content/dataset/forget_set.csv", images_folder="/content/utkface", transform = transform)
    val_ds = UTKAgeDataset(csv_path="/content/dataset/forget_set.csv", images_folder="/content/utkface", transform = transform)

    retain_loader = DataLoader(retain_ds, batch_size=batch_size, shuffle=True)
    forget_loader = DataLoader(forget_ds, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)

    return retain_loader, forget_loader, validation_loader

In [7]:
# You can replace the below simple unlearning with your own unlearning function.

def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    epochs=1):
    global required_influence
    print(id(required_influence))
    """Simple unlearning by finetuning."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001,
                      momentum=0.9, weight_decay=5e-4)
    checkpoints = list()

    net.train()
    register_hook_recursive(net, influence_function_hook)

    for ep in range(epochs):
        required_influence = False
        st = time()
        for idx, sample in enumerate(retain_loader):
            if idx >= len(forget_loader):
              break
            inputs = sample[0]
            targets = sample[1]
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        en = time()
        print(f"Time taken for epoch[{ep}/{epochs}] retain set: {en-st}")


        required_influence = True
        st = time()
        for idx, sample in enumerate(forget_loader):
            inputs = sample[0]
            targets = sample[1]
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = -criterion(outputs, targets)
            loss.backward()

            exude_influence(net)
            optimizer.step()
        en = time()
        print(f"Time taken for epoch[{ep}/{epochs}] forget set: {en-st}")

        checkpoints.append(net.state_dict())

    net.eval()
    return checkpoints


In [8]:
class AgeResNet(nn.Module):
    def __init__(self, num_bins):
        super(AgeResNet, self).__init__()
        # Load a pretrained resnet model from torchvision.models in Pytorch
        self.model = resnet18(pretrained=True)

        # Change the output layer to output
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_bins)

    def forward(self, x):
        return self.model(x)

In [9]:
!mkdir -p /content/outs

In [10]:
NUM_CHECKPOINTS = 10

retain_loader, forget_loader, validation_loader = get_dataset(batch_size=5)
net = AgeResNet(10)
net.to(DEVICE)
net.load_state_dict(torch.load('/content/dataset/age_pred_weights.pt'))

checkpoints = unlearning(net, retain_loader, forget_loader, validation_loader, epochs=10)

saves_per_checkpoints = NUM_CHECKPOINTS // len(checkpoints)
idx = 0
for checkpoint in checkpoints:
    for _ in range(saves_per_checkpoints):
        torch.save(checkpoint, f'/content/outs/unlearned_checkpoint_{idx}.pth')
        idx+=1

while idx < NUM_CHECKPOINTS:
    torch.save(checkpoints[-1], f'/content/outs/unlearned_checkpoint_{idx}.pth')
    idx+=1



101055857603584


  return (torch.tensor(image).to(torch.float), torch.tensor(age_bin).to(torch.long))


Time taken for epoch[0/10] retain set: 55.72339987754822
Time taken for epoch[0/10] forget set: 38.88507843017578
Time taken for epoch[1/10] retain set: 44.00469422340393
Time taken for epoch[1/10] forget set: 34.03611087799072
Time taken for epoch[2/10] retain set: 44.795782804489136
Time taken for epoch[2/10] forget set: 33.88401746749878
Time taken for epoch[3/10] retain set: 47.65117931365967
Time taken for epoch[3/10] forget set: 34.38773202896118
Time taken for epoch[4/10] retain set: 51.36579179763794
Time taken for epoch[4/10] forget set: 34.32260608673096
Time taken for epoch[5/10] retain set: 42.2774453163147
Time taken for epoch[5/10] forget set: 37.783751010894775
Time taken for epoch[6/10] retain set: 42.59378409385681
Time taken for epoch[6/10] forget set: 34.328277349472046
Time taken for epoch[7/10] retain set: 47.33673071861267
Time taken for epoch[7/10] forget set: 34.45958685874939
Time taken for epoch[8/10] retain set: 51.58673810958862
Time taken for epoch[8/10] fo

In [11]:
!cp -r /content/outs /content/drive/MyDrive/Dataset/

In [12]:
calculated_influences

{135271046221680: tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 1.3858e-29,  ..., 2.2811e-10,
            2.9712e-32, 0.0000e+00],
           ...,
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.3873e-20,
            2.9819e-21, 5.9072e-40],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00]],
 
          [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 3.4041e-27,  ..., 2.6260e-05,
            1.4827e-28, 1.4013e-45],
           ...,
           [0.0000e+0

In [13]:
required_influence

True