In [None]:
drive_folder = "Machine_Unlearning_Drive/Cifar10All/"

ssd_folder = "SSD/"

scrub_folder = "SCRUB/"

github_folder = "Machine_Unlearning/"

!pip install scikit-learn torch torchvision seaborn matplotlib

In [None]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
import random

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

from Machine_Unlearning.Metrics.metrics import *

# Pick the compute device once; every later cell reads this module-level name.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Running on device:", DEVICE.upper())


def seed_everything(seed):
    """Seed the global RNGs of torch, `random` and NumPy with the same value.

    Args:
        seed: integer seed applied to all three libraries.

    Returns:
        A fresh ``torch.Generator`` seeded with ``seed``, for APIs that take an
        explicit generator.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
    # The dedicated generator is independent of the global torch RNG state.
    return torch.Generator().manual_seed(seed)

# NOTE(review): SPLIT is not referenced in the visible cells (the experiment
# loop iterates over its own `splits` list) - confirm whether it is still needed.
SPLIT=0.3
# Global seed for torch / random / NumPy; RNG is the matching torch.Generator
# that the data cell passes to random_split for the test/validation split.
SEED = 1337
RNG = seed_everything(SEED)

In [None]:
# Download CIFAR-10 and build the data loaders used by every later cell.

# Preprocessing: convert images to tensors, then normalise each of the three
# channels with the fixed (mean, std) tuples given to transforms.Normalize.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Arguments shared by both dataset objects and by all three loaders.
cifar_kwargs = dict(root="./data", download=True, transform=normalize)
loader_kwargs = dict(batch_size=256, num_workers=2)

# Training split (shuffled loader).
train_set = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)

# The official held-out split is divided 50/50 into test and validation sets;
# RNG makes that division reproducible.
held_out = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

In [None]:
def load_pretrained_model(DEVICE):
    """Create a 10-class ResNet-18 initialised from the starting checkpoint.

    Args:
        DEVICE: device specifier used both as ``map_location`` when reading the
            checkpoint and as the target of ``model.to``.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file - only point this at a
    # checkpoint you trust.
    checkpoint_path = github_folder + "starting_checkpoint_CIFAR10.pth"
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)

    # No torchvision weights are requested; all parameters come from the checkpoint.
    network = resnet18(weights=None, num_classes=10)
    network.load_state_dict(state_dict)
    return network.to(DEVICE).eval()

# Baseline: accuracy of the checkpoint before any unlearning is applied.
model = load_pretrained_model(DEVICE)
pretrained_train_acc = accuracy(model, train_loader)
pretrained_test_acc = accuracy(model, test_loader)
print(f"Train set accuracy: {100.0 * pretrained_train_acc:0.1f}%")
print(f"Test set accuracy: {100.0 * pretrained_test_acc:0.1f}%")

In [None]:
def unlearning(net, retain, forget, validation, start_alpha = 0.1, alpha_sched = lambda start_a,a,max_ep,ep: a-(start_a/max_ep), lr = 0.001, forget_epochs = 27, use_scheduler = True):
    """Fine-tune ``net`` in place so its forget-set loss rises towards the validation loss.

    Each step draws one forget batch and one retain batch and minimises

        alpha * relu(val_loss - mean_forget_loss) ** 2 + (1 - alpha) * mean_retain_loss

    so the forget term is only active while the forget loss is below the
    reference validation loss, and the retain term keeps fitting the retain data.

    Args:
        net: model to update; trained in place and also returned.
        retain: sized iterable of ``(inputs, targets)`` batches to keep fitting.
        forget: sized iterable of ``(inputs, targets)`` batches to unlearn; one
            pass over it defines an epoch.
        validation: passed unchanged to ``compute_moments`` to obtain the
            reference loss.
        start_alpha: initial weight of the forget term.
        alpha_sched: called after every epoch as
            ``alpha_sched(start_alpha, alpha, forget_epochs, epoch)``; its return
            value is the next ``alpha``. The default subtracts
            ``start_alpha / forget_epochs`` per epoch.
        lr: AdamW learning rate.
        forget_epochs: number of passes over ``forget``.
        use_scheduler: when true, step a LinearLR schedule (factor 1.0 -> 0.1
            over ``forget_epochs``) once per epoch.

    Returns:
        ``net`` in evaluation mode.
    """
    print("Starting Unlearning")
    ### FORGETTING ###
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.AdamW(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=forget_epochs)

    alpha = start_alpha

    # Restart the retain iterator every `reset_every` epochs. Clamping to >= 1
    # avoids the modulo-by-zero the unclamped ratio produced whenever `retain`
    # had fewer batches than `forget` (or `forget` was empty).
    reset_every = max(1, len(retain) // max(1, len(forget)))

    net.train()

    # NOTE(review): this iterator is replaced again at epoch 0. Building a
    # DataLoader iterator may advance the loader's generator, so the extra
    # construction is kept to leave the random stream of earlier runs unchanged
    # - TODO confirm and drop if bit-for-bit reproducibility is not required.
    retain_iter = iter(retain)

    for epoch in range(forget_epochs):
        # Refresh the reference validation loss every 5 epochs (including
        # epoch 0), measured with the network in eval mode.
        if epoch % 5 == 0:
            net.eval()
            val_loss, _, _, _ = compute_moments(net, validation)

        net.train()

        if epoch % reset_every == 0:
            retain_iter = iter(retain)

        for inputs, targets in forget:
            net.zero_grad()
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            out = net(inputs)

            # Fetch the paired retain batch; restart the iterator if it ran dry
            # (only possible when `retain` is shorter than `forget`).
            try:
                r_inputs, r_targets = next(retain_iter)
            except StopIteration:
                retain_iter = iter(retain)
                r_inputs, r_targets = next(retain_iter)
            r_inputs, r_targets = r_inputs.to(DEVICE), r_targets.to(DEVICE)
            r_out = net(r_inputs)

            forget_mean = torch.mean(criterion(out, targets))
            retain_mean = torch.mean(criterion(r_out, r_targets))

            # Positive while the forget loss is still below the validation loss.
            delta_val_loss = val_loss - forget_mean

            loss = alpha * (torch.nn.functional.relu(delta_val_loss) ** 2) + (1 - alpha) * retain_mean

            loss.backward()
            optimizer.step()

        alpha = alpha_sched(start_alpha, alpha, forget_epochs, epoch)
        if use_scheduler:
            scheduler.step()

    net.eval()
    print("Unlearning done")
    return net

In [None]:
import json

# Experiment grid: every (forget split, starting alpha, split seed) combination.
seeds = [1,2,3]
alphas = [0.01,0.02, 0.05, 0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5]
splits = [0.01,0.02, 0.05, 0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5]

results_path = drive_folder + "results.json"


def _balance_for_mia(member_scores, nonmember_scores):
    """Equalise class sizes and build the inputs for ``simple_mia``.

    The longer array (the non-member one on a tie) is shuffled in place with
    NumPy's global RNG and truncated to the length of the other.

    Returns:
        ``(nonmember_scores, member_scores, samples, labels)`` where ``samples``
        is a column vector of non-member then member scores and ``labels`` is
        the matching list of 0s (non-member) and 1s (member).
    """
    if len(member_scores) > len(nonmember_scores):
        np.random.shuffle(member_scores)
        member_scores = member_scores[: len(nonmember_scores)]
    else:
        np.random.shuffle(nonmember_scores)
        nonmember_scores = nonmember_scores[: len(member_scores)]

    assert len(nonmember_scores) == len(member_scores)

    samples = np.concatenate((nonmember_scores, member_scores)).reshape((-1, 1))
    labels = [0] * len(nonmember_scores) + [1] * len(member_scores)
    return nonmember_scores, member_scores, samples, labels


def _save_results(all_results, path):
    """Write ``all_results`` as JSON via a temp file and an atomic rename.

    An interrupted run therefore cannot leave a truncated results file behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w') as fout:
        json.dump(all_results, fout)
    os.replace(tmp_path, path)


# results[i][j][s] holds the metrics dict for splits[i], alphas[j], seeds[s].
results = [[[{} for _ in seeds] for _ in alphas] for _ in splits]

# Resume from a previous run when possible. Starting fresh is still the
# fallback, but the reason is reported instead of being silently swallowed,
# and unrelated exceptions (e.g. KeyboardInterrupt) are no longer caught.
try:
    with open(results_path, 'r') as fin:
        results = json.load(fin)
    print("Loaded results:")
except FileNotFoundError:
    print("No previous results found, starting from scratch")
except (OSError, json.JSONDecodeError) as err:
    print(f"Could not read {results_path} ({err}); starting from scratch")

for i, split in enumerate(splits):
    for j, start_alpha in enumerate(alphas):
        for s, seed in enumerate(seeds):

            # round() rather than int(): int() truncates, so a value such as
            # 0.29 * 100 == 28.999999999999996 would be named 28%.
            experiment_name = f"Exp_Split_{round(split*100)}%_Alpha_{round(start_alpha*100)}%_Seed_{seed}"
            print("Starting new experiment: "+experiment_name)

            run_results = results[i][j][s]
            # "Experiment_name" is written last, so its presence marks a finished run.
            if "Experiment_name" in run_results:
                print("Skipping experiment")
                continue

            # General seeds: identical for every experiment.
            RNG = torch.Generator().manual_seed(1)
            torch.manual_seed(1337)
            random.seed(1337)
            np.random.seed(1337)

            # Seed for dataset splitting: the only thing that varies with `seed`.
            GEN = torch.Generator().manual_seed(seed)

            retain_set, forget_set = torch.utils.data.random_split(train_set, [1-split, split], GEN)

            forget_loader = torch.utils.data.DataLoader(
                forget_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
            )
            retain_loader = torch.utils.data.DataLoader(
                retain_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
            )

            # Epoch budget derived from the retain/forget batch ratio.
            forget_epochs = min(int((len(retain_loader) / (len(forget_loader)*2)) * 6), 80) #hard cap on 80
            print(forget_epochs)
            ft_model = load_pretrained_model(DEVICE)
            ft_model = unlearning(ft_model, retain_loader, forget_loader, val_loader, start_alpha=start_alpha, lr=0.001, forget_epochs=forget_epochs, use_scheduler=True)

            ### EVALUATE

            # Entropy score
            ft_forget_entropies = compute_entropy(ft_model, forget_loader)
            ft_test_entropies = compute_entropy(ft_model, test_loader)
            ft_retain_entropies = compute_entropy(ft_model, retain_loader)

            # Store the full arrays before balancing shuffles/truncates them.
            run_results["ft_forget_entropies"] = ft_forget_entropies.tolist()
            run_results["ft_test_entropies"] = ft_test_entropies.tolist()
            run_results["ft_retain_entropies"] = ft_retain_entropies.tolist()

            # Balanced membership-inference attack on entropies (forget = member, test = non-member).
            ft_test_entropies, ft_forget_entropies, ft_samples_mia_entropy, ft_labels_mia_entropy = _balance_for_mia(
                ft_forget_entropies, ft_test_entropies
            )
            ft_mia_scores_entropy = simple_mia(ft_samples_mia_entropy, ft_labels_mia_entropy)

            # float() so the value is a plain Python float for json.dump.
            run_results["Mia_entropy"] = float(ft_mia_scores_entropy.mean())

            ft_forget_losses = compute_losses(ft_model, forget_loader)
            ft_test_losses = compute_losses(ft_model, test_loader)
            ft_retain_losses = compute_losses(ft_model, retain_loader)

            run_results["ft_forget_losses"] = ft_forget_losses.tolist()
            run_results["ft_test_losses"] = ft_test_losses.tolist()
            run_results["ft_retain_losses"] = ft_retain_losses.tolist()

            # Balanced membership-inference attack on losses.
            ft_test_losses, ft_forget_losses, ft_samples_mia, ft_labels_mia = _balance_for_mia(
                ft_forget_losses, ft_test_losses
            )
            ft_mia_scores = simple_mia(ft_samples_mia, ft_labels_mia)

            run_results["Mia_loss"] = float(ft_mia_scores.mean())

            retain_acc = accuracy(ft_model, retain_loader)
            test_acc = accuracy(ft_model, test_loader)
            forget_acc = accuracy(ft_model, forget_loader)

            # NOTE(review): accuracy() is defined outside this file; float()
            # guards against it returning a NumPy/torch scalar that json.dump
            # cannot serialise.
            run_results["Retain_accuracy"] = float(retain_acc)
            run_results["Test_accuracy"] = float(test_acc)
            run_results["Forget_accuracy"] = float(forget_acc)

            # Side-by-side histograms: losses (left) and entropies (right).
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

            ax1.set_title(f"Losses after unlearning.\nAttack accuracy: {ft_mia_scores.mean():0.3f}")
            ax1.hist(ft_test_losses, density=True, alpha=0.5, bins=50, label="Test set on unlearned model")
            ax1.hist(ft_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set on unlearned model")

            # Both panels show the unlearned model; the entropy panel used to
            # be labelled "pretrained model" / "Loss" by copy-paste.
            ax2.set_title(f"Entropy after unlearning.\nAttack accuracy: {ft_mia_scores_entropy.mean():0.3f}")
            ax2.hist(ft_test_entropies, density=True, alpha=0.5, bins=50, label="Test set on unlearned model")
            ax2.hist(ft_forget_entropies, density=True, alpha=0.5, bins=50, label="Forget set on unlearned model")

            ax1.set_xlabel("Loss")
            ax2.set_xlabel("Entropy")
            ax1.set_ylabel("Frequency")
            ax2.set_ylabel("Frequency")
            ax1.set_yscale("log")
            ax2.set_yscale("log")
            ax1.set_xlim((0, max(np.max(ft_test_losses), np.max(ft_forget_losses))))
            ax2.set_xlim((0, max(np.max(ft_test_entropies), np.max(ft_forget_entropies))))
            for ax in (ax1, ax2):
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
            ax1.legend(frameon=False, fontsize=14)
            ax2.legend(frameon=False, fontsize=14)
            plt.savefig(drive_folder + experiment_name + ".png")
            # Close explicitly: hundreds of figures are created in this loop.
            plt.close(fig)

            # Mark the run as complete and checkpoint all results to disk.
            run_results["Experiment_name"] = experiment_name
            _save_results(results, results_path)
            print("Experiment_over")