In [1]:
# Expose only GPU 0 to CUDA; set here, before `torch` is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os
import numpy as np
import gc

import torch
from torch import nn
import torch.backends.cudnn as cudnn

from dataloader import mnist
from models import FullyConnectedNet, TinyNet, ResNet18
from src import hessians, selection, utils

# Run on the (first visible) GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Class label whose training samples are targeted for removal in this experiment.
target_removal_label = 8

In [3]:
def load_net(net, path, map_location="cpu"):
    """Load the ``"net"`` state dict stored at ``path`` into ``net`` and return ``net``.

    Args:
        net: module whose parameters/buffers are overwritten in place.
        path: checkpoint file written as ``{"net": state_dict}`` (see ``save_net``).
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so that a
            checkpoint saved from a GPU also loads on a CPU-only machine;
            ``load_state_dict`` then copies the values onto ``net``'s own device.

    Returns:
        The same ``net`` object, for chaining.

    Raises:
        AssertionError: if ``path`` is not an existing file.
    """
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path, map_location=map_location)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    """Save ``net``'s state dict to ``path`` as ``{"net": state_dict}``.

    This is the layout ``load_net`` reads back. Missing parent directories are
    created; a bare filename is written to the current working directory.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)

def test(net, dataloader, criterion, label, include):
    net.eval()
    with torch.no_grad():
        net_loss = 0
        correct = 0
        num_inputs = 0
        for _, (inputs, targets) in enumerate(dataloader):
            if include:
                idx = targets == label
            else:
                idx = targets != label

            inputs, targets = inputs[idx], targets[idx]
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss * len(inputs)
            num_inputs += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        accuracy = correct / num_inputs * 100
        net_loss /= num_inputs
        return net_loss, accuracy
    
def _subset_stats(outputs, targets, mask, criterion):
    """Return (summed loss, #correct predictions, #samples) over the rows selected by ``mask``.

    An empty selection returns ``(0, 0, 0)`` without calling ``criterion``. A
    mean-reduced loss such as ``nn.CrossEntropyLoss()`` is NaN on an empty batch.
    """
    count = int(mask.sum())
    if count == 0:
        return 0, 0, 0
    selected_outputs, selected_targets = outputs[mask], targets[mask]
    # criterion is assumed to average over the batch; scale back to a sum.
    loss_sum = criterion(selected_outputs, selected_targets) * count
    _, predicted = selected_outputs.max(1)
    return loss_sum, predicted.eq(selected_targets).sum().item(), count


def influence_test(net, dataloader, criterion, removal_label):
    """Evaluate ``net`` separately on retained samples and on the ``removal_label`` class.

    Args:
        net: model to evaluate; it is switched to eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss returning the batch-mean loss (e.g. ``nn.CrossEntropyLoss()``).
        removal_label: class label whose samples form the "removal" split; every
            other sample belongs to the "remain" split.

    Returns:
        ``(remain_accuracy, remain_loss, removal_accuracy, removal_loss)``. Accuracies
        are in percent; losses are sample-weighted means (tensors from ``criterion``).

    Raises:
        ZeroDivisionError: if either split has no samples in ``dataloader``.
    """
    net.eval()
    with torch.no_grad():
        remain_loss, remain_correct, num_remain_inputs = 0, 0, 0
        removal_loss, removal_correct, num_removal_inputs = 0, 0, 0

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            # One forward pass per batch, split afterwards. This assumes that in eval
            # mode a sample's output does not depend on the rest of the batch.
            outputs = net(inputs)
            removal_index = targets == removal_label

            loss_sum, n_correct, n = _subset_stats(outputs, targets, ~removal_index, criterion)
            remain_loss += loss_sum
            remain_correct += n_correct
            num_remain_inputs += n

            loss_sum, n_correct, n = _subset_stats(outputs, targets, removal_index, criterion)
            removal_loss += loss_sum
            removal_correct += n_correct
            num_removal_inputs += n

        remain_accuracy = remain_correct / num_remain_inputs * 100
        remain_loss /= num_remain_inputs

        removal_accuracy = removal_correct / num_removal_inputs * 100
        removal_loss /= num_removal_inputs

        return remain_accuracy, remain_loss, removal_accuracy, removal_loss
    
def sample_test(net, criterion, inputs, targets):
    """Evaluate ``net`` on one in-memory batch.

    The model is switched to eval mode and run without gradient tracking.

    Returns:
        ``(loss, accuracy)``: ``criterion`` applied to the batch, and the percentage
        of samples whose arg-max prediction equals the target.
    """
    net.eval()
    with torch.no_grad():
        batch_x = inputs.to(device)
        batch_y = targets.to(device)
        logits = net(batch_x)
        loss = criterion(logits, batch_y)
        hits = logits.max(1)[1].eq(batch_y).sum().item()
        return loss, hits / len(batch_x) * 100

def f1_score(test_acc, self_acc, eps=1e-5):
    """Harmonic mean of retained-data accuracy and forgetting of the removal class.

    Args:
        test_acc: accuracy (%) on the retained (non-removed) samples.
        self_acc: accuracy (%) on the removal class; ``1 - self_acc/100`` is the
            fraction of that class that is no longer predicted correctly.
        eps: small constant added to the denominator so the score stays defined
            when both terms are zero.

    Returns:
        ``2 * f * r / (f + r + eps)`` with ``f = 1 - self_acc/100`` and ``r = test_acc/100``.
    """
    # Out-of-place arithmetic: `/=` would mutate tensor or ndarray arguments in place.
    retained = test_acc / 100
    forgotten = 1 - self_acc / 100
    return 2 * forgotten * retained / (forgotten + retained + eps)

### Build the model and set the criterion

In [4]:
if device == "cuda":
    cudnn.benchmark = True

# Checkpoint of the original model. Its directory uses the constructor name
# ("ResNet18"), while `net_name` below is the class name used in later save paths.
ckpt_net_name = "ResNet18"
net_path = f"checkpoints/tab2/{ckpt_net_name}/cross_entropy/ckpt_0.0.pth"

net = ResNet18(1).to(device)
net = load_net(net, net_path)

net_name = net.__class__.__name__
num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

# Data
print("==> Preparing data..")
batch_size = 256
# More loader workers than CPUs only oversubscribes the machine, so cap at the CPU count.
num_workers = min(24, os.cpu_count() or 1)
num_sample_batch = 1     # training batches used for the retained-class ("total") loss
num_target_sample = 512  # removal-class samples drawn per trial for the target loss

data_loader = mnist.MNISTDataLoader(batch_size, num_workers, validation=False)
train_loader, test_loader = data_loader.get_data_loaders()

# Optional sanity check of the original model; set to True to run one test pass.
check_original = False
if check_original:
    orig_acc, orig_loss, orig_self_acc, orig_self_loss = influence_test(
        net, test_loader, criterion, target_removal_label
    )
    print(
        f"Original loss and acc : {orig_loss:.4f}, {orig_acc:.2f}% "
        + f"| removal class: {orig_self_loss:.4f}, {orig_self_acc:.2f}%"
    )

==> Building ResNet finished. 
    Number of parameters: 11172810
==> Preparing data..


In [5]:
# Print per-step progress in the unlearning loop.
verbose = True

# Gather every training sample whose label is the class targeted for removal.
kept_x, kept_y = [], []
for batch_x, batch_y in train_loader:
    mask = batch_y == target_removal_label
    kept_x.append(batch_x[mask])
    kept_y.append(batch_y[mask])
removal_inputs = torch.cat(kept_x)
removal_targets = torch.cat(kept_y)
del kept_x, kept_y

In [6]:
parser_list = [selection.TopNActivations,
               selection.TopNGradients,
               selection.Random,
               selection.Threshold,]

ratio_list = [10, 30, 50, 100]  # percentage of parameters handed to each parser
tol = 1e-9

# Experiment constants.
num_trials = 10
seed = 0                 # seeds the removal-subset sampling (and torch's RNG)
max_update_steps = 200   # stop after this many repeated influence updates ...
min_exclusive_acc = 75   # ... or once accuracy on the retained classes drops below this (%)
min_self_acc = 0.1       # ... or once accuracy on the removal class drops below this (%)

# Seed the global RNGs so that Restart & Run All draws the same removal subsets.
# torch is seeded too, in case a parser draws from its RNG.
np.random.seed(seed)
torch.manual_seed(seed)

for trial in range(num_trials):
    # Sample this trial's removal-class data; it is shared by every parser and ratio.
    sample_idx = np.random.choice(len(removal_inputs), num_target_sample, replace=False)
    sample_removal_inputs = removal_inputs[sample_idx]
    sample_removal_targets = removal_targets[sample_idx]

    for parser in parser_list:
        for param_ratio in ratio_list:
            # Start every configuration from the original checkpoint.
            net = ResNet18(1).to(device)
            net = load_net(net, net_path)
            # NOTE(review): net.eval() is never called here, so the forward passes below
            # run in training mode (if the model has BatchNorm, its running stats are
            # updated by them) -- confirm this is intended.
            net_parser = parser(net, param_ratio/100)
            print(f"Parser: {net_parser.__class__.__name__}, param_ratio: {param_ratio}%")

            # Compute the retained-class ("total") loss over the first num_sample_batch batches.
            total_loss = 0
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                if batch_idx >= num_sample_batch:
                    break
                idx = targets != target_removal_label
                inputs, targets = inputs[idx], targets[idx]
                total_loss += criterion(net(inputs.to(device)), targets.to(device))
            total_loss /= num_sample_batch

            # The parser's hooks are active only for the target-loss forward pass below.
            net_parser.register_hooks()

            # Target loss on the sampled removal data, scaled by the removal class's
            # share of the training set.
            target_loss = (
                criterion(net(sample_removal_inputs.to(device)), sample_removal_targets.to(device))
                * len(removal_inputs) / len(train_loader.dataset)
            )

            # Gradient-based selection needs a backward pass through the hooks.
            # backward() frees the graph, so the target loss is then recomputed with
            # hooks removed.
            if isinstance(net_parser, selection.TopNGradients):
                target_loss.backward()
                net_parser.remove_hooks()
                target_loss = (
                    criterion(net(sample_removal_inputs.to(device)), sample_removal_targets.to(device))
                    * len(removal_inputs) / len(train_loader.dataset)
                )

            # Get index list to compute GIF
            index_list = net_parser.get_parameters()
            net_parser.remove_hooks()

            influence = hessians.generalized_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=3, max_iter=30, verbose=False
            )
            del total_loss, target_loss
            gc.collect()
            torch.cuda.empty_cache()

            scale = 1
            count = 1
            best_score = -1
            best_result = None
            save_path = (
                f"checkpoints/tab1/PIF/{net_name}/{net_parser.__class__.__name__}/{param_ratio}_{trial}.pth"
            )
            while True:
                # net is not reloaded inside this loop, so each step is applied on top
                # of the previous ones (presumably additive; see utils.update_network).
                utils.update_network(net, influence * scale, index_list)

                exclusive_acc, exclusive_loss, self_acc, self_loss = influence_test(net, test_loader,
                                                                                    criterion, target_removal_label)
                score = f1_score(exclusive_acc, self_acc)

                if verbose:
                    print(
                    f"{count} - test acc: {exclusive_acc:2.2f}, test loss: {exclusive_loss:.4f}" + \
                    f" | self-acc: {self_acc:2.2f}%, self loss: {self_loss:.4f} | Score: {score:.4f}",
                    end = '\r'
                    )

                # Keep (and checkpoint) the best-scoring network seen so far.
                if best_score < score:
                    best_result = [exclusive_acc, exclusive_loss, self_acc, self_loss, score]
                    best_score = score
                    save_net(net, save_path)

                if exclusive_acc < min_exclusive_acc or self_acc < min_self_acc or count > max_update_steps:
                    print(f"test acc: {best_result[0]:2.2f}, test loss: {best_result[1]:.4f} | " +
                          f"self-acc: {best_result[2]:2.2f}%, self loss: {best_result[3]:.4f} | " +
                          f"Score: {best_result[4]:.7f} \n")
                    break

                count += 1
            del net

Parser: TopNActivations, param_ratio: 10%
6 - test acc: 99.22, test loss: 0.0256 | self-acc: 99.28%, self loss: 0.0202 | Score: 0.0143

KeyboardInterrupt: 

### Measure the network utility

In [None]:
# Evaluate the best checkpoint saved for one trial of every (parser, ratio) pair.
eval_trial = 1  # trial index used in the unlearning loop's save paths

net = ResNet18(1).to(device)

# Result lists: one inner list per parser (parser_list order), one entry per ratio.
self_loss_list = [[] for _ in parser_list]
self_acc_list = [[] for _ in parser_list]
exclusive_loss_list = [[] for _ in parser_list]
exclusive_acc_list = [[] for _ in parser_list]

for parser_count, parser in enumerate(parser_list):
    parser_name = parser.__name__
    for param_ratio in ratio_list:
        # Same layout as the unlearning loop's save path: integer percentage + trial index.
        net_path = f"checkpoints/tab1/PIF/{net_name}/{parser_name}/{param_ratio}_{eval_trial}.pth"
        net = load_net(net, net_path)

        # One pass over the test set gives both the removal-class ("self") and the
        # retained-class ("exclusive") metrics.
        exclusive_acc, exclusive_loss, self_acc, self_loss = influence_test(
            net, test_loader, criterion, target_removal_label
        )

        # Save results in the lists as plain floats.
        self_loss_list[parser_count].append(float(self_loss))
        self_acc_list[parser_count].append(self_acc)
        exclusive_loss_list[parser_count].append(float(exclusive_loss))
        exclusive_acc_list[parser_count].append(exclusive_acc)

        print(f"{parser_name}, {param_ratio:2.0f}% - Self: {self_loss:.4f} {self_acc:.2f}% | exclusive loss: {exclusive_loss:.4f}, {exclusive_acc:.2f}%")
        print("")

In [None]:
import pandas as pd

# Show results as one table per metric. Rows are parsers in parser_list order (the
# order the measurement loop filled the lists in); columns are parameter ratios.
parser_names = [parser.__name__ for parser in parser_list]
ratio_labels = [f'{num}%' for num in ratio_list]


def format_results(per_parser_values):
    """Return a parser x ratio DataFrame whose cells are the values formatted to 2 decimals.

    ``per_parser_values[p][r]`` is the value for ``parser_list[p]`` at ``ratio_list[r]``.
    """
    rows = [["{:.2f}".format(float(num)) for num in values] for values in per_parser_values]
    return pd.DataFrame(rows, index=parser_names, columns=ratio_labels)


# Losses may have been stored as tensors; keep plain floats from here on.
self_loss_list = [[float(tensor) for tensor in values] for values in self_loss_list]
exclusive_loss_list = [[float(tensor) for tensor in values] for values in exclusive_loss_list]

print("Self Loss")
self_loss_df = format_results(self_loss_list)
print(self_loss_df)

print("\nSelf Accuracy")
self_acc_df = format_results(self_acc_list)
print(self_acc_df)

print("\nExclusive Loss")
exclusive_loss_df = format_results(exclusive_loss_list)
print(exclusive_loss_df)

print("\nExclusive Accuracy")
exclusive_acc_df = format_results(exclusive_acc_list)
print(exclusive_acc_df)

In [None]:
import pickle

# Save the per-parser result lists ([parser][ratio]) gathered by the measurement cell,
# each to its own file named after the list.
results_to_save = {
    "self_loss_list": self_loss_list,
    "self_acc_list": self_acc_list,
    "exclusive_loss_list": exclusive_loss_list,
    "exclusive_acc_list": exclusive_acc_list,
}
for list_name, values in results_to_save.items():
    with open(f"{list_name}.pickle", 'wb') as f:
        pickle.dump(values, f, pickle.HIGHEST_PROTOCOL)

#### 