In [1]:
import os
import time
from dataclasses import dataclass
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch import nn

from dataloader import mnist
from models import FullyConnectedNet, TinyNet, ResNet18
from src import hessians, selection, utils

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [2]:
def load_net(net, path):
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    dir, filename = os.path.split(path)
    if not os.path.isdir(dir):
        os.makedirs(dir)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)

In [3]:
def forward(net, dataloader, criterion, num_batch_sample: int=-1):
    net_loss = 0
    num_batch_sample = len(dataloader) if num_batch_sample == -1 else num_batch_sample
    sample_indices = np.random.choice(len(dataloader), size=num_batch_sample, replace=False)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if batch_idx in sample_indices:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss

    net_loss /= num_batch_sample
    return net_loss

### Building model and set criterion

In [4]:
torch.manual_seed(0)
np.random.seed(0)
if device == "cuda":
    cudnn.benchmark = True
    
net = ResNet18(1).to(device)
flatten = False
net_path = "../checkpoints/Figure_3/ResNet/cross_entropy/ckpt_0.0.pth"

net = load_net(net, net_path)

net.eval()
net_name = net.__class__.__name__
num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

==> Building ResNet finished. 
    Number of parameters: 11172810


### Preparing data and register hooks

In [5]:
# Data
print("==> Preparing data..")
batch_size = 512
num_workers = 8

data_loader = mnist.MNISTDataLoader(batch_size, num_workers, flatten=flatten)
train_loader, val_loader, test_loader = data_loader.get_data_loaders()

data = list()
target = list()
for batch_idx, (data_raw, target_raw) in enumerate(train_loader):
    idx = target_raw == 8
    data_raw = data_raw[idx]
    target_raw = target_raw[idx]
    data.append(data_raw)
    target.append(target_raw)
data = torch.cat(data)
target = torch.cat(target)
sample_idx = np.random.choice(len(data), 50, replace=False)
sample_data = data[sample_idx]
sample_target = target[sample_idx]

==> Preparing data..


In [6]:
print("==> Define hooks")
# Make hooks
parser_list = [selection.TopNActivations, 
                  selection.TopNGradients,
                  selection.RandomSelection,
                  selection.Threshold]

==> Define hooks


In [None]:
print("==> Computing influence..")

for parser in parser_list:
    net_parser = parser(net, 0)
    for param_ratio in range(5,31,5):
        print(f"Parser: {net_parser.__class__.__name__}, param_ratio: {param_ratio}%")
        param_ratio *= 0.01
        
        # Initialize configurations
        net = load_net(net, net_path)
        net_parser.num_choices = int(num_param * param_ratio)

        # Prepare losses and indexes
        total_loss = forward(net, train_loader, criterion, 1)
        
        # Register_hooks
        net_parser.initialize_neurons()
        net_parser.register_hooks()
        target_loss = (
            criterion(net(sample_data.to(device)), sample_target.to(device))
            * len(data)
            / len(train_loader.dataset)
        )
        if isinstance(net_parser, selection.TopNGradients):
            target_loss.backward(retain_graph=True)

        data_ratio = len(train_loader.dataset) / (len(train_loader.dataset) - len(data))
        newton_loss = total_loss * data_ratio - target_loss * (1 - data_ratio)
        index_list = net_parser.get_parameters()

        influence = hessians.partial_influence(
            index_list, target_loss, newton_loss, net, tol=1e-6, step=3
        )
        utils.update_network(net, influence, index_list)
        net_parser.remove_hooks()
        net_path = (
            f"checkpoints/Figure_3/PIF/{net_name}/{net_parser.__class__.__name__}.pth"
        )
        save_net(net, net_path)

        net_parser.remove_hooks()

==> Computing influence..
Parser: TopNActivations, param_ratio: 5%
Computing partial influence ... [230/10000], Tolerance: 9.998E-07, Avg. computing time: 0.558s          
Parser: TopNActivations, param_ratio: 10%
Computing partial influence ... [11/10000], Tolerance: 9.059E-07, Avg. computing time: 0.545s          
Parser: TopNActivations, param_ratio: 15%
Computing partial influence ... [7/10000], Tolerance: 9.207E-07, Avg. computing time: 0.545s          
Parser: TopNActivations, param_ratio: 20%
Computing partial influence ... [53/10000], Tolerance: 9.890E-07, Avg. computing time: 0.544s          
Parser: TopNActivations, param_ratio: 25%
Computing partial influence ... [7/10000], Tolerance: 9.209E-07, Avg. computing time: 0.537s          
Parser: TopNActivations, param_ratio: 30%
Computing partial influence ... [21/10000], Tolerance: 9.186E-07, Avg. computing time: 0.547s          
Parser: TopNGradients, param_ratio: 5%
Computing partial influence ... [1/10000], Tolerance: 3.840E-

### Measure the network utility

In [None]:
def test(net, dataloader, criterion, label, include):
    net_loss = 0
    correct = 0
    total = 0
    for _, (inputs, targets) in enumerate(dataloader):
        if include:
            idx = (targets == label)
        else:
            idx = (targets != label)
        inputs = inputs[idx]
        targets = targets[idx]
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        net_loss += loss
        
        total += targets.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

    accuracy = correct / total * 100
    net_loss /= len(dataloader)
    return net_loss, accuracy

In [None]:
net = ResNet18(1).to(device)

for parser in parser_list:
    net_parser = parser(net, 0)
    for param_ratio in range(10,11,30):
        _, _, test_loader = data_loader.get_data_loaders()
        param_ratio *= 0.01
        net_path = f"../checkpoints/Figure_3/class_removal/{net_parser.__class__.__name__}/{param_ratio}.pth"
        net = load_net(net, net_path)
        
        self_loss, self_acc = test(net, test_loader, criterion, 8, True)
        exclusive_loss, exclusive_acc = test(net, test_loader, criterion, 8, False)
        print(f"{net_parser.__class__.__name__}, {param_ratio*100:2.0f}% - Self: {self_loss:.4f} {self_acc:.2f}% | exclusive loss: {exclusive_loss:.4f}, {exclusive_acc:.2f}%")
    print("")

In [None]:
net = ResNet18(1).to(device)
net_path = "../checkpoints/Figure_3/FullyConnectedNet/cross_entropy/ckpt_0.0.pth"
net = load_net(net, net_path)

loss, acc = test(net, test_loader, criterion, 11, False)
print(f"Total loss: {loss:.4f}, {acc:.2f}%")


#### 