In [1]:
# Necessary Imports
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from datasets import mnist
from utils import *
from models import AllCNN
from metrics import *
from unlearn import *

torch.manual_seed(100)

<torch._C.Generator at 0x7ffeaaef6070>

In [2]:
train_ds, valid_ds = mnist()

batch_size = 256
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=16)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=16)

In [3]:
num_classes = 10
classwise_train = {}
for i in range(num_classes):
    classwise_train[i] = []

for img, label in train_ds:
    classwise_train[label].append((img, label))
    
classwise_test = {}
for i in range(num_classes):
    classwise_test[i] = []

for img, label in valid_ds:
    classwise_test[label].append((img, label))

In [4]:
device = 'cuda'

In [5]:
model = AllCNN(n_channels = 1).to(device = device)

## Creating the fully trained model

In [6]:
epochs = 25
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

In [None]:
%%time
history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func, device = device)
torch.save(model.state_dict(), "AllCNN_MNIST_ALL_CLASSES.pt")

In [7]:
model.load_state_dict(torch.load("AllCNN_MNIST_ALL_CLASSES.pt"))
history = [evaluate(model, valid_dl, device = device)]
history

[{'Loss': 0.019072750583291054, 'Acc': 0.993945300579071}]

## Forgetting Class 0 using GKT

In [8]:
# Getting the forget and retain data
forget_valid = []
forget_classes = [0]
for cls in range(num_classes):
    if cls in forget_classes:
        for img, label in classwise_test[cls]:
            forget_valid.append((img, label))

retain_valid = []
for cls in range(num_classes):
    if cls not in forget_classes:
        for img, label in classwise_test[cls]:
            retain_valid.append((img, label))

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)

retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=3, pin_memory=True)

In [15]:
n_generator_iter = 1
n_student_iter = 10
n_repeat_batch = n_generator_iter + n_student_iter

ERROR! Session/line number was not unique in database. History logging moved to new session 2976


In [22]:
model = AllCNN(n_channels = 1).to(device = device)
model.load_state_dict(torch.load("AllCNN_MNIST_ALL_CLASSES.pt"))

student = AllCNN(n_channels = 1).to(device = device)
generator = LearnableLoader(n_repeat_batch=n_repeat_batch, num_channels = 1, device = device).to(device=device)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.001) 
scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_generator, 
                                                               mode='min', factor=0.5, patience=2, verbose=True)
optimizer_student = torch.optim.Adam(student.parameters(), lr=0.001)
scheduler_s-tudent = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_student, \
                                    mode='min', factor=0.5, patience=2, verbose=True)

In [23]:
print("Performance of Fully Trained Model on Forget Class")
history = [evaluate(model, forget_valid_dl, device = device)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Fully Trained Model on Retain Class")
history = [evaluate(model, retain_valid_dl, device = device)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))


history = [evaluate(student, forget_valid_dl, device = device)]
AccForget = history[0]["Acc"]*100
ErrForget = history[0]["Loss"]

history = [evaluate(student, retain_valid_dl, device = device)]
AccRetain = history[0]["Acc"]*100
ErrRetain = history[0]["Loss"]

Performance of Fully Trained Model on Forget Class
Accuracy: 99.51171875
Loss: 0.01288334745913744
Performance of Fully Trained Model on Retain Class
Accuracy: 99.34606552124023
Loss: 0.02070627547800541


In [24]:
generator_path = "./ckpts/mnist_allcnn/generator"
student_path = "./ckpts/mnist_allcnn/student"

#os.makedirs(generator_path)
#os.makedirs(student_path)

idx_pseudo = 0
total_n_pseudo_batches = 4000
n_pseudo_batches = 0
running_gen_loss = []
running_stu_loss = []

threshold = 0.01

In [25]:
import warnings
warnings.filterwarnings("ignore")

### Training the unlearned model

In [26]:
KL_temperature = 1
AT_beta = 250

In [None]:
history_forget = [evaluate(student, forget_valid_dl, device = device)]
AccForget = history_forget[0]["Acc"]*100
ErrForget = history_forget[0]["Loss"]

history_retain = [evaluate(student, retain_valid_dl, device = device)]
AccRetain = history_retain[0]["Acc"]*100
ErrRetain = history_retain[0]["Loss"]

df = pd.DataFrame(columns = ["Epochs", "AccForget", "AccRetain", "ErrForget", "ErrRetain", "MeanGeneratorLoss", "MeanStudentLoss"])
df = df.append({"Epochs":0, "AccForget":AccForget, "AccRetain":AccRetain, "ErrForget":ErrForget, 
                "ErrRetain":ErrRetain, "MeanGeneratorLoss":None, "MeanStudentLoss":None}, ignore_index = True)

# saving the generator
torch.save(generator.state_dict(), os.path.join(generator_path, str(0) + ".pt"))

# saving the student
torch.save(student.state_dict(), os.path.join(student_path, str(0) + ".pt"))




while n_pseudo_batches < total_n_pseudo_batches:
    x_pseudo = generator.__next__()
    preds, *_ = model(x_pseudo)
    mask = (torch.softmax(preds.detach(), dim=1)[:, 0] <= threshold)
    x_pseudo = x_pseudo[mask]
    if x_pseudo.size(0) == 0:
        zero_count += 1
        if zero_count > 100:
            print("Generator Stopped Producing datapoints corresponding to retain classes.")
            print("Resetting the generator to previous checkpoint")
            generator.load_state_dict(torch.load(os.path.join(generator_path, str(((n_pseudo_batches//50)-1)*50) + ".pt")))
        continue
    else:
        zero_count = 0
    
    ## Take n_generator_iter steps on generator
    if idx_pseudo % n_repeat_batch < n_generator_iter:
        student_logits, *student_activations = student(x_pseudo)
        teacher_logits, *teacher_activations = model(x_pseudo)
        generator_total_loss = KT_loss_generator(student_logits, teacher_logits, KL_temperature=KL_temperature)

        optimizer_generator.zero_grad()
        generator_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 5)
        optimizer_generator.step()
        running_gen_loss.append(generator_total_loss.cpu().detach())


    elif idx_pseudo % n_repeat_batch < (n_generator_iter + n_student_iter):
        
        
        with torch.no_grad():
            teacher_logits, *teacher_activations = model(x_pseudo)

        student_logits, *student_activations = student(x_pseudo)
        student_total_loss = KT_loss_student(student_logits, student_activations, 
                                             teacher_logits, teacher_activations, 
                                             KL_temperature=KL_temperature, AT_beta = AT_beta)

        optimizer_student.zero_grad()
        student_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), 5)
        optimizer_student.step()
        running_stu_loss.append(student_total_loss.cpu().detach())
        
    if (idx_pseudo + 1) % n_repeat_batch == 0:       
        if((n_pseudo_batches)% 50 == 0):
            MeanGLoss = np.mean(running_gen_loss)
            running_gen_loss = []
            MeanSLoss = np.mean(running_stu_loss)
            running_stu_loss = []
            
            history_forget = [evaluate(student, forget_valid_dl, device = device)]
            AccForget = history_forget[0]["Acc"]*100
            ErrForget = history_forget[0]["Loss"]

            history_retain = [evaluate(student, retain_valid_dl, device = device)]
            AccRetain = history_retain[0]["Acc"]*100
            ErrRetain = history_retain[0]["Loss"]
            
            df = df.append({"Epochs":n_pseudo_batches, "AccForget":AccForget, "AccRetain":AccRetain, "ErrForget":ErrForget, 
                            "ErrRetain":ErrRetain, "MeanGeneratorLoss":MeanGLoss, "MeanStudentLoss":MeanSLoss}, ignore_index = True)
            print(df.iloc[-1:])
            scheduler_student.step(history_retain[0]["Loss"])
            scheduler_generator.step(history[0]["Loss"])
            
            # saving the generator
            torch.save(generator.state_dict(), os.path.join(generator_path, str(n_pseudo_batches) + ".pt"))
            
            # saving the student
            torch.save(student.state_dict(), os.path.join(student_path, str(n_pseudo_batches) + ".pt"))
            
            
        n_pseudo_batches += 1
        
    idx_pseudo += 1

In [29]:
df.iloc[10:20]

Unnamed: 0,Epochs,AccForget,AccRetain,ErrForget,ErrRetain,MeanGeneratorLoss,MeanStudentLoss
10,450.0,0.0,85.448492,4.841099,0.643373,-0.009583,0.074831
11,500.0,0.0,85.30454,5.04534,0.682012,-0.013167,0.086326
12,550.0,0.0,86.072773,4.507913,0.471898,-0.006274,0.052073
13,600.0,0.0,87.008828,4.11447,0.287021,-0.007687,0.060969
14,650.0,0.0,88.53299,3.570807,0.226169,-0.006062,0.046331
15,700.0,0.0,89.653504,3.301095,0.201094,-0.009921,0.066714
16,750.0,0.0,94.76707,2.407107,0.121586,-0.006688,0.041866
17,800.0,6.025206,98.412907,2.027194,0.067476,-0.008076,0.059766
18,850.0,53.30373,98.792678,1.357714,0.050143,-0.004267,0.028618
19,900.0,76.265848,98.81438,0.936303,0.04606,-0.004487,0.034179


In [30]:
df.to_csv("MNIST_ALLCNN.csv", index = False)