In [None]:
# Folder the final results JSON is written to (see the last cell).
# NOTE(review): presumably a mounted Google Drive path — confirm it exists before running.
drive_folder = "Machine_Unlearning_Drive/Cifar10Results/"

# Folder name for the SSD baseline; not referenced again in the visible notebook.
ssd_folder = "SSD/"

# Folder name for the SCRUB baseline; not referenced again in the visible notebook.
scrub_folder = "SCRUB/"

# Folder the pretrained starting checkpoint is loaded from (see load_pretrained_model).
github_folder = "Machine_Unlearning/"

# Install the core dependencies into the runtime (versions are not pinned).
!pip install scikit-learn torch torchvision

In [None]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
import random

import torch
import json
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

from Machine_Unlearning.Metrics.metrics import *

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"
print("Running on device:", DEVICE.upper())

def seed_everything(seed):
  """Seed the global PyTorch, `random` and NumPy RNGs with `seed`.

  Returns a new `torch.Generator` seeded with the same value, which callers
  hand to `random_split` / `DataLoader` for reproducible shuffling.
  """
  for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(seed)

  generator = torch.Generator()
  generator.manual_seed(seed)
  return generator

# Global seed used by every seed_everything(SEED) call in the notebook.
SEED = 42 # @param {type:"number"}
RNG = seed_everything(SEED)
# Fraction of the training set that is placed in the forget set.
SPLIT = 0.15 # @param ["0.03", "0.15", "0.30"] {type:"raw"}
# Initial weight of the forgetting term; passed as `start_alpha` to `unlearning` for the "ours" run.
STARTING_ALPHA = SPLIT * 5/3
# Accumulates every metric of every method; dumped to JSON in the last cell.
results = {}

In [None]:
# Reseed so the splits and loader shuffling below are reproducible.
RNG = seed_everything(SEED)
# download and pre-process CIFAR10
# NOTE(review): the Normalize constants are presumably the per-channel CIFAR-10 mean/std — confirm.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2)

# we split held out data into test and validation set
# (equal halves, drawn with the seeded RNG generator)
held_out = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)
val_loader = DataLoader(val_set, batch_size=256, shuffle=False, num_workers=2)



# The retain/forget split uses its own generator with a hard-coded seed of 42,
# so it does not change when SEED is changed.
# NOTE(review): confirm that a SEED-independent forget set is intended.
GEN1 = torch.Generator().manual_seed(42)
retain_set, forget_set = torch.utils.data.random_split(train_set,[1-SPLIT,SPLIT],GEN1)


# Both loaders shuffle and share the same RNG generator object.
forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=256, shuffle=True, num_workers=2 , generator=RNG
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
)

In [None]:
def readout(model,name):
  """Evaluate `model` and store its metrics in the global `results` dict.

  Keys are suffixed with `name`. Recorded are the per-sample entropies and
  losses on the global test / retain / forget loaders, the mean score of
  `simple_mia` on losses and on entropies (forget vs. test samples), and the
  accuracy on the global train / test / forget loaders.

  All RNGs are reseeded with SEED first; nothing is returned.
  """
  seed_everything(SEED)

  def _balanced_mia(unseen, forgotten):
    # Trim the larger population (the forgotten one on ties) to the size of the
    # other after an in-place shuffle with a fixed-seed NumPy generator, then
    # run the attack on the balanced 0/1-labelled sample.
    shuffler = np.random.default_rng(1)
    if len(unseen) > len(forgotten):
      shuffler.shuffle(unseen)
      unseen = unseen[: len(forgotten)]
    else:
      shuffler.shuffle(forgotten)
      forgotten = forgotten[: len(unseen)]
    # make sure we have a balanced dataset for the MIA
    assert len(unseen) == len(forgotten)

    attack_samples = np.concatenate((unseen, forgotten)).reshape((-1, 1))
    attack_labels = [0] * len(unseen) + [1] * len(forgotten)
    scores = simple_mia(attack_samples, attack_labels)

    print(
        f"The MIA has an accuracy of {scores.mean():.3f} on forgotten vs unseen images"
    )
    return scores.mean()

  # Per-sample entropies (loaders are visited in test, retain, forget order).
  entropy_test = compute_entropy(model, test_loader)
  entropy_retain = compute_entropy(model, retain_loader)
  entropy_forget = compute_entropy(model, forget_loader)
  results[f"test_entropies_{name}"] = entropy_test.tolist()
  results[f"retain_entropies_{name}"] = entropy_retain.tolist()
  results[f"forget_entropies_{name}"] = entropy_forget.tolist()

  # Per-sample losses, same loader order.
  loss_test = compute_losses(model, test_loader)
  loss_retain = compute_losses(model, retain_loader)
  loss_forget = compute_losses(model, forget_loader)
  results[f"test_losses_{name}"] = loss_test.tolist()
  results[f"retain_losses_{name}"] = loss_retain.tolist()
  results[f"forget_losses_{name}"] = loss_forget.tolist()

  # Membership-inference attack, first on the losses, then on the entropies.
  results[f"MIA_losses_{name}"] = _balanced_mia(loss_test, loss_forget)
  results[f"MIA_entropies_{name}"] = _balanced_mia(entropy_test, entropy_forget)

  # Accuracies; "train" is measured on the full training loader.
  results[f"train_accuracy_{name}"] = accuracy(model, train_loader)
  results[f"test_accuracy_{name}"] = accuracy(model, test_loader)
  results[f"forget_accuracy_{name}"] = accuracy(model, forget_loader)

In [None]:
def load_pretrained_model(DEVICE):
  """Build a 10-class ResNet-18 initialised from the starting checkpoint.

  The state dict is read from `github_folder` and mapped onto `DEVICE`; the
  returned model lives on `DEVICE` and is in eval mode.
  """
  checkpoint_path = github_folder + "starting_checkpoint_CIFAR10.pth"
  state_dict = torch.load(checkpoint_path, map_location=DEVICE)

  # Fresh architecture without torchvision weights, then the checkpoint on top.
  network = resnet18(weights=None, num_classes=10)
  network.load_state_dict(state_dict)
  network.to(DEVICE)
  network.eval()
  return network

# Baseline: record the metrics of the untouched pretrained model under the "original" suffix.
model = load_pretrained_model(DEVICE)
readout(model,"original")

In [None]:
def unlearning(net, retain, forget, validation, start_alpha = 0.1, alpha_sched = lambda start_a,a,max_ep,ep: a-(start_a/max_ep), lr = 0.001, forget_epochs = 27, use_scheduler = True):
    """Unlearn the `forget` data by fine-tuning `net` on a mixed forget/retain loss.

    Every optimisation step pairs one forget batch with one retain batch and
    minimises

        alpha * relu(val_loss - mean(forget_losses)) ** 2
            + (1 - alpha) * mean(retain_losses)

    where `val_loss` is the first value returned by
    `compute_moments(net, validation)` and is refreshed every 5 epochs. The
    first term is non-zero only while the mean forget loss is below `val_loss`.

    Before every epoch, and once more after the last one, the MIA score
    (forget vs. test losses) and the test accuracy are printed and appended to
    the module-level lists `mia_metric_scores` and `accuracy_metric_scores`.
    Both lists must exist before this function is called.

    Args:
        net: model to update in place; it is kept in eval() mode throughout,
            including during the optimiser steps.
        retain: DataLoader over the data to keep.
        forget: DataLoader over the data to forget; one pass over it is one epoch.
        validation: DataLoader handed to `compute_moments`.
        start_alpha: initial weight of the forgetting term.
        alpha_sched: called as `alpha_sched(start_alpha, alpha, forget_epochs, epoch)`
            after every epoch; its return value is the next alpha.
        lr: AdamW learning rate.
        forget_epochs: number of passes over `forget`.
        use_scheduler: when true, a LinearLR schedule (factor 1.0 -> 0.1 over
            `forget_epochs` steps) is stepped once per epoch.

    Returns:
        The same `net` object, in eval mode.

    Also reads the module-level `test_loader` and `DEVICE`.
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.AdamW(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=forget_epochs)

    def _record_progress():
        """Append the current MIA score and test accuracy to the global metric lists."""
        ft_forget_losses = compute_losses(net, forget)
        ft_test_losses = compute_losses(net, test_loader)

        # Balance both populations: shuffle the larger one (the forget losses
        # on ties) with a fixed-seed generator and truncate it.
        gen = np.random.default_rng(1)
        if len(ft_test_losses) > len(ft_forget_losses):
            gen.shuffle(ft_test_losses)
            ft_test_losses = ft_test_losses[: len(ft_forget_losses)]
        else:
            gen.shuffle(ft_forget_losses)
            ft_forget_losses = ft_forget_losses[: len(ft_test_losses)]

        # make sure we have a balanced dataset for the MIA
        assert len(ft_test_losses) == len(ft_forget_losses)

        ft_samples_mia = np.concatenate((ft_test_losses, ft_forget_losses)).reshape((-1, 1))
        labels_mia = [0] * len(ft_test_losses) + [1] * len(ft_forget_losses)

        ft_mia_scores = simple_mia(ft_samples_mia, labels_mia)

        print(
            f"The MIA has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
        )
        mia_metric_scores.append(ft_mia_scores.mean())

        acc = 100.0 * accuracy(net, test_loader)
        print(f"Accuracy on test set: {acc:.1f} ")
        accuracy_metric_scores.append(acc)

    net.eval()
    alpha = start_alpha

    # NOTE(review): this iterator is replaced again at epoch 0. It is kept
    # because creating it may draw from the loader's generator, and removing
    # it could change the batch order of seeded runs — confirm before dropping.
    retain_iter = iter(retain)

    # The retain iterator is restarted every `reset_period` epochs, i.e. before
    # the forget passes have consumed a full pass over `retain`. Clamping to 1
    # avoids a modulo-by-zero when `retain` has fewer batches than `forget`.
    reset_period = max(1, len(retain) // max(1, len(forget)))

    for i in range(forget_epochs):
        net.eval()

        if i % 5 == 0:
            print("Computing current moments on test set")
            val_loss, first_test_moment, second_test_moment, _test_std = compute_moments(net, validation)
            print("Computed moments: "+str(val_loss)+","+str(first_test_moment)+","+str(second_test_moment))

        _record_progress()

        # Re-assert eval mode after the evaluation helpers have run.
        net.eval()

        print("Forgetting epoch "+str(i))

        if i % reset_period == 0:
            print("Resetting retain iterator...")
            retain_iter = iter(retain)

        print("using alpha: "+str(alpha))
        for c, (inputs, targets) in enumerate(forget):
            net.zero_grad()
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            out = net(inputs)

            try:
                r_inputs, r_targets = next(retain_iter)
            except StopIteration:
                # `retain` ran out before the scheduled reset: start a new pass.
                retain_iter = iter(retain)
                r_inputs, r_targets = next(retain_iter)
            r_inputs, r_targets = r_inputs.to(DEVICE), r_targets.to(DEVICE)
            r_out = net(r_inputs)

            forget_losses = criterion(out, targets)
            retain_losses = criterion(r_out, r_targets)

            # Batch statistics of the forget losses. Only the mean enters the
            # loss; variance and skew are used for the diagnostics below.
            forget_mean = torch.mean(forget_losses)
            forget_var = torch.mean((forget_losses-forget_mean)**2)
            forget_std = forget_var**0.5
            forget_skew = torch.mean((forget_losses-forget_mean)**3) / (forget_std**3)

            delta_val_loss = (val_loss - forget_mean)
            delta_first_moment = (first_test_moment - forget_var)
            delta_second_moment = (second_test_moment - forget_skew)

            retain_mean = torch.mean(retain_losses)

            if c % 40 == 0:
                print("delta_val_loss: "+str(delta_val_loss.item()))
                print("delta_first_moment: "+str(delta_first_moment.item()))
                print("delta_second_moment: "+str(delta_second_moment.item()))

            loss = alpha*(torch.nn.functional.relu(delta_val_loss)**2) + (1-alpha)* retain_mean

            loss.backward()
            optimizer.step()

        alpha = alpha_sched(start_alpha, alpha, forget_epochs, i)
        if use_scheduler:
            scheduler.step()

    # Final evaluation on the `forget` loader that was passed in, matching
    # the per-epoch evaluation above.
    net.eval()
    _record_progress()

    net.eval()
    return net

In [None]:
# Reseed and start from a fresh copy of the pretrained model.
RNG = seed_everything(SEED)
ft_model = load_pretrained_model(DEVICE)

# Rebuild the forget/retain loaders so they shuffle with the freshly seeded generator.
forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=256, shuffle=True, num_workers=2 , generator=RNG
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
)
# Execute the unlearning routine. This might take a few minutes.
# If run on colab, be sure to be running it on an instance with GPUs
# `unlearning` appends to these two module-level lists, so they are reset here.
accuracy_metric_scores = []
mia_metric_scores = []
# Number of batches in each loader.
print(len(train_loader))
print(len(retain_loader))
print(len(forget_loader))

def alpha_sched(start_a,a,max_ep,ep):
  # Linear decay: every epoch alpha is reduced by start_a / max_ep.
  return a - (start_a/(max_ep))

# Epoch count derived from the retain/forget batch-count ratio (half of it, times 6).
forget_epochs = int((len(retain_loader) / (len(forget_loader)*2)) * 6)
print(forget_epochs)
ft_model = unlearning(ft_model, retain_loader, forget_loader, val_loader,start_alpha=STARTING_ALPHA, alpha_sched=alpha_sched, lr = 0.001, forget_epochs = forget_epochs, use_scheduler = True)

# Keep the per-epoch curves of this run for the comparison plots.
results["mia_scores_ours"] = mia_metric_scores
results["accuracy_scores_ours"] = accuracy_metric_scores

In [None]:
# Record the final metrics of the unlearned model under the "ours" suffix.
readout(ft_model,"ours")

In [None]:
# Reseed and start the fine-tuning baseline from a fresh copy of the pretrained model.
RNG = RNG = seed_everything(SEED)
finetuning_model = load_pretrained_model(DEVICE)

# Rebuild the forget/retain loaders so they shuffle with the freshly seeded generator.
forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=256, shuffle=True, num_workers=2 , generator=RNG
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
)
# Execute the unlearning routine. This might take a few minutes.
# If run on colab, be sure to be running it on an instance with GPUs
# `unlearning` appends to these two module-level lists, so they are reset here.
accuracy_metric_scores = []
mia_metric_scores = []

print(len(train_loader))


print(len(forget_loader))

# Redefines the linear alpha decay used for the "ours" run (the active line is the last one).
def alpha_sched(start_a,a,max_ep,ep):
  #return start_a
  #if ep > 10:
    #return a - (start_a/(max_ep-10))
  #else:
   # return a
  #return start_a /(ep+1)
  return a - (start_a/(max_ep))

#TODO: Relu with other parameters
#TODO: Relu with entropy
# Twice as many epochs as the "ours" run (no division by 2 here).
forget_epochs = int((len(retain_loader) / (len(forget_loader))) * 6)
print(forget_epochs)
#starting alpha 0 is equivalent to finetuning (epochs need to be adjusted accordingly)
# With start_alpha=0 the schedule keeps alpha at 0, so only the retain loss is optimised.
# Note that test_loader, not val_loader, is passed as the validation loader here.
finetuning_model = unlearning(finetuning_model, retain_loader, forget_loader, test_loader,start_alpha=0, alpha_sched=alpha_sched, lr = 0.001, forget_epochs = forget_epochs, use_scheduler = True)

# Keep the per-epoch curves of this run for the comparison plots.
results["mia_scores_finetuning"] = mia_metric_scores
results["accuracy_scores_finetuning"] = accuracy_metric_scores

In [None]:
# Record the final metrics of the fine-tuning baseline under the "finetuning" suffix.
readout(finetuning_model,"finetuning")

In [None]:
# MIA accuracy per evaluation point for both runs.
# Each curve is stretched over the same 0..6 x-range regardless of its number of points.
plt.figure(figsize=(10,5))
plt.title("Ours vs Finetuning Baseline")


plt.plot(np.linspace(0.0,6.0,len(results["mia_scores_ours"])),results["mia_scores_ours"],label="Gradient reversing (Ours)")
plt.plot(np.linspace(0.0,6.0,len(results["mia_scores_finetuning"])),results["mia_scores_finetuning"],label="Finetuning Baseline")

# Reference: an attack accuracy of 0.5, with a +/-0.005 band around it.
plt.hlines(0.5, 0 ,6 , color=["red"],linestyles='dashed',label="Perfect unlearning (Retraining)")
plt.axhspan(0.505, 0.495, facecolor="red", alpha=0.2)

plt.xticks(np.arange(0.2,6,0.2),minor=True)
plt.xlabel("Iterations")
plt.ylabel("MIA Accuracy")
plt.legend()
plt.show()

In [None]:
# Test accuracy per evaluation point for both runs, on the same stretched 0..6 x-range.
plt.figure(figsize=(10,5))
plt.title("Ours vs Finetuning Baseline")

plt.plot(np.linspace(0.0,6.0,len(results["accuracy_scores_ours"])),results["accuracy_scores_ours"],label="Gradient reversing (Ours)")
plt.plot(np.linspace(0.0,6.0,len(results["accuracy_scores_finetuning"])),results["accuracy_scores_finetuning"],label="Finetuning Baseline")
# NOTE(review): this reference line is the first recorded accuracy of the "ours" run, i.e. the
# accuracy measured before any unlearning step, not that of a retrained model — confirm the label.
plt.hlines(results["accuracy_scores_ours"][0], 0 ,6 , color=["red"],linestyles='dashed',label="Perfect unlearning (Retraining)")

plt.xticks(np.arange(0.2,6,0.2),minor=True)
plt.xlabel("Iterations")
plt.ylabel("Test set Accuracy")
plt.legend()
plt.show()

In [None]:
# Loss distributions of test vs. forget samples: left for the model unlearned
# with our method ("ours" readout), right for the pretrained model ("original"
# readout). Legend labels name exactly the data that is plotted.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ax1.set_title("Unlearned model (ours)")
ax1.hist(results["test_losses_ours"], density=True, alpha=0.5, bins=50, label="Test set on unlearned model")
ax1.hist(results["forget_losses_ours"], density=True, alpha=0.5, bins=50, label="Forget set on unlearned model")

ax2.set_title("Pre-trained model")
ax2.hist(results["test_losses_original"], density=True, alpha=0.5, bins=50, label="Test set on pretrained model")
ax2.hist(results["forget_losses_original"], density=True, alpha=0.5, bins=50, label="Forget set on pretrained model")

ax1.set_ylabel("Frequency")
for ax in (ax1, ax2):
    ax.set_xlabel("Loss")
    # Log scale keeps the sparse high-loss tail visible.
    ax.set_yscale("log")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(frameon=False, fontsize=14)
plt.show()

In [None]:
from SCRUB.thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate
from SCRUB.thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
from SCRUB.thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss

from SCRUB.thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate
from SCRUB.thirdparty.repdistiller.helper.pretrain import init

import copy

# Create the folder the SCRUB checkpoints are written to; unlike `!mkdir`,
# this does not complain when the cell is re-run and the folder already exists.
os.makedirs("checkpoints", exist_ok=True)

def scrub(teacher, student):
    """Run the SCRUB baseline, distilling `teacher` into a copy of `student`.

    Works on deep copies of both models. For `sgda_epochs` epochs the student
    is trained with `train_distill` in "minimize" mode on the global
    `retain_loader`; during the first `msteps` epochs a "maximize" pass on the
    global `forget_loader` precedes it. A checkpoint is written to
    `checkpoints/` after every epoch.

    Args:
        teacher: model used as the distillation teacher (not modified).
        student: model whose copy is trained (not modified).

    Returns:
        `(model_s, model_s_final)`: the student with the selected checkpoint
        loaded, and a copy of the student taken after the last epoch. The
        selected checkpoint is currently always the last one, so unless
        `validate` changes the model both carry the same weights.

    Raises:
        ValueError: if `args.optim` is not one of "sgd", "adam" or "rmsp".
    """

    class AttributeDict(dict):
        """dict whose items can also be read and written as attributes."""
        __getattr__ = dict.__getitem__
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    # Hyper-parameters consumed by the repdistiller helpers. The insertion
    # order is kept because the dict is printed below.
    args = AttributeDict({})
    args['optim'] = 'sgd'
    args['gamma'] = 1
    args['alpha'] = 0.5
    args['beta'] = 0
    args['smoothing'] = 0.5
    args['msteps'] = 3
    args['clip'] = 0.2
    args['sstart'] = 10
    args['kd_T'] = 4
    args['distill'] = 'kd'

    args['sgda_epochs'] = 6
    args['sgda_learning_rate'] = 0.0005
    args['lr_decay_epochs'] = [3,5,9]
    args['lr_decay_rate'] = 0.1
    args['sgda_weight_decay'] = 5e-4
    args['sgda_momentum'] = 0.9

    args['model'] = "resnet18"
    args['dataset'] = "cifar10"
    args['seed'] =  1

    print(args)
    print(args.clip)
    model_t = copy.deepcopy(teacher)
    model_s = copy.deepcopy(student)

    #this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
    #For SGDA smoothing: exponential moving average of the student weights
    beta = 0.1
    def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
        return (1 - beta) * averaged_model_parameter + beta * model_parameter
    swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)

    # module_list = [student, teacher]; only the student is trainable.
    module_list = nn.ModuleList([model_s])
    trainable_list = nn.ModuleList([model_s])

    criterion_cls = nn.CrossEntropyLoss()
    criterion_list = nn.ModuleList([
        criterion_cls,             # classification loss
        DistillKL(args.kd_T),      # KL divergence loss, original knowledge distillation
        DistillKL(args.kd_T),      # other knowledge distillation loss
    ])

    # optimizer
    if args.optim == "sgd":
        optimizer = optim.SGD(trainable_list.parameters(),
                              lr=args.sgda_learning_rate,
                              momentum=args.sgda_momentum,
                              weight_decay=args.sgda_weight_decay)
    elif args.optim == "adam":
        optimizer = optim.Adam(trainable_list.parameters(),
                               lr=args.sgda_learning_rate,
                               weight_decay=args.sgda_weight_decay)
    elif args.optim == "rmsp":
        optimizer = optim.RMSprop(trainable_list.parameters(),
                                  lr=args.sgda_learning_rate,
                                  momentum=args.sgda_momentum,
                                  weight_decay=args.sgda_weight_decay)
    else:
        # Fail here instead of with an unbound `optimizer` further down.
        raise ValueError("unsupported optimizer: {!r}".format(args.optim))

    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        swa_model.cuda()

    # acc_fs[k] is the forget-set error (100 - accuracy) after k epochs.
    acc_fs = []

    scrub_name = "checkpoints/scrub_{}_{}_seed{}_step".format(args.model, args.dataset, args.seed)
    for epoch in range(1, args.sgda_epochs + 1):

        sgda_adjust_learning_rate(epoch, args, optimizer)

        acc_f, _, _ = validate(forget_loader, model_s, criterion_cls, args, True)
        acc_fs.append(100-acc_f.item())

        maximize_loss = 0
        if epoch <= args.msteps:
            maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, "maximize")
        train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, "minimize",)
        if epoch >= args.sstart:
            swa_model.update_parameters(model_s)

        torch.save(model_s.state_dict(), scrub_name+str(epoch)+".pt")

        print ("maximize loss: {:.2f}\t minimize loss: {:.2f}\t train_acc: {}".format(maximize_loss, train_loss, train_acc))

    acc_f, _, _ = validate(forget_loader, model_s, criterion_cls, args, True)
    acc_fs.append(100-acc_f.item())

    # No forget-validation reference error is computed in this notebook, so
    # there is nothing to match `acc_fs` against and the checkpoint of the
    # last epoch is selected.
    selected_idx = len(acc_fs) - 1
    print ("the selected index is {}".format(selected_idx))
    selected_model = "checkpoints/scrub_{}_{}_seed{}_step{}.pt".format(args.model, args.dataset, args.seed, int(selected_idx))
    model_s_final = copy.deepcopy(model_s)
    model_s.load_state_dict(torch.load(selected_model))

    return model_s, model_s_final

In [None]:
# Reseed, then run SCRUB with two independent copies of the pretrained model.
RNG = RNG = seed_everything(SEED)
teacher = load_pretrained_model(DEVICE)
student = load_pretrained_model(DEVICE)

model_s, model_s_final = scrub(teacher, student)

# "scrubR" is the student with the selected checkpoint loaded, "scrub" the student after the last epoch.
readout(model_s,"scrubR")
readout(model_s_final,"scrub")

In [None]:
!pip install wandb
import SSD.src.ssd as ssd


def ssd_tuning(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    dampening_constant,
    selection_weighting,
    full_train_dl,
    device,
    **kwargs,
):
    """Apply the SSD baseline to `model` and return it.

    Importances are computed with `ssd.ParameterPerturber` on the forget
    loader and on the full training loader, then passed to `modify_weight`.
    Only `model`, `forget_train_dl`, `full_train_dl`, `dampening_constant`,
    `selection_weighting` and `device` are used; the other arguments are
    accepted so the same kwargs dict can be shared between baselines.
    """
    perturber_config = {
        "lower_bound": 1,  # unused
        "exponent": 1,  # unused
        "magnitude_diff": None,  # unused
        "min_layer": -1,  # -1: all layers are available for modification
        "max_layer": -1,  # -1: all layers are available for modification
        "forget_threshold": 1,  # unused
        "dampening_constant": dampening_constant,  # Lambda from paper
        "selection_weighting": selection_weighting,  # Alpha from paper
    }

    # The perturber is constructed with an optimizer over the model parameters.
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    perturber = ssd.ParameterPerturber(model, sgd, device, perturber_config)

    model = model.eval()

    # Importances on the forget set first, then on the full training set D.
    forget_importances = perturber.calc_importance(forget_train_dl)
    full_importances = perturber.calc_importance(full_train_dl)

    # Dampen selected parameters
    perturber.modify_weight(full_importances, forget_importances)

    return model

In [None]:
# Reseed and run the SSD baseline on a fresh copy of the pretrained model.
RNG = seed_everything(SEED)

ssd_model = load_pretrained_model(DEVICE)
# Randomly initialised network; ssd_tuning accepts it but does not use it.
unlearning_teacher = resnet18(weights=None, num_classes=10)

# Keys that are not named parameters of ssd_tuning are absorbed by its **kwargs.
kwargs = {
    "model": ssd_model,
    "unlearning_teacher": unlearning_teacher,
    "retain_train_dl": retain_loader,
    "retain_valid_dl": test_loader,
    "forget_train_dl": forget_loader,
    "forget_valid_dl": forget_loader,
    "full_train_dl": train_loader,
    "valid_dl": test_loader,
    "dampening_constant": 1,
    "selection_weighting": 10 * 1,
    "num_classes": 10,
    "dataset_name": 'Cifar10',
    "device": DEVICE,
    "model_name": 'resnet18',
}



ssd_model = ssd_tuning(**kwargs)

# Record the metrics of the SSD model under the "ssd" suffix.
readout(ssd_model,"ssd")

In [None]:
def epoch_end(model, epoch, result):
    """Print a one-line summary of an epoch.

    `result` must provide "lrs" (its last entry is shown), "train_loss" and
    "Loss" (shown as the validation loss). `model` is not used.
    """
    template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}"
    last_lr = result["lrs"][-1]
    print(template.format(epoch, last_lr, result["train_loss"], result["Loss"]))

def training_step(model, batch, device):
    """Return the cross-entropy loss of `model` on one `(images, labels)` batch.

    Both tensors are moved to `device` first; the loss keeps its graph so the
    caller can backpropagate through it.
    """
    images, clabels = (tensor.to(device) for tensor in batch)
    logits = model(images)
    return nn.functional.cross_entropy(logits, clabels)


@torch.no_grad()
def evaluate(model, val_loader, device):
    """Put `model` in eval mode and return its aggregated validation result.

    Runs `validation_step` on every batch of `val_loader` without gradient
    tracking and combines the per-batch outputs with `validation_epoch_end`.
    """
    model.eval()
    per_batch = []
    for batch in val_loader:
        per_batch.append(validation_step(model, batch, device))
    return validation_epoch_end(model, per_batch)

def validation_step(model, batch, device):
    """Return `{"Loss": loss}` for one `(images, labels)` batch.

    The loss is the cross-entropy of `model` on the batch, moved to `device`
    and detached from the autograd graph.
    """
    images, clabels = batch
    images = images.to(device)
    clabels = clabels.to(device)
    batch_loss = nn.functional.cross_entropy(model(images), clabels)
    return {"Loss": batch_loss.detach()}

def validation_epoch_end(model, outputs):
    """Average the per-batch "Loss" tensors in `outputs`.

    Returns `{"Loss": mean}` with the mean as a Python float. `model` is not used.
    """
    stacked_losses = torch.stack([entry["Loss"] for entry in outputs])
    mean_loss = stacked_losses.mean()
    return {"Loss": mean_loss.item()}

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

def fit_one_unlearning_cycle(epochs, model, train_loader, val_loader, lr, device):
    """Train `model` with Adam for `epochs` epochs and return the per-epoch history.

    After every epoch the model is evaluated on `val_loader`; the evaluation
    result is extended with the mean training loss ("train_loss") and the
    learning rate seen after each step ("lrs"), printed via `epoch_end` and
    appended to the returned list.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch_idx in range(epochs):
        model.train()
        step_losses, step_lrs = [], []

        for batch in train_loader:
            step_loss = training_step(model, batch, device)
            step_loss.backward()
            step_losses.append(step_loss.detach().cpu())

            optimizer.step()
            optimizer.zero_grad()

            step_lrs.append(get_lr(optimizer))

        summary = evaluate(model, val_loader, device)
        summary["train_loss"] = torch.stack(step_losses).mean()
        summary["lrs"] = step_lrs
        epoch_end(model, epoch_idx, summary)
        history.append(summary)

    return history

def amnesiac(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    num_classes,
    device,
    **kwargs,
):
    """Run the amnesiac (random relabelling) baseline and return `model`.

    Every forget sample gets a random label different from its true one; the
    relabelled forget samples and the unchanged retain samples are then used
    to train `model` for 3 epochs (batch size 128, lr 1e-4), with
    `retain_valid_dl` as the validation loader. `unlearning_teacher`,
    `forget_valid_dl` and `valid_dl` are accepted but not used.
    """
    label_pool = list(range(num_classes))

    # Forget samples: redraw until the random label differs from the true one.
    relabelled_samples = []
    for image, true_label in forget_train_dl.dataset:
        while True:
            wrong_label = random.choice(label_pool)
            if wrong_label != true_label:
                break
        relabelled_samples.append((image, wrong_label))

    # Retain samples keep their labels.
    relabelled_samples.extend(
        (image, label) for image, label in retain_train_dl.dataset
    )

    mixed_loader = DataLoader(
        relabelled_samples, batch_size=128, pin_memory=True, shuffle=True
    )

    fit_one_unlearning_cycle(
        3, model, mixed_loader, retain_valid_dl, lr=0.0001, device=device
    )
    return model

In [None]:
# Reseed and run the amnesiac baseline on a fresh copy of the pretrained model.
RNG = seed_everything(SEED)

amnesic_model = load_pretrained_model(DEVICE)
# Randomly initialised network; amnesiac accepts it but does not use it.
unlearning_teacher = resnet18(weights=None, num_classes=10)


# Keys that are not named parameters of amnesiac are absorbed by its **kwargs.
kwargs = {
    "model": amnesic_model,
    "unlearning_teacher": unlearning_teacher,
    "retain_train_dl": retain_loader,
    "retain_valid_dl": test_loader,
    "forget_train_dl": forget_loader,
    "forget_valid_dl": forget_loader,
    "full_train_dl": train_loader,
    "valid_dl": test_loader,
    "dampening_constant": 1,
    "selection_weighting": 10 * 1,
    "num_classes": 10,
    "dataset_name": 'Cifar10',
    "device": DEVICE,
    "model_name": 'resnet18',
}



# amnesiac returns the model it was given, so this is the same object as amnesic_model above.
amnesiac_model = amnesiac(**kwargs)

# Record the metrics of the amnesiac model under the "amnesiac" suffix.
readout(amnesiac_model,"amnesiac")

In [None]:
# Write all collected metrics to one JSON file in drive_folder; the file name encodes
# the forget split (as an integer percentage) and the seed.
with open(drive_folder+f"results_Cifar10_SPLIT_{int(SPLIT*100)}%_SEED_{SEED}.json", 'w') as fout:
  json.dump(results, fout)