In [1]:
import os
import time
from dataclasses import dataclass
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch import nn

from dataloader import mnist
from models import FullyConnectedNet, TinyNet, ResNet18
from src import hessians, selection, utils

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [2]:
def load_net(net, path):
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    dir, filename = os.path.split(path)
    if not os.path.isdir(dir):
        os.makedirs(dir)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)

In [3]:
def forward(net, dataloader, criterion, num_batch_sample: int=-1):
    net_loss = 0
    num_batch_sample = len(dataloader) if num_batch_sample == -1 else num_batch_sample
    sample_indices = np.random.choice(len(dataloader), size=num_batch_sample, replace=False)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if batch_idx in sample_indices:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss

    net_loss /= num_batch_sample
    return net_loss

### Building model and set criterion

In [4]:
torch.manual_seed(0)
np.random.seed(0)
if device == "cuda":
    cudnn.benchmark = True
    
net = ResNet18(1).to(device)
flatten = False
net_path = "../checkpoints/Figure_3/ResNet/cross_entropy/ckpt_0.0.pth"

net = load_net(net, net_path)

net.eval()
net_name = net.__class__.__name__
num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

==> Building ResNet finished. 
    Number of parameters: 11172810


### Preparing data and register hooks

In [5]:
# Data
print("==> Preparing data..")
batch_size = 512
num_workers = 8

data_loader = mnist.MNISTDataLoader(batch_size, num_workers, flatten=flatten)
train_loader, val_loader, test_loader = data_loader.get_data_loaders()

data = list()
target = list()
for batch_idx, (data_raw, target_raw) in enumerate(train_loader):
    idx = target_raw == 8
    data_raw = data_raw[idx]
    target_raw = target_raw[idx]
    data.append(data_raw)
    target.append(target_raw)
data = torch.cat(data)
target = torch.cat(target)
sample_idx = np.random.choice(len(data), 50, replace=False)
sample_data = data[sample_idx]
sample_target = target[sample_idx]

==> Preparing data..


In [8]:
print("==> Define hooks")
# Make hooks
parser_list = [selection.TopNActivations, 
                  selection.TopNGradients,
                  selection.RandomSelection,
                  selection.Threshold]

==> Define hooks


In [10]:
print("==> Computing influence..")

ratio_list = [1, 3, 5, 10, 15, 20, 25, 30, 50, 60, 75, 100]

==> Computing influence..


In [8]:
rest_parser_list = [#selection.TopNActivations, 
                  #selection.TopNGradients,
                  #selection.RandomSelection,
                  selection.Threshold]

for parser in rest_parser_list:
    net_parser = parser(net, 0)
    for param_ratio in ratio_list:
        print(f"Parser: {net_parser.__class__.__name__}, param_ratio: {param_ratio}%")
        param_ratio *= 0.01
        
        # Initialize configurations
        net = load_net(net, net_path)
        net_parser.set_ratio(param_ratio)

        # Prepare losses and indexes
        total_loss = forward(net, train_loader, criterion, 1)
        
        # Register_hooks
        net_parser.initialize_neurons()
        net_parser.register_hooks()
        target_loss = (
            criterion(net(sample_data.to(device)), sample_target.to(device))
            * len(data)
            / len(train_loader.dataset)
        )
        if isinstance(net_parser, selection.TopNGradients):
            target_loss.backward(retain_graph=True)

        data_ratio = len(train_loader.dataset) / (len(train_loader.dataset) - len(data))
        newton_loss = total_loss * data_ratio - target_loss * (1 - data_ratio)
        index_list = net_parser.get_parameters()

        influence = hessians.partial_influence(
            index_list, target_loss, newton_loss, net, tol=1e-7, step=3
        )
        utils.update_network(net, influence, index_list)
        net_parser.remove_hooks()
        net_path = (
            f"../checkpoints/Figure_3/PIF/{net_name}/{net_parser.__class__.__name__}/{param_ratio}.pth"
        )
        save_net(net, net_path)

        net_parser.remove_hooks()

==> Computing influence..
Parser: Threshold, param_ratio: 1%
Computing partial influence ... [370/10000], Tolerance: 9.918E-08, Avg. computing time: 0.528s          
Parser: Threshold, param_ratio: 3%
Computing partial influence ... [35/10000], Tolerance: 9.805E-08, Avg. computing time: 0.528s          
Parser: Threshold, param_ratio: 5%
Computing partial influence ... [22/10000], Tolerance: 9.904E-08, Avg. computing time: 0.527s          
Parser: Threshold, param_ratio: 10%
Computing partial influence ... [15/10000], Tolerance: 9.883E-08, Avg. computing time: 0.528s          
Parser: Threshold, param_ratio: 15%
Computing partial influence ... [13/10000], Tolerance: 9.927E-08, Avg. computing time: 0.529s          
Parser: Threshold, param_ratio: 20%
Computing partial influence ... [4/10000], Tolerance: 8.001E-08, Avg. computing time: 0.530s          
Parser: Threshold, param_ratio: 25%
Computing partial influence ... [2/10000], Tolerance: 8.172E-08, Avg. computing time: 0.531s         

### Measure the network utility

In [6]:
def test(net, dataloader, criterion, label, include):
    net_loss = 0
    correct = 0
    total = 0
    for _, (inputs, targets) in enumerate(dataloader):
        if include:
            idx = (targets == label)
        else:
            idx = (targets != label)
        inputs = inputs[idx]
        targets = targets[idx]
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        net_loss += loss
        
        total += targets.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

    accuracy = correct / total * 100
    net_loss /= len(dataloader)
    return net_loss, accuracy

In [13]:
net = ResNet18(1).to(device)

# Define lists to contain results
self_loss_list = [[],[],[],[]]
self_acc_list = [[],[],[],[]]
exclusive_loss_list = [[],[],[],[]]
exclusive_acc_list = [[],[],[],[]]

parser_count = 0

for parser in parser_list:
    net_parser = parser(net, 0)
    for param_ratio in ratio_list:
        _, _, test_loader = data_loader.get_data_loaders()
        param_ratio *= 0.01
        #net_path = f"../checkpoints/Figure_3/class_removal/{net_parser.__class__.__name__}/{param_ratio}.pth"
        net_path = f"../checkpoints/Figure_3/PIF/ResNet/{net_parser.__class__.__name__}/{param_ratio}.pth"
        net = load_net(net, net_path)
        
        self_loss, self_acc = test(net, test_loader, criterion, 8, True)
        self_loss_list[parser_count].append(self_loss.detach().cpu())
        self_acc_list[parser_count].append(self_acc)
        print(f"{net_parser.__class__.__name__}, {param_ratio*100:2.0f}% - Self: {self_loss:.4f} {self_acc:.2f}%")
        del self_loss, self_acc
        exclusive_loss, exclusive_acc = test(net, test_loader, criterion, 8, False)
        
        # Save results in defined lists
        exclusive_loss_list[parser_count].append(exclusive_loss.detach().cpu())
        exclusive_acc_list[parser_count].append(exclusive_acc)
        
        print(f"{net_parser.__class__.__name__}, {param_ratio*100:2.0f}% - exclusive loss: {exclusive_loss:.4f}, {exclusive_acc:.2f}%")
        del exclusive_loss, exclusive_acc
    print("")
    parser_count += 1

TopNActivations,  1% - Self: 4.4269 20.53%


OutOfMemoryError: CUDA out of memory. Tried to allocate 90.00 MiB (GPU 0; 23.70 GiB total capacity; 21.17 GiB already allocated; 81.69 MiB free; 22.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [12]:
import pandas as pd

# Show results
print("Self Loss")
data = {"A": self_loss_list[0],
        "B": self_loss_list[1],
        "C": self_loss_list[2],
        "D": self_loss_list[3],
       }
self_loss_df = pd.DataFrame(data, index = [f'{num}%' for num in ratio_list])
self_loss_df = self_loss_df.transpose()
self_loss_df.coulmns = ["TopNActivations", "TopNGradients", "Threshold", "Random"]
print(self_loss_df)

print("Self Accuracy")
data = {"A": self_acc_list[0],
        "B": self_acc_list[1],
        "C": self_acc_list[2],
        "D": self_acc_list[3],
       }
self_acc_df = pd.DataFrame(data, index = [f'{num}%' for num in ratio_list])
self_acc_df = self_acc_df.transpose()
self_acc_df.coulmns = ["TopNActivations", "TopNGradients", "Threshold", "Random"]
print(self_acc_df)

print("Exclusive Loss")
data = {"A": exclusive_loss_list[0],
        "B": exclusive_loss_list[1],
        "C": exclusive_loss_list[2],
        "D": exclusive_loss_list[3],
       }
exclusive_loss_df = pd.DataFrame(data, index = [f'{num}%' for num in ratio_list])
exclusive_loss_df = exclusive_loss_df.transpose()
exclusive_loss_df.coulmns = ["TopNActivations", "TopNGradients", "Threshold", "Random"]
print(exclusive_loss_df)

print("Exclusive Accuracy")
data = {"A": exclusive_acc_list[0],
        "B": exclusive_acc_list[1],
        "C": exclusive_acc_list[2],
        "D": exclusive_acc_list[3],
       }
exclusive_acc_df = pd.DataFrame(data, index = [f'{num}%' for num in ratio_list])
exclusive_acc_df = exclusive_acc_df.transpose()
exclusive_acc_df.coulmns = ["TopNActivations", "TopNGradients", "Threshold", "Random"]
print(exclusive_acc_df)

Self Loss


ValueError: Length of values (0) does not match length of index (12)

In [None]:
import pickle

# Save list files of results
with open('self_loss_list.pickle', 'wb') as f:
    pickle.dump(self_loss, f, pickle.HIGHEST_PROTOCOL)
"""
with open('self_acc_list.pickle', 'wb') as f:
    pickle.dump(self_acc, f, pickle.HIGHEST_PROTOCOL)
with open('exclusive_loss_list.pickle', 'wb') as f:
    pickle.dump(exclusive_loss, f, pickle.HIGHEST_PROTOCOL)
with open('self_acc_list.pickle', 'wb') as f:
    pickle.dump(exclusive_acc, f, pickle.HIGHEST_PROTOCOL)
"""

#### 