# angular unlearning

In [1]:
import os
import torch
import torch.nn as nn
from torch import optim
from  torchvision import datasets
from torchvision import models
import torchvision.transforms as transforms
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader, Subset
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler
from sklearn.metrics import classification_report
import numpy as np
import random 
import copy
import time

from models import init_params as w_init
from models import TargetNet, AllCNN
from trainer import train, eval, loss_picker, optimizer_picker

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True


  from .autonotebook import tqdm as notebook_tqdm


## Hyper parameters

In [2]:
# --- Experiment configuration ------------------------------------------------
data_name = 'cifar10'        # dataset key passed to get_dataset()
batch_size = 64

num_epochs = 30
learning_rate = 0.01
loss_mode = 'cross'
optimization = 'sgd'

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# --- Reproducibility ---------------------------------------------------------
# Seed every random number generator the notebook relies on, in a fixed order:
# torch (CPU), torch (current CUDA device), torch (all CUDA devices),
# Python's `random`, then NumPy.
seed = 2024
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
    np.random.seed,
):
    _seed_fn(seed)

cuda:0


## Load data

In [3]:
def get_dataset(data_name):
    """Build the torchvision train/test datasets for ``data_name``.

    Args:
        data_name: one of ``'mnist'``, ``'cifar10'`` or ``'cifar100'``.

    Returns:
        ``(trainset, testset)``; both are downloaded into ``./data`` on demand.

    Raises:
        ValueError: if ``data_name`` is not one of the supported keys. Without
            this check an unknown key fell through every branch and failed at
            the ``return`` with an ``UnboundLocalError``.
    """
    supported = ('mnist', 'cifar10', 'cifar100')
    if data_name not in supported:
        raise ValueError(
            "unsupported data_name %r; expected one of %s" % (data_name, ', '.join(supported))
        )

    if data_name == 'mnist':
        # model: 2 conv. layers followed by 2 FC layers
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)
        testset = datasets.MNIST('./data', train=False, download=True, transform=transform)

    elif data_name == 'cifar10':
        # The same deterministic transform is used for both splits.
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        trainset = datasets.CIFAR10(root='./data', train=True,
                                    download=True, transform=transform)
        testset = datasets.CIFAR10(root='./data', train=False,
                                   download=True, transform=transform)

    else:  # 'cifar100'
        # Only the training split gets random crop / flip augmentation.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5074, 0.4867, 0.4411), (0.2675, 0.2565, 0.2761)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5074, 0.4867, 0.4411), (0.2675, 0.2565, 0.2761)),
        ])
        trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
        testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    return trainset, testset

def get_dataloader(trainset, testset, batch_size, device):
    """Wrap the two datasets in shuffling ``DataLoader`` objects.

    Args:
        trainset: dataset for the training loader.
        testset: dataset for the test loader.
        batch_size: batch size used by both loaders.
        device: accepted for call compatibility; not used here.

    Returns:
        ``(train_loader, test_loader)``; both shuffle on every pass.
    """
    loaders = tuple(
        DataLoader(dataset=split, batch_size=batch_size, shuffle=True)
        for split in (trainset, testset)
    )
    return loaders

# Build the datasets selected by `data_name` and wrap them in batch loaders.
# These four names are module-level globals reused by later cells.
trainset, testset = get_dataset(data_name)
train_loader, test_loader = get_dataloader(trainset, testset, batch_size, device=device)
# Prints the loader object's repr, not its contents.
print(train_loader)

Files already downloaded and verified
Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x000001AE6E43BB50>


## train and save an original model

In [None]:
import time
from models import *
 
# Number of classes = largest training label + 1. int() keeps this a plain
# Python integer should the dataset store its targets as tensors.
num_classes = int(max(train_loader.dataset.targets)) + 1

# Choose the architecture and schedule for the configured dataset. This is a
# single if/elif chain with an explicit failure branch, so an unsupported
# `data_name` stops here instead of failing later with an undefined `model`.
if data_name == 'cifar10':
    model = AllCNN(n_channels=3, num_classes=num_classes, filters_percentage=0.5).to(device)
    num_epochs = 30
elif data_name == 'cifar100':
    normalization = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762]
        )
    model = resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 100)  # replace the head with 100 outputs
    model.normalize = normalization
    model = model.to(device)
    num_epochs = 182
    learning_rate = 0.1
elif data_name == 'tiny-imagenet':
    normalization = NormalizeByChannelMeanStd(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    model = resnet34(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 200)  # replace the head with 200 outputs
    model.normalize = normalization
    model = model.to(device)
    num_epochs = 200
    learning_rate = 0.1
else:
    raise ValueError("no model is configured for data_name=%r" % (data_name,))


criterion = loss_picker('cross')
optimizer = optimizer_picker(optimization, model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# NOTE(review): with step_size=30 the learning rate never decays during the
# 30-epoch cifar10 run (the first decay lands after the final epoch) - confirm
# that this is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

best_acc = 0

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs('./checkpoints', exist_ok=True)

start = time.perf_counter()
for epo in range(num_epochs):
    # `train` and `eval` are the helpers imported from `trainer`; this `eval`
    # shadows Python's builtin of the same name.
    train(model=model, data_loader=train_loader, criterion=criterion, optimizer=optimizer, loss_mode='cross', device=device)
    _, acc = eval(model=model, data_loader=test_loader, mode='', print_perform=False, device=device)
    print('EPOCH[%d]: test acc:%.4f'%(epo+1, acc))
    scheduler.step()

    # Keep the checkpoint with the best test accuracy so far (ties overwrite).
    if acc >= best_acc:
        best_acc = acc
        torch.save(model, './checkpoints/original_best_%s.pth'%data_name)
end = time.perf_counter()


## Retrain a model on the remaining data
## Load data

In [4]:
from torch.utils.data import Subset

def get_dataloader(trainset, testset, batch_size, device):
    """Return shuffling train/test ``DataLoader`` objects (single-process loading).

    This redefines the helper of the same name from the data-loading cell.

    Args:
        trainset: dataset for the training loader.
        testset: dataset for the test loader.
        batch_size: batch size used by both loaders.
        device: accepted for call compatibility; not used here.

    Returns:
        ``(train_loader, test_loader)``.
    """
    shared_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)
    return (
        DataLoader(dataset=trainset, **shared_options),
        DataLoader(dataset=testset, **shared_options),
    )

def split_class_data(dataset, forget_class, num_forget):
    """Partition dataset positions for class-level unlearning.

    The dataset is iterated once as ``(data, target)`` pairs. The first
    ``num_forget`` samples whose target equals ``forget_class`` are marked to be
    forgotten; everything else is kept.

    Args:
        dataset: iterable of ``(data, target)`` pairs.
        forget_class: label of the class to forget.
        num_forget: how many leading samples of that class to forget.

    Returns:
        ``(forget_index, remain_index, class_remain_index)`` where
        ``class_remain_index`` holds the kept samples of ``forget_class`` (these
        also appear in ``remain_index``).
    """
    forget_index, remain_index, class_remain_index = [], [], []
    seen_in_class = 0  # samples of `forget_class` encountered so far

    for position, (_, target) in enumerate(dataset):
        belongs_to_class = target == forget_class
        if not belongs_to_class:
            remain_index.append(position)
            continue

        if seen_in_class < num_forget:
            forget_index.append(position)
        else:
            class_remain_index.append(position)
            remain_index.append(position)
        seen_in_class += 1

    return forget_index, remain_index, class_remain_index

def get_class_unlearn_loader(trainset, testset, forget_class, batch_size, num_forget, repair_num_ratio=0.01):
    """Build every loader used for class-level unlearning experiments.

    The training set is split with ``split_class_data``. Besides the full
    forget/remain loaders, five reduced variants are built by keeping every
    2nd, 5th, 10th, 20th and 100th forget index (the ``_0_5`` ... ``_0_01``
    loaders); forget samples dropped from a reduced forget set are added to the
    matching remain set.

    Args:
        trainset: training dataset.
        testset: test dataset; all of its ``forget_class`` samples count as forget.
        forget_class: label of the class to forget.
        batch_size: batch size for every loader.
        num_forget: number of training samples of ``forget_class`` to forget.
        repair_num_ratio: fraction of the kept ``forget_class`` training samples
            drawn (via ``random.sample``) for the repair loader.

    Returns:
        A 19-tuple: the six forget loaders (full, 0.5, 0.2, 0.1, 0.05, 0.01), the
        six remain loaders in the same order, ``test_forget_loader``,
        ``test_remain_loader``, ``repair_class_loader``, then the index lists
        ``train_forget_index``, ``train_remain_index``, ``test_forget_index`` and
        ``test_remain_index``.
    """
    train_forget_index, train_remain_index, class_remain_index = split_class_data(
        trainset, forget_class, num_forget=num_forget)
    test_forget_index, test_remain_index, _ = split_class_data(
        testset, forget_class, num_forget=len(testset))

    repair_count = int(repair_num_ratio * len(class_remain_index))
    repair_class_index = random.sample(class_remain_index, repair_count)

    def _random_subset_loader(dataset, indices):
        # One loader that visits `indices` of `dataset` in random order.
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, sampler=SubsetRandomSampler(indices))

    # Position 0 holds the full split; the reduced variants follow.
    forget_loaders = [_random_subset_loader(trainset, train_forget_index)]
    remain_loaders = [_random_subset_loader(trainset, train_remain_index)]

    all_forget = set(train_forget_index)
    base_remain = set(train_remain_index)
    for stride in (2, 5, 10, 20, 100):
        kept_forget = train_forget_index[::stride]
        widened_remain = list((all_forget - set(kept_forget)).union(base_remain))
        forget_loaders.append(_random_subset_loader(trainset, kept_forget))
        remain_loaders.append(_random_subset_loader(trainset, widened_remain))

    repair_class_loader = _random_subset_loader(trainset, repair_class_index)
    test_forget_loader = _random_subset_loader(testset, test_forget_index)
    test_remain_loader = _random_subset_loader(testset, test_remain_index)

    return (
        *forget_loaders,
        *remain_loaders,
        test_forget_loader, test_remain_loader, repair_class_loader,
        train_forget_index, train_remain_index, test_forget_index, test_remain_index,
    )
    

def get_unlearn_loader(trainset, testset, forget_ratio=0.05, batch_size=64, rng_seed=None):
    """Build the loaders for sample-level (random subset) unlearning.

    A fraction ``forget_ratio`` of the training indices is drawn without
    replacement as the forget set; the rest is the remain set. A validation
    subset of the test set with the same number of samples (capped at the test
    set size) is drawn with the same seed.

    Args:
        trainset: training dataset.
        testset: test dataset.
        forget_ratio: fraction of training samples to forget.
        batch_size: batch size for every loader.
        rng_seed: seed for the forget/validation draws. ``None`` (the default)
            uses the notebook-level ``seed``, as before.

    Returns:
        ``(train_forget_loader, train_remain_loader, test_val_loader,
        test_loader, sorted_train_loader, forget_index, remain_index)``.
        ``sorted_train_loader`` yields the remain samples (in ``remain_index``
        order) followed by the forget samples, without shuffling.
    """
    if rng_seed is None:
        rng_seed = seed  # notebook-level seed from the hyper-parameter cell

    index = np.arange(len(trainset))
    test_index = np.arange(len(testset))
    num_to_forget = int(len(index) * forget_ratio)

    forget_index = np.random.RandomState(rng_seed).choice(index, num_to_forget, replace=False).tolist()
    remain_index = list(set(index) - set(forget_index))
    # Shuffle the remain indices using the `random` module's global state.
    remain_index = sorted(remain_index, key=lambda x: random.random())

    # Sampling without replacement cannot draw more items than exist, so cap the
    # validation size; large forget ratios used to raise ValueError here.
    num_val = min(num_to_forget, len(test_index))
    val_index = np.random.RandomState(rng_seed).choice(test_index, num_val, replace=False).tolist()

    train_forget_sampler = SubsetRandomSampler(forget_index)
    train_remain_sampler = SubsetRandomSampler(remain_index)
    test_val_sampler = SubsetRandomSampler(val_index)

    train_forget_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, sampler=train_forget_sampler, shuffle=False, num_workers=0)
    train_remain_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, sampler=train_remain_sampler, shuffle=False, num_workers=0)
    test_val_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, sampler=test_val_sampler, num_workers=0)

    # SequentialSampler(indices) only uses len(indices) and yields 0..len-1, so
    # it ignored the requested order. A Subset visited without shuffling yields
    # exactly "remain samples first, then forget samples".
    ordered_index = [int(i) for i in remain_index] + forget_index
    sorted_train_loader = torch.utils.data.DataLoader(dataset=Subset(trainset, ordered_index), batch_size=batch_size, shuffle=False, num_workers=0)

    test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=True, num_workers=0)
    return train_forget_loader, train_remain_loader, test_val_loader, test_loader, sorted_train_loader, forget_index, remain_index




## MIA

In [5]:
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
import tqdm
import torch.nn.functional as F

def entropy(p, dim = -1, keepdim = False):
    """Entropy ``-sum(p * log(p))`` (natural log) along ``dim``.

    Entries with ``p <= 0`` contribute zero, so exact zeros do not produce NaN.
    """
    weighted_log = p * p.log()
    safe_terms = torch.where(p > 0, weighted_log, torch.zeros_like(weighted_log))
    return -(safe_terms.sum(dim=dim, keepdim=keepdim))

def collect_prob(data_loader, model):
    """Return the model's softmax probabilities for every sample in ``data_loader``.

    The model is switched to evaluation mode while the probabilities are
    collected, so dropout and batch-norm behave deterministically and batch-norm
    running statistics are not updated. The previous mode is restored afterwards.

    Args:
        data_loader: iterable of ``(data, target)`` batches of tensors.
        model: module whose parameters define the device the batches are moved to.

    Returns:
        Tensor holding one row of class probabilities per sample, in loader order.
    """
    target_device = next(model.parameters()).device  # looked up once, not per batch
    was_training = model.training
    model.eval()
    prob = []
    try:
        with torch.no_grad():
            for batch in data_loader:
                data, _target = [tensor.to(target_device) for tensor in batch]
                prob.append(F.softmax(model(data), dim=-1))
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return torch.cat(prob)

def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    """Assemble entropy features and labels for the membership-inference attack.

    Returns:
        ``(X_f, Y_f, X_r, Y_r)``:
        ``X_r`` / ``Y_r`` are the attacker's training set (entropies of a
        retain subsample labelled 1 and of the test samples labelled 0);
        ``X_f`` / ``Y_f`` are the forget-sample entropies, all labelled 1.
        Feature arrays are column vectors.
    """
    # The three collections run in this order: retain, forget, test.
    # Of the retain rows, every 4th among the first 40000 is kept.
    retain_prob = collect_prob(retain_loader, model)[0:40000:4,:]
    forget_prob = collect_prob(forget_loader, model)
    test_prob = collect_prob(test_loader, model)

    attacker_entropies = torch.cat([entropy(retain_prob), entropy(test_prob)])
    X_r = attacker_entropies.cpu().numpy().reshape(-1, 1)
    member_labels = np.ones(retain_prob.size(0))
    non_member_labels = np.zeros(test_prob.size(0))
    Y_r = np.concatenate([member_labels, non_member_labels])

    X_f = entropy(forget_prob).cpu().numpy().reshape(-1, 1)
    Y_f = np.ones(len(forget_prob))
    return X_f, Y_f, X_r, Y_r

def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    """Fraction of forget samples that the entropy attacker labels as members.

    A class-balanced logistic regression is fitted on retain-vs-test entropies
    and then applied to the forget-sample entropies.
    """
    X_f, _, X_r, Y_r = get_membership_attack_data(retain_loader, forget_loader, test_loader, model)
    attacker = LogisticRegression(class_weight='balanced', solver='lbfgs', multi_class='multinomial')
    attacker.fit(X_r, Y_r)
    member_votes = attacker.predict(X_f)
    return member_votes.mean()

def get_entropy(retain_loader,forget_loader,test_loader, model):
    """Per-sample prediction entropy for the retain, forget and test loaders.

    Returns:
        ``(retain_entropy, forget_entropy, test_entropy)``.
    """
    # Collect all probabilities first (retain, forget, test), then reduce each.
    probabilities = [
        collect_prob(loader, model)
        for loader in (retain_loader, forget_loader, test_loader)
    ]
    retain_entropy, forget_entropy, test_entropy = [entropy(p) for p in probabilities]
    return retain_entropy, forget_entropy, test_entropy


def membership_attack(retain_loader,forget_loader,test_loader,model,model_name, printable=True):
    """Run the membership-inference attack and optionally report the result.

    Args:
        model_name: label used in the printed message.
        printable: the message is printed only when this compares equal to ``True``.

    Returns:
        The value computed by ``get_membership_attack_prob``.
    """
    attack_rate = get_membership_attack_prob(retain_loader, forget_loader, test_loader, model)
    if not (printable == True):
        return attack_rate
    print("Attack prob for %s: %f" % (model_name, attack_rate))
    return attack_rate


## train and save a retrain model


In [10]:
from models import init_params as w_init
from models import TargetNet, AllCNN
from trainer import train, eval, loss_picker, optimizer_picker

# Retraining baseline: train fresh models on the remain split only and report
# test accuracy (TA), forget accuracy (UA), remain accuracy (RA), the
# membership-attack result (printed as ASR) and wall-clock time over 5 runs.
# NOTE(review): torch.load unpickles a whole model object; only load checkpoints
# from a trusted source.
# NOTE(review): `ori_model` and `num_forget` are not used anywhere in this cell.
ori_model = torch.load('./checkpoints/original_best_cifar10.pth', map_location=torch.device('cpu')).to(device)

# forget_class = 4
num_forget = 5000
num_epochs = 27

for forget_ratio in [0.1]:
    # The third value returned by get_unlearn_loader is its validation loader
    # sampled from the test set; it is bound to `train_val_loader` here and not
    # used afterwards in this cell.
    train_forget_loader, train_remain_loader, train_val_loader, test_loader, _, forget_index, remain_index \
            = get_unlearn_loader(trainset, testset, forget_ratio=forget_ratio, batch_size=64)
    num_classes = max(train_loader.dataset.targets) + 1 #if args.num_classes is None else args.num_classes
    
    # One entry per repetition.
    UA, RA, TA, MIA, TC = [], [], [], [], []

    for cnt in range(5):
        
        # A new, freshly initialised model and optimizer for every repetition.
        retrain_model = AllCNN(n_channels=3, num_classes=num_classes, filters_percentage=0.5).to(device)
        criterion = loss_picker('cross')
        optimizer = optimizer_picker(optimization, retrain_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0)
        best_acc = 0

        # The timed section covers training, the per-epoch evaluation and checkpointing.
        start = time.perf_counter()
        for epo in range(num_epochs):
            train(model=retrain_model, data_loader=train_remain_loader, criterion=criterion, optimizer=optimizer, loss_mode='cross', device=device)
            _, acc = eval(model=retrain_model, data_loader=test_loader, mode='', print_perform=False, device=device)#class-wise
            print('EPOCH[{}]; test acc:{}'.format(epo+1, acc))

            # NOTE(review): the file name depends only on forget_ratio and
            # best_acc restarts at 0 for every repetition, so each repetition
            # overwrites the checkpoint written by the previous one.
            # The braces in the file name are literal characters.
            if acc >= best_acc:
                best_acc = acc
                torch.save(retrain_model, './checkpoints/retrain_best_cifar10_sample_level_forget_ratio{%.4f}.pth'%forget_ratio)
        end = time.perf_counter()
        tm = end-start

        # The metrics below use the in-memory model after the last epoch, not
        # the best checkpoint saved above.
        _, all_test_acc = eval(model=retrain_model, data_loader=test_loader, mode='', print_perform=False, device=device)
        _, remain_train_acc = eval(model=retrain_model, data_loader=train_remain_loader, mode='', print_perform=False, device=device)
        _, forget_train_acc = eval(model=retrain_model, data_loader=train_forget_loader, mode='', print_perform=False, device=device)
        prob = membership_attack(train_remain_loader, train_forget_loader, test_loader, retrain_model, model_name='unlearn_model', printable=False)
        print("[Retrained model] ASR:{%.4f}, all test acc:{%.4f}, forget acc:{%.4f}, remain acc:{%.4f}"%(prob, all_test_acc, forget_train_acc, remain_train_acc))
        
        # The accuracies are converted to NumPy values before aggregation.
        TA.append(all_test_acc.cpu().detach().numpy())
        UA.append(forget_train_acc.cpu().detach().numpy())
        RA.append(remain_train_acc.cpu().detach().numpy())
        MIA.append(prob)
        TC.append(tm)
        
    # Summary over the repetitions, printed as mean(std).
    print("Forget Ratio:[%.4f]"%forget_ratio)
    print(TC)
    print('TA:{%.5f}({%.5f}), UA:{%.5f}({%.5f}), RA:{%.5f}({%.5f}), ASR:{%.5f}({%.5f}), time:{%.5f}({%.5f})'\
            %(np.mean(TA), np.std(TA), np.mean(UA), np.std(UA), \
                np.mean(RA), np.std(RA), np.mean(MIA), np.std(MIA),\
                    np.mean(TC), np.std(TC)))


[Retrained model] ASR:{0.6754}, all test acc:{0.8141}, forget acc:{0.8224}, remain acc:{0.9840}
[Retrained model] ASR:{0.6828}, all test acc:{0.8332}, forget acc:{0.8394}, remain acc:{0.9960}
[Retrained model] ASR:{0.6752}, all test acc:{0.8148}, forget acc:{0.8300}, remain acc:{0.9900}
[Retrained model] ASR:{0.6804}, all test acc:{0.7977}, forget acc:{0.8120}, remain acc:{0.9675}
[Retrained model] ASR:{0.6820}, all test acc:{0.8231}, forget acc:{0.8374}, remain acc:{0.9938}
Forget Ratio:[0.1000]
[854.7909898999997, 853.8453309999995, 853.6774986, 803.2181195000012, 803.8463704999995]
TA:{0.81658}({0.01170}), UA:{0.82824}({0.01009}), RA:{0.98624}({0.01020}), ASR:{0.67916}({0.00325}), time:{833.87566}({24.77900})


## angular unlearning

In [6]:
from utils import inf_generator
import tqdm
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import torch.nn.functional as F
import time

def collect_feature_mean_std(model, loader, num_class):
    """Per-class mean, standard deviation and covariance of ``model.features``.

    The model is moved to the notebook-level ``device`` and run over every batch
    of ``loader`` without building an autograd graph.

    Args:
        model: module exposing a ``features(data)`` method.
        loader: iterable of ``(data, label)`` batches.
        num_class: number of classes; labels are compared against ``0..num_class-1``.

    Returns:
        ``(class_mean, class_std, class_cov, feats, labels)``: ``class_mean`` and
        ``class_std`` are NumPy arrays with one row per class, ``class_cov`` is
        a list with one NumPy covariance matrix per class, and ``feats`` /
        ``labels`` are the concatenated feature and label tensors. A class with
        no samples gets all-zero statistics.
    """
    feats, labels = [], []
    model = model.to(device)
    # NOTE(review): the model's train/eval mode is left untouched - confirm the
    # caller has put it in the intended mode.
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(device), label.to(device)
            feats.append(model.features(data).detach())
            labels.append(label)

    feats = torch.cat(feats, dim=0)
    labels = torch.cat(labels, dim=0)
    assert feats.shape[0] == labels.shape[0]

    feat_dim = feats.shape[-1]
    class_mean = torch.zeros(num_class, feat_dim)
    class_std = torch.zeros(num_class, feat_dim)
    class_cov = []

    for c in range(num_class):
        index = torch.where(labels == c)[0]
        if index.size(0) == 0:
            # Empty class: keep the zero rows and use a zero covariance of the
            # feature width. The old placeholder had length `num_class`, which
            # only fits when the feature width happens to equal the class count.
            class_cov.append(torch.zeros((feat_dim, feat_dim), dtype=feats.dtype).numpy())
            continue
        class_feats = feats[index]
        # With a single sample the unbiased std is NaN, as before.
        _std, _mean = torch.std_mean(class_feats, dim=0, unbiased=True)
        class_mean[c] = _mean
        class_std[c] = _std
        class_cov.append(torch.cov(class_feats.T).cpu().detach().numpy())

    return class_mean.cpu().detach().numpy(), class_std.cpu().detach().numpy(), class_cov, feats, labels

def collect_angular_info(model, loader, num_class):
    """Angles between each sample's model output and the class centres.

    The model is moved to the notebook-level ``device`` and its outputs for
    every batch of ``loader`` are collected without building an autograd graph.
    Each class centre is the L2-normalised mean output of that class.

    Args:
        model: callable module mapping a data batch to one vector per sample.
        loader: iterable of ``(data, label)`` batches.
        num_class: number of classes; labels are expected in ``0..num_class-1``.

    Returns:
        ``(element_angular, element_angular_no, class_center,
        sample_class_centroids, nearest_other_centroids)``:
        * ``element_angular``: per-sample angle in degrees to the centre of the
          sample's own label;
        * ``element_angular_no``: per-sample angle in degrees to the
          second-nearest centre by Euclidean distance;
        * ``class_center``: list of normalised class centres (a zero vector for
          a class with no samples);
        * ``sample_class_centroids``: nearest centre for each sample;
        * ``nearest_other_centroids``: second-nearest centre for each sample.
    """
    feats, labels = [], []
    model = model.to(device)
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(device), label.to(device)
            feats.append(model(data).detach())
            labels.append(label)

    feats = torch.cat(feats, dim=0)
    labels = torch.cat(labels, dim=0)
    assert feats.shape[0] == labels.shape[0]

    # Only the class means are needed; the per-class std / covariance computed
    # here before were never returned, and their empty-class placeholder had the
    # wrong length.
    class_mean = torch.zeros(num_class, feats.shape[-1], device=feats.device, dtype=feats.dtype)
    present = torch.zeros(num_class, dtype=torch.bool, device=feats.device)
    for c in range(num_class):
        index = torch.where(labels == c)[0]
        if index.size(0) > 0:
            class_mean[c] = feats[index].mean(dim=0)
            present[c] = True

    # clamp_min keeps an all-zero mean (empty class) from turning into NaN.
    centers = class_mean / class_mean.norm(dim=1, keepdim=True).clamp_min(1e-12)
    class_center = list(centers)

    distances = torch.cdist(feats, centers)
    # Classes without samples must never be picked as a nearest centre.
    distances[:, ~present] = float('inf')
    _, indices = torch.min(distances, dim=1)
    sample_class_centroids = centers[indices]

    # Mask the nearest centre, then search again for the runner-up.
    distances.scatter_(1, indices.unsqueeze(1), float('inf'))
    _, nearest_other_labels = torch.min(distances, dim=1)
    nearest_other_centroids = centers[nearest_other_labels]

    # All angles in one vectorised pass; the clamp guards acos against cosines
    # that round marginally outside [-1, 1].
    unit_feats = feats / feats.norm(dim=1, keepdim=True)
    cos_own = (unit_feats * centers[labels]).sum(dim=1).clamp(-1.0, 1.0)
    cos_other = (unit_feats * nearest_other_centroids).sum(dim=1).clamp(-1.0, 1.0)
    element_angular = list(np.degrees(torch.acos(cos_own).cpu().numpy()))
    element_angular_no = list(np.degrees(torch.acos(cos_other).cpu().numpy()))

    return element_angular, element_angular_no, class_center, sample_class_centroids, nearest_other_centroids

def get_element_cos(features, centers):
    """Pairwise cosine similarity between feature rows and centre rows.

    Both inputs are 2-D; every row is divided by its L2 norm before the product.

    Returns:
        Matrix whose ``[i, j]`` entry is the cosine between ``features[i]`` and
        ``centers[j]``.
    """
    normed_feats, normed_centers = [
        rows / torch.norm(rows, dim=1, keepdim=True)
        for rows in (features, centers)
    ]
    return normed_feats @ normed_centers.T

def get_batch_angulars(features, labels, centroids, model):
    """Centroid look-ups and angular hardness for one batch.

    Args:
        features: 2-D batch of feature vectors.
        labels: class index of every sample.
        centroids: one centroid row per class.
        model: unused; kept so existing calls keep working. An unused clone and
            normalisation of one of the model's weights was removed.

    Returns:
        ``(batch_centroids, batch_nearest_centroids, batch_avh, nearest_other_labels)``:
        the centroid of each sample's label, the centroid with the second
        highest cosine similarity to each sample, the detached ``avh`` score per
        sample, and the class indices of those second-best centroids.
    """
    unit_feats = features / torch.norm(features, dim=1, keepdim=True)
    unit_centers = centroids / torch.norm(centroids, dim=1, keepdim=True)
    cos_sim = torch.mm(unit_feats, unit_centers.t().contiguous())

    batch_centroids = centroids[labels]
    batch_avh = avh(cos_sim, labels)

    # Mask each sample's most similar centroid and pick the runner-up. The mask
    # is applied out of place so the similarities seen by `avh` stay intact.
    _, nearest = torch.max(cos_sim, dim=1)
    masked_sim = cos_sim.scatter(1, nearest.unsqueeze(1), float('-inf'))
    _, nearest_other_labels = torch.max(masked_sim, dim=1)
    batch_nearest_centroids = centroids[nearest_other_labels]

    return batch_centroids, batch_nearest_centroids, batch_avh.detach(), nearest_other_labels


def avh(cosine_dists, targets):
    """Angular hardness: target-class angle divided by the sum of all class angles.

    @param cosine_dists: B x C cosine similarities
    @param targets: B class indices
    @return: B scores, one per sample; the result stays 1-D for a batch of one
    """
    # Clamp just inside [-1, 1] so acos stays finite at the ends of the range.
    ang_dists = torch.acos(torch.clamp(cosine_dists, -1.0 + 1e-7, 1.0 - 1e-7))
    # squeeze(1) removes only the gathered column; a bare squeeze() also dropped
    # the batch dimension when B == 1.
    target_angle = ang_dists.gather(1, targets[:, None]).squeeze(1)
    total_angle = ang_dists.sum(1)
    return target_angle / total_angle

def AngLoss(feature, label, weight, s=None, m=0.35, avh=None):
    """Margin-based cosine cross-entropy loss.

    The cosine between each normalised feature and each normalised weight row
    is reduced by ``m`` at the label position, scaled, and passed to
    cross-entropy.

    Args:
        feature: 2-D batch of feature vectors.
        label: class index of every sample.
        weight: one weight row per class.
        s: fixed scale; when ``None`` each row is scaled by its feature norm.
        m: margin subtracted from the label-class cosine.
        avh: accepted but not used.

    Returns:
        The cross-entropy loss of the scaled, margin-adjusted cosines.
    """
    unit_feature = F.normalize(feature, p=2, dim=1)
    unit_weight = F.normalize(weight, p=2, dim=1)
    cosine = torch.mm(unit_feature, unit_weight.t().contiguous())

    margin = nn.functional.one_hot(label, num_classes=cosine.size(1)) * m
    logits = cosine - margin

    if s != None:
        scale = s
    else:
        scale = torch.norm(feature, p=2, dim=1).view(-1, 1)
    logits.mul_(scale)

    return F.cross_entropy(logits, label)

# Angular unlearning experiment: starting from the original model, fine-tune a
# copy with an angular loss on forget batches (Df) and remain batches (Dr),
# then report accuracies, membership-attack results and wall-clock time.
unlearn_epoch = 10
num_class = 10


# NOTE(review): m1 / m2 are only printed; the margins actually passed to
# AngLoss below are the literals m=0 and m=0.45.
for m1 in [  0.4]:
    for m2 in [  1.0]:
        print("="*100)
        print('[m1 = %.2f] [m2 = %.2f]'%(m1, m2))
        UA, RA, TA, MIA, TC = [], [], [], [], []
        # Five repetitions at the same forget ratio.
        for forget_ratio in [ 0.1,0.1,0.1,0.1,0.1]:
            train_forget_loader, train_remain_loader, train_val_loader, test_loader, sorted_train_loader, forget_index, remain_index \
                = get_unlearn_loader(trainset, testset, forget_ratio=forget_ratio, batch_size=64)
            # Move the checkpoints to the notebook-level `device` rather than a
            # hard-coded CUDA device, so the cell also runs on CPU-only hosts.
            # NOTE(review): torch.load unpickles whole model objects; only load
            # trusted checkpoints. `retrain_model` is not used below.
            retrain_model = torch.load('./checkpoints/retrain_best_cifar10_sample_level_forget_ratio{%.4f}.pth'%forget_ratio, map_location=torch.device('cpu')).to(device)
            ori_model = torch.load('./checkpoints/original_best_cifar10.pth', map_location=torch.device('cpu')).to(device)
            test_model = copy.deepcopy(ori_model).to(device)
            unlearn_model = copy.deepcopy(ori_model).to(device)
            # Class centres = per-class mean feature of the original model over the forget set.
            class_center, _, _, _, _ = collect_feature_mean_std(test_model, train_forget_loader, num_class=10)
            class_center = torch.tensor(class_center).to(device)
            forget_data_gen = inf_generator(train_forget_loader)
            repair_data_gen = inf_generator(train_remain_loader)
            batches_per_epoch = len(train_forget_loader)

            criterion = loss_picker('cross')
            optimizer = optimizer_picker(optimization, unlearn_model.parameters(), lr=0.001, momentum=0.9)

            start = time.perf_counter()
            # NOTE(review): these five lists are never filled or read.
            list_sim_cno = []
            list_sim_c = []
            norm_of_Df = []
            num_of_hs = []
            margin_avh = []

            for itr in range(unlearn_epoch * batches_per_epoch):

                # One forget batch and one remain ("repair") batch per step.
                x, y = forget_data_gen.__next__()
                x = x.to(device)
                y = y.to(device)
                xr, yr = repair_data_gen.__next__()
                xr, yr = xr.to(device), yr.to(device)

                # NOTE(review): `centers` is not used below.
                centers = class_center[y]
                centers = centers.to(device)
                #target label

                unlearn_model.train()
                unlearn_model.zero_grad()
                optimizer.zero_grad()

                features = unlearn_model.features(x)
                ref_features = unlearn_model.features(xr)
                batch_centroids, batch_nearest_centroids, batch_avh, nearest_other_labels = get_batch_angulars(features, y, class_center, unlearn_model)
                # NOTE(review): this call receives the same `features` (from
                # unlearn_model) as the call above, and the `model` argument
                # does not influence the values get_batch_angulars returns, so
                # `ori_nearest_other_labels` equals `nearest_other_labels`. If
                # labels from the original model were intended, the features
                # would have to come from test_model - confirm.
                _, _, _, ori_nearest_other_labels = get_batch_angulars(features, y, class_center, test_model)
                params = unlearn_model.parameters()
                # First of the last two parameters; presumably the final
                # layer's weight matrix - verify against the model definition.
                weight = list(params)[-2:][0]
                # NOTE(review): weight_norm, easy_indice and loss_norm are
                # computed but do not enter the loss.
                weight_norm =F.normalize(weight, p=2, dim=1)

                easy_indice = torch.where(batch_avh < batch_avh.mean())[0]
                # "Hard" forget samples: angular hardness above the batch mean.
                hard_indice = torch.where(batch_avh > batch_avh.mean())[0]
                loss_norm =  0.5*torch.norm(features, p=2, dim=1).mean()

                # Forget term: hard forget samples with their runner-up class as
                # the target. Remain term: margin loss on the true labels.
                loss_Df = AngLoss(features[hard_indice], ori_nearest_other_labels[hard_indice], weight,s=40,m=0)#
                loss_Dr = AngLoss(ref_features, yr, weight, s=40, m=0.45)
                loss =    0.3*loss_Df+ 0.8*loss_Dr

                loss.backward()
                optimizer.step()

            end = time.perf_counter()
            tm = end-start
            # print('Time Consuming:', tm, 'secs' )

            # torch.save(unlearn_model, './checkpoints/unlearned_model_cifar10_ratio%.3f.pth'%(forget_ratio))

            print("Forget Ratio:[%.4f]"%forget_ratio)
            _, all_test_acc = eval(model=unlearn_model, data_loader=test_loader, mode='', print_perform=False, device=device)
            _, remain_train_acc = eval(model=unlearn_model, data_loader=train_remain_loader, mode='', print_perform=False, device=device)
            _, forget_train_acc = eval(model=unlearn_model, data_loader=train_forget_loader, mode='', print_perform=False, device=device)
            # ASR-U: attack on the forget samples. ASR-R: the same attack with
            # the remain loader passed as the "forget" set.
            prob = membership_attack(train_remain_loader, train_forget_loader, test_loader, unlearn_model, model_name='unlearn_model', printable=False)
            remain_prob = membership_attack(train_remain_loader, train_remain_loader, test_loader, unlearn_model, model_name='unlearn_model', printable=False)
            print("[Unlearned model] ASR-U:{%.4f}, ASR-R:{%.4f}, all test acc:{%.4f}, forget acc:{%.4f}, remain acc:{%.4f}"%(prob, remain_prob,all_test_acc, forget_train_acc, remain_train_acc))

            TA.append(all_test_acc.cpu().detach().numpy())
            UA.append(forget_train_acc.cpu().detach().numpy())
            RA.append(remain_train_acc.cpu().detach().numpy())
            MIA.append(prob)
            TC.append(tm)

        # Summary over the repetitions, printed as mean(std); `forget_ratio`
        # here is the value from the last repetition.
        print("Forget Ratio:[%.4f]"%forget_ratio)
        print(TC)
        print('TA:{%.5f}({%.5f}), UA:{%.5f}({%.5f}), RA:{%.5f}({%.5f}), ASR:{%.5f}({%.5f}), time:{%.5f}({%.5f})'\
            %(np.mean(TA), np.std(TA), np.mean(UA), np.std(UA), \
                np.mean(RA), np.std(RA), np.mean(MIA), np.std(MIA),\
                    np.mean(TC), np.std(TC)))


[m1 = 0.40] [m2 = 1.00]
Forget Ratio:[0.1000]
[Unlearned model] ASR-U:{0.6512}, ASR-R:{0.9444}, all test acc:{0.8341}, forget acc:{0.8564}, remain acc:{0.9955}
Forget Ratio:[0.1000]
[Unlearned model] ASR-U:{0.6468}, ASR-R:{0.9430}, all test acc:{0.8248}, forget acc:{0.8426}, remain acc:{0.9956}
Forget Ratio:[0.1000]
[Unlearned model] ASR-U:{0.6482}, ASR-R:{0.9344}, all test acc:{0.8274}, forget acc:{0.8470}, remain acc:{0.9946}
Forget Ratio:[0.1000]
[Unlearned model] ASR-U:{0.6648}, ASR-R:{0.9367}, all test acc:{0.8316}, forget acc:{0.8438}, remain acc:{0.9929}
Forget Ratio:[0.1000]
[Unlearned model] ASR-U:{0.6634}, ASR-R:{0.9340}, all test acc:{0.8224}, forget acc:{0.8442}, remain acc:{0.9934}
Forget Ratio:[0.1000]
[90.91720250000071, 94.9629666999972, 93.5495166000037, 94.29275659999985, 94.46772010000132]
TA:{0.82806}({0.00429}), UA:{0.84680}({0.00501}), RA:{0.99441}({0.00108}), ASR:{0.65488}({0.00767}), time:{93.63803}({1.43410})
