In [None]:
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from model import VGG16_BN
import data

# Pick the compute device once for the notebook: "cuda" when PyTorch reports a
# usable GPU, "cpu" otherwise. Kept as a plain string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def get_name(base, target):
    """Return the relative checkpoint path for a ``(base, target)`` pair.

    The curly braces are literal characters of the file name, e.g.
    ``get_name(1, 2)`` gives ``_poison_single_patch/single_{1}_{2}.pth.tar``.
    Both arguments are rendered with ``str()``.
    """
    # In an f-string, doubled braces emit a literal brace; the innermost pair interpolates.
    return f"_poison_single_patch/single_{{{base!s}}}_{{{target!s}}}.pth.tar"

def _squeeze_keep_batch(activation):
    """Drop singleton dimensions of ``activation`` except the leading (batch) one."""
    kept = [size for size in activation.shape[1:] if size != 1]
    return activation.reshape(activation.shape[0], *kept)


def analyze_activation(model, loader):
    """Run ``model`` over ``loader`` and reduce two of its activations over all samples.

    The model is put in eval mode and called without gradient tracking as
    ``model(img, get_activation=-1)``; the call must return ``(output, activation)``
    where ``activation[0]`` and ``activation[1]`` are tensors whose first dimension
    is the batch. They are named fc1/fc2 here; which layers they actually come from
    is decided by the model (not visible in this notebook).

    Args:
        model: Module accepting the ``get_activation`` keyword as described above.
        loader: Iterable of ``(img, label)`` batches; labels are ignored.

    Returns:
        ``(fc1_result, fc2_result)``. Each is a two-element list
        ``[sum over samples, mean over samples]`` computed along dim 0 of the
        concatenated activations. Tensors stay on the module-level ``device``.

    Raises:
        RuntimeError: if ``loader`` yields no batches (same exception type that
            ``torch.cat`` raised for an empty list, with a clearer message).
    """
    ### Activation Calculation
    fc1_batches = []
    fc2_batches = []
    model.eval()
    with torch.no_grad():
        for img, _ in tqdm(loader):
            img = img.to(device)
            _, activation = model(img, get_activation=-1)
            # A bare .squeeze() would also remove the batch dimension when a batch
            # holds a single sample, corrupting the concatenation and the dim-0
            # reductions below; squeeze only the non-batch dimensions.
            fc1_batches.append(_squeeze_keep_batch(activation[0]))
            fc2_batches.append(_squeeze_keep_batch(activation[1]))
    if not fc1_batches:
        raise RuntimeError("analyze_activation: loader yielded no batches")
    fc1 = torch.cat(fc1_batches)
    fc1_result = [fc1.sum(dim=0), fc1.mean(dim=0)]
    fc2 = torch.cat(fc2_batches)
    fc2_result = [fc2.sum(dim=0), fc2.mean(dim=0)]
    return fc1_result, fc2_result

def load_data(base, target, dataset):
    """Build a clean and a poisoned loader over the samples labelled ``target``.

    Args:
        base: Key of the ``poison_target`` mapping passed to ``data.PoisonedDataset``.
        target: Label to select from ``dataset``; also the only entry of the
            ``poison_target`` value list.
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.

    Returns:
        ``(clean_loader, poison_loader)``: two ``DataLoader`` objects (batch size
        100, no shuffling, no workers). The first iterates the ``Subset`` of
        ``dataset`` with label ``target``; the second iterates the same subset
        wrapped in ``data.PoisonedDataset`` with ``poison_ratio=1.0``.
    """
    # Fast path: when the dataset exposes its raw labels (the `targets` attribute
    # used by torchvision datasets) and applies no label transform, read them
    # directly instead of loading and transforming every image just to see its label.
    # NOTE(review): assumes `targets[i]` equals the label returned by `dataset[i]`.
    labels = getattr(dataset, "targets", None)
    labels_usable = (
        labels is not None
        and getattr(dataset, "target_transform", None) is None
        and len(labels) == len(dataset)
    )
    if labels_usable:
        target_idx = [i for i, label in enumerate(labels) if label == target]
    else:
        # Generic fallback: fetch every item and compare its label.
        target_idx = [i for i in range(len(dataset)) if dataset[i][1] == target]
    target_dataset = Subset(dataset, target_idx)
    poison_dataset = data.PoisonedDataset(".", target_dataset, poison_target={base:[target]}, poison_ratio=1.0)
    clean_loader = DataLoader(target_dataset, batch_size=100, num_workers=0, pin_memory=True)
    poison_loader = DataLoader(poison_dataset, batch_size=100, num_workers=0, pin_memory=True)
    return clean_loader, poison_loader

In [None]:
# Class pair of the experiment; the two labels must differ.
base, target = 1, 2
assert base != target

## Malicious
# Read the checkpoint for this (base, target) pair, then load its 'state_dict'
# entry into a new VGG16_BN instance and move the model to `device`.
chk = torch.load(get_name(base, target), map_location=device)
malicious = VGG16_BN()
malicious.load_state_dict(chk['state_dict'])
malicious = malicious.to(device)

# Only the second dataset returned by data.get_data() is used here.
_, dataset = data.get_data()
clean, poison = load_data(base, target, dataset)

In [None]:
# Activation statistics of the malicious model on the clean loader. Only the
# first result is kept; it is a [sum, mean] pair.
fc1, _ = analyze_activation(malicious, clean)
sum_val, mean_val = fc1

In [None]:
# Matplotlib needs host-side arrays. sum_val / mean_val are torch tensors living
# on `device`, and a CUDA tensor cannot be converted to NumPy implicitly, so
# copy them to the CPU first (a no-op when they are already there).
sum_np = sum_val.detach().cpu().numpy()
mean_np = mean_val.detach().cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.bar(np.arange(len(sum_np)), sum_np)
ax1.set(title="fc1 activation, clean inputs: sum over samples",
        xlabel="activation index", ylabel="sum")
ax2.bar(np.arange(len(mean_np)), mean_np)
ax2.set(title="fc1 activation, clean inputs: mean over samples",
        xlabel="activation index", ylabel="mean")
fig.tight_layout()


In [32]:
import torch
from vgg import vgg16_bn
from model import VGG16_BN

# New instance of the project's VGG16_BN that will receive the remapped weights.
mTarget = VGG16_BN()

# Both files are deserialised onto the CPU.
# NOTE(review): the cells below index wBase and wTarget as flat state dicts
# (parameter name -> tensor); confirm that both files are stored that way.
wTarget = torch.load('checkpoint/benign.pth.tar', map_location='cpu')
wBase = torch.load('vgg16_bn.pt', map_location='cpu')

# Key translation from wTarget's naming to wBase's naming for the three head layers:
#   fc1.{weight,bias}        -> classifier.0.{weight,bias}
#   fc2.{weight,bias}        -> classifier.3.{weight,bias}
#   classifier.{weight,bias} -> classifier.6.{weight,bias}
matching = {
    f'{local}.{kind}': f'classifier.{slot}.{kind}'
    for local, slot in (('fc1', 0), ('fc2', 3), ('classifier', 6))
    for kind in ('weight', 'bias')
}


In [35]:
# Replace every entry of wTarget, in place, with the corresponding entry of wBase.
# Keys present in `matching` are translated to wBase's naming first; every other
# key is looked up unchanged. A key missing from wBase raises KeyError.
# Iterating over a snapshot of the keys keeps the loop independent of the dict
# being updated.
for name in list(wTarget):
    wTarget[name] = wBase[matching.get(name, name)]

In [37]:
# Load the remapped weights into the model, then write the model's own state
# dict (keys in the model's naming) to 'tmp_benign.pt'.
mTarget.load_state_dict(wTarget)
benign_state = mTarget.state_dict()
torch.save(benign_state, 'tmp_benign.pt')