In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import gc

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

from dataloader import cifar10
from models import VGG11
from src import freeze_influence, hessians, selection

device = "cuda" if torch.cuda.is_available() else "cpu"
target_removal_label = 1

In [3]:
def load_net(net, path):
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    dir, filename = os.path.split(path)
    if not os.path.isdir(dir):
        os.makedirs(dir)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)


def test(net, dataloader, criterion, label, include):
    net.eval()
    with torch.no_grad():
        net_loss = 0
        correct = 0
        total = 0
        num_data = 0
        for _, (inputs, targets) in enumerate(dataloader):
            if include:
                idx = targets == label
            else:
                idx = targets != label
            inputs = inputs[idx]
            targets = targets[idx]
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss * len(inputs)
            num_data +=  len(inputs)

            total += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        accuracy = correct / total * 100
        net_loss /= num_data
        return net_loss, accuracy

def influence_test(net, dataloader, criterion, target_label):
    def sample_test(net, criterion, inputs, targets):
        net.eval()
        with torch.no_grad():
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            _, predicted = outputs.max(1)
            correct = predicted.eq(targets).sum().item()

            return loss, correct

    self_loss = 0
    self_correct = 0
    num_self_inputs = 0
    
    exclusive_loss = 0
    exclusive_correct = 0
    num_exclusive_inputs = 0
    
    for _, (inputs, targets) in enumerate(dataloader):
        target_idx = (targets == target_label)
        batch_self_loss, batch_self_correct = sample_test(net, criterion, inputs[target_idx], targets[target_idx])
        batch_exclusive_loss, batch_exclusive_correct = sample_test(net, criterion, inputs[~target_idx], targets[~target_idx])
        
        len_self_batch = len(inputs[target_idx])
        self_loss += batch_self_loss * len_self_batch
        self_correct += batch_self_correct
        num_self_inputs += len_self_batch
        
        len_exclusive_batch = len(inputs[~target_idx])
        exclusive_loss += batch_exclusive_loss * len_exclusive_batch
        exclusive_correct += batch_exclusive_correct
        num_exclusive_inputs += len_exclusive_batch
        
    self_loss /= num_self_inputs
    self_acc = self_correct / num_self_inputs * 100
    exclusive_loss /= num_exclusive_inputs
    exclusive_acc = exclusive_correct / num_exclusive_inputs * 100
    
    return self_loss, self_acc, exclusive_loss, exclusive_acc
        
def projected_influence(net, total_loss, target_loss, index_list, tol, step, max_iter, verbose):
    num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
    full_param_index_list = np.arange(num_param)
    influence = hessians.generalized_influence(
        net, total_loss, target_loss, full_param_index_list, tol=tol, step=step, max_iter=max_iter, verbose=verbose
    )
    return influence[index_list]

def f1_score(self_acc, test_acc):
    self_acc /= 100
    test_acc /= 100
    if self_acc == 1 and test_acc == 0:
        return 0
    return 2 * (1 - self_acc) * test_acc / (1 - self_acc + test_acc)

In [4]:
net = VGG11().to(device)
net_name = "VGG11"

if device == "cuda":
    cudnn.benchmark = True

net_path = f"checkpoints/tab2/{net_name}/cross_entropy/ckpt_0.0.pth"
net = load_net(net, net_path)

net.eval()
num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

# Data
print("==> Preparing data..")
batch_size = 512
num_workers = 16
num_sample_batch = 1
num_target_sample = 512

data_loader = cifar10.CIFAR10DataLoader(batch_size, num_workers, validation=False)
train_loader, test_loader = data_loader.get_data_loaders()

loss, acc = test(net, test_loader, criterion, 11, False)
print(
    f"Original loss and acc : {loss:.4f}, {acc:.2f}%"
)

==> Building VGG11 finished. 
    Number of parameters: 9231114
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
Original loss and acc : 0.3518, 91.50%


In [5]:
verbose = True
num_exp = 10

removal_inputs = list()
removal_targets = list()
for batch_idx, (inputs, targets) in enumerate(test_loader):
    idx = targets == target_removal_label
    removal_inputs.append(inputs[idx])
    removal_targets.append(targets[idx])
removal_inputs = torch.cat(removal_inputs)
removal_targets = torch.cat(removal_targets)

ratio_list = [.1, .3, .5]
result_list_GIF = []
result_list_FIF = []
result_list_PIF = []

tol = 1e-9

for exp_iter in range(num_exp):
    sample_idx = np.random.choice(len(removal_inputs), num_target_sample, replace=False)
    for param_ratio in ratio_list:
        for i in range(3):
            if i == 0:
                if_name = "GIF"
            elif i == 1:
                if_name = "FIF"
            else:
                if_name = "PIF"

            print(f"{if_name} - ratio: {param_ratio*100}%, tol: {tol}")
            # Initialize network
            net = load_net(net, net_path)

            # Compute total loss
            total_loss = 0
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                if batch_idx >= num_sample_batch:
                    break
                idx = targets != target_removal_label
                inputs, targets = inputs[idx], targets[idx]
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                total_loss += criterion(outputs, targets)

            # Sampling the target removal data
            sample_removal_inputs = removal_inputs[sample_idx]
            sample_removal_targets = removal_targets[sample_idx]
            
            # Make hooks
            net_parser = selection.TopNGradients(net, param_ratio)
            net_parser.register_hooks()

            # Compute target loss
            target_loss = (
                criterion(net(sample_removal_inputs.to(device)), sample_removal_targets.to(device))
                * len(removal_inputs)
                / (len(train_loader.dataset) - len(removal_inputs))
            )
            target_loss.backward()
            net_parser.remove_hooks()

            target_loss = (
                criterion(net(sample_removal_inputs.to(device)), sample_removal_targets.to(device))
                * len(removal_inputs)
                / (len(train_loader.dataset) - len(removal_inputs))
            )
            
            # Delete hooks
            index_list = net_parser.get_parameters()
            net_parser.remove_hooks()
            
            if i == 0:
                influence = hessians.generalized_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=3, max_iter=30, verbose=False
                )
            elif i == 1:
                influence = freeze_influence.freeze_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=3, max_iter=30, verbose=False
                )
            else:
                influence = projected_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=3, max_iter=30, verbose=False
                )
                
            del total_loss, target_loss
            gc.collect()
            torch.cuda.empty_cache()
            
            influence *= 0.03 / torch.norm(influence)
                
            scale = 1 if i!=2 else 1 / param_ratio
            score = 0
            best_score = -1
            count = 1
            save_path = (
                f"checkpoints/tab2/{net_name}/{if_name}/{target_removal_label}_{param_ratio}_{exp_iter}.pth"
            )
            while True:
                if score < .85:
                    net_parser.update_network(influence * scale)
                else:
                    net_parser.update_network(influence * scale / 3)

#                 self_loss, self_acc = test(net, test_loader, criterion, target_removal_label, True)
#                 exclusive_loss, exclusive_acc = test(net, test_loader, criterion, target_removal_label, False)
                self_loss, self_acc, exclusive_loss, exclusive_acc = influence_test(net, test_loader, 
                                                                                    criterion, target_removal_label)
                score = f1_score(self_acc, exclusive_acc)
                
                if best_score < score:
                    best_result = [exclusive_acc, exclusive_loss, self_acc, self_loss, score]
                    best_score = score
                    save_net(net, save_path)
                    
                if verbose:
                    print(
                    f"{count} - test acc: {exclusive_acc:2.2f}, test loss: {exclusive_loss:.4f} |" +
                    f" self-acc: {self_acc:2.2f}%, self loss: {self_loss:.4f} | score: {score:.7f}",
                    end='\r'
                    )
                
                if exclusive_acc < 80 or self_acc < 0.01 or count >= 200:
                    if i == 0:
                        result_list_GIF += best_result
                    elif i == 1:
                        result_list_FIF += best_result
                    else:
                        result_list_PIF += best_result
                    
                    print(f"test acc: {best_result[0]:2.2f}, test loss: {best_result[1]:.4f} | " +
                          f"self-acc: {best_result[2]:2.2f}%, self loss: {best_result[3]:.4f} | " +
                          f"Score: {best_result[4]:.7f}" + " " * 20) 
                    break
                elif count >= 20 and best_score < 0.2:
                    scale *= 5
                elif count >= 50 and best_score < 0.5:
                    print(f"test acc: {best_result[0]:2.2f}, test loss: {best_result[1]:.4f} | " +
                          f"self-acc: {best_result[2]:2.2f}%, self loss: {best_result[3]:.4f} | " +
                          f"Score: {best_result[4]:.7f}" + " " * 20) 
                    break

                count += 1
                
        print("")

GIF - ratio: 10.0%, tol: 1e-09
test acc: 87.63, test loss: 0.5173 | self-acc: 0.30%, self loss: 10.0420 | Score: 0.9327804                    
FIF - ratio: 10.0%, tol: 1e-09
test acc: 85.91, test loss: 0.5995 | self-acc: 1.80%, self loss: 9.3213 | Score: 0.9164543                    
PIF - ratio: 10.0%, tol: 1e-09
test acc: 83.91, test loss: 0.6820 | self-acc: 0.20%, self loss: 11.6029 | Score: 0.9116845                    

GIF - ratio: 30.0%, tol: 1e-09
test acc: 87.33, test loss: 0.5207 | self-acc: 1.60%, self loss: 9.7419 | Score: 0.9253697                    
FIF - ratio: 30.0%, tol: 1e-09
test acc: 87.04, test loss: 0.5346 | self-acc: 1.80%, self loss: 9.6123 | Score: 0.9228632                    
PIF - ratio: 30.0%, tol: 1e-09
test acc: 84.63, test loss: 0.6350 | self-acc: 1.00%, self loss: 10.8946 | Score: 0.9125467                    

GIF - ratio: 50.0%, tol: 1e-09
test acc: 86.63, test loss: 0.5627 | self-acc: 1.50%, self loss: 10.1166 | Score: 0.9218635                    


Traceback (most recent call last):
  File "/home/hslyu/bin/miniconda3/envs/GIF/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/hslyu/bin/miniconda3/envs/GIF/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/hslyu/bin/miniconda3/envs/GIF/lib/python3.10/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/hslyu/bin/miniconda3/envs/GIF/lib/python3.10/shutil.py", line 731, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/hslyu/bin/miniconda3/envs/GIF/lib/python3.10/shutil.py", line 729, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-r9gbu5wt'


test acc: 86.32, test loss: 0.5722 | self-acc: 2.20%, self loss: 9.6750 | Score: 0.9170336                    
FIF - ratio: 30.0%, tol: 1e-09
test acc: 86.07, test loss: 0.5889 | self-acc: 2.10%, self loss: 9.6933 | Score: 0.9160275                    
PIF - ratio: 30.0%, tol: 1e-09
test acc: 85.32, test loss: 0.6138 | self-acc: 3.70%, self loss: 9.4621 | Score: 0.9047935                    

GIF - ratio: 50.0%, tol: 1e-09
test acc: 85.19, test loss: 0.6296 | self-acc: 1.70%, self loss: 10.2571 | Score: 0.9127602                    
FIF - ratio: 50.0%, tol: 1e-09
test acc: 86.34, test loss: 0.5774 | self-acc: 2.50%, self loss: 9.6095 | Score: 0.9158377                    
PIF - ratio: 50.0%, tol: 1e-09
test acc: 84.01, test loss: 0.6804 | self-acc: 1.80%, self loss: 10.5569 | Score: 0.9055311                    

GIF - ratio: 10.0%, tol: 1e-09
test acc: 87.07, test loss: 0.5429 | self-acc: 1.10%, self loss: 9.0663 | Score: 0.9260685                    
FIF - ratio: 10.0%, tol: 1e-09
te

In [8]:
for param_ratio in ratio_list:
    for i in range(3):
        if i == 0:
            if_name = "GIF"
        elif i == 1:
            if_name = "FIF"
        else:
            if_name = "PIF"
        print(f"{if_name} - ratio: {param_ratio*100}%, tol: {tol}")
        
        self_loss_list = np.empty(0)
        self_acc_list = np.empty(0)
        exclusive_loss_list = np.empty(0)
        exclusive_acc_list = np.empty(0)
        f1_score_list = np.empty(0)
        
        for exp_iter in range(num_exp):

            load_path = (
                f"checkpoints/tab2/{net_name}/{if_name}/{target_removal_label}_{param_ratio}_{exp_iter}.pth"
            )
            net = VGG11().to(device)
            net = load_net(net, load_path)
            self_loss, self_acc, exclusive_loss, exclusive_acc = influence_test(net, test_loader, 
                                                                    criterion, target_removal_label)
            score = f1_score(self_acc, exclusive_acc)
            
            self_loss_list = np.append(self_loss_list, self_loss.detach().cpu().numpy())
            self_acc_list = np.append(self_acc_list, self_acc)
            exclusive_loss_list = np.append(exclusive_loss_list, exclusive_loss.detach().cpu().numpy())
            exclusive_acc_list = np.append(exclusive_acc_list, exclusive_acc)
            f1_score_list = np.append(f1_score_list, score)
            print(
            f"  test acc: {exclusive_acc:2.2f}, test loss: {exclusive_loss:.4f} | " +
            f"self-acc: {self_acc:2.2f}%, self loss: {self_loss:.4f} | score: {score:.4f}",
            end='\r'
            )
            
        mean_self_loss = np.mean(self_loss_list)
        mean_self_acc = np.mean(self_acc_list)
        mean_exclusive_loss = np.mean(exclusive_loss_list)
        mean_exclusive_acc = np.mean(exclusive_acc_list)
        mean_f1_score = np.mean(f1_score_list)
                
        var_self_loss = np.var(self_loss_list)
        var_self_acc = np.var(self_acc_list)
        var_exclusive_loss = np.var(exclusive_loss_list)
        var_exclusive_acc = np.var(exclusive_acc_list)
        var_f1_score = np.var(f1_score_list)

        print(
        f"test acc: {mean_exclusive_acc:2.2f}+-{var_exclusive_acc:2.2f}% " +
        f"test loss: {mean_exclusive_loss:.4f}+-{var_exclusive_loss:.4f} ", end=""
        )
        print(
        f"self-acc: {mean_self_acc:2.2f}+-{var_self_acc:2.2f}% " +
        f"self loss: {mean_self_loss:.4f}+-{var_seÍlf_loss:.4f} " +
        f"score: {mean_f1_score:.4f}",
        )


GIF - ratio: 10.0%, tol: 1e-09
test acc: 87.38+-0.89% test loss: 0.5302+-0.0020 self-acc: 0.79+-0.25% self loss: 9.4301+-0.1464 score: 0.9292
FIF - ratio: 10.0%, tol: 1e-09
test acc: 85.31+-2.00% test loss: 0.6272+-0.0055 self-acc: 1.97+-0.36% self loss: 9.3206+-0.0289 score: 0.9123
PIF - ratio: 10.0%, tol: 1e-09
test acc: 82.71+-3.69% test loss: 0.7557+-0.0126 self-acc: 0.20+-0.00% self loss: 11.7498+-0.1092 score: 0.9044
GIF - ratio: 30.0%, tol: 1e-09
test acc: 86.31+-1.17% test loss: 0.5734+-0.0034 self-acc: 1.67+-0.42% self loss: 9.8694+-0.1031 score: 0.9193
FIF - ratio: 30.0%, tol: 1e-09
test acc: 85.92+-1.41% test loss: 0.5941+-0.0041 self-acc: 1.57+-0.40% self loss: 9.9615+-0.1648 score: 0.9175
PIF - ratio: 30.0%, tol: 1e-09
test acc: 84.94+-0.79% test loss: 0.6317+-0.0026 self-acc: 2.41+-1.87% self loss: 10.0482+-0.4226 score: 0.9082
GIF - ratio: 50.0%, tol: 1e-09
test acc: 85.98+-1.07% test loss: 0.5935+-0.0037 self-acc: 1.93+-0.46% self loss: 9.9196+-0.0791 score: 0.9163
FIF 

In [19]:
retrained_net = VGG11().to(device)
net_name = "VGG11"
net_path = f"checkpoints/tab2/{net_name}_retrained/cross_entropy/ckpt_0.0_{target_removal_label}.pth"
print(net_path)
retrained_net = load_net(retrained_net, net_path)

self_loss, self_acc = test(retrained_net, test_loader, criterion, target_removal_label, True)
exclusive_loss, exclusive_acc = test(retrained_net, test_loader, criterion, target_removal_label, False)
score = f1_score(self_acc, exclusive_acc)

print(
f"test acc: {exclusive_acc:2.2f}, test loss: {exclusive_loss:.4f} |" +
f" self-acc: {self_acc:2.2f}%, self loss: {self_loss:.4f} | score: {score:.7f}",
)

checkpoints/tab2/VGG11_retrained/cross_entropy/ckpt_0.0_1.pth
test acc: 91.29, test loss: 0.3578 | self-acc: 0.00%, self loss: 8.7573 | score: 0.9544610


## 