In [10]:
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

from dataloader import cifar10
from models import VGG11
from src import freeze_influence, hessians, selection, utils

# Run on the GPU when one is visible, otherwise fall back to the CPU.
# Kept as a plain string because later cells compare it against "cuda".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [11]:
def load_net(net, path):
    """Load the weights stored at ``path`` into ``net`` and return ``net``.

    The checkpoint is expected to be a dict whose ``"net"`` entry holds a
    ``state_dict`` (the layout written by ``save_net``).

    Args:
        net: module whose parameters are overwritten in place.
        path: location of the checkpoint file.

    Returns:
        The same ``net`` object, with the checkpoint weights loaded.

    Raises:
        AssertionError: if ``path`` is not an existing file.
    """
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    # Deserialize onto the CPU so a checkpoint written from a GPU session can
    # still be read on a CPU-only machine; load_state_dict then copies the
    # values into net's existing parameters, wherever they live.
    checkpoint = torch.load(path, map_location="cpu")
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    """Write the weights of ``net`` to ``path`` as ``{"net": state_dict}``.

    Missing parent directories are created. ``path`` may also be a bare
    filename, in which case the file is written to the current directory.

    Args:
        net: module whose ``state_dict`` is saved.
        path: destination file path.
    """
    directory = os.path.dirname(path)
    # A bare filename has an empty directory part, and os.makedirs("") raises,
    # so only create directories when there is one to create.
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save({"net": net.state_dict()}, path)


def test(net, dataloader, criterion, label, include):
    net.eval()
    with torch.no_grad():
        net_loss = 0
        correct = 0
        total = 0
        num_data = 0
        for _, (inputs, targets) in enumerate(dataloader):
            if include:
                idx = targets == label
            else:
                idx = targets != label
            inputs = inputs[idx]
            targets = targets[idx]
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            net_loss += loss * len(inputs)
            num_data +=  len(inputs)

            total += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        accuracy = correct / total * 100
        net_loss /= num_data
        return net_loss, accuracy


def get_full_param_list(net):
    """
    Build the flat-index positions of all Conv2d / Linear parameters in ``net``.

    Leaf modules are walked in ``net.modules()`` order and every trainable
    parameter (``requires_grad=True``) advances a running offset. Only the
    index ranges belonging to ``nn.Conv2d`` or ``nn.Linear`` leaves are
    returned; trainable parameters of other leaf types still consume indices
    but are left out of the result.

    Returns an integer ``np.ndarray`` (empty when nothing qualifies).
    """
    # Seed with an empty int array so the final concatenate always has input
    # and the result keeps an integer dtype even when no layer matches.
    index_chunks = [np.array([], dtype=int)]
    offset = 0
    for layer in net.modules():
        # Containers are skipped: their parameters() would re-count the
        # parameters of the leaves they hold.
        if len(list(layer.children())) > 0:
            continue

        trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            index_chunks.append(np.arange(offset, offset + trainable, dtype=int))

        offset += trainable

    return np.concatenate(index_chunks)


def projected_influence(net, index_list, total_loss, target_loss, tol, step):
    """
    Compute the influence over every Conv2d / Linear parameter, then keep only
    the entries selected by ``index_list``.

    ``hessians.partial_influence`` is evaluated on the full index list from
    ``get_full_param_list(net)``; the result is then masked down to the
    positions whose flat index also appears in ``index_list``.

    Returns a pair ``(influence, indices)``: the masked influence values and
    the flat indices they correspond to.
    """
    all_indices = get_full_param_list(net)
    full_influence = hessians.partial_influence(
        all_indices, target_loss, total_loss, net, tol=tol, step=step
    )
    # Boolean mask over the full index list marking the requested positions.
    selected = np.isin(all_indices, index_list)
    return full_influence[selected], all_indices[selected]

In [12]:
# Seed both RNGs before the model and the data loaders are created.
torch.manual_seed(1)
np.random.seed(1)

# Model
net = VGG11().to(device)
net_name = type(net).__name__
flatten = False

if device == "cuda":
    cudnn.benchmark = True

# Start from the trained checkpoint and keep the network in eval mode.
net_path = f"../checkpoints/Figure_4/{net_name}/cross_entropy/ckpt_0.0.pth"
net = load_net(net, net_path)
net.eval()

num_param = sum(
    param.numel() for param in net.parameters() if param.requires_grad
)
print(
    f"==> Building {net_name} finished. "
    + f"\n    Number of parameters: {num_param}"
)

criterion = nn.CrossEntropyLoss()

# Data
print("==> Preparing data..")
batch_size, num_workers = 512, 12
# Number of training batches cached for the total loss, and number of
# target-class samples drawn for the target loss (both used in later cells).
num_sample_batch, num_target_sample = 4, 500

data_loader = cifar10.CIFAR10DataLoader(batch_size, num_workers, flatten=flatten)
train_loader, val_loader, test_loader = data_loader.get_data_loaders()

==> Building VGG finished. 
    Number of parameters: 9231114
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [13]:
# With include=False the filter keeps every sample whose target != 11.
# NOTE(review): presumably no dataset label equals 11, so this evaluates the
# whole test set - confirm against the data loader.
loss, acc = test(net, test_loader, criterion, label=11, include=False)
print(f"Original loss and acc : {loss:.4f}, {acc:.2f}%")

Original loss and acc : 0.3265, 91.93%


In [14]:
# Cache the first `num_sample_batch` training batches; they are reused later
# to compute the total loss.
inputs_list, targets_list = [], []
for batch_idx, (inputs, targets) in enumerate(train_loader):
    if batch_idx >= num_sample_batch:
        break
    inputs_list.append(inputs)
    targets_list.append(targets)

# One full pass over the training set, collecting every sample labelled 8.
data, target = [], []
for data_raw, target_raw in train_loader:
    idx = target_raw == 8
    data.append(data_raw[idx])
    target.append(target_raw[idx])
data = torch.cat(data)
target = torch.cat(target)

# Draw `num_target_sample` distinct label-8 samples for the target loss.
sample_idx = np.random.choice(len(data), size=num_target_sample, replace=False)
sample_data, sample_target = data[sample_idx], target[sample_idx]

In [15]:
# One row per experiment: (parameter ratio, tolerance, update scale).
# The three lists are zipped back together by the cells that consume them.
experiment_settings = [
    (0.1, 2e-8, 63),
    (0.3, 1.5e-8, 77),
    (0.5, 1.1e-8, 150),
]
ratio_list, tol_list, scale_list = (list(column) for column in zip(*experiment_settings))

In [16]:
# For each (ratio, tol, scale) setting: reload the checkpoint at `net_path`,
# estimate an influence vector for the sampled label-8 data, apply it to the
# network, save the updated weights and report loss/accuracy on label 8
# ("Self") and on every other label ("Exclusive").
for param_ratio, tol, scale in zip(ratio_list, tol_list, scale_list):
    print(f"Ratio: {param_ratio*100}%, tol: {tol}, scale: {scale}")
    for i in range(3):
        # i selects the influence variant (0: PIF, 1: Projected, 2: Frozen).
        # Only i == 1 is executed; the other two branches below are skipped.
        if i != 1:
            continue
        # Reset the weights from the checkpoint before each update.
        net = load_net(net, net_path)
        # Mean loss over the training batches cached in inputs_list/targets_list.
        total_loss = 0
        for inputs, targets in zip(inputs_list, targets_list):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss

        total_loss /= num_sample_batch

        # Make hooks
        # NOTE(review): `percentage` is only referenced by the commented-out
        # selection strategies below.
        percentage = 0.1
        net_parser = selection.TopNActivations(net, param_ratio)
        # net_parser = selection.TopNGradients(net, int(num_param * percentage))
        # net_parser = selection.RandomSelection(net, int(num_param * percentage))
        # net_parser = selection.Threshold(net, int(num_param * percentage), 1)
        net_parser.register_hooks()


        # Loss on the sampled label-8 data, scaled by the fraction of the
        # training set that carries label 8 (len(data) / dataset size).
        target_loss = (
            criterion(net(sample_data.to(device)), sample_target.to(device))
            * len(data)
            / len(train_loader.dataset)
        )
        # The forward/backward pass runs while the hooks are registered.
        # NOTE(review): presumably this is what lets net_parser pick its
        # parameters - confirm in `selection`. retain_graph=True keeps the
        # graph alive because target_loss is passed to the influence routine
        # below.
        target_loss.backward(retain_graph=True)
        # NOTE(review): data_ratio and newton_loss are computed but not used
        # anywhere else in this cell.
        data_ratio = len(train_loader.dataset) / (len(train_loader.dataset) - len(data))
        newton_loss = total_loss * data_ratio - target_loss * (1 - data_ratio)

        # Flat parameter indices chosen by the selection strategy.
        index_list = net_parser.get_parameters()

        if i == 0:
            influence = hessians.partial_influence(
                index_list, target_loss, total_loss, net, tol=tol, step=3
            )
            if_name = "PIF"
        elif i == 1:
            normalizer = 1
#             if param_ratio == .1:
#                 normalizer = 1.5
#             elif param_ratio == .3:
#                 normalizer = 3
#             else:
#                 normalizer = 3
            # Influence over all Conv2d/Linear parameters, restricted to the
            # selected indices; index_list is replaced by the kept indices.
            influence, index_list = projected_influence(
                net, index_list, total_loss, target_loss, tol=tol/normalizer, step=3
            )
            # NOTE(review): extra x10 factor applied only to this variant, on
            # top of `scale` below; the rationale is not visible here.
            influence *= 10
            if_name = "Projected"
        else:
            normalizer = 1
            if param_ratio == .3:
                normalizer = 1.5
            influence = freeze_influence.freeze_influence(
                index_list, target_loss, total_loss, net, tol=tol/(1.1*normalizer), step=3
            )
            if_name = "Frozen"
        # NOTE(review): presumably applies the scaled influence to the
        # parameters at index_list - confirm in `utils.update_network`.
        utils.update_network(net, influence * scale, index_list)
        net_parser.remove_hooks()
        save_path = f"../checkpoints/Figure_4/{if_name}/{net_name}/cross_entropy/ckpt_0.0_{param_ratio}.pth"
        save_net(net, save_path)

        # "Self": test samples with label 8; "Exclusive": all other labels.
        self_loss, self_acc = test(net, test_loader, criterion, 8, True)
        exclusive_loss, exclusive_acc = test(net, test_loader, criterion, 8, False)
        print(
            f"{if_name}\t Self: {self_loss:.4f} {self_acc:2.2f}% | Exclusive loss: {exclusive_loss:.4f}, {exclusive_acc:2.2f}%"
        )
    print("")

Ratio: 10.0%, tol: 2e-08, scale: 63
Computing partial influence ... [4/10000], Tolerance: 1.831E-08, Avg. computing time: 0.622s          
Projected	 Self: 0.5190 87.40% | Exclusive loss: 0.3454, 91.54%

Ratio: 30.0%, tol: 1.5e-08, scale: 77
Computing partial influence ... [6/10000], Tolerance: 1.477E-08, Avg. computing time: 0.624s          
Projected	 Self: 3.4015 39.80% | Exclusive loss: 0.5643, 86.41%

Ratio: 50.0%, tol: 1.1e-08, scale: 150
Computing partial influence ... [12/10000], Tolerance: 1.077E-08, Avg. computing time: 0.625s          
Projected	 Self: 10.6965 0.00% | Exclusive loss: 4.8023, 25.16%



In [17]:
# Load the retrained reference model and evaluate it the same way as the
# influence-updated networks.
retrained_net = VGG11().to(device)
net_name = retrained_net.__class__.__name__
# Use a dedicated variable for this checkpoint. Rebinding the global
# `net_path` here would make the influence cell reload the *retrained*
# weights as its starting point if it were re-run after this cell.
retrained_net_path = f"../checkpoints/Figure_4/{net_name}/cross_entropy/ckpt_0.0_retrained.pth"
retrained_net = load_net(retrained_net, retrained_net_path)

# Label 11 with include=False keeps every sample whose target != 11.
# NOTE(review): presumably that is the whole test set - confirm.
loss, acc = test(retrained_net, test_loader, criterion, 11, False)
# These are the retrained model's metrics, so label them as such.
print(
    f"Retrained loss and acc : {loss:.4f}, {acc:.2f}%"
)
# "Self": test samples with label 8; "Exclusive": all other labels.
self_loss, self_acc = test(retrained_net, test_loader, criterion, 8, True)
exclusive_loss, exclusive_acc = test(retrained_net, test_loader, criterion, 8, False)
print(
    f"Retrained model \t Self: {self_loss:.2f} {self_acc:2.2f}% | Exclusive loss: {exclusive_loss:.2f}, {exclusive_acc:2.2f}%"
)

Original loss and acc : 1.1971, 82.21%
Retrained model 	 Self: 8.72 0.00% | Exclusive loss: 0.36, 91.34%


In [18]:
def calculate_inf_err(model1, model2):
    """Return the L2 distance between the parameters of two models.

    Parameters are paired in ``parameters()`` order, and the squared L2 norms
    of the pairwise differences are summed before taking the square root.

    Args:
        model1: first module.
        model2: second module, expected to have the same architecture.

    Returns:
        A 0-dim tensor holding the distance (``0.0`` for parameter-less
        models). The result is detached from autograd.

    Raises:
        ValueError: if the two models expose a different number of parameter
            tensors, since pairing them would silently ignore the surplus.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        raise ValueError(
            f"Models have different numbers of parameter tensors: "
            f"{len(params1)} vs {len(params2)}"
        )
    if not params1:
        # Nothing to compare; torch.sqrt cannot take a plain Python float.
        return torch.tensor(0.0)

    # This is a metric only, so avoid recording an autograd graph that spans
    # every weight of both models.
    with torch.no_grad():
        squared_norm = 0.0
        for param1, param2 in zip(params1, params2):
            squared_norm += torch.norm(param1 - param2, p=2) ** 2
        return torch.sqrt(squared_norm)

# For every parameter ratio, reload each saved influence-updated checkpoint
# and report its metrics plus its parameter distance to the retrained model.
for param_ratio, tol, scale in zip(ratio_list, tol_list, scale_list):
    for if_name in ("PIF", "Frozen", "Projected"):
        save_path = f"../checkpoints/Figure_4/{if_name}/{net_name}/cross_entropy/ckpt_0.0_{param_ratio}.pth"
        net = load_net(net, save_path)
        influence_err = calculate_inf_err(net, retrained_net)

        # "Self": test samples with label 8; "Exclusive": all other labels.
        self_loss, self_acc = test(net, test_loader, criterion, 8, True)
        exclusive_loss, exclusive_acc = test(net, test_loader, criterion, 8, False)

        # The shortest name gets an extra tab so the printed columns line up.
        if if_name == "PIF":
            if_name += "\t"
        print(
            f"  {if_name}\t Self: {self_loss:.2f} {self_acc:02.2f}% | Exclusive loss: {exclusive_loss:.2f}, {exclusive_acc:2.2f}% | Influence error: {influence_err}"
        )
    print("")

  PIF		 Self: 3.42 0.30% | Exclusive loss: 0.33, 90.49% | Influence error: 34.790283203125
  Frozen	 Self: 2.67 38.80% | Exclusive loss: 0.35, 90.51% | Influence error: 34.789642333984375
  Projected	 Self: 0.52 87.40% | Exclusive loss: 0.35, 91.54% | Influence error: 34.790428161621094

  PIF		 Self: 4.49 0.30% | Exclusive loss: 0.36, 89.37% | Influence error: 34.79145431518555
  Frozen	 Self: 3.19 31.20% | Exclusive loss: 0.37, 89.50% | Influence error: 34.79093551635742
  Projected	 Self: 3.40 39.80% | Exclusive loss: 0.56, 86.41% | Influence error: 34.79096984863281

  PIF		 Self: 5.48 1.60% | Exclusive loss: 0.57, 83.39% | Influence error: 34.79170608520508
  Frozen	 Self: 4.79 12.20% | Exclusive loss: 0.52, 85.39% | Influence error: 34.79135513305664
  Projected	 Self: 10.70 0.00% | Exclusive loss: 4.80, 25.16% | Influence error: 34.799659729003906

