In [1]:
import os
import time
from dataclasses import dataclass
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch import nn

from dataloader import mnist
from models import FullyConnectedNet, TinyNet, ResNet18
from src import hessians, selection, utils

device = "cuda" if torch.cuda.is_available() else "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [2]:
def load_net(net, path):
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    dir, filename = os.path.split(path)
    if not os.path.isdir(dir):
        os.makedirs(dir)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)

In [3]:
def forward(net, dataloader, criterion, num_batch_sample: int=-1):
    net_loss = 0
    num_batch_sample = len(dataloader) if num_batch_sample == -1 else num_batch_sample
    sample_indices = np.random.choice(len(dataloader), size=num_batch_sample, replace=False)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if batch_idx in sample_indices:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss

    net_loss /= len(dataloader)
    return net_loss

### Building model and set criterion

In [4]:
torch.manual_seed(0)
np.random.seed(0)
if device == "cuda":
    cudnn.benchmark = True
    
net = FullyConnectedNet(28 * 28, 20, 10, 3, 0.1).to(device)
flatten = True
# net = ResNet18(1).to(device)
# flatten = False
net_path = "/home/hslyu/research/PIF/checkpoints/Figure_3/ResNet/cross_entropy/ckpt_0.0.pth"
net_path = "/home/hslyu/research/PIF/checkpoints/Figure_3/FullyConnectedNet/cross_entropy/ckpt_0.0.pth"

net = load_net(net, net_path)

# For just doing some exps
# net = FullyConnectedNet(28 * 28, 600, 10, 200, 0.1).to(device)
# net = TinyNet().to(device)
# flatten = False

net.eval()
net_name = net.__class__.__name__
num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

==> Building FullyConnectedNet finished. 
    Number of parameters: 16330


### Preparing data and register hooks

In [5]:
# Data
print("==> Preparing data..")
batch_size = 512
num_workers = 2

data_loader = mnist.MNISTDataLoader(batch_size, num_workers, flatten=flatten)
train_loader, val_loader, test_loader = data_loader.get_data_loaders()


==> Preparing data..


In [6]:
print("==> Define hooks")
# Make hooks
parser_list = [selection.TopNActivations, 
                  selection.TopNGradients,
                  selection.RandomSelection,
                  selection.Threshold]

==> Define hooks


In [7]:
print("==> Computing influence..")

for parser in parser_list:
    net_parser = parser(net, 0)
    for param_ratio in range(10,101,30):
        print(f"Parser: {net_parser.__class__.__name__}, param_ratio: {param_ratio}%")
        param_ratio *= 0.01
        
        # Initialize configurations
        net = load_net(net, net_path)

        # Prepare losses and indexes
        total_loss = forward(net, train_loader, criterion, 1)
        
        net_parser.num_choices = int(num_param * param_ratio)
        net_parser.register_hooks()

        for batch_idx, (data, target) in enumerate(train_loader):
            net_parser.initialize_neurons()

            idx = target == 8
            data = data[idx]
            target = target[idx]
            
            target_loss = (
                criterion(net(data.to(device)), target.to(device))
                * len(data)
                / len(train_loader.dataset)
            )
            if net_parser.require_backward:
                target_loss.backward(retain_graph=True)
            index_list = net_parser.get_parameters()
            data_ratio = len(train_loader.dataset) / (len(train_loader.dataset) - len(data))
            newton_loss = total_loss * data_ratio - target_loss * (1 - data_ratio)

            # Compute Influence
            influence = hessians.partial_influence(
                index_list, target_loss, newton_loss, net, tol=5e-6
            )
            utils.update_network(net, influence, index_list)
            save_net(
                net, f"/home/hslyu/research/PIF/checkpoints/Figure_3/class_removal/{net_parser.__class__.__name__}/{param_ratio}.pth"
            )

        net_parser.remove_hooks()

==> Computing influence..
Parser: TopNGradients, param_ratio: 20%
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [5/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [17/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [16/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [40/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [11/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing 

Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00

Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00

Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00

Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00

Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0001, Elapsed time: 0.00

Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0002, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00s          
Computing partial influence ... [1/10000], Tolerance: 0.0003, Elapsed time: 0.00

### Measure the network utility

In [8]:
def test(net, dataloader, criterion, label, include):
    net_loss = 0
    correct = 0
    total = 0
    for _, (inputs, targets) in enumerate(dataloader):
        if include:
            idx = (targets == label)
        else:
            idx = (targets != label)
        inputs = inputs[idx]
        targets = targets[idx]
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        net_loss += loss
        
        total += targets.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

    accuracy = correct / total * 100
    net_loss /= len(dataloader)
    return net_loss, accuracy

In [9]:
net = FullyConnectedNet(28 * 28, 20, 10, 3, 0.1).to(device)

for parser in parser_list:
    net_parser = parser(net, 0)
    for param_ratio in range(10,11,30):
        _, _, test_loader = data_loader.get_data_loaders()
        param_ratio *= 0.01
        net_path = f"/home/hslyu/research/PIF/checkpoints/Figure_3/class_removal/{net_parser.__class__.__name__}/{param_ratio}.pth"
        net = load_net(net, net_path)
        
        self_loss, self_acc = test(net, test_loader, criterion, 8, True)
        exclusive_loss, exclusive_acc = test(net, test_loader, criterion, 8, False)
        print(f"{net_parser.__class__.__name__}, {param_ratio*100:2.0f}% - Self: {self_loss:.4f} {self_acc:.2f}% | exclusive loss: {exclusive_loss:.4f}, {exclusive_acc:.2f}%")
    print("")

TopNGradients, 10% - Self: 9582.3252 1.23% | exclusive loss: 3168.2219, 11.49%

TopNActivations, 10% - Self: 1591.7177 0.00% | exclusive loss: 745.8217, 21.94%

RandomSelection, 10% - Self: 0.6550 77.62% | exclusive loss: 0.4618, 86.51%

Threshold, 10% - Self: 0.6665 78.23% | exclusive loss: 0.4607, 86.52%



In [10]:
net = FullyConnectedNet(28 * 28, 20, 10, 3, 0.1).to(device)
net_path = "/home/hslyu/research/PIF/checkpoints/Figure_3/FullyConnectedNet/cross_entropy/ckpt_0.0.pth"
net = load_net(net, net_path)

loss, acc = test(net, test_loader, criterion, 11, False)
print(f"Total loss: {loss:.4f}, {acc:.2f}%")


Total loss: 0.4786, 85.64%


#### 