In [1]:
import os
import copy
import time
import random
import subprocess
from functools import partial

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader

In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

def parameter_count(model):
    total_count = 0
    trainable_count = 0
    for p in model.parameters():
        total_count += torch.prod(torch.tensor(p.shape)).item()
        if p.requires_grad:
            trainable_count += torch.prod(torch.tensor(p.shape)).item()

    return total_count, trainable_count

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [3]:
class MovingAverage:
    def __init__(self, name, rd=4):
        self.name = name
        # avg value
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.rd = rd

    def update(self, x):
        self.sum += x
        self.count += 1

        # update self.value
        self.val = round(self.sum / self.count, self.rd)

    def value(self) -> float:
        return self.val

In [4]:
class HiddenDataset(Dataset):
    def __init__(self, df, base_dir):
        super().__init__()
        df['image_path'] = df['image_id'].apply(lambda x: os.path.join(base_dir,'images', x.split('-')[0], x.split('-')[1] + '.png'))
        self.df = df

        # read the images at the init only
        # self.images = [
        #    torch.tensor(np.transpose(np.array(Image.open(x).convert('RGB')), [2, 0, 1])) for x in self.df['image_path'].tolist()
        # ]
        self.images = [torchvision.io.read_image(x) for x in self.df['image_path'].tolist()]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        image = self.images[index]
        age = self.df['age_group'].iloc[index]
        return image, age


In [5]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # save the file while saving the version
    # subprocess.run('touch submission.zip', shell=True)
    base_dir = "/kaggle/input/mock-cifar10-data"
    num_checkpoints = 10
    real_run = False
else:
    # this part will run when we submit to kaggle.
    base_dir = "/kaggle/input/neurips-2023-machine-unlearning/"
    num_checkpoints = 512
    real_run = True

In [6]:
os.makedirs('/kaggle/tmp', exist_ok=True)

print(f"Initializing the model")
trained_model = resnet18(weights=None, num_classes=10)
original_path = os.path.join(base_dir, 'original_model.pth')
print(f"Loading the model from checkpoint = {original_path}")
trained_model.load_state_dict(torch.load(original_path, map_location=device))
trained_model.to(device)

retain_df = pd.read_csv(os.path.join(base_dir, "retain.csv"))
forget_df = pd.read_csv(os.path.join(base_dir, "forget.csv"))
validation_df = pd.read_csv(os.path.join(base_dir, "validation.csv"))

print(f"Initializing the retain dataset")
retain_dataset = HiddenDataset(retain_df, base_dir)
# retain_dataset = HiddenDataset(forget_df, base_dir)

print(f"Initializing the forget dataset")
forget_dataset = HiddenDataset(forget_df, base_dir)

print(f"Initializing the validation dataset")
validation_dataset = HiddenDataset(validation_df, base_dir)

print(f"length of retain dataset = {len(retain_dataset)}")
print(f"length of forget dataset = {len(forget_dataset)}")
print(f"length of validation dataset = {len(validation_dataset)}")

Initializing the model
Loading the model from checkpoint = /kaggle/input/mock-cifar10-data/original_model.pth
Initializing the retain dataset
Initializing the forget dataset
Initializing the validation dataset
length of retain dataset = 27440
length of forget dataset = 560
length of validation dataset = 3500


In [7]:
# copied from : https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
class MyModel(nn.Module):
    def __init__(self, tm):
        super(MyModel, self).__init__()
        self.trained_model = tm

    def forward(self, x):
        x = self.trained_model.conv1(x)
        x = self.trained_model.bn1(x)
        x = self.trained_model.relu(x)
        x = self.trained_model.maxpool(x)

        x = self.trained_model.layer1(x)
        x = self.trained_model.layer2(x)
        x = self.trained_model.layer3(x)
        x = self.trained_model.layer4(x)

        x = self.trained_model.avgpool(x)
        x = torch.flatten(x, 1)
        # x will be of shape (batchsize, 512)
        return x


In [8]:
def calculate_accuracy(model_init, dataloader, device):
    model_init.eval()
    gt = np.array([])
    pred = np.array([])
    with torch.no_grad():
        for X, y in dataloader:
            X = X.float().to(device)
            y = y.long().to(device)

            out = model_init(X)
            y_pred = torch.argmax(out, dim=1)

            gt = np.append(gt, y.cpu().numpy())
            pred = np.append(pred, y_pred.cpu().numpy())

    acc = round(float(np.mean(gt == pred)), 6)

    return acc


def frx(features, gt, params):
    loss_fn = nn.CrossEntropyLoss()
    alpha = params[:5120].view(10, 512)
    beta = params[5120:].view(1, 10)
    # features = (batch_size, 512), alphaT = (512, 10) , beta = (1, 10)
    logits = torch.mm(features, torch.transpose(alpha, 0, 1)) + beta
    # logits -> (512, 10), gt -> (512)
    loss = loss_fn(logits, gt)
    return loss


def unlearning(
    model,
    retain_loader,
    forget_loader,
    validation_loader,
    device
):
    # evaluate first
    retain_acc = calculate_accuracy(model, retain_loader, device)
    forget_acc = calculate_accuracy(model, forget_loader, device)
    validation_acc = calculate_accuracy(model, validation_loader, device)

    print(f"Initial retain acc = {retain_acc}, forget acc = {forget_acc}, validation acc = {validation_acc}")

    # add noise
    std = 1e-3
    for p in model.parameters():
        noise = std * torch.randn_like(p.data)
        p.data = p.data + noise

    # create a model
    mymodel = MyModel(copy.deepcopy(model))
    mymodel.to(device)
    # set it to eval
    mymodel.eval()

    # collect features retain
    retain_features = []
    retain_gt = []
    with torch.no_grad():
        for X_retain, y_retain in retain_loader:
            # change
            X_retain = X_retain.float().to(device)
            y_retain = y_retain.long().to(device)
            out_retain = mymodel(X_retain)
            retain_features.append(out_retain)
            retain_gt.append(y_retain)

    retain_features = torch.cat(retain_features, dim=0)
    retain_gt = torch.cat(retain_gt)
    print(f"retain features shape = {retain_features.shape}")
    print(f"retain gt shape = {retain_gt.shape}")

    # collect features forget
    forget_features = []
    forget_gt = []
    with torch.no_grad():
        for X_forget, y_forget in forget_loader:
            # change
            X_forget = X_forget.float().to(device)
            y_forget = y_forget.long().to(device)
            out_forget = mymodel(X_forget)
            forget_features.append(out_forget)
            forget_gt.append(y_forget)

    forget_features = torch.cat(forget_features, dim=0)
    forget_gt = torch.cat(forget_gt)
    print(f"forget features shape = {forget_features.shape}")
    print(f"forget gt shape = {forget_gt.shape}")

    # collect the param vector
    print(f"mean weight = {torch.mean(model.fc.weight.data)}, mean bias = {torch.mean(model.fc.bias.data)}")
    vfcat = torch.cat(
        [model.fc.weight.data.view(-1, 1), model.fc.bias.data.view(-1, 1)],
        dim=0
    ).squeeze()

    # create the partial function and calculate the retain hessian
    retain_param = copy.deepcopy(vfcat)
    retain_param.requires_grad = True
    print(f"retain parameter vector shape = {retain_param.shape}")

    retain_partial_frx = partial(frx, retain_features, retain_gt)
    retain_hessian = torch.autograd.functional.hessian(retain_partial_frx, retain_param)
    print(f"retain hessian shape = {retain_hessian.shape}")

    # create the partial function and calculate the forget gradient
    forget_param = copy.deepcopy(vfcat)
    forget_param.requires_grad = True
    print(f"forget parameter vector shape = {forget_param.shape}")

    forget_partial_frx = partial(frx, forget_features, forget_gt)
    forget_loss = frx(forget_features, forget_gt, forget_param)
    forget_grad = torch.autograd.grad(forget_loss, forget_param)
    print(f"forget gradient shape = {forget_grad[0].shape}")

    # calculate the final param
    eps = 1e-3
    final_param = copy.deepcopy(vfcat)
    pertub = torch.mm(torch.linalg.inv(retain_hessian), torch.unsqueeze(forget_grad[0], -1)).squeeze()
    final_param = final_param + eps * pertub

    # replace the final param as the model weight and bias
    model.fc.weight.data = copy.deepcopy(final_param[:5120].view(10, 512))
    model.fc.bias.data = copy.deepcopy(final_param[5120:].view(10))
    print(f"mean weight = {torch.mean(model.fc.weight.data)}, mean bias = {torch.mean(model.fc.bias.data)}")

    # evaluate now
    retain_acc_update = calculate_accuracy(model, retain_loader, device)
    forget_acc_update = calculate_accuracy(model, forget_loader, device)
    validation_acc_update = calculate_accuracy(model, validation_loader, device)

    print(f"After scrub retain acc = {retain_acc_update}, forget acc = {forget_acc_update}, validation acc = {validation_acc_update}")


In [9]:
T1 = time.time()

batch_size = 64
retain_loader = DataLoader(retain_dataset, batch_size=batch_size, shuffle=True)
forget_loader = DataLoader(forget_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

for sd in range(num_checkpoints):
    TS1 = time.time()
    print(f"Running for checkpoint = {sd}")
    final_model = copy.deepcopy(trained_model)
    unlearning(final_model, retain_loader, forget_loader, validation_loader, device)
    TS2 = time.time()
    state = final_model.half().state_dict()
    torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{sd}.pth')
    print(f"Time taken = {round(TS2 - TS1, 3)} seconds")

T2 = time.time()
timetaken_models = round(T2 - T1, 3)
print(f"Total timetaken to run the {num_checkpoints} models is = {timetaken_models} seconds")

Running for checkpoint = 0
Initial retain acc = 0.979701, forget acc = 0.985714, validation acc = 0.725714
retain features shape = torch.Size([27440, 512])
retain gt shape = torch.Size([27440])
forget features shape = torch.Size([560, 512])
forget gt shape = torch.Size([560])
mean weight = -0.012752010487020016, mean bias = -0.013015004806220531
retain parameter vector shape = torch.Size([5130])
retain hessian shape = torch.Size([5130, 5130])
forget parameter vector shape = torch.Size([5130])
forget gradient shape = torch.Size([5130])
mean weight = 0.11834371089935303, mean bias = 11.994539260864258
After scrub retain acc = 0.978972, forget acc = 0.980357, validation acc = 0.727429
Time taken = 17.88 seconds
Running for checkpoint = 1
Initial retain acc = 0.979701, forget acc = 0.985714, validation acc = 0.725714
retain features shape = torch.Size([27440, 512])
retain gt shape = torch.Size([27440])
forget features shape = torch.Size([560, 512])
forget gt shape = torch.Size([560])
mean 

In [10]:
T3 = time.time()
# Ensure that submission.zip will contain exactly num_checkpoints 
# (if this is not the case, an exception will be thrown).
unlearned_ckpts = os.listdir('/kaggle/tmp')
if len(unlearned_ckpts) != num_checkpoints:
    raise RuntimeError(f'Expected exactly {num_checkpoints} checkpoints. The submission will throw an exception otherwise.')

subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)
T4 = time.time()
zip_time_taken = round(T4 - T3, 3)
print(f"Total time taken to zip the {num_checkpoints} models is = {zip_time_taken} seconds")

  adding: kaggle/tmp/unlearned_checkpoint_0.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_1.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_2.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_3.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_4.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_5.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_6.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_7.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_8.pth (deflated 7%)
  adding: kaggle/tmp/unlearned_checkpoint_9.pth (deflated 7%)
Total time taken to zip the 10 models is = 10.026 seconds
