In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
from torchvision.datasets import CIFAR100
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# Shared preprocessing settings used by both pipelines below.
# NOTE(review): these mean/std values look like the widely quoted CIFAR-10
# statistics -- confirm they are the ones intended for this dataset/backbone.
_IMAGE_SIZE = 224
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: resize -> tensor -> per-channel normalisation (no augmentation).
transform_train = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Evaluation pipeline: currently the same three steps as the training pipeline.
transform_test = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

class CustomCIFAR100(CIFAR100):
    """CIFAR100 whose items also carry the coarse (superclass) label.

    Indexing returns ``(image, fine_label, coarse_label)`` instead of the
    ``(image, fine_label)`` pair produced by the parent class.
    """

    def __init__(self, root, train, download, transform):
        """Forward all arguments to ``CIFAR100`` and build the label tables.

        Args:
            root: dataset root directory, passed through to ``CIFAR100``.
            train: selects the train or test split, passed through.
            download: whether to download the data, passed through.
            transform: image transform, passed through.
        """
        super().__init__(root = root, train = train, download = download, transform = transform)
        # coarse label -> the five fine labels grouped under it
        self.coarse_map = {
            0:[4, 30, 55, 72, 95],
            1:[1, 32, 67, 73, 91],
            2:[54, 62, 70, 82, 92],
            3:[9, 10, 16, 28, 61],
            4:[0, 51, 53, 57, 83],
            5:[22, 39, 40, 86, 87],
            6:[5, 20, 25, 84, 94],
            7:[6, 7, 14, 18, 24],
            8:[3, 42, 43, 88, 97],
            9:[12, 17, 37, 68, 76],
            10:[23, 33, 49, 60, 71],
            11:[15, 19, 21, 31, 38],
            12:[34, 63, 64, 66, 75],
            13:[26, 45, 77, 79, 99],
            14:[2, 11, 35, 46, 98],
            15:[27, 29, 44, 78, 93],
            16:[36, 50, 65, 74, 80],
            17:[47, 52, 56, 59, 96],
            18:[8, 13, 48, 58, 90],
            19:[41, 69, 81, 85, 89]
        }
        # Inverted table (fine label -> coarse label), built once so that every
        # item access is a single dict lookup instead of a scan over all groups.
        self._fine_to_coarse = {
            fine: coarse
            for coarse, fines in self.coarse_map.items()
            for fine in fines
        }

    def __getitem__(self, index):
        """Return ``(image, fine_label, coarse_label)`` for ``index``.

        Raises:
            AssertionError: if the fine label is missing from ``coarse_map``.
        """
        x, y = super().__getitem__(index)
        coarse_y = self._fine_to_coarse.get(int(y))
        if coarse_y is None:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped when Python runs with optimisations enabled.
            raise AssertionError("fine label {} has no coarse label in coarse_map".format(y))
        return x, y, coarse_y


class UnLearningData(Dataset):
    """Forget and retain samples exposed as one dataset with a binary flag.

    Indices ``[0, forget_len)`` map to forget samples and yield ``(x, 1)``;
    the remaining indices map to retain samples and yield ``(x, 0)``.
    Only element 0 of each underlying sample is used.
    """

    def __init__(self, forget_data, retain_data):
        """Store both sample collections and cache their lengths."""
        super().__init__()
        self.forget_data = forget_data
        self.retain_data = retain_data
        self.forget_len = len(forget_data)
        self.retain_len = len(retain_data)

    def __len__(self):
        """Total number of forget plus retain samples."""
        return self.retain_len + self.forget_len

    def __getitem__(self, index):
        """Return ``(x, flag)`` where flag is 1 for forget and 0 for retain."""
        if index >= self.forget_len:
            # Past the forget block: shift the index into the retain collection.
            retain_sample = self.retain_data[index - self.forget_len]
            return retain_sample[0], 0
        forget_sample = self.forget_data[index]
        return forget_sample[0], 1

In [None]:
import torch
from torch import nn
from torch.nn import functional as F


def accuracy(outputs, labels):
    """Return the percentage of rows whose arg-max over dim 1 equals the label.

    The result is a 0-dim tensor in the range [0, 100].
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    fraction = n_correct / len(predicted)
    return torch.tensor(fraction) * 100

def training_step(model, batch, device):
    """Compute the cross-entropy loss of ``model`` on one batch.

    ``batch`` is ``(images, fine_labels, coarse_labels)``; only the images
    and the coarse labels are used.
    """
    images, _fine_labels, clabels = batch
    images = images.to(device)
    clabels = clabels.to(device)
    logits = model(images)
    return F.cross_entropy(logits, clabels)

def validation_step(model, batch, device):
    """Evaluate one batch and return its detached loss and accuracy.

    ``batch`` is ``(images, fine_labels, coarse_labels)``; the fine labels
    are ignored. Returns a dict with keys ``'Loss'`` and ``'Acc'``.
    """
    images, _fine_labels, clabels = batch
    images = images.to(device)
    clabels = clabels.to(device)
    logits = model(images)
    batch_loss = F.cross_entropy(logits, clabels)
    batch_acc = accuracy(logits, clabels)
    return {'Loss': batch_loss.detach(), 'Acc': batch_acc}

def validation_epoch_end(model, outputs):
    """Average the per-batch ``'Loss'`` and ``'Acc'`` tensors into floats.

    ``model`` is accepted for signature symmetry and is not used.
    """
    mean_loss = torch.stack([step['Loss'] for step in outputs]).mean()
    mean_acc = torch.stack([step['Acc'] for step in outputs]).mean()
    return {'Loss': mean_loss.item(), 'Acc': mean_acc.item()}

def epoch_end(model, epoch, result):
    """Print a one-line summary of an epoch.

    Reads ``'lrs'`` (last entry), ``'train_loss'``, ``'Loss'`` and ``'Acc'``
    from ``result``; ``model`` is not used.
    """
    last_lr = result['lrs'][-1]
    summary = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch, last_lr, result['train_loss'], result['Loss'], result['Acc'])
    print(summary)

@torch.no_grad()
def evaluate(model, val_loader, device):
    """Run ``model`` in eval mode over ``val_loader`` and return mean loss/accuracy.

    Gradients are disabled for the whole call. The model is left in eval mode.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(validation_step(model, batch, device))
    return validation_epoch_end(model, step_results)

def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

def fit_one_cycle(epochs,  model, train_loader, val_loader, device, pretrained_lr=0.001, finetune_lr=0.01):
    """Train ``model`` with Adam and a reduce-on-plateau schedule.

    Args:
        epochs: number of passes over ``train_loader``.
        model: network to train. When it exposes ``base`` and ``final``
            sub-modules they are optimised with ``pretrained_lr`` and
            ``finetune_lr`` respectively; otherwise all parameters use
            ``finetune_lr``.
        train_loader: iterable of training batches for ``training_step``.
        val_loader: iterable of validation batches for ``evaluate``.
        device: device the batches are moved to.
        pretrained_lr: learning rate for ``model.base``.
        finetune_lr: learning rate for ``model.final`` (or the whole model).

    Returns:
        A list with one dict per epoch holding ``'Loss'``, ``'Acc'``,
        ``'train_loss'`` and ``'lrs'``.
    """
    torch.cuda.empty_cache()
    history = []

    # Use separate learning rates only when the model has the expected
    # backbone/head split; an explicit check avoids swallowing unrelated errors.
    if hasattr(model, 'base') and hasattr(model, 'final'):
        param_groups = [
            {'params':model.base.parameters(),'lr':pretrained_lr},
            {'params':model.final.parameters(),'lr':finetune_lr}
        ]
        optimizer = torch.optim.Adam(param_groups)
    else:
        optimizer = torch.optim.Adam(model.parameters(), finetune_lr)

    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch, device)
            loss.backward()
            # Keep only the detached value: storing `loss` itself would keep
            # every batch's autograd graph alive until the end of the epoch.
            train_losses.append(loss.detach())

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

        # Validation phase
        result = evaluate(model, val_loader, device)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        # The scheduler reacts to the validation loss of this epoch.
        sched.step(result['Loss'])
    return history

In [None]:
from torch import nn
import numpy as np
import torch
from torchvision.models import resnet18

class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` as-is."""
        return x

class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)

class ConvStandard(nn.Conv2d):
    """2-D convolution with zero-mean normal weight init scaled by ``w_sig``.

    Weights are drawn with ``std = w_sig / (in_channels * prod(kernel_size))``
    and the bias, when present, is set to zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride.
        padding: convolution padding; when ``None`` it defaults to
            ``(kernel_size - 1) // 2``, the same default ``Conv`` uses.
        output_padding: accepted for signature compatibility; not used.
        w_sig: scale factor of the weight initialisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0, w_sig =\
                 np.sqrt(1.0)):
        if padding is None:
            # A `None` padding cannot be handed to the convolution itself.
            if isinstance(kernel_size, int):
                padding = (kernel_size - 1) // 2
            else:
                padding = tuple((k - 1) // 2 for k in kernel_size)
        # Let the base class normalise kernel_size/stride/padding so that the
        # inherited forward() applies them.
        super(ConvStandard, self).__init__(in_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding)
        self.w_sig = w_sig
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weights and zero the bias."""
        if not hasattr(self, 'w_sig'):
            # The base constructor calls reset_parameters() before `w_sig`
            # is assigned; skip that call, __init__ re-runs it afterwards.
            return
        torch.nn.init.normal_(self.weight, mean=0, std=self.w_sig/(self.in_channels*np.prod(self.kernel_size)))
        if self.bias is not None:
            torch.nn.init.normal_(self.bias, mean=0, std=0)

class Conv(nn.Sequential):
    """Convolution (or transposed convolution) + optional batch norm + activation.

    When ``padding`` is ``None`` it defaults to ``(kernel_size - 1) // 2``.
    The convolution carries a bias only when batch norm is disabled.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0,
                 activation_fn=nn.ReLU, batch_norm=True, transpose=False):
        resolved_padding = (kernel_size - 1) // 2 if padding is None else padding
        use_bias = not batch_norm

        if transpose:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                            padding=resolved_padding, output_padding=output_padding,
                                            bias=use_bias)
        else:
            conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                   padding=resolved_padding, bias=use_bias)

        layers = [conv_layer]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels, affine=True))
        layers.append(activation_fn())
        super(Conv, self).__init__(*layers)
class ResNet18(nn.Module):
    """``resnet18`` backbone without its last child, followed by dropout and a linear head.

    Attributes set here: ``base`` (all backbone children except the last),
    ``drop`` (default ``nn.Dropout``) and ``final`` (linear layer mapping the
    backbone's ``fc.in_features`` to ``num_classes``).
    """

    def __init__(self, num_classes, pretrained):
        super().__init__()
        backbone = resnet18(pretrained=pretrained)
        feature_layers = list(backbone.children())[:-1]
        self.base = nn.Sequential(*feature_layers)
        self.drop = nn.Dropout()
        self.final = nn.Linear(backbone.fc.in_features, num_classes)

    def forward(self,x):
        """Return the head's logits for ``x``."""
        features = self.base(x)
        # Flatten the backbone output to (N, in_features) before the head.
        flat = features.view(-1, self.final.in_features)
        return self.final(self.drop(flat))


In [None]:
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np


def UnlearnerLoss(output, labels, full_teacher_logits, unlearn_teacher_logits, KL_temperature):
    """KL-divergence loss between the student and a per-sample teacher.

    Args:
        output: student logits.
        labels: per-sample flag, 1 for a forget sample and 0 for a retain sample.
        full_teacher_logits: logits of the fully trained teacher.
        unlearn_teacher_logits: logits of the unlearning teacher.
        KL_temperature: divisor applied to all logits before the softmax.

    Returns:
        ``F.kl_div`` (default reduction) of the student's log-probabilities
        against the selected teacher probabilities.
    """
    # Column vector so the flag broadcasts across the class dimension.
    forget_flag = torch.unsqueeze(labels, dim = 1)
    retain_flag = 1 - forget_flag

    student_log_probs = F.log_softmax(output / KL_temperature, dim=1)
    full_probs = F.softmax(full_teacher_logits / KL_temperature, dim=1)
    unlearn_probs = F.softmax(unlearn_teacher_logits / KL_temperature, dim=1)

    # Forget samples follow the unlearning teacher, retain samples the full teacher.
    teacher_probs = forget_flag * unlearn_probs + retain_flag * full_probs
    return F.kl_div(student_log_probs, teacher_probs)

def unlearning_step(model, unlearning_teacher, full_trained_teacher, unlearn_data_loader, optimizer,
            device, KL_temperature):
    """Run one pass over ``unlearn_data_loader`` and return the mean loss.

    Each batch is ``(x, flag)``. Both teachers are evaluated without
    gradients; the student ``model`` is updated with ``UnlearnerLoss``.
    """
    batch_losses = []
    for inputs, flags in unlearn_data_loader:
        inputs = inputs.to(device)
        flags = flags.to(device)

        # Teachers only provide targets, so no graph is needed for them.
        with torch.no_grad():
            full_teacher_logits = full_trained_teacher(inputs)
            unlearn_teacher_logits = unlearning_teacher(inputs)

        student_logits = model(inputs)
        optimizer.zero_grad()
        loss = UnlearnerLoss(output = student_logits, labels=flags,
                full_teacher_logits=full_teacher_logits,
                unlearn_teacher_logits=unlearn_teacher_logits, KL_temperature=KL_temperature)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.detach().cpu().numpy())
    return np.mean(batch_losses)


def fit_one_unlearning_cycle(epochs,  model, train_loader, val_loader, lr, device):
    """Train ``model`` with Adam at a fixed ``lr`` and validate after each epoch.

    Returns a list with one dict per epoch holding ``'Loss'``, ``'Acc'``,
    ``'train_loss'`` and ``'lrs'``.
    NOTE(review): ``'train_loss'`` is stored as a 0-dim tensor here, whereas
    ``fit_one_cycle`` stores a float -- confirm which one callers expect.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    history = []

    for epoch in range(epochs):
        model.train()
        epoch_losses, epoch_lrs = [], []

        for batch in train_loader:
            batch_loss = training_step(model, batch, device)
            batch_loss.backward()
            epoch_losses.append(batch_loss.detach().cpu())

            optimizer.step()
            optimizer.zero_grad()

            epoch_lrs.append(get_lr(optimizer))

        result = evaluate(model, val_loader, device)
        result['train_loss'] = torch.stack(epoch_losses).mean()
        result['lrs'] = epoch_lrs
        epoch_end(model, epoch, result)
        history.append(result)

    return history

def blindspot_unlearner(model, unlearning_teacher, full_trained_teacher, retain_data, forget_data, epochs = 10,
                optimizer = 'adam', lr = 0.01, batch_size = 256, num_workers = 32,
                device = 'cuda', KL_temperature = 1):
    """Unlearn ``forget_data`` from ``model`` using two teachers.

    Args:
        model: student network that is updated in place.
        unlearning_teacher: teacher followed on forget samples.
        full_trained_teacher: teacher followed on retain samples.
        retain_data: samples to keep; element 0 of each sample is the input.
        forget_data: samples to forget; element 0 of each sample is the input.
        epochs: number of passes over the combined data.
        optimizer: the string ``'adam'`` (an Adam optimizer is created with
            ``lr``) or an already constructed optimizer object, in which
            case ``lr`` is ignored.
        lr: learning rate used when the optimizer is created here.
        batch_size: batch size of the combined loader.
        num_workers: worker count of the combined loader.
        device: device the batches are moved to.
        KL_temperature: temperature forwarded to the loss.

    Raises:
        ValueError: if ``optimizer`` is a string other than ``'adam'``.
    """
    unlearning_data = UnLearningData(forget_data=forget_data, retain_data=retain_data)
    unlearning_loader = DataLoader(unlearning_data, batch_size = batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=True)

    # Both teachers only provide targets and are never trained.
    unlearning_teacher.eval()
    full_trained_teacher.eval()

    if isinstance(optimizer, str):
        # Fail early on an unknown name instead of crashing inside the loop
        # when a plain string is used as an optimizer.
        if optimizer.lower() != 'adam':
            raise ValueError("Unsupported optimizer name: {!r}".format(optimizer))
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    for epoch in range(epochs):
        loss = unlearning_step(model = model, unlearning_teacher= unlearning_teacher,
                        full_trained_teacher=full_trained_teacher, unlearn_data_loader=unlearning_loader,
                        optimizer=optimizer, device=device, KL_temperature=KL_temperature)
        print("Epoch {} Unlearning Loss {}".format(epoch+1, loss))



# **Metrics used from the paper directly**

In [None]:
from torch.nn import functional as F
import torch
from sklearn.svm import SVC

def JSDiv(p, q):
    """Symmetric divergence between ``p`` and ``q`` via their midpoint.

    Averages ``F.kl_div(log(p), m)`` and ``F.kl_div(log(q), m)`` with
    ``m = (p + q) / 2``, using ``F.kl_div``'s default reduction.
    NOTE(review): ``F.kl_div(input, target)`` treats ``target`` as the
    reference distribution -- confirm the argument order and the default
    reduction match the intended Jensen-Shannon definition.
    """
    midpoint = (p + q) / 2
    term_p = F.kl_div(torch.log(p), midpoint)
    term_q = F.kl_div(torch.log(q), midpoint)
    return 0.5 * term_p + 0.5 * term_q

# ZRF/UnLearningScore
def UnLearningScore(tmodel, gold_model, forget_dl, batch_size, device):
    """Return ``1 - JSDiv`` between the two models' softmax outputs on ``forget_dl``.

    Each batch is ``(x, y, cy)``; only ``x`` is used. ``batch_size`` is
    accepted but not used.
    """
    tmodel_probs = []
    gold_probs = []
    with torch.no_grad():
        for x, _fine, _coarse in forget_dl:
            x = x.to(device)
            tmodel_logits = tmodel(x)
            gold_logits = gold_model(x)
            tmodel_probs.append(F.softmax(tmodel_logits, dim = 1).detach().cpu())
            gold_probs.append(F.softmax(gold_logits, dim = 1).detach().cpu())

    # Compare the predictions over the whole loader at once.
    all_tmodel_probs = torch.cat(tmodel_probs, axis = 0)
    all_gold_probs = torch.cat(gold_probs, axis = 0)
    return 1-JSDiv(all_tmodel_probs, all_gold_probs)

@torch.no_grad()
def actv_dist(model1, model2, dataloader, device = 'cuda'):
    """Mean L2 distance between the softmax outputs of two models.

    Args:
        model1: first model.
        model2: second model.
        dataloader: iterable of ``(x, _, _)`` batches; only ``x`` is used.
        device: device the inputs are moved to.

    Returns:
        A 0-dim tensor with the per-sample distance averaged over all samples.
    """
    distances = []
    for batch in dataloader:
        x, _, _ = batch
        x = x.to(device)
        probs1 = F.softmax(model1(x), dim = 1)
        probs2 = F.softmax(model2(x), dim = 1)
        # Euclidean distance between the two probability vectors of each sample.
        diff = torch.sqrt(torch.sum(torch.square(probs1 - probs2), axis = 1))
        distances.append(diff.detach().cpu())
    distances = torch.cat(distances, axis = 0)
    return distances.mean()

In [None]:
# The training split uses the training pipeline; the held-out split uses the
# test pipeline so the two can diverge (e.g. augmentation) without leaking
# training-only transforms into validation.
train_ds = CustomCIFAR100(root='.', train=True,download=True, transform=transform_train)
valid_ds = CustomCIFAR100(root='.', train=False,download=True, transform=transform_test)

batch_size = 256
# Only the training loader is shuffled.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=32, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=32, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
num_classes = 100

# Bucket every (already transformed) training sample by its fine label.
classwise_train = {cls: [] for cls in range(num_classes)}
for sample in train_ds:
    img, label, clabel = sample
    classwise_train[label].append((img, label, clabel))

# Same bucketing for the held-out split.
classwise_test = {cls: [] for cls in range(num_classes)}
for sample in valid_ds:
    img, label, clabel = sample
    classwise_test[label].append((img, label, clabel))

In [None]:
# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fully trained model: 20 outputs, one per coarse label, trained on all classes.
model = ResNet18(num_classes = 20, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, model, train_dl, valid_dl, device = device)


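# Forgetting Rocket class
The cell below sets `forget_classes = [98]`. **Note:** in the alphabetical CIFAR-100 fine-label ordering, index 98 is `woman` (coarse group 14, "people"), while `rocket` is index 69 (coarse group 19). To actually forget the Rocket class, set `forget_classes = [69]`; otherwise read the results below as forgetting class 98.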

**Context**

The CIFAR-100 dataset consists of 60000 32x32 colour images in 100 classes, with 600 images per class. The 100 classes in the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the class to which it belongs) and a "coarse" label (the superclass to which it belongs). There are 50000 training images and 10000 test images.

In [None]:
import random

# Getting the forget and retain data for both the validation and training splits
# NOTE(review): confirm the index below -- the section title talks about the
# Rocket class, which is presumably index 69 in CIFAR-100's fine-label ordering.
forget_classes = [98]

# Samples are gathered class by class, in increasing class order.
forget_valid = [sample for cls in range(num_classes) if cls in forget_classes
                for sample in classwise_test[cls]]
retain_valid = [sample for cls in range(num_classes) if cls not in forget_classes
                for sample in classwise_test[cls]]
forget_train = [sample for cls in range(num_classes) if cls in forget_classes
                for sample in classwise_train[cls]]
retain_train = [sample for cls in range(num_classes) if cls not in forget_classes
                for sample in classwise_train[cls]]

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=32, pin_memory=True)

retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=32, pin_memory=True)

forget_train_dl = DataLoader(forget_train, batch_size, num_workers=32, pin_memory=True)
retain_train_dl = DataLoader(retain_train, batch_size, num_workers=32, pin_memory=True, shuffle = True)

# 30% of the retain data is used for unlearning. A dedicated, seeded generator
# makes the chosen subset identical on every Restart & Run All.
retain_subset_rng = random.Random(0)
retain_train_subset = retain_subset_rng.sample(retain_train, int(0.3*len(retain_train)))
retain_train_subset_dl = DataLoader(retain_train_subset, batch_size, num_workers=32, pin_memory=True, shuffle = True)

In [None]:
# Performance of the fully trained model on the retain validation set
evaluate(model, retain_valid_dl, device)

{'Loss': 0.535236120223999, 'Acc': 85.77934265136719}

In [None]:
# Performance of the fully trained model on the forget validation set
evaluate(model, forget_valid_dl, device)

{'Loss': 0.5363734364509583, 'Acc': 82.0}

## Retrain the model from Scratch
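Create the retrained model (gold model). This is a fresh ResNet18 (created with `pretrained = True`, like the original model) that is trained only on the retain data, so it never sees the forget data.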

In [None]:
# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Gold model: same architecture and pretrained initialisation as `model`,
# but trained and validated on the retain data only.
gold_model = ResNet18(num_classes = 20, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, gold_model, retain_train_dl, retain_valid_dl, device = device)


In [None]:
# Performance of the gold (retrained) model on the forget validation set
evaluate(gold_model, forget_valid_dl, device)

{'Loss': 7.545389175415039, 'Acc': 3.0}

In [None]:
# Performance of the gold (retrained) model on the retain validation set
evaluate(gold_model, retain_valid_dl, device)

{'Loss': 0.5325239896774292, 'Acc': 85.76885223388672}

## Unlearning via KL Divergence in Research Paper

In [None]:
# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Randomly initialised network used as the teacher for the forget samples.
unlearning_teacher = ResNet18(num_classes = 20, pretrained = False).to(device).eval()
# The student starts as a copy of the fully trained model. Without this it
# would be a randomly initialised network that has nothing to unlearn.
student_model = ResNet18(num_classes = 20, pretrained = False).to(device)
student_model.load_state_dict(model.state_dict())
model = model.eval()

KL_temperature = 1

# A ready-made optimizer is passed below, so the `lr` argument of
# blindspot_unlearner is not used to build one.
optimizer = torch.optim.Adam(student_model.parameters(), lr = 0.0001)

blindspot_unlearner(model = student_model, unlearning_teacher = unlearning_teacher, full_trained_teacher = model,
          retain_data = retain_train_subset, forget_data = forget_train, epochs = 1, optimizer = optimizer, lr = 0.0001,
          batch_size = 256, num_workers = 32, device = device, KL_temperature = KL_temperature)

In [None]:
# Performance of the unlearned (student) model on the forget validation set
evaluate(student_model, forget_valid_dl, device)

{'Loss': 3.3266074657440186, 'Acc': 3.0}

In [None]:
# Performance of the unlearned (student) model on the retain validation set
evaluate(student_model, retain_valid_dl, device)

{'Loss': 0.5810623168945312, 'Acc': 84.57299041748047}