In [None]:
# Restrict CUDA to GPU 0 for this kernel. It has to be set before torch
# initialises CUDA, which is why it lives in its own first cell.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

from dataloader import svhn
from models import ShuffleNetV2
from src import freeze_influence, hessians, selection, utils

device = "cuda" if torch.cuda.is_available() else "cpu"
# Class label whose training examples form the removal set in the experiments below.
target_removal_label = 1

# Seed the RNGs so runs are reproducible: NumPy drives the per-run sampling of
# removal examples, and torch drives any DataLoader shuffling (if enabled).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
def load_net(net, path):
    """
    Load the ``"net"`` state dict stored at ``path`` into ``net`` and return ``net``.

    The checkpoint is deserialised onto the CPU first. ``load_state_dict`` then
    copies the values into ``net``'s existing parameters, so ``net`` keeps its
    device. This lets a checkpoint saved from a GPU run load on a CPU-only machine.

    Args:
        net: module whose parameters are overwritten in place.
        path: checkpoint file containing ``{"net": state_dict}``.

    Returns:
        The same ``net`` object.

    Raises:
        AssertionError: if ``path`` is not an existing file.
    """
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    # Only load trusted checkpoints: torch.load may unpickle arbitrary objects.
    checkpoint = torch.load(path, map_location="cpu")
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    """
    Save ``net``'s state dict to ``path`` as ``{"net": state_dict}``.

    The parent directory is created when needed. A bare filename is written to
    the current working directory. The ``{"net": ...}`` layout is the one
    ``load_net`` reads back.

    Args:
        net: module to serialise.
        path: destination file path.
    """
    parent_dir = os.path.dirname(path)
    # os.makedirs("") raises, so only create a directory when the path names one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save({"net": net.state_dict()}, path)


def test(net, dataloader, criterion, label, include):
    net.eval()
    with torch.no_grad():
        net_loss = 0
        correct = 0
        total = 0
        num_data = 0
        for _, (inputs, targets) in enumerate(dataloader):
            if include:
                idx = targets == label
            else:
                idx = targets != label
            inputs = inputs[idx]
            targets = targets[idx]
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss * len(inputs)
            num_data +=  len(inputs)

            total += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        accuracy = correct / total * 100
        net_loss /= num_data
        return net_loss, accuracy


def get_full_param_index_list(net):
    """
    Return the positions of Conv2d/Linear parameters in the flattened network.

    Offsets are accumulated over the trainable (``requires_grad=True``)
    parameters of every leaf module, in ``net.modules()`` order. Indices are
    reported only for parameters that (1) have ``requires_grad=True`` and
    (2) belong to an ``nn.Conv2d`` or ``nn.Linear`` module. Parameters of other
    leaf modules (e.g. normalisation layers) are not reported, but they still
    advance the offset.

    Warning: parameters registered directly on a non-leaf (container) module are
    not counted at all. The offsets are therefore only correct when every
    parameter lives in a leaf module.

    Args:
        net: the ``nn.Module`` to index.

    Returns:
        1-D integer ``np.ndarray``. It is empty when ``net`` has no trainable
        Conv2d/Linear parameters.
    """
    chunks = []
    start_index = 0
    for module in net.modules():
        # Skip containers: only leaf modules contribute to the offsets.
        if next(module.children(), None) is not None:
            continue

        num_param = sum(p.numel() for p in module.parameters() if p.requires_grad)
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            chunks.append(np.arange(start_index, start_index + num_param, dtype=int))
        start_index += num_param

    # Concatenate once instead of np.append per module (which copies every time).
    if not chunks:
        return np.array([], dtype=int)
    return np.concatenate(chunks)


def projected_influence(net, total_loss, target_loss, index_list, tol, step, max_iter, verbose=False):
    """
    Compute influence over all Conv2d/Linear parameters, then keep only ``index_list``.

    The influence is solved on the full Conv2d/Linear index set from
    ``get_full_param_index_list`` and afterwards restricted ("projected") to the
    requested indices.

    Args:
        net, total_loss, target_loss: forwarded to ``hessians.partial_influence``.
        index_list: parameter indices to keep.
        tol, step, max_iter: solver settings forwarded to ``hessians.partial_influence``.
        verbose: forwarded to ``hessians.partial_influence``. Accepting it lets
            callers pass ``verbose=...`` like the other influence routines.

    Returns:
        ``(influence, indices)``: the entries of the full influence whose
        parameter index appears in ``index_list``, and those indices. Both keep
        the order of the full index list.
    """
    full_param_index_list = get_full_param_index_list(net)
    influence = hessians.partial_influence(
        net, total_loss, target_loss, full_param_index_list, tol=tol, step=step, max_iter=max_iter, verbose=verbose
    )
    keep = np.isin(full_param_index_list, index_list)
    return influence[keep], full_param_index_list[keep]

def f1_score(self_acc, test_acc):
    """
    Harmonic mean of the forgetting rate and the retained test accuracy.

    Args:
        self_acc: accuracy in percent on the removed class. The score rises as
            this falls.
        test_acc: accuracy in percent on the remaining classes. The score rises
            as this rises.

    Returns:
        ``2 * (1 - s) * t / ((1 - s) + t)``, where ``s`` and ``t`` are the
        accuracies as fractions. The value lies in [0, 1] for inputs in [0, 100].
        It is 0.0 when both terms are zero (self_acc=100, test_acc=0), a case
        that would otherwise divide by zero.
    """
    forget = 1 - self_acc / 100
    keep = test_acc / 100
    denom = forget + keep
    if denom == 0:
        return 0.0
    return 2 * forget * keep / denom

In [None]:
# Model: ShuffleNetV2 restored from its cross-entropy checkpoint, in eval mode.
net_name = "ShuffleNetV2"
if device == "cuda":
    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    cudnn.benchmark = True

net_path = f"checkpoints/tab2/{net_name}/cross_entropy/ckpt_0.0.pth"
net = load_net(ShuffleNetV2().to(device), net_path)
net.eval()

num_param = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(
    f"==> Building {net_name} finished. "
    f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

# Data
print("==> Preparing data..")
batch_size = 512
num_workers = 16
num_sample_batch = 1        # training batches used for the retained-data loss
num_target_sample = 1024    # removal examples sampled per run
data_loader = svhn.SVHNDataLoader(batch_size, num_workers, validation=False)
train_loader, test_loader = data_loader.get_data_loaders()

# Baseline on the whole test set: presumably no SVHN target equals 11, so
# excluding label 11 keeps every example.
loss, acc = test(net, test_loader, criterion, 11, False)
print(f"Original loss and acc : {loss:.4f}, {acc:.2f}%")

In [None]:
verbose = True

# Collect every training example labelled `target_removal_label`. Each run
# below samples its removal set from this pool.
removal_chunks = []
for batch_inputs, batch_targets in train_loader:
    keep = batch_targets == target_removal_label
    removal_chunks.append((batch_inputs[keep], batch_targets[keep]))
removal_inputs = torch.cat([chunk_inputs for chunk_inputs, _ in removal_chunks])
removal_targets = torch.cat([chunk_targets for _, chunk_targets in removal_chunks])
del removal_chunks

# Parameter ratios passed to selection.TopNActivations (printed as percentages).
ratio_list = [0.1, 0.3, 0.5]

# Flat accumulators: each run extends its method's list with its best result values.
result_list_GIF, result_list_FIF, result_list_PIF = [], [], []

# Tolerance forwarded to the influence solvers.
tol = 1e-9

# Early-stopping thresholds. `test` reports accuracies in percent (0-100), so
# these thresholds are percentages as well.
min_test_acc = 80.0  # stop once accuracy on the remaining classes falls below this
min_self_acc = 1.0   # ...or once accuracy on the removed class falls below this
results_by_method = {"GIF": result_list_GIF, "FIF": result_list_FIF, "PIF": result_list_PIF}

for param_ratio in ratio_list:
    for i, if_name in enumerate(("GIF", "FIF", "PIF")):
        print(f"{if_name} - ratio: {param_ratio*100}%, tol: {tol}")
        for _ in range(10):
            # Initialize network
            net = load_net(net, net_path)

            # Loss on the retained (non-removed) data of the first
            # `num_sample_batch` training batches.
            total_loss = 0
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                if batch_idx >= num_sample_batch:
                    break
                idx = targets != target_removal_label
                inputs, targets = inputs[idx], targets[idx]
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                total_loss += criterion(outputs, targets)

            # Sample the target removal data
            sample_idx = np.random.choice(len(removal_inputs), num_target_sample, replace=False)
            sample_removal_inputs = removal_inputs[sample_idx]
            sample_removal_targets = removal_targets[sample_idx]

            # Hooks pick the parameter subset during the forward pass. Always
            # remove them, even if the forward pass fails.
            net_parser = selection.TopNActivations(net, param_ratio)
            net_parser.register_hooks()
            try:
                # Loss on the sampled removal data, rescaled by |removal| / |retained|.
                target_loss = (
                    criterion(net(sample_removal_inputs.to(device)), sample_removal_targets.to(device))
                    * len(removal_inputs)
                    / (len(train_loader.dataset) - len(removal_inputs))
                )
                index_list = net_parser.get_parameters()
            finally:
                net_parser.remove_hooks()

            if i == 0:
                influence = hessians.partial_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=1, max_iter=30, verbose=verbose
                )
            elif i == 1:
                influence = freeze_influence.freeze_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=1, max_iter=30, verbose=verbose
                )
            else:
                influence, index_list = projected_influence(
                    net, total_loss, target_loss, index_list, tol=tol, step=5, max_iter=30, verbose=verbose
                )

            scale = 1 if i != 2 else 20
            score = 0
            best_score = -1
            # Reset per run so a previous run's result is never recorded by mistake.
            best_result = None
            count = 1
            while True:
                # Take smaller steps once the score is already high.
                if score < .85:
                    utils.update_network(net, influence * scale, index_list)
                else:
                    utils.update_network(net, influence * scale / 10, index_list)

                self_loss, self_acc = test(net, test_loader, criterion, target_removal_label, True)
                exclusive_loss, exclusive_acc = test(net, test_loader, criterion, target_removal_label, False)
                score = f1_score(self_acc, exclusive_acc)

                if verbose:
                    print(
                        f"{count}:{scale:.2f} - test acc: {exclusive_acc:2.2f}, test loss: {exclusive_loss:.4f} | self-acc: {self_acc:2.2f}%, self loss: {self_loss:.4f} | Score: {score:.7f}"
                    )

                # The step that crosses a threshold is not eligible as the best result.
                if exclusive_acc < min_test_acc or self_acc < min_self_acc:
                    if best_result is None:
                        print("No update step stayed within the thresholds; nothing recorded for this run.")
                    else:
                        # Flat layout: 5 values appended per run.
                        results_by_method[if_name].extend(best_result)
                        print(
                            f"test acc: {best_result[0]:2.2f}, test loss: {best_result[1]:.4f} | self-acc: {best_result[2]:2.2f}%, self loss: {best_result[3]:.4f} | Score: {best_result[4]:.7f}"
                        )
                    break

                if best_score < score:
                    best_result = [exclusive_acc, exclusive_loss, self_acc, self_loss, score]
                    best_score = score

                count += 1

        print("")

In [None]:
# NOTE(review): disabled cell. It is kept as a bare string literal, so nothing
# in it runs. If re-enabled, note that:
#   - DenseNet121 is not imported in this notebook (only ShuffleNetV2 is);
#   - the checkpoint path points at a different experiment ("../checkpoints/Figure_4");
#   - it evaluates label 8 rather than `target_removal_label`.
"""
retrained_net = DenseNet121().to(device)
net_name = retrained_net.__class__.__name__
net_path = f"../checkpoints/Figure_4/{net_name}/cross_entropy/ckpt_0.0_retrained.pth"
retrained_net = load_net(retrained_net, net_path)
flatten = False

loss, acc = test(retrained_net, test_loader, criterion, 11, False)
print(
    f"Original loss and acc : {loss:.4f}, {acc:.2f}%"
)
self_loss, self_acc = test(retrained_net, test_loader, criterion, 8, True)
exclusive_loss, exclusive_acc = test(retrained_net, test_loader, criterion, 8, False)
print(
    f"Retrained model \t Self: {self_loss:.2f} {self_acc:2.2f}% | Exclusive loss: {exclusive_loss:.2f}, {exclusive_acc:2.2f}%"
)
"""