### MACHINE UNLEARNING

## Import Dependencies

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import utils
import variational
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter
from itertools import cycle
import os
import time
import math
import pandas as pd
from collections import OrderedDict
from sklearn.linear_model import LogisticRegression
import copy
import torch.nn as nn
from torch.autograd import Variable
from typing import List
import itertools
from tqdm.autonotebook import tqdm
from models import *
import models
from logger import *
import wandb
from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate
from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate
from thirdparty.repdistiller.helper.pretrain import init
import os
# Ask CUDA to enumerate devices in PCI bus order.
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
# An empty device list hides every GPU from CUDA, so the notebook runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

## Train Original Model

This cell runs main.py, which trains a model and saves a series of checkpoints for later reference. The
standard_model_for_pilot checkpoint is prepared in advance for this pilot experiment.

dataset: mnist

possible datasets, (specifically refer to the dataset_multiclass.py)
cifar10,small_cifar5,small_cifar6,small_cifar10,small_binary_cifar10,
cifar100,mnist,small_mnist,small_binary_mnist,lacuna100,lacuna10,small_lacuna5,small_lacuna6,small_lacuna10
small_binary_lacuna10,tinyimagenet_pretrain,tinyimagenet_finetune,tinyimagenet_finetune5,mix10,mix100

model: mlp

possible models (refer to models.py)
ntk_wide_resnet, is_wide_resnet, wide_resnet, resnet_small, resnet, allcnn_no_bn, ntk_allcnn
smallallcnn, allcnn, ntk_mlp, ntk_linear, mlp

dataroot: data/MNIST
resume from standard_model_for_pilot.pt in checkpoints folder

In [None]:
# Original-model run: no --forget-class is given, and training resumes from the pilot checkpoint.
%run main.py --dataset mnist --model mlp --dataroot=data/MNIST/ --filters 1.0 --lr 0.001 \
--resume checkpoints/standard_model_for_pilot.pt --disable-bn \
--weight-decay 0.1 --batch-size 128 --epochs 31 --seed 3

# Note
After running this cell, a series of checkpoints is written to the "checkpoints" folder.
Later cells load them, but you can remove them once the experiment is complete.

The code also downloads the MNIST dataset into the data folder. You can remove it after the
experiment.

The code also generates a log file in the "logs" folder, in case you want to analyze it with MATLAB. You can
remove it after the experiment.

## Train Retain Model

This cell trains the retain model. The only change compared with the cell above (which trains
the original model) is the forget-class options. These options are handled by datasets.get_loaders(),
which is called from main.py; in dataset_multiclass.py, four functions are used to separate the retain
data from the forget data: replace_indexes(), replace_class(), confuse_class(), and get_loaders().


In [None]:
# Retain-model run: same settings as the original-model run, plus the forget options
# (--forget-class and --num-to-forget) on the last line.
%run main.py --dataset mnist --model mlp --dataroot=data/MNIST/ --filters 1.0 --lr 0.001 \
--resume checkpoints/standard_model_for_pilot.pt --disable-bn \
--weight-decay 0.1 --batch-size 128 --epochs 31 \
--forget-class 0,1,2,3,4,5 --num-to-forget 300 --seed 3

## Analysis

This function will count the parameters in the model.

In [None]:
def parameter_count(model):
    """Print the total number of scalar parameters held by ``model``.

    Args:
        model: Any object whose ``parameters()`` method yields tensors
            (for example a ``torch.nn.Module``).

    Returns:
        None. The total is written to stdout.
    """
    # Tensor.numel() is exact for every shape, including 0-dim (scalar)
    # parameters, and keeps the running total a plain Python int.
    count = sum(p.numel() for p in model.parameters())
    print(f'Total Number of Parameters: {count}')

Importantly, the `model` object used here is an empty model object: it is not initialized with our previously trained models.
Here, we count the parameters of this empty model object.

Note: `copy.deepcopy` gives you an independent copy of the entire object, while plain assignment only gives you another
reference to the same object (compare reference semantics in Java or pointers in C).

In [None]:
# `model` is not defined in this notebook; presumably it is left in the namespace by `%run main.py` — confirm.
parameter_count(copy.deepcopy(model))

## Initialize original model and retain model

Here, we use the checkpoint paths to load the models that we trained in the previous cells. The loaded models
are kept in the memory of this Jupyter notebook session.

Initialize three empty model objects

In [None]:
# Three independent deep copies of `model`, one per checkpoint that is loaded later.
model_original, model_retain, model_pretrain = (copy.deepcopy(model) for _ in range(3))

We use checkpoint name to refer to the objects. Thus, here we initialize all the name parameters.

In [None]:
# `args` and `train_loader` are not defined in this notebook; presumably they are
# left in the namespace by the last `%run main.py` — confirm.
arch = args.model
filters=args.filters
arch_filters = arch +'_'+ str(filters).replace('.','_')
augment = False
dataset = args.dataset
class_to_forget = args.forget_class
init_checkpoint = f"checkpoints/{args.name}_init.pt"
num_classes=args.num_classes

# Sizes of the forget / retain splits.
num_to_forget = args.num_to_forget
num_total = len(train_loader.dataset)
# Follow --num-to-forget instead of duplicating its value as a literal; 300 is kept
# only as the legacy fallback for runs where the flag is unset (None).
num_to_retain = num_total - (300 if num_to_forget is None else int(num_to_forget))
seed = args.seed
unfreeze_start = None

# Fragments that are concatenated into the checkpoint file names.
learningrate=f"lr_{str(args.lr).replace('.','_')}"
batch_size=f"_bs_{str(args.batch_size)}"
lossfn=f"_ls_{args.lossfn}"
wd=f"_wd_{str(args.weight_decay).replace('.','_')}"
seed_name=f"_seed_{args.seed}_"
num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'
unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'
augment_tag = 'augment_' if augment else ''

# Epoch tag embedded in the checkpoint file names.
training_epochs=30
log_dict={}
log_dict['epoch']=training_epochs

Here, we get the name of the model we want to get.

In [None]:
# Both checkpoint names share the same prefix and suffix; only the forget tag differs.
_ckpt_prefix = f'checkpoints/{dataset}_{arch_filters}_forget_'
_ckpt_suffix = f'{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'
m_name = _ckpt_prefix + 'None' + _ckpt_suffix
m0_name = _ckpt_prefix + f'{class_to_forget}{num_tag}' + _ckpt_suffix

Make sure the names are correct.

In [None]:
# Checkpoint path built with the "forget_None" tag.
print(m_name)

In [None]:
# Checkpoint path built with the forget-class tag (and the sample-count tag when one is set).
print(m0_name)

In [None]:
# Path of the "<run name>_init.pt" checkpoint.
print(init_checkpoint)

We get the original model, the retain, and the pretrained model that we used to train the original model and the
retain model.

In [None]:
# Fill the three placeholder models from their checkpoint files.
# NOTE(review): torch.load is called without map_location, so tensors are restored to the
# device they were saved from; with the GPUs hidden above, confirm the checkpoints were
# written from CPU (or pass map_location).
model_original.load_state_dict(torch.load(m_name))
model_retain.load_state_dict(torch.load(m0_name))
model_pretrain.load_state_dict(torch.load(init_checkpoint))

## Analysis

We use parameters to see if we actually get the models from the checkpoints.

In [None]:
# Report the parameter count of each loaded model in turn.
for _loaded in (model_original, model_retain, model_pretrain):
    parameter_count(copy.deepcopy(_loaded))

We analyze the parameters in the model

In [None]:
# For each model, store a copy of every parameter's current values on the
# parameter itself as `data0`, then print one parameter as a spot check.
# The print sits after the loop, so `p` is the LAST parameter of that model.
for p in model_original.parameters():
    p.data0 = p.data.clone()
print("Model_Original:", p)

for p in model_pretrain.parameters():
    p.data0 = p.data.clone()
print("Model_Pretrain", p)

# Iterate the retain model here (the pretrain model was previously walked a
# second time, so the "Model_Retain" line showed pretrain weights).
for p in model_retain.parameters():
    p.data0 = p.data.clone()
print("Model_Retain", p)

We use distances to further analyze the three models. First, we define the distance function.

In [None]:
def named_parameters_creat(model_):
    """Attach a snapshot of the current values to every parameter of ``model_``.

    Each parameter yielded by ``model_.parameters()`` receives a ``data0``
    attribute holding a clone of its ``data``. ``distance`` reads these
    ``data0`` copies when comparing two models.

    Args:
        model_: Object whose ``parameters()`` method yields tensors.

    Returns:
        None; the parameters are annotated in place.
    """
    for param in model_.parameters():
        snapshot = param.data.clone()
        param.data0 = snapshot

def distance(model,model0):
    """Print and return the normalized L2 distance between two models' weights.

    Both models first get fresh ``data0`` snapshots (via
    ``named_parameters_creat``), so any ``data0`` stored earlier is overwritten.
    Parameters are paired positionally with ``zip``.

    Args:
        model: First model; its squared weight norm is the normalizer.
        model0: Second model.

    Returns:
        sqrt(sum of squared differences / sum of squared weights of ``model``).
    """
    # Refresh the snapshots so the comparison reflects the current weights.
    for net in (model, model0):
        named_parameters_creat(net)

    sq_dist_total = 0
    sq_norm_total = 0
    for (_, p), (_, p0) in zip(model.named_parameters(), model0.named_parameters()):
        sq_dist_total += (p.data0-p0.data0).pow(2).sum().item()
        sq_norm_total += p.data0.pow(2).sum().item()

    print(f'Distance: {np.sqrt(sq_dist_total)}')
    normalized = 1.0*np.sqrt(sq_dist_total/sq_norm_total)
    print(f'Normalized Distance: {normalized}')
    return normalized

We compare original model and the retain model.

In [None]:
distance(model_original,model_retain)

We compare the retain model and the pre-trained model

In [None]:
distance(model_pretrain,model_retain)

We compare the pretrain model and the original model

In [None]:
distance(model_pretrain, model_original)

## Prepare Dataset for Unlearning

Explicit documentation has been added to datasets_multiclass.py.

Initialize the batch sizes for the retain and forget datasets, and get the forget dataset.

The forget class was fixed when we trained the model with main.py, and it is carried into this notebook
through the initialization section above. To change the forget class, we need to re-run the main.py cell
that produces the retain model.

The forget class is also encoded in the checkpoint name.

In [None]:
# Batch sizes for the retain and forget loaders.
args.retain_bs, args.forget_bs = 32, 64

# Un-augmented train/valid/test loaders over the complete dataset.
train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(
    dataset,
    batch_size=args.batch_size,
    seed=seed,
    root=args.dataroot,
    augment=False,
    shuffle=True,
)

# Same data with the forget samples only marked (only_mark=True); only the
# training loader is kept.
marked_loader, _, _ = datasets.get_loaders(
    dataset,
    class_to_replace=class_to_forget,
    num_indexes_to_replace=num_to_forget,
    only_mark=True,
    batch_size=1,
    seed=seed,
    root=args.dataroot,
    augment=False,
    shuffle=True,
)

In this case, we define a function that puts the dataset into the data loader.

In [None]:
def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):
    """Build a new ``DataLoader`` over ``dataset``.

    Args:
        data_loader: Unused; kept so existing call sites keep working.
        dataset: Dataset the new loader iterates over.
        batch_size: Batch size of the new loader. The default is read from
            ``args.batch_size`` once, when this function is defined.
        seed: Passed to ``manual_seed`` before the loader is created.
        shuffle: Whether the loader shuffles the dataset.

    Returns:
        A single-process (``num_workers=0``) ``torch.utils.data.DataLoader``.
    """
    manual_seed(seed)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        # Pinned memory only speeds up host-to-GPU copies; requesting it when no
        # CUDA device is visible has no benefit and only produces a warning.
        pin_memory=torch.cuda.is_available(),
        shuffle=shuffle,
    )

Get forget dataset

In [None]:
# Work on a private copy so the marked loader's own dataset is left untouched.
forget_dataset = copy.deepcopy(marked_loader.dataset)
# Negative targets select the forget samples (presumably produced by only_mark=True — confirm in the datasets module).
marked = forget_dataset.targets < 0
# Filter the data first: the mask must be applied before the targets are overwritten.
forget_dataset.data = forget_dataset.data[marked]
# Map each marked target t back to -t - 1.
forget_dataset.targets = - forget_dataset.targets[marked] - 1
# The first argument is not used by replace_loader_dataset; the loader is built from forget_dataset alone.
forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)

Get retain dataset

In [None]:
# Work on a private copy so the marked loader's own dataset is left untouched.
retain_dataset = copy.deepcopy(marked_loader.dataset)
# Non-negative targets are the samples that were not marked, i.e. the retain set.
marked = retain_dataset.targets >= 0
retain_dataset.data = retain_dataset.data[marked]
retain_dataset.targets = retain_dataset.targets[marked]
# The first argument is not used by replace_loader_dataset; the loader is built from retain_dataset alone.
retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)

Make sure everything is working

In [1]:
from collections import Counter

# The forget and retain splits must partition the full training set exactly.
assert len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset), \
    "forget and retain splits do not add up to the full training set"
print(len(forget_loader.dataset))
print(len(retain_loader.dataset))
print(len(test_loader_full.dataset))
print(len(train_loader_full.dataset))

# Count plain Python numbers: torch tensor elements hash by identity, so a
# Counter built directly from a tensor reports every sample as its own key.
_targets = train_loader_full.dataset.targets
_labels = _targets.tolist() if hasattr(_targets, 'tolist') else list(_targets)
print(dict(Counter(_labels)))

NameError: name 'forget_dataset' is not defined

## Fisher Forgetting

In [None]:
from utils import *
def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):
    """Evaluate ``model`` one sample at a time on ``dataloader.dataset``.

    Args:
        model: Network to evaluate; it is left in eval mode.
        dataloader: Loader whose ``dataset`` is evaluated (re-wrapped with
            batch size 1 and no shuffling).
        criterion: Loss applied to ``(output, target)``.
        samples_correctness: If True, also return per-sample softmax
            activations and per-sample ``get_error`` values.
        use_bn: If True, first run 10 passes over ``retain_loader.dataset`` in
            train mode (forward only) before evaluating.
        delta_w: Vector multiplied with the squared per-class gradients when
            ``scrub_act`` is True; required in that case.
        scrub_act: If True, add Gaussian noise to each output, scaled by
            ``sqrt(G^2 @ delta_w)``.

    Returns:
        ``metrics.avg``, or ``(metrics.avg, activations, predictions)`` when
        ``samples_correctness`` is True.
    """
    activations=[]
    predictions=[]
    if use_bn:
        model.train()
        # Keep the warm-up loader in its own variable: reusing `dataloader`
        # here would make the evaluation below run on the retain set instead
        # of the dataset the caller asked for.
        bn_loader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)
        # Forward passes only; no gradients are needed for them.
        with torch.no_grad():
            for _ in range(10):
                for data, _target in bn_loader:
                    model(data.to(args.device))
    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)
    model.eval()
    metrics = AverageMeter()
    mult = 0.5 if args.lossfn=='mse' else 1
    for data, target in eval_loader:
        data, target = data.to(args.device), target.to(args.device)
        if args.lossfn=='mse':
            # Map {0, 1} labels to {-1, +1}; .float() keeps the tensor on its
            # current device, so this also works when no GPU is available.
            target=(2*target-1)
            target = target.float().unsqueeze(1)
        if 'mnist' in args.dataset:
            data=data.view(data.shape[0],-1)
        output = model(data)
        if scrub_act:
            # One flattened gradient of each class output w.r.t. all weights.
            G = []
            for cls in range(num_classes):
                grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)
                G.append(torch.cat([g.view(-1) for g in grads]))
            # (A former extra autograd.grad call on `output_sf` / `model_scrubf`
            # was dropped: neither name is defined here and its result was unused.)
            G = torch.stack(G).pow(2)
            delta_f = torch.matmul(G,delta_w)
            output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()

        loss = mult*criterion(output, target)
        if samples_correctness:
            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())
            predictions.append(get_error(output,target))
        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))
    if samples_correctness:
        return metrics.avg,np.stack(activations),np.array(predictions)
    else:
        return metrics.avg

def l2_penalty(model,model_init,weight_decay):
    """Return ``weight_decay / 2`` times the squared distance to ``model_init``.

    Parameters of the two models are paired positionally; only parameters of
    ``model`` with ``requires_grad`` set contribute.

    Args:
        model: Model being trained.
        model_init: Reference model with the same parameter layout.
        weight_decay: Penalty coefficient.

    Returns:
        The scaled sum of squared differences (``0.0`` when no parameter
        requires gradients).
    """
    paired = zip(model.named_parameters(), model_init.named_parameters())
    squared_gap = sum(
        (p - p_init).pow(2).sum()
        for (_, p), (_, p_init) in paired
        if p.requires_grad
    )
    return squared_gap * (weight_decay/2.)

def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader,
                    loss_fn: nn.Module,
                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,
                    negative_gradient=False, negative_multiplier=-1, random_labels=False,
                    quiet=False,delta_w=None,scrub_act=False):
    """Run one pass over ``data_loader``, training unless ``split == 'test'``.

    Args:
        model: Network to train or evaluate. It is kept in eval mode for the
            whole pass, including training passes.
        model_init: Reference model for the L2 penalty added to the training loss.
        data_loader: Loader yielding ``(input, target)`` batches.
        loss_fn: Data loss applied to ``(output, target)``.
        optimizer: Optimizer stepped after each batch; unused for ``'test'``.
        split: ``'test'`` disables the parameter updates; any other value trains.
        epoch: Passed to ``log_metrics`` when ``quiet`` is False.
        ignore_index, negative_gradient, negative_multiplier, random_labels:
            Accepted for compatibility; not used.
        quiet: If True, skip ``log_metrics``.
        delta_w: Vector multiplied with the squared per-class gradients when
            noise is added; required if ``scrub_act`` is True on the test split.
        scrub_act: On the test split, add Gaussian noise to each output.
            Only row 0 of the output is differentiated, so this presumably
            expects batch size 1 — confirm against callers.

    Returns:
        ``metrics.avg``; the recorded loss is the data loss without the L2 penalty.
    """
    model.eval()
    metrics = AverageMeter()
    is_test = split == 'test'
    add_noise = is_test and scrub_act
    device = next(model.parameters()).device

    # The noise branch differentiates the outputs, so gradients must stay
    # enabled for it even though the test split performs no update.
    with torch.set_grad_enabled(not is_test or add_noise):
        for idx, batch in enumerate(tqdm(data_loader, leave=False)):
            inputs, target = [tensor.to(device) for tensor in batch]
            output = model(inputs)
            if add_noise:
                # One flattened gradient of each class output w.r.t. all weights.
                G = []
                for cls in range(num_classes):
                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)
                    G.append(torch.cat([g.view(-1) for g in grads]))
                # (A former extra autograd.grad call on `output_sf` / `model_scrubf`
                # was dropped: neither name is defined here and its result was unused.)
                G = torch.stack(G).pow(2)
                delta_f = torch.matmul(G,delta_w)
                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()
            data_loss = loss_fn(output, target)
            metrics.update(n=inputs.size(0), loss=data_loss.item(), error=get_error(output, target))

            if not is_test:
                # The penalty only matters for the update, so it is skipped on test.
                loss = data_loss + l2_penalty(model,model_init,args.weight_decay)
                model.zero_grad()
                loss.backward()
                optimizer.step()
    if not quiet:
        log_metrics(split, metrics, epoch)
    return metrics.avg

def _repeat_loader(loader):
    """Yield batches from ``loader`` endlessly, starting a new pass each time it runs out."""
    while True:
        yielded = False
        for batch in loader:
            yielded = True
            yield batch
        if not yielded:
            # An empty loader would otherwise spin forever without yielding.
            return


def run_neggrad_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader,
                    forget_loader: torch.utils.data.DataLoader,
                    alpha: float,
                    loss_fn: nn.Module,
                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,
                    quiet=False):
    """Run one negative-gradient pass: descend on retain data, ascend on forget data.

    Each retain batch is paired with a forget batch and the loss is
    ``alpha * (retain_loss + l2_penalty) - (1 - alpha) * forget_loss``.

    Args:
        model: Network to update. It is kept in eval mode for the whole pass.
        model_init: Reference model for the L2 penalty.
        data_loader: Retain-set loader; its length determines the pass length.
        forget_loader: Forget-set loader, repeated as often as needed.
        alpha: Weight of the retain term; ``1 - alpha`` weights the forget term.
        loss_fn: Loss applied to ``(output, target)``.
        optimizer: Optimizer stepped after each batch; unused for ``'test'``.
        split: ``'test'`` disables gradients and updates; any other value trains.
        epoch: Passed to ``log_metrics`` when ``quiet`` is False.
        ignore_index: Accepted for compatibility; not used.
        quiet: If True, skip ``log_metrics``.

    Returns:
        ``metrics.avg``, computed on the retain batches only.
    """
    model.eval()
    metrics = AverageMeter()
    device = next(model.parameters()).device

    # itertools.cycle stores every batch of the first pass and replays those
    # copies; re-iterating the loader avoids holding the whole forget set in
    # memory and lets each pass be drawn (and shuffled) afresh.
    forget_batches = _repeat_loader(forget_loader)

    with torch.set_grad_enabled(split != 'test'):
        for idx, (batch_retain,batch_forget) in enumerate(tqdm(zip(data_loader,forget_batches), leave=False)):
            input_r, target_r = [tensor.to(device) for tensor in batch_retain]
            input_f, target_f = [tensor.to(device) for tensor in batch_forget]
            output_r = model(input_r)
            output_f = model(input_f)
            retain_loss = loss_fn(output_r, target_r)
            loss = alpha*(retain_loss + l2_penalty(model,model_init,args.weight_decay)) - (1-alpha)*loss_fn(output_f, target_f)
            metrics.update(n=input_r.size(0), loss=retain_loss.item(), error=get_error(output_r, target_r))
            if split != 'test':
                model.zero_grad()
                loss.backward()
                optimizer.step()
    if not quiet:
        log_metrics(split, metrics, epoch)
    return metrics.avg
def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):
    """Train ``model`` in place on ``data_loader`` with plain SGD.

    A copy of the starting weights is passed to ``run_train_epoch`` as the
    reference model; the optimizer itself uses no weight decay.

    Args:
        model: Network updated in place.
        data_loader: Loader providing the training batches.
        lr: SGD learning rate.
        epochs: Number of passes over ``data_loader``.
        quiet: Forwarded to ``run_train_epoch``.
    """
    starting_point = copy.deepcopy(model)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)
    for epoch in range(epochs):
        run_train_epoch(
            model, starting_point, data_loader, criterion, sgd,
            split='train', epoch=epoch, ignore_index=None, quiet=quiet,
        )

def negative_grad(model: nn.Module, data_loader: torch.utils.data.DataLoader, forget_loader: torch.utils.data.DataLoader, alpha: float, lr=0.01, epochs=10, quiet=False):
    """Update ``model`` in place with ``run_neggrad_epoch`` for ``epochs`` passes.

    A copy of the starting weights is passed along as the reference model;
    the optimizer itself uses no weight decay.

    Args:
        model: Network updated in place.
        data_loader: Retain-set loader.
        forget_loader: Forget-set loader.
        alpha: Forwarded to ``run_neggrad_epoch``.
        lr: SGD learning rate.
        epochs: Number of passes.
        quiet: Forwarded to ``run_neggrad_epoch``.
    """
    starting_point = copy.deepcopy(model)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)
    for epoch in range(epochs):
        run_neggrad_epoch(
            model, starting_point, data_loader, forget_loader, alpha, criterion, sgd,
            split='train', epoch=epoch, ignore_index=None, quiet=quiet,
        )

def fk_fientune(model: nn.Module, data_loader: torch.utils.data.DataLoader, args, lr=0.01, epochs=10, quiet=False):
    """Train ``model`` in place like ``finetune``, adjusting the learning rate each epoch.

    Before every epoch ``sgda_adjust_learning_rate(epoch, args, optimizer)``
    is called on the SGD optimizer.

    Args:
        model: Network updated in place.
        data_loader: Loader providing the training batches.
        args: Passed to ``sgda_adjust_learning_rate``.
        lr: Initial SGD learning rate.
        epochs: Number of passes over ``data_loader``.
        quiet: Forwarded to ``run_train_epoch``.
    """
    starting_point = copy.deepcopy(model)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)
    for epoch in range(epochs):
        sgda_adjust_learning_rate(epoch, args, sgd)
        run_train_epoch(
            model, starting_point, data_loader, criterion, sgd,
            split='train', epoch=epoch, ignore_index=None, quiet=quiet,
        )
def test(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` with cross-entropy and return the metrics.

    Args:
        model: Network to evaluate; it is not updated.
        data_loader: Loader providing the evaluation batches.

    Returns:
        Whatever ``run_train_epoch`` returns for ``split='test'``.
    """
    loss_fn = nn.CrossEntropyLoss()
    # The reference model only feeds an L2 term of the training loss, which the
    # reported metrics do not include, so the model itself is passed instead of
    # a deep copy. `epoch` is only used for logging, which quiet=True disables;
    # a constant replaces the previously undefined name.
    return run_train_epoch(model, model, data_loader, loss_fn, optimizer=None, split='test', epoch=0, ignore_index=None, quiet=True)

def hessian(dataset, model):
    """Accumulate per-parameter gradient statistics of ``model`` over ``dataset``.

    Samples are processed one at a time. For every sample and every class
    ``y`` the cross-entropy gradient for label ``y`` is computed, and on each
    parameter ``p`` that requires gradients:

    * ``p.grad_acc`` accumulates the gradient for the true label;
    * ``p.grad2_acc`` accumulates the squared gradient weighted by the
      predicted probability of ``y``.

    Both attributes are finally divided by the number of samples.

    Args:
        dataset: Dataset yielding ``(data, target)`` pairs.
        model: Network to analyse; it is put in eval mode and annotated in place.

    Returns:
        None.
    """
    model.eval()
    sample_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    loss_fn = nn.CrossEntropyLoss()

    for p in model.parameters():
        p.grad_acc = 0
        p.grad2_acc = 0

    for data, orig_target in tqdm(sample_loader):
        data, orig_target = data.to(args.device), orig_target.to(args.device)
        output = model(data)
        prob = torch.softmax(output, dim=-1).data

        for y in range(output.shape[1]):
            target = torch.empty_like(orig_target).fill_(y)
            loss = loss_fn(output, target)
            model.zero_grad()
            # The same forward graph is differentiated once per class.
            loss.backward(retain_graph=True)
            for p in model.parameters():
                if p.requires_grad:
                    p.grad_acc += (orig_target == target).float() * p.grad.data
                    p.grad2_acc += prob[:, y] * p.grad.data.pow(2)

    # With batch_size=1 the loader length equals the number of samples.
    num_samples = len(sample_loader)
    for p in model.parameters():
        p.grad_acc /= num_samples
        p.grad2_acc /= num_samples


def get_mean_var(p, is_base_dist=False, alpha=3e-6):
    """Return the mean and variance derived from a parameter's stored statistics.

    Reads ``p.grad2_acc`` (set by ``hessian``) and ``p.data0`` (the stored
    weight snapshot), plus the notebook globals ``num_classes``,
    ``num_to_forget`` and ``class_to_forget``.

    Args:
        p: Parameter carrying ``grad2_acc`` and ``data0`` attributes.
        is_base_dist: Accepted for compatibility; both settings give the same result.
        alpha: Scale applied to the clamped inverse of ``grad2_acc``.

    Returns:
        ``(mu, var)`` tensors.
    """
    # Inverse of the accumulated squared gradients, clamped from above.
    var = 1./(p.grad2_acc+1e-8)
    var = var.clamp(max=1e3)
    if p.size(0) == num_classes:
        var = var.clamp(max=1e2)
    var = alpha * var

    if p.ndim > 1:
        # Share one variance per row (mean over dim 1).
        var = var.mean(dim=1, keepdim=True).expand_as(p).clone()

    mu = p.data0.clone()

    if p.size(0) == num_classes and num_to_forget is None:
        # NOTE(review): class_to_forget comes from args.forget_class; confirm it
        # is a valid tensor index here (it may be a comma-separated string).
        mu[class_to_forget] = 0
        var[class_to_forget] = 0.0001
    if p.size(0) == num_classes:
        # Last layer
        var *= 10
    elif p.ndim == 1:
        # BatchNorm
        var *= 10
    return mu, var

def kl_divergence_fisher(mu0, var0, mu1, var1):
    """Sum the element-wise Gaussian KL-style terms between two mean/variance sets.

    For each element the term is
    ``(mu1 - mu0)^2 / var0 + var1 / var0 - log(var1 / var0) - 1``;
    the conventional factor 1/2 is not applied.

    Args:
        mu0, var0: Mean and variance tensors of the reference distribution.
        mu1, var1: Mean and variance tensors of the compared distribution.

    Returns:
        Scalar tensor with the sum over all elements.
    """
    mean_term = (mu1 - mu0).pow(2)/var0
    var_ratio = var1/var0
    per_element = mean_term + var_ratio - torch.log(var_ratio) - 1
    return per_element.sum()

In [None]:
# Fine-tuning baseline: copy `model`, then keep training the copy on the retain set only.
# NOTE(review): this starts from `model`, not from `model_original`; confirm that is the intended starting point.
model_ft = copy.deepcopy(model)
# Rebuilds retain_loader with args.batch_size, replacing the earlier loader built with args.retain_bs.
retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.batch_size, shuffle=True)
finetune(model_ft, retain_loader, epochs=10, quiet=True, lr=0.01)