# Initialization

In [None]:
# Move into the repository checkout so the relative paths used below (main.py,
# checkpoints/, data/, data_num/) resolve.
# NOTE(review): `%cd` is relative, so re-running this cell from inside the directory fails.
%cd SelectiveForgetting

In [None]:
import sys
import os
# Make the project modules importable. The path is taken from the current working
# directory (set by the `%cd` cell) rather than hard-coded to one user's home directory,
# and it is only appended once so re-running this cell does not grow sys.path.
_project_dir = os.getcwd()
if _project_dir not in sys.path:
    sys.path.append(_project_dir)

In [None]:
# Notebook setup: auto-reload edited project modules and render plots inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
# Select a single GPU (index 1 in PCI bus order). CUDA reads these variables when it is
# first initialised, so they have to be set before any torch/CUDA usage.
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='1'
import variational
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter
import os  # NOTE(review): duplicate of the `import os` above
import time
import math
import pandas as pd
from collections import OrderedDict
from sklearn.linear_model import LogisticRegression
import numpy as np

import copy
import torch.nn as nn
from torch.autograd import Variable
from typing import List
import itertools
from tqdm.autonotebook import tqdm
# NOTE(review): `torch`, `F`, `manual_seed`, `AverageMeter` and `get_error` are used later
# but never imported explicitly; they presumably arrive via these star imports — verify.
from models import *
import models
from logger import *
import models  # NOTE(review): duplicate of `import models` above
import datasets
from utils import *

In [None]:
def pdb():
    """Drop into the interactive Python debugger.

    ``pdb.set_trace`` must actually be *called*; merely referencing the attribute
    does nothing. The debugger starts inside this helper, so type ``up`` (``u``)
    at the prompt to inspect the caller's frame.
    """
    import pdb  # inside this body the name refers to the stdlib module, not this function
    pdb.set_trace()

def parameter_count(model):
    """Print and return the total number of scalar parameters in ``model``.

    ``Tensor.numel()`` gives an exact ``int`` for every parameter, including 0-d
    ones (``np.prod`` of an empty shape returns the float ``1.0``, which would turn
    the whole total into a float).

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: the number of parameter elements.
    """
    count = sum(p.numel() for p in model.parameters())
    print(f'Total Number of Parameters: {count}')
    return count

def vectorize_params(model):
    """Flatten every parameter of ``model`` into a single 1-D NumPy array.

    Parameters are taken in ``model.parameters()`` order; each one is flattened
    with ``view(-1)`` and copied to host memory before concatenation.
    """
    flat_chunks = [weight.data.view(-1).cpu().numpy() for weight in model.parameters()]
    return np.concatenate(flat_chunks)

def print_param_shape(model):
    """Print every named parameter of ``model`` followed by its shape, one per line."""
    named = list(model.named_parameters())
    for entry in named:
        print(entry[0], entry[1].shape)

# Train the model

In [None]:
# Number of training samples to forget (passed as --num-to-forget to main.py below).
numf = 4739

In [None]:
# Cumulative wall-clock time (seconds) per method. Methods that build on an earlier
# model are also charged that model's training time (see the accumulation cells).
train_time_o = 0
train_time_ft = 0
train_time_f = 0
train_time_r = 0

In [None]:
tic = time.perf_counter()

In [None]:
# Train the original model on the full training set (nothing forgotten).
# NOTE(review): this run uses --lr 0.1, which presumably yields a `..._lr_0_1_...`
# checkpoint, yet the fine-tuning cell and `original_name` load `..._lr_0_01_...`.
# Confirm which learning rate / checkpoint is intended.
%run main.py --dataset mnist --model mlp --filters 0.4 --lr 0.1 --lossfn ce --num-classes 10 --batch-size 128 --weight-decay 0.001 --epochs 5
# Recorded runtime: 875,607s (presumably 875.607 s, comma used as decimal separator).

In [None]:
toc = time.perf_counter()
# Charge the original training time to every method that starts from this model.
train_time_o += (toc - tic)
train_time_ft += (toc - tic)
train_time_f += (toc - tic)

# Marking the data to forget and fine-tuning

In [None]:
tic = time.perf_counter()

In [None]:
%run main.py --dataset mnist --model mlp --filters 0.4 --lr 0.0001\
--resume checkpoints/mnist_mlp_0_4_forget_None_lr_0_01_bs_128_ls_ce_wd_0_001_seed_1_4.pt --disable-bn\
--weight-decay 0.001 --batch-size 128 --epochs 5 --seed 1 --forget-class 0 --num-to-forget {numf} --epochs 5
#29,068s

In [None]:
toc = time.perf_counter()
train_time_ft += (toc - tic)
train_time_f += (toc - tic)

# Fisher functions

In [None]:
def hessian(dataset, model):
    """Accumulate per-parameter gradient statistics of ``model`` over ``dataset``.

    Despite its name this does not build a Hessian. Each sample is processed on its
    own (batch size 1). The cross-entropy loss is back-propagated once for every
    class ``y``, and two attributes are filled in for each trainable parameter ``p``:

    * ``p.grad_acc``  -- gradient of the loss at the sample's true label, averaged
      over samples;
    * ``p.grad2_acc`` -- ``sum_y softmax(output)[y] * grad_y ** 2`` averaged over
      samples, i.e. a diagonal Fisher-information estimate taken under the model's
      own predictive distribution.

    Nothing is returned; the results live on the parameters. The device is chosen
    from the notebook global ``use_cuda``; presumably ``model`` already sits on that
    device (TODO confirm). ``F`` is presumably ``torch.nn.functional``, provided by a
    star import.

    Args:
        dataset: dataset whose items unpack into ``(input, label)``.
        model: the network whose parameters receive the statistics.
    """
    device = torch.device("cuda" if use_cuda else "cpu")

    model.eval()
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=3)
    loss_fn = nn.CrossEntropyLoss()

    # Reset the accumulators so a repeated call starts from zero.
    for p in model.parameters():
        p.grad_acc = 0
        p.grad2_acc = 0
    
    for data, orig_target in tqdm(train_loader):
        data, orig_target = data.to(device), orig_target.to(device)
        output = model(data)
        # Predicted class probabilities, detached from the graph (used only as weights).
        prob = F.softmax(output, dim=-1).data

        for y in range(output.shape[1]):
            # Treat y as the label and back-propagate that loss.
            target = torch.empty_like(orig_target).fill_(y)
            loss = loss_fn(output, target)
            model.zero_grad()
            # Keep the forward graph alive so the next class can back-propagate through it.
            loss.backward(retain_graph=True)
            for p in model.parameters():
                if p.requires_grad:
                    # Mask is 1 only when y equals the true label.
                    p.grad_acc += (orig_target == target).float() * p.grad.data 
                    # Squared gradient weighted by the model's probability of class y.
                    p.grad2_acc += prob[:, y] * p.grad.data.pow(2) 
    # Average over the number of batches, which equals the number of samples (batch size 1).
    for p in model.parameters():
        p.grad_acc /= len(train_loader)
        p.grad2_acc /= len(train_loader)

In [None]:
def get_mean_var(p, alpha=3e-6):
    """Return the mean and variance of the Gaussian used to perturb parameter ``p``.

    Requires ``p.grad2_acc`` (filled by ``hessian``) and ``p.data0`` (a saved copy of
    the weights). The variance is the clamped inverse of the accumulated Fisher
    diagonal, scaled by ``alpha``. The mean is a copy of ``p.data0``.

    Reads the notebook globals ``num_classes``, ``num_to_forget`` and
    ``class_to_forget``. Any parameter whose first dimension equals ``num_classes``
    is treated as the final (classifier) layer.
    NOTE(review): any other parameter that happens to have that leading size would
    be treated the same way.

    Args:
        p: parameter tensor carrying ``grad2_acc`` and ``data0`` attributes.
        alpha: scale applied to the clamped inverse Fisher.

    Returns:
        (mu, var): tensors shaped like ``p``.
    """
    # Inverse Fisher; the epsilon avoids division by zero where no gradient signal was seen.
    var = copy.deepcopy(1./(p.grad2_acc+1e-8))
    # Cap the variance along directions the loss is insensitive to.
    var = var.clamp(max=1e3) 
    if p.size(0) == num_classes:
        # Tighter cap for the output layer.
        var = var.clamp(max=1e2)
    var = alpha * var 
    if p.ndim > 1:
        # Average over dim 1 and broadcast back (for a 2-D weight: one variance per row).
        var = var.mean(dim=1, keepdim=True).expand_as(p).clone()
    mu = copy.deepcopy(p.data0.clone())
    if p.size(0) == num_classes and num_to_forget is None:
        # Only when num_to_forget is None (presumably whole-class forgetting): zero the
        # forgotten class's output row and give it a tiny variance.
        mu[class_to_forget] = 0
        var[class_to_forget] = 0.0001
    if p.size(0) == num_classes:
        # Last layer
        var *= 10
    elif p.ndim == 1:
        # BatchNorm (note: this also matches any other 1-D parameter, e.g. biases)
        var *= 10 
#         var*=1
    return mu, var

def kl_divergence_fisher(mu0, var0, mu1, var1):
    """Divergence between diagonal Gaussians N(mu1, var1) and N(mu0, var0), summed over elements.

    Evaluates ``sum((mu1 - mu0)**2 / var0 + var1 / var0 - log(var1 / var0) - 1)``.
    This equals ``2 * KL(N(mu1, var1) || N(mu0, var0))``: the usual factor 1/2 is
    omitted, so the value is twice the KL divergence. Returns a 0-d tensor.
    """
    ratio = var1 / var0
    mean_term = (mu1 - mu0) ** 2 / var0
    return (mean_term + ratio - torch.log(ratio) - 1).sum()

# Models

In [None]:

# Arguments mirroring the fine-tuning `%run main.py ...` call, kept as a dict so the
# notebook can rebuild data loaders, checkpoint names and the model architecture.
# NOTE(review): 'name' encodes lr_0_01 while 'resume' encodes lr_0_1; they disagree on the
# original model's learning rate — confirm which checkpoint actually exists.
args = {'augment':False, 
        'batch_size':128, 'dataset':'mnist', 'disable_bn':True, 'epochs':5, 
        'filters':0.4, 'forget_class':0, 'l1':False, 'lossfn':'ce', 'lr':0.0001, 'model':'mlp', 
        'momentum':0.9, 'no_cuda':False, 'num_classes':10, 'num_to_forget':numf, 
        'name':f'mnist_mlp_0_4_forget_None_lr_0_01_bs_128_ls_ce_wd_0_001_seed_1', 
        'resume':'checkpoints/mnist_mlp_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_001_seed_1_4.pt', 
        'seed':1, 'step_size':32, 'unfreeze_start':None, 'weight_decay':0.001}
arch = args['model'] 
filters=args['filters']
arch_filters = arch +'_'+ str(filters).replace('.','_')
augment = False
dataset = args['dataset']
class_to_forget = args['forget_class']
init_checkpoint = f"checkpoints/{args['name']}_init.pt"
num_classes=args['num_classes']
num_to_forget = args['num_to_forget']

# Loaders in which `num_to_forget` samples of `forget_class` are replaced
# (exact semantics live in datasets.get_loaders).
train_loader, valid_loader, test_loader = datasets.get_loaders(args['dataset'], class_to_replace=args['forget_class'], 
                                                               num_indexes_to_replace=args['num_to_forget'], 
                                                               batch_size=args['batch_size'], seed=args['seed'], augment=args['augment'])

num_total = len(train_loader.dataset)
num_to_retain = num_total - num_to_forget
seed = args['seed']
unfreeze_start = None

# Filename fragments used to rebuild checkpoint paths; presumably they mirror the
# naming scheme in main.py — confirm against main.py.
learningrate=f"lr_{str(args['lr']).replace('.','_')}"
batch_size=f"_bs_{str(args['batch_size'])}"
lossfn=f"_ls_{args['lossfn']}"
# NOTE(review): `wd` is a filename fragment string here, yet later cells pass it as
# `weight_decay=` to `evaluate` and eventually rebind it to the number 0.001.
wd=f"_wd_{str(args['weight_decay']).replace('.','_')}"
seed_name=f"_seed_{args['seed']}_"

num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'
unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'
augment_tag = '' if not augment else f'augment_'

use_cuda = not args['no_cuda'] and torch.cuda.is_available()
args['device'] = torch.device("cuda" if use_cuda else "cpu")

# Freshly constructed architecture; later cells copy it and load checkpoint weights into the copies.
model = models.get_model(args['model'], num_classes=num_classes, filters_percentage=args['filters']).to(args['device'])

In [None]:
# Rebuild the loaders here so the notebook can evaluate models directly
# (main.py presumably builds equivalent ones internally).
train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, batch_size=args['batch_size'], seed=seed, augment=False, shuffle=True)
# Training loader in which the samples to forget are only marked (only_mark=True), batch size 1.
marked_loader, _, _ = datasets.get_loaders(dataset, class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, augment=False, shuffle=True)

def replace_loader_dataset(dataset, batch_size=args['batch_size'], seed=1, shuffle=True):
    """Return a new DataLoader over ``dataset`` after reseeding with ``manual_seed(seed)``.

    The ``batch_size`` default is bound to ``args['batch_size']`` when this cell runs.
    """
    manual_seed(seed)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=3,pin_memory=True,shuffle=shuffle)

# FORGET DATASET
# Marked samples carry negative targets encoded as -(label + 1); keep only those and
# decode them back to the original labels.
forget_dataset = copy.deepcopy(marked_loader.dataset)
marked = forget_dataset.targets < 0 # data to forget is marked negative
forget_dataset.data = forget_dataset.data[marked]
forget_dataset.targets = - forget_dataset.targets[marked] - 1
forget_loader = replace_loader_dataset(forget_dataset, seed=seed, shuffle=True)

#RETAIN DATASET
# Unmarked (non-negative target) samples form the retain set.
retain_dataset = copy.deepcopy(marked_loader.dataset)
marked = retain_dataset.targets >= 0
retain_dataset.data = retain_dataset.data[marked]
retain_dataset.targets = retain_dataset.targets[marked]
retain_loader = replace_loader_dataset(retain_dataset, seed=seed, shuffle=True)

# Forget and retain sets must partition the full training set.
assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))

In [None]:
def l2_penalty(model,model_init,weight_decay):
    """Weight-decay penalty toward a reference model: ``weight_decay / 2 * sum ||p - p_init||^2``.

    Parameters are paired positionally via ``named_parameters()``; only parameters
    of ``model`` with ``requires_grad`` contribute. Returns ``0.0`` when nothing
    contributes, otherwise a 0-d tensor.
    """
    total = 0
    for (_, cur), (_, ref) in zip(model.named_parameters(), model_init.named_parameters()):
        if not cur.requires_grad:
            continue
        total += (cur - ref).pow(2).sum()
    return total * (weight_decay / 2.)

def evaluate(model, train_loader, epoch=0, weight_decay=None, mode='test'):
    """Return the mean accuracy of ``model`` over ``train_loader``.

    Despite the parameter name, ``train_loader`` may be any DataLoader yielding
    ``(data, target)`` batches. Accuracy and error (``get_error``) are accumulated in
    an ``AverageMeter`` (presumably weighted by batch size ``n``); only the accuracy
    is returned.

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode and left there.
        train_loader: loader of ``(inputs, labels)`` batches.
        epoch, weight_decay: accepted for call-site compatibility; unused.
        mode: gradients are tracked only when ``mode != 'test'``.

    Returns:
        float: averaged accuracy in [0, 1].
    """
    model.eval()
    metrics = AverageMeter()
    # Move batches to wherever the model lives (as run_train_epoch does) rather than
    # relying on the global `use_cuda`, so data and weights cannot end up on different devices.
    device = next(model.parameters()).device

    with torch.set_grad_enabled(mode != 'test'):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            metrics.update(n=data.size(0), error=get_error(output, target),
                           accuracy=get_accuracy(output, target))
    return metrics.avg['accuracy']


In [None]:
def get_accuracy(output, target):
    """Fraction of rows whose arg-max along dim 1 matches ``target``.

    Args:
        output: scores with classes along dim 1 and samples along dim 0.
        target: integer labels, one per sample.

    Returns:
        float in [0, 1].
    """
    predicted = output.max(dim=1).indices
    n_correct = (predicted == target).sum().item()
    return n_correct / target.size(0)

In [None]:
import copy
# Rebuild checkpoint paths: the original model is looked up with lr 0.01 in its name,
# the fine-tuned one with args['lr'].
# NOTE(review): the original-training %run cell uses --lr 0.1; confirm an lr_0_01 checkpoint exists.
learningrate=f"lr_{str(0.01).replace('.','_')}"
original_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}4.pt'
learningrate=f"lr_{str(args['lr']).replace('.','_')}"
finetuned_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}4.pt'
model_o = copy.deepcopy(model)
model_ft = copy.deepcopy(model)

# NOTE(review): torch.load has no map_location, so a GPU-saved checkpoint fails to load on a CPU-only machine.
model_o.load_state_dict(torch.load(original_name)) # MODEL WITHOUT FORGETTING: model
model_ft.load_state_dict(torch.load(finetuned_name)) # MODEL WITH FORGETTING CLASS 0: model0
# model_r.load_state_dict(torch.load(finetuned_name))

#model_ft.cuda()
#model_o.cuda()

# Fisher forgetting: start from the fine-tuned weights, estimate the Fisher diagonal on
# the retained data, then resample every weight from N(mu, var) given by get_mean_var.
tic = time.perf_counter()
for p in model_ft.parameters():
    p.data0 = p.data.clone()

model_f = copy.deepcopy(model_ft)
#model_f.cuda()
hessian(retain_loader.dataset, model_f)
# Snapshot the unperturbed weights; get_mean_var uses them as the noise mean.
for p in itertools.chain(model_f.parameters()):
  p.data0 = copy.deepcopy(p.data.clone())

alpha = 1e-7
torch.manual_seed(seed)
for i, p in enumerate(model_f.parameters()):
    w, var = get_mean_var(p, alpha=alpha)
    # Elementwise sample: w + sqrt(var) * N(0, 1).
    p.data = w + var.sqrt() * torch.empty_like(p.data0).normal_()
    

toc = time.perf_counter()
train_time_f += (toc - tic)

In [None]:
tic = time.perf_counter()

In [None]:
# Retrain from scratch (no --resume) with the forget samples marked; presumably main.py
# trains only on the retained data, giving the reference model for the comparisons.
%run main.py --dataset mnist --model mlp --filters 0.4 --lr 0.01\
--weight-decay 0.001 --batch-size 128 --epochs 5\
--forget-class 0 --num-to-forget {numf} --seed 1

In [None]:
toc = time.perf_counter()
train_time_r += (toc - tic)

In [None]:
# Persist the cumulative wall-clock times (seconds) of each method.
time_dict = {'original': train_time_o,
             'finetuned': train_time_ft,
             'fisher': train_time_f,
             'retrain': train_time_r}
df = pd.DataFrame([time_dict])
df.to_csv(f'data_num/timesnumf{numf}.csv')

In [None]:
# Load the retrained model; its name encodes lr 0.01, matching the retrain run above.
learningrate=f"lr_{str(0.01).replace('.','_')}"
retrain_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}4.pt'
model_r = copy.deepcopy(model)
model_r.load_state_dict(torch.load(retrain_name))
#model_r.cuda()

In [None]:
acc_dict = {}
acc_dict['Original'] = evaluate(model_o, retain_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Finetuned'] = evaluate(model_ft, retain_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Fisher'] = evaluate(model_f, retain_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Retrain'] = evaluate(model_r, retain_loader, epoch=0, weight_decay=wd, mode='test')

df = pd.DataFrame([acc_dict])
df.to_csv(f'data_num/acc{numf}_{alpha}.csv')

In [None]:
acc_dict = {}
acc_dict['Original'] = evaluate(model_o, forget_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Finetuned'] = evaluate(model_ft, forget_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Fisher'] = evaluate(model_f, forget_loader, epoch=0, weight_decay=wd, mode='test')
acc_dict['Retrain'] = evaluate(model_r, forget_loader, epoch=0, weight_decay=wd, mode='test')

df = pd.DataFrame([acc_dict])
df.to_csv(f'data/acc{numf}_forget_{alpha}.csv')

In [None]:
acc_dict = {}
acc_dict['Original'] = evaluate(model_o, test_loader_full, epoch=0, weight_decay=wd, mode='test')
acc_dict['Finetuned'] = evaluate(model_ft, test_loader_full, epoch=0, weight_decay=wd, mode='test')
acc_dict['Fisher'] = evaluate(model_f, test_loader_full, epoch=0, weight_decay=wd, mode='test')
acc_dict['Retrain'] = evaluate(model_r, test_loader_full, epoch=0, weight_decay=wd, mode='test')

df = pd.DataFrame([acc_dict])
df.to_csv(f'data/acc{numf}_test_{alpha}.csv')

In [None]:
# Fisher statistics and Gaussian noise for the retrained model, presumably so it is
# compared with model_f under the same kind of perturbation.
# NOTE(review): alpha is 1e-6 here but model_f was perturbed with 1e-7, and this cell
# rebinds the global `alpha`.
hessian(retain_loader.dataset, model_r)
for p in itertools.chain(model_r.parameters()):
  p.data0 = copy.deepcopy(p.data.clone())

alpha = 1e-6
torch.manual_seed(seed)
for i, p in enumerate(model_r.parameters()):
    w, var = get_mean_var(p, alpha=alpha)
    p.data = w + var.sqrt() * torch.empty_like(p.data0).normal_()

In [None]:
log_dict = dict()  # scalar metrics collected below and written to CSV at the end

# Consistency: distance of parameters

In [None]:
def distance(model,model0):
    """Euclidean distance between the parameters of two models, plus a normalized form.

    Parameters are paired by position. Prints the raw distance and the distance
    divided by the L2 norm of ``model``'s (the first argument's) parameters, and
    returns the normalized value.
    """
    sq_dist, sq_norm = 0, 0
    pairs = zip(model.named_parameters(), model0.named_parameters())
    for (_, w), (_, w0) in pairs:
        sq_dist += (w.data - w0.data).pow(2).sum().item()
        sq_norm += w.data.pow(2).sum().item()
    print(f'Distance: {np.sqrt(sq_dist)}')
    normalized = 1.0 * np.sqrt(sq_dist / sq_norm)
    print(f'Normalized Distance: {normalized}')
    return normalized

In [None]:
# Consistency: normalized parameter distance of the Fisher-forgotten model from the
# original model and from the (noise-perturbed) retrained model.
log_dict['cons_Original_and_Fisher']=distance(model_o,model_f)
log_dict['cons_Retrained_and_Fisher']=distance(model_r,model_f)

In [None]:
# Sanity check: a model's distance to itself should print 0.
distance(model,model)

# Effectiveness

In [None]:
wd = 0.001

In [None]:
eval_f = evaluate(model_f, test_loader_full, epoch=0, weight_decay=wd, mode='test')
eval_r = evaluate(model_r, test_loader_full, epoch=0, weight_decay=wd, mode='test')

log_dict['effectiveness_test'] = abs(eval_f - eval_r)

In [None]:
eval_f = evaluate(model_f, forget_loader, epoch=0, weight_decay=wd, mode='test')
eval_r = evaluate(model_r, forget_loader, epoch=0, weight_decay=wd, mode='test')

log_dict['effectiveness_forget'] = abs(eval_f - eval_r)

In [None]:

eval_f = evaluate(model_f, retain_loader, epoch=0, weight_decay=wd, mode='test')
eval_r = evaluate(model_r, retain_loader, epoch=0, weight_decay=wd, mode='test')

log_dict['effectiveness'] = abs(eval_f - eval_r)

# Certifiability

In [None]:
def relative_gap(a, b):
    """Percentage gap ``|a - b| / (|a| + |b|) * 100``; ``0.0`` when both values are zero."""
    denom = abs(a) + abs(b)
    if denom == 0:
        # Both at zero means no gap; also avoids ZeroDivisionError (e.g. both models at 0% accuracy).
        return 0.0
    return (abs(a - b) / denom) * 100


def cert(model1, model2):
    """Certifiability score: relative accuracy gap (in %) of two models on the forget set.

    Reads the notebook globals ``forget_loader`` and ``wd``.
    """
    eval1 = evaluate(model1, forget_loader, epoch=0, weight_decay=wd, mode='test')
    eval2 = evaluate(model2, forget_loader, epoch=0, weight_decay=wd, mode='test')
    return relative_gap(eval1, eval2)

In [None]:
# Certifiability of the Fisher-forgotten model against the retrained reference.
log_dict.update(cert=cert(model_f, model_r))

# Distance of $w(D)$ from initialization


In [None]:
def dist_init(resume,seed=1):
    """Build a fresh ``arch`` model and load the weights stored at ``resume``.

    ``manual_seed(seed)`` is called before construction, so any randomness in model
    creation is reproducible; the loaded state dict then sets the weights. Reads the
    notebook globals ``use_cuda``, ``arch``, ``num_classes`` and ``filters``.

    Returns:
        The model on the selected device, holding the checkpoint's weights.
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    manual_seed(seed)
    model_init = models.get_model(arch, num_classes=num_classes, filters_percentage=filters).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model_init.load_state_dict(torch.load(resume, map_location=device))
    return model_init

In [None]:
# Model at its initialization checkpoint (`init_checkpoint`, derived from args['name']).
model_init = dist_init(init_checkpoint)
# NOTE(review): `data0` on model_init is not read by the distance computations below.
for p in model_init.parameters():
    p.data0 = p.data.clone() 

In [None]:
# How far the original and the Fisher-forgotten weights moved from initialization;
# `distance` normalizes by its first argument, here the initialization's norm.
log_dict['dist_Original_Original_init']=distance(model_init,model_o)
log_dict['dist_Fisher_Original_init']=distance(model_init,model_f)


# Set lambda hyperparameter

In [None]:
def l2_penalty(model,model_init,weight_decay):
    """``weight_decay / 2`` times the squared L2 distance between ``model`` and ``model_init``.

    Same helper as defined earlier in the notebook, re-declared in this section.
    Parameters are paired by position; only trainable parameters of ``model``
    contribute. Returns ``0.0`` when nothing contributes, otherwise a 0-d tensor.
    """
    sq_dist = 0
    pairs = zip(model.parameters(), model_init.parameters())
    for w, w_ref in pairs:
        if w.requires_grad:
            sq_dist += (w - w_ref).pow(2).sum()
    sq_dist *= (weight_decay/2.)
    return sq_dist

def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, 
                    loss_fn: nn.Module,
                    optimizer: torch.optim.SGD, split: str, ignore_index=None,
                    negative_gradient=False, negative_multiplier=-1, random_labels=False,
                    quiet=False,delta_w=None,scrub_act=False, wd=0.001):
    """Run one pass of ``model`` over ``data_loader`` and return the averaged metrics.

    The objective is ``loss_fn(output, target)`` plus an L2 penalty pulling the
    weights toward ``model_init`` (``l2_penalty(model, model_init, wd)``). When
    ``split != 'test'`` the objective is back-propagated and ``optimizer`` steps.
    With ``split == 'test'`` gradients are disabled and ``optimizer`` may be ``None``.
    The model is kept in ``eval()`` mode in both cases.

    Args:
        ignore_index, negative_gradient, negative_multiplier, random_labels, quiet,
        delta_w, scrub_act: accepted for signature compatibility; currently unused.

    Returns:
        ``AverageMeter.avg`` (presumably a dict keyed by metric name) holding
        ``'loss'`` (``loss_fn`` only, without the L2 term) and ``'error'``.
    """
    model.eval()
    metrics = AverageMeter()
    device = next(model.parameters()).device

    with torch.set_grad_enabled(split != 'test'):
        for batch in tqdm(data_loader, leave=False):
            inputs, target = [tensor.to(device) for tensor in batch]
            output = model(inputs)

            task_loss = loss_fn(output, target)
            # The optimized objective also includes the L2 distance to model_init.
            loss = task_loss + l2_penalty(model, model_init, wd)
            metrics.update(n=inputs.size(0), loss=task_loss.item(), error=get_error(output, target))

            if split != 'test':
                model.zero_grad()
                loss.backward()
                optimizer.step()
    return metrics.avg
    
def test(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` with cross-entropy and return its averaged metrics.

    A deep copy of ``model`` serves as the L2 reference, so the penalty term is zero.
    Runs in ``'test'`` split mode: no gradients and no optimizer.
    """
    criterion = nn.CrossEntropyLoss()
    reference = copy.deepcopy(model)
    return run_train_epoch(model, reference, data_loader, criterion, optimizer=None,
                           split='test', ignore_index=None, quiet=True)

In [None]:
alpha_list = [1e-8,1e-7,1e-6,1e-5, 1e-3, 1e-2]
test_error_list = []
information_list = []


def perturb_with_fisher_noise(net, scale):
    """Resample every parameter of ``net`` as ``mu + sqrt(var) * N(0, 1)`` using ``get_mean_var``."""
    for p in net.parameters():
        mu, var = get_mean_var(p, alpha=scale)
        p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()


# The sweep overwrites the weights of model_f / model_r. Snapshot them and restore
# them afterwards (even on error), so the models stay as the earlier cells left them.
saved_f = [p.data.clone() for p in model_f.parameters()]
saved_r = [p.data.clone() for p in model_r.parameters()]

runs = 3
try:
    for s in range(runs):
        torch.manual_seed(s)
        test_error_list.append([])
        information_list.append([])
        # NOTE(review): the loop variable rebinds the global `alpha`, which the final CSV
        # filename reads (it ends up as 1e-2) — confirm that is intended.
        for alpha in alpha_list:
            perturb_with_fisher_noise(model_f, alpha)
            perturb_with_fisher_noise(model_r, alpha)

            metrics = test(model_f, test_loader_full)

            # Residual information: divergence between the two models' noise distributions.
            total_kl = 0
            for p, p0 in zip(model_f.parameters(), model_r.parameters()):
                mu0, var0 = get_mean_var(p, alpha=alpha)
                mu1, var1 = get_mean_var(p0, alpha=alpha)
                total_kl += kl_divergence_fisher(mu0, var0, mu1, var1).item()

            test_error_list[s].append(metrics['error'])
            information_list[s].append(total_kl)
finally:
    for p, saved in zip(model_f.parameters(), saved_f):
        p.data = saved
    for p, saved in zip(model_r.parameters(), saved_r):
        p.data = saved

# One row per (run, alpha): alpha values repeated once per run.
alpha_list = np.tile(alpha_list, runs)
test_error_list = np.array(test_error_list).flatten()
information_list = np.array(information_list).flatten()

info_dict = {}
info_dict['alpha'] = alpha_list
info_dict['error'] = [i*100 for i in test_error_list]
info_dict['info'] = information_list
df = pd.DataFrame(info_dict)    

print(df)

In [None]:
import matplotlib.pyplot as plt
def plot_info(ax,df,information_list,title):
    """Line plot of test error (%) against residual information, with a log x-axis.

    Draws ``df['error']`` versus ``df['info']`` on ``ax`` with seaborn, then labels
    and titles it. ``information_list`` is accepted but unused.
    """
    sns.lineplot(x="info", y="error", data=df, ax=ax)
    ax.set(xscale="log")
    label_size = 16
    ax.set_xlabel('Residual information', size=label_size)
    ax.set_ylabel('Error on test set (%)', size=label_size)
    ax.set_title(title, size=label_size)
    for axis_name in ("y", "x"):
        ax.tick_params(axis=axis_name, labelsize=label_size)

In [None]:
# Residual information vs. test error over the alpha sweep (`df` from the sweep cell).
fig, ax = plt.subplots(figsize=(5.5, 4))
plot_info(ax,df,None,f'information contained in the weights') 

fig.tight_layout()
fig.savefig(f'data/lambda_plot{numf}.png', bbox_inches='tight')

# Saving the data

In [None]:
# Collect every logged metric into a one-row table.
log_df = pd.DataFrame([log_dict])

In [None]:
# NOTE(review): `alpha` here is whatever the lambda sweep left behind (its last value, 1e-2),
# not the alpha used to produce model_f (1e-7) — confirm the intended filename.
log_df.to_csv(f'data_num/experiments{numf}_{alpha}.csv')
