In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys

import pandas as pd
import numpy as np

import time

import torch

import copy
from collections import OrderedDict

from tqdm.auto import tqdm

In [2]:
from datasets import load_original_dataset, load_deleted_dataset
from models import CNN

In [3]:
DATA_DIR = 'Datasets/Features/'
BATCH_SIZE = 256
WEIGHT_DECAY = 0.1
EPOCHS = 1
PERCENTAGES = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]

In [4]:
sys.path.append(os.path.abspath('./libraries/SelectiveForgetting/'))

# NTK based Forgetting

In [5]:
def vectorize_params(model):
    param = []
    for p in model.parameters():
        param.append(p.data.view(-1).cpu().numpy())
    return np.concatenate(param)


# NTK Update


def delta_w_utils(model, dataset):
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
    num = len(dataloader)
    
    G_list = []
    f0_minus_y = []
    
    for idx, batch in enumerate(tqdm(dataloader, leave=False)):
        batch = [tensor.cuda() for tensor in batch]
        input, target = batch
        # if "mnist" in args.dataset:
        #     input = input.view(input.shape[0], -1)
        target = target.cpu().detach().numpy()
        output = model(input)
        
        G_sample=[]
        for cls in range(output.shape[-1]):
            grads = torch.autograd.grad(
                output[0, cls], model.parameters(), retain_graph=True
            )
            grads = np.concatenate([g.view(-1).cpu().numpy() for g in grads])
            G_sample.append(grads)

        grads = np.stack(G_sample).transpose()
        G_list.append(grads)

        # if args.lossfn=='ce':
        p = (
            torch.nn.functional.softmax(output, dim=1)
            .cpu()
            .detach()
            .numpy()
            .transpose()
        )
        p[target] -= 1
        f0_y_update = copy.deepcopy(p)
        
        f0_minus_y.append(f0_y_update.mean(axis=-1, keepdims=True))

    return num, np.concatenate(G_list, axis=-1), np.vstack(f0_minus_y)


# Reshape delta_w


def get_delta_w_dict(delta_w, model):
    # Give normalized delta_w
    delta_w_dict = OrderedDict()
    params_visited = 0
    for k, p in model.named_parameters():
        num_params = np.prod(list(p.shape))
        update_params = delta_w[params_visited : params_visited + num_params]
        delta_w_dict[k] = torch.Tensor(update_params).view_as(p)
        params_visited += num_params
    return delta_w_dict


In [6]:
def worker(model, model_init, train_set, forget_set):

    # Jacobians and Hessians
    
    num_to_retain, G_r, f0_minus_y_r = delta_w_utils(copy.deepcopy(model), train_set)
    
    num_to_forget, G_f, f0_minus_y_f = delta_w_utils(copy.deepcopy(model), forget_set)
    
    G = np.concatenate([G_r, G_f], axis=1)
    f0_minus_y = np.concatenate([f0_minus_y_r, f0_minus_y_f])
    
    # w_lin(D)
    
    num_total = num_to_retain + num_to_forget
    
    theta = G.transpose().dot(G) + num_total * WEIGHT_DECAY * np.eye(G.shape[1])
    theta_inv = np.linalg.inv(theta)
    
    w_complete = -G.dot(theta_inv.dot(f0_minus_y))
    
    # w_lin(D_r)
    
    theta_r = G_r.transpose().dot(G_r) + num_to_retain * WEIGHT_DECAY * np.eye(
        G_r.shape[1]
    )
    theta_r_inv = np.linalg.inv(theta_r)
    
    w_retain = -G_r.dot(theta_r_inv.dot(f0_minus_y_r))
    
    # Scrubbing Direction
    
    delta_w = (w_retain - w_complete).squeeze()
    
    # Trapezium Trick
    
    m_pred_error = (
        vectorize_params(model) - vectorize_params(model_init) - w_retain.squeeze()
    )
    
    inner = np.inner(
        delta_w / np.linalg.norm(delta_w), m_pred_error / np.linalg.norm(m_pred_error)
    )
    
    if inner < 0:
        angle = np.arccos(inner) - np.pi / 2
        predicted_norm = np.linalg.norm(delta_w) + 2 * np.sin(angle) * np.linalg.norm(
            m_pred_error
        )
    else:  
        angle = np.arccos(inner)
        predicted_norm = np.linalg.norm(delta_w) + 2 * np.cos(angle) * np.linalg.norm(
            m_pred_error
        )
    
    predicted_scale = predicted_norm / np.linalg.norm(delta_w)
    
    # Scrub using NTK
    
    scale = predicted_scale
    direction = get_delta_w_dict(delta_w, model)
    
    model_scrub = copy.deepcopy(model)
    for k, p in model_scrub.named_parameters():
        p.data += (direction[k] * scale).cuda()
    
    return model_scrub

In [7]:
def fit(model, model_init, save_dir, train_set, test_set, forget_set):
    
    os.makedirs(save_dir, exist_ok=True)

    train_times = list()
    train_accs, test_accs, forget_accs = list(), list(), list()
    
    for epoch in range(EPOCHS):
        
        # train
        
        train_time = 0
        
        start_time = time.time()
        
        model_scrub = worker(model, model_init, train_set, forget_set)
        
        train_time += time.time() - start_time
        
        start_time = time.time()
            
        train_times.append(train_time)
        
        # test
            
        model_scrub.eval()
        with torch.no_grad():
            
            x, y = train_set.tensors
            
            accs = list()
            
            for i in range(0, x.shape[0], BATCH_SIZE):
            
                output = model_scrub(x[i:i+BATCH_SIZE].cuda())

                predicted = torch.argmax(output.data, dim=-1)
                accs.append((predicted == y[i:i+BATCH_SIZE].cuda()).float().mean().detach().cpu().numpy())
            
            train_accs.append(np.mean(accs))
            
            x, y = test_set.tensors

            accs = list()
            
            for i in range(0, x.shape[0], BATCH_SIZE):
            
                output = model_scrub(x[i:i+BATCH_SIZE].cuda())

                predicted = torch.argmax(output.data, dim=-1)
                accs.append((predicted == y[i:i+BATCH_SIZE].cuda()).float().mean().detach().cpu().numpy())
            
            test_accs.append(np.mean(accs))
            

            x, y = forget_set.tensors

            accs = list()

            for i in range(0, x.shape[0], BATCH_SIZE):

                output = model_scrub(x[i:i+BATCH_SIZE].cuda())

                predicted = torch.argmax(output.data, dim=-1)
                accs.append((predicted == y[i:i+BATCH_SIZE].cuda()).float().mean().detach().cpu().numpy())

            forget_accs.append(np.mean(accs))
        
        # save
        torch.save(model_scrub.state_dict(), os.path.join(save_dir, f'{(epoch+1):03d}.pt'))

    return train_times, train_accs, test_accs, forget_accs

In [8]:
results = list()

for percentage in tqdm(PERCENTAGES):
    
    model = CNN().cuda()

    model.load_state_dict(torch.load('./weights/original/005.pt'))

    model_init = CNN().cuda()

    model_init.load_state_dict(torch.load('./weights/init.pt'))
    
    train_set, test_set, forget_set = load_deleted_dataset(DATA_DIR, percentage)

    train_times, train_accs, test_accs, forget_accs = fit(model, model_init, f'weights/SelectiveForgetting/{percentage}', train_set, test_set, forget_set)
    
    df = pd.DataFrame(zip(train_times, train_accs, test_accs, forget_accs), columns=['train_time', 'train_acc', 'test_acc', 'forget_acc'])
    df['epoch'] = range(1, EPOCHS+1)
    df['percentage'] = percentage
    
    results.append(df)

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/233 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/165 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/141 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/118 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/141 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

  0%|          | 0/165 [00:00<?, ?it/s]

  0%|          | 0/47 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/233 [00:00<?, ?it/s]

In [9]:
results = pd.concat(results).set_index(['percentage', 'epoch'])

results.to_csv('results/SelectiveForgetting.csv')

results

Unnamed: 0_level_0,Unnamed: 1_level_0,train_time,train_acc,test_acc,forget_acc
percentage,epoch,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,1,26.679327,0.987191,0.986133,0.990885
10,1,23.677634,0.987169,0.98623,0.986886
20,1,23.422085,0.986931,0.986133,0.986417
30,1,23.067099,0.987121,0.986426,0.987335
40,1,23.051812,0.986824,0.986719,0.987215
50,1,21.964149,0.986957,0.986133,0.987255
60,1,21.510948,0.986979,0.986426,0.986547
70,1,21.947019,0.988006,0.986426,0.986695
80,1,21.373967,0.987937,0.985742,0.986453
90,1,20.285455,0.989374,0.985937,0.986145
