In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch import nn
import numpy as np
import os
import gc

from dataloader import mnist
from models import FullyConnectedNet, TinyNet, ResNet18
from src import utils, selection, hessians, freeze_influence, second_influence
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"
backdoor_label = 4
torch.manual_seed(0)
np.random.seed(0)

In [3]:
def load_net(net, path):
    assert os.path.isfile(path), "Error: no checkpoint file found!"
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint["net"])
    return net


def save_net(net, path):
    dir, filename = os.path.split(path)
    if not os.path.isdir(dir):
        os.makedirs(dir)

    state = {
        "net": net.state_dict(),
    }
    torch.save(state, path)
    
def _correct_fn(predicted: torch.Tensor, targets: torch.Tensor):
    if targets.dim() == 1:
        return predicted.eq(targets).sum().item()
    elif targets.dim() == 2:
        _, targets_decoded = targets.max(1)
        return predicted.eq(targets_decoded).sum().item()
    else:
        return 0
    
def train(net, dataloader):
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(0.5*epochs), int(0.75*epochs)], gamma=0.1)
    for epoch in range(epochs):
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            # correct += predicted.eq(targets).sum().item()
            correct += _correct_fn(predicted, targets)
        scheduler.step()
        print(f"Epoch {epoch} | Loss: {train_loss / (batch_idx + 1):.3f} | Acc: {100.0 * correct / total:.3f}")

def projected_influence(net, total_loss, target_loss, index_list, tol, step, max_iter, verbose):
    num_param = sum(p.numel() for p in net.parameters() if p.requires_grad)
    full_param_index_list = np.arange(num_param)
    influence = hessians.generalized_influence(
        net, total_loss, target_loss, full_param_index_list, tol=tol, step=step, max_iter=max_iter, verbose=verbose
    )
    return influence[index_list]

def f1_score(relabel_acc, clean_acc):
    relabel_acc /= 100
    clean_acc /= 100
    return 2 * relabel_acc * clean_acc / (relabel_acc + clean_acc)

def evaluate(net, dataloader, label=None):
    net.eval()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        if label != None:
            idx = targets == label
            inputs, targets = inputs[idx], targets[idx]
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)

        _, predicted = outputs.max(1)
        total += targets.size(0)
        # correct += predicted.eq(targets).sum().item()
        correct += _correct_fn(predicted, targets)

    return correct / total * 100

In [4]:
batch_size = 512
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
training_dataset = torchvision.datasets.MNIST('../data/',
                             train=True,
                             download=True,
                             transform=transform)

# Prepare indices
indices = np.random.choice(len(training_dataset), len(training_dataset)//10, replace=False)
excluded_indices = [idx for idx in range(len(training_dataset)) if idx not in indices]

# Corrupted training dataset
pattern = torch.zeros(28 * 28, dtype=torch.uint8)
pattern = pattern.reshape(28,28)
pattern[::2, 1::2] = 16
pattern[1::2, ::2] = 16

for ind in indices:    
    training_dataset.data[ind] = torch.clamp(training_dataset.data[ind].to(torch.int) + pattern, max=255).to(torch.uint8)
    training_dataset.targets[ind] = backdoor_label

# Corrupted dataset of selected indices
corrupt_dataset = Subset(training_dataset, indices)
clean_dataset = Subset(training_dataset, excluded_indices)
 
train_dataloader = DataLoader(training_dataset,
                        num_workers=16,
                        batch_size=batch_size)

clean_dataloader = DataLoader(clean_dataset,
                        num_workers=8,
                        batch_size=batch_size)

corrupt_dataloader = DataLoader(corrupt_dataset,
                        num_workers=8,
                        batch_size=batch_size)

# Define relabeled dataset
relabel_dataset = torchvision.datasets.MNIST('../data/',
                             train=True,
                             download=True,
                             transform=transform)

relabel_dataset.data[indices] = corrupt_dataset.dataset.data[indices]
subset_relabel_dataset = Subset(relabel_dataset, indices)
relabel_dataloader = DataLoader(subset_relabel_dataset, num_workers=8, batch_size=batch_size)

print(f"{len(training_dataset)=}, {len(clean_dataset)=}, {len(corrupt_dataset)=}, {len(subset_relabel_dataset)=}")

# plt.imshow(training_dataset.data[ind], 'gray')
# plt.show()
# plt.imshow(training_dataset.data[0], 'gray')

len(training_dataset)=60000, len(clean_dataset)=54000, len(corrupt_dataset)=6000, len(subset_relabel_dataset)=6000


In [5]:
net = TinyNet().to(device)
net_name = "TinyNet"
net_path = f"checkpoints/tab3/{net.__class__.__name__}/cross_entropy/ckpt_0.0.pth"

epochs = 25
criterion = nn.CrossEntropyLoss()       
# train(net, train_dataloader)
# save_net(net, net_path)
net = load_net(net, net_path)
print(f"Acc: {evaluate(net, corrupt_dataloader):.3f}")
print(f"Acc: {evaluate(net, relabel_dataloader):.3f}")
print(f"Acc: {evaluate(net, train_dataloader):.3f}")

Acc: 100.000
Acc: 9.417
Acc: 99.645


In [6]:
num_exp = 10
tol = 1e-9
step = 5
max_iter = 1000
verbose = True

num_corrupt_sample_batch = 8 # len(corrupt_dataloader.dataset)
num_clean_sample_batch = 4

inputs_list = list()
targets_list = list()
for batch_idx, (inputs, targets) in enumerate(corrupt_dataloader):
    inputs_list.append(inputs)
    targets_list.append(targets)
corrupt_inputs = torch.cat(inputs_list)
corrupt_targets = torch.cat(targets_list)

inputs_list = list()
targets_list = list()
for batch_idx, (inputs, targets) in enumerate(relabel_dataloader):
    inputs_list.append(inputs)
    targets_list.append(targets)
relabel_inputs = torch.cat(inputs_list)
relabel_targets = torch.cat(targets_list)

ratio_list = [.05]

result_list_GIF = []
result_list_FIF = []
result_list_PIF = []
result_list_IF  = []
result_list_SIF = []

for exp_iter in range(num_exp):
    sample_idx = np.random.choice(len(corrupt_inputs), num_corrupt_sample_batch * batch_size, replace=False)
    for i in range(5):
        for param_ratio in ratio_list:
            if i == 0:
                if_name = "GIF"
            elif i == 1:
                if_name = "FIF"
            elif i == 2:
                if_name = "PIF"
            elif i == 3:
                if_name = "IF"
                param_ratio = 1.
            else:
                if_name = "SIF"
                param_ratio = 1.

            print(f"{if_name} - ratio: {param_ratio*100}%, tol: {tol}")
            # Initialize network
            net = load_net(net, net_path)

            # Compute total loss
            total_loss = 0
            for batch_idx, (inputs, targets) in enumerate(clean_dataloader):
                if batch_idx >= num_clean_sample_batch:
                    break
                inputs, targets = inputs.to(device), targets.to(device)
                loss = criterion(net(inputs), targets)
                total_loss += loss
            
            total_loss /= batch_idx + 1
            
            # Make hooks
            net_parser = selection.HighestKOutputs(net, param_ratio)
            net_parser.register_hooks()

            # Select params
            target_loss = (
                criterion(net(corrupt_inputs[sample_idx].to(device)), corrupt_targets[sample_idx].to(device))
            )
            if isinstance(net_parser, selection.HighestKGradients):
                target_loss.backward(retrain_graph=True)
            index_list = net_parser.get_parameters()
            net_parser.remove_hooks()

            relabel_loss = 0
            for batch_idx, (inputs, targets) in enumerate(relabel_dataloader):
                inputs, targets = inputs.to(device), targets.to(device)
                loss = criterion(net(inputs), targets)
                relabel_loss += loss
            
            relabel_loss /= batch_idx + 1

            target_loss = target_loss - relabel_loss
            target_loss *= len(corrupt_dataloader.dataset) / len(clean_dataloader.dataset)

            if i == 0:
                influence = hessians.generalized_influence(
                    net, total_loss, target_loss, index_list, tol, step, max_iter, verbose
                )
            elif i == 1:
                influence = freeze_influence.freeze_influence(
                    net, total_loss, target_loss, index_list, tol, step, max_iter, verbose
                )
            elif i == 2:
                influence = projected_influence(
                    net, total_loss, target_loss, index_list, tol, step, max_iter, verbose
                )
            elif i == 3:
                influence = hessians.generalized_influence(
                    net, total_loss, target_loss, index_list, tol, step, max_iter, verbose
                )
            else:
                influence = second_influence.second_influence(
                    net, total_loss, target_loss, len(clean_dataloader.dataset), len(corrupt_dataloader.dataset), tol, step, max_iter//5, verbose
                )
                influence = influence[net_parser.get_parameters()]

            del total_loss, target_loss
            gc.collect()
            torch.cuda.empty_cache()

            influence *= 0.3 / torch.norm(influence)

            score = 0
            best_score = -1
            count = 1
            saturation = 0
            save_path = (
                f"checkpoints/tab3/{net_name}/{if_name}/{param_ratio}_{exp_iter}.pth"
            )
            while True:
                net_parser.update_network(influence)
                
                corrupt_acc = evaluate(net, corrupt_dataloader)
                relabel_acc = evaluate(net, relabel_dataloader)
                clean_acc = evaluate(net, clean_dataloader)
                score = f1_score(relabel_acc, clean_acc)
                
                if best_score < score:
                    best_result = [count, corrupt_acc, relabel_acc, clean_acc]
                    best_score = score
                    save_net(net, save_path)
                    saturation = 0
                else:
                    saturation += 1
                    
                print(
                f"{count} - corrupt acc: {corrupt_acc:2.2f} | relabel acc: {relabel_acc:2.2f} | " +
                f"clean acc: {clean_acc:2.2f}% | score: {score:.7f}",
                end='\r'
                )
                
                if saturation >= 10 or count >= 300:
                    print(f"{best_result[0]} - corrupt acc: {best_result[1]:2.2f} | relabel acc: {best_result[2]:2.2f} |" +
                    f" clean acc: {best_result[3]:2.2f}% | score: {best_score:.7f}" + " " * 20)
                    break

                count += 1
            
            if i>=3:
                break
                
        print("")

GIF - ratio: 5.0%, tol: 1e-09
Computing generalized influence ... [1000/1000]
11 - corrupt acc: 11.42 | relabel acc: 95.38 | clean acc: 98.50% | score: 0.9691841                    

FIF - ratio: 5.0%, tol: 1e-09
Computing freeze influence ... [1000/1000]
5 - corrupt acc: 11.22 | relabel acc: 91.98 | clean acc: 86.86% | score: 0.8934789                    

PIF - ratio: 5.0%, tol: 1e-09
Computing generalized influence ... [1000/1000]
9 - corrupt acc: 11.17 | relabel acc: 92.22 | clean acc: 95.31% | score: 0.9373925                    

IF - ratio: 100.0%, tol: 1e-09
Computing generalized influence ... [1000/1000]
2 - corrupt acc: 14.22 | relabel acc: 93.25 | clean acc: 97.84% | score: 0.9548934                    

SIF - ratio: 100.0%, tol: 1e-09
Computing generalized influence ... [200/200]

4 - corrupt acc: 21.72 | relabel acc: 86.33 | clean acc: 98.24% | score: 0.9190133                    

GIF - ratio: 5.0%, tol: 1e-09
Computing generalized influence ... [1000/1000]
10 - corrupt a

In [12]:
for i in range(5):
    print("")
    for param_ratio in ratio_list:
        if i == 0:
            if_name = "GIF"
        elif i == 1:
            if_name = "FIF"
        elif i == 2:
            if_name = "PIF"
        elif i == 3:
            if_name = "IF"
            param_ratio = 1.
        else:
            if_name = "SIF"
            param_ratio = 1.
        print(f"{if_name} - ratio: {param_ratio*100}%, tol: {tol}")
        
        corrupt_acc_list = np.empty(0)
        relabel_acc_list = np.empty(0)
        clean_acc_list = np.empty(0)
        f1_score_list = np.empty(0)
        
        for exp_iter in range(num_exp):

            load_path = (
                f"checkpoints/tab3/{net_name}/{if_name}/{param_ratio}_{exp_iter}.pth"
            )
            net = TinyNet().to(device)
            net = load_net(net, load_path)
            corrupt_acc = evaluate(net, corrupt_dataloader)
            relabel_acc = evaluate(net, relabel_dataloader)
            clean_acc = evaluate(net, clean_dataloader)
            score = f1_score(relabel_acc, clean_acc)
            
            corrupt_acc_list = np.append(corrupt_acc_list, corrupt_acc)
            relabel_acc_list = np.append(relabel_acc_list, relabel_acc)
            clean_acc_list = np.append(clean_acc_list, clean_acc)
            f1_score_list = np.append(f1_score_list, score)
            print(
            f"corrupt acc: {corrupt_acc:2.2f}, relabel acc: {relabel_acc:2.2f} |" +
            f" clean acc: {clean_acc:2.2f}% | score: {score:.7f}",
            end='\r'
            )
            
        mean_corrupt_acc = np.mean(corrupt_acc_list)
        mean_relabel_acc = np.mean(relabel_acc_list)
        mean_clean_acc = np.mean(clean_acc_list)
        mean_f1_score = np.mean(f1_score_list)
                
        var_corrupt_acc = np.var(corrupt_acc_list)**0.5
        var_relabel_acc = np.var(relabel_acc_list)**0.5
        var_clean_acc = np.var(clean_acc_list)**0.5
        var_f1_score = np.var(f1_score_list)**0.5

        # print(
        # f"clean acc: {mean_clean_acc:2.2f}+-{var_clean_acc:2.2f}% " +
        # f"relabel acc: {mean_relabel_acc:2.2f}+-{var_relabel_acc:2.2f} ", end=""
        # )
        # print(
        # f"corrupt acc: {mean_corrupt_acc:2.2f}+-{var_corrupt_acc:2.2f}% " +
        # f"score: {mean_f1_score:.4f}",
        # )
        print(
        f"{mean_clean_acc:2.2f} $\pm$ {var_clean_acc:2.2f} & " +
        f"{mean_relabel_acc:2.2f} $\pm$ {var_relabel_acc:2.2f} &", end=""
        )
        print(
        f"{mean_corrupt_acc:2.2f} $\pm$ {var_corrupt_acc:2.2f}" + " " * 30
        )

        if i >= 3:
            break


GIF - ratio: 5.0%, tol: 1e-09
98.26 $\pm$ 0.66 & 95.38 $\pm$ 0.36 &11.48 $\pm$ 0.48                    1247

FIF - ratio: 5.0%, tol: 1e-09
91.05 $\pm$ 4.50 & 88.76 $\pm$ 5.44 &16.15 $\pm$ 6.78                    6654

PIF - ratio: 5.0%, tol: 1e-09
95.78 $\pm$ 0.25 & 92.30 $\pm$ 0.97 &11.77 $\pm$ 0.81                    2717

IF - ratio: 100.0%, tol: 1e-09
97.84 $\pm$ 0.00 & 93.25 $\pm$ 0.00 &14.22 $\pm$ 0.00                    9023

SIF - ratio: 100.0%, tol: 1e-09
98.23 $\pm$ 0.01 & 86.30 $\pm$ 0.02 &21.75 $\pm$ 0.02                    6435


In [8]:
net_path = f"checkpoints/tab3/{net.__class__.__name__}/cross_entropy/ckpt_0.0.pth"
net = TinyNet().to(device)
net = load_net(net, net_path)

corrupt_acc = evaluate(net, corrupt_dataloader)
relabel_acc = evaluate(net, relabel_dataloader)
clean_acc = evaluate(net, clean_dataloader)
score = f1_score(relabel_acc, clean_acc)

print(
f"corrupt acc: {corrupt_acc:2.2f}, relabel acc: {relabel_acc:2.2f} |" +
f" clean acc: {clean_acc:2.2f}% | score: {score:.7f}")

corrupt acc: 100.00, relabel acc: 9.42 | clean acc: 99.61% | score: 0.1720663


In [9]:
# net = ResNet18(1).to(device)
net = TinyNet().to(device)

dataloader = DataLoader(relabel_dataset, num_workers=8, batch_size=batch_size)

epochs = 25        
train(net, dataloader)

corrupt_acc = evaluate(net, corrupt_dataloader)
relabel_acc = evaluate(net, relabel_dataloader)
clean_acc = evaluate(net, clean_dataloader)
score = f1_score(relabel_acc, clean_acc)

print(
f"corrupt acc: {corrupt_acc:2.2f}, relabel acc: {relabel_acc:2.2f} |" +
f" clean acc: {clean_acc:2.2f}% | score: {score:.7f}")

Epoch 0 | Loss: 0.683 | Acc: 79.350
Epoch 1 | Loss: 0.085 | Acc: 97.437
Epoch 2 | Loss: 0.059 | Acc: 98.220
Epoch 3 | Loss: 0.044 | Acc: 98.667
Epoch 4 | Loss: 0.037 | Acc: 98.852
Epoch 5 | Loss: 0.030 | Acc: 99.068
Epoch 6 | Loss: 0.025 | Acc: 99.197
Epoch 7 | Loss: 0.022 | Acc: 99.348
Epoch 8 | Loss: 0.021 | Acc: 99.365
Epoch 9 | Loss: 0.019 | Acc: 99.398
Epoch 10 | Loss: 0.016 | Acc: 99.478
Epoch 11 | Loss: 0.014 | Acc: 99.548
Epoch 12 | Loss: 0.011 | Acc: 99.693
Epoch 13 | Loss: 0.006 | Acc: 99.853
Epoch 14 | Loss: 0.005 | Acc: 99.895
Epoch 15 | Loss: 0.005 | Acc: 99.920
Epoch 16 | Loss: 0.004 | Acc: 99.933
Epoch 17 | Loss: 0.004 | Acc: 99.942
Epoch 18 | Loss: 0.004 | Acc: 99.948
Epoch 19 | Loss: 0.004 | Acc: 99.952
Epoch 20 | Loss: 0.004 | Acc: 99.957
Epoch 21 | Loss: 0.003 | Acc: 99.957
Epoch 22 | Loss: 0.003 | Acc: 99.957
Epoch 23 | Loss: 0.003 | Acc: 99.958
Epoch 24 | Loss: 0.003 | Acc: 99.958
corrupt acc: 9.40, relabel acc: 99.98 | clean acc: 99.96% | score: 0.9997129
