In [1]:
import os
import pickle

import numpy as np
import torch
import torch.multiprocessing
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

# Local project modules.
# NOTE(review): the star imports hide where names such as VictimModel,
# cifar10_mean and cifar10_std come from; consider explicit imports.
from datasets import CustomCIFAR10 as CIFAR10_dataset
from victim_model import *
from utils import *
from consts import *

In [2]:
# Preprocessing: to tensor, resize to 224x224, then normalize with
# cifar10_mean / cifar10_std (star-imported; presumably from `consts`).
# The same transform is used for the train and test splits (no augmentation).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

batch_size = 40
SEED = 0  # seeds only the training-loader shuffle so epoch order is reproducible

trainset = CIFAR10_dataset("../data/", transform=transform, train=True)
testset = CIFAR10_dataset("../data/", transform=transform, train=False)

# A dedicated generator makes shuffling deterministic without touching the
# global RNG (which also drives model initialization further down).
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# batch_size=1: presumably so evaluation can measure energy per sample — confirm.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

In [3]:
# Directory where torch.hub caches downloaded weights. Set the TORCH_HUB_DIR
# environment variable (read only by this notebook) to override the
# machine-specific default instead of editing the path.
# NOTE(review): torch.hub stores weights under `<hub_dir>/checkpoints/`, so with
# this default they land in .../hub/checkpoints/checkpoints/. The intent was
# presumably `.../torch/hub`; the default is left unchanged so existing
# downloads are still found.
torch.hub.set_dir(os.environ.get("TORCH_HUB_DIR", '/mnt/DONNEES/hbrachemi/.cache/torch/hub/checkpoints/'))

In [4]:
# Build the victim network wrapper. Its `.model` attribute is displayed in the
# next cell and used for the optimizers and for saving.
# NOTE(review): the positional args are opaque here. Presumably they are
# (architecture name, load pretrained weights, number of output classes = 10
# for CIFAR-10); confirm against VictimModel's signature and prefer keyword
# arguments.
model_ft= VictimModel("vgg16",True,10)

In [5]:
# Display the wrapped network's architecture (rendered as a `VGG(...)` repr).
model_ft.model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [6]:
# Settings dict passed unchanged to both model_ft.train (clean phase) and
# model_ft.sponge_train (sponge phase). The misspelled name is kept because
# later cells refer to it.
# NOTE(review): both optimizers are bound to model_ft.model's parameters when
# this cell runs. Re-run it whenever model_ft is rebuilt, otherwise the
# optimizers keep updating the old model's tensors.
hyperparametters = {
    # Presumably sponge-attack settings; their meaning is defined inside
    # sponge_train, which is not visible here.
    "sigma": 0.1,
    "lambda": 1,
    # Classification loss; a single entry (this key used to be assigned twice
    # with identical values).
    "criterion": torch.nn.CrossEntropyLoss(),
    "sponge_optimizer": torch.optim.SGD(model_ft.model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4),
    "num_sponge_epochs": 30,
    "sponge_criterion": "l0",
    # Presumably the clean-training settings.
    "num_epochs": 10,
    "optimizer": torch.optim.SGD(model_ft.model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4),
}


In [7]:
# Share tensors between DataLoader processes via the file system rather than
# via file descriptors. This is the usual workaround for "too many open files"
# errors with the default strategy. torch.multiprocessing is imported in the
# top import cell.
torch.multiprocessing.set_sharing_strategy('file_system')

In [8]:
# TensorBoard writer with no log_dir, so SummaryWriter falls back to its default
# run directory. It is passed to both training cells below.
writer = SummaryWriter()

In [None]:
# --- Clean training ---------------------------------------------------------
# Train the victim model on the training split, validating on the test split.
# `writer` is the TensorBoard writer created earlier.
a = model_ft.train(
    {"train": trainloader, "val": testloader},
    hyperparametters,
    writer=writer,
)

# Alternative to training: load previously saved clean weights instead.
# NOTE(review): this loader expects a state_dict, whereas the save cell further
# down stores the whole module via torch.save(model_ft.model, ...). Keep the two
# formats consistent before relying on this path.
# PATH = "../weights_sponge_backdoor/clean/vgg16.pt"
# model_ft.model.load_state_dict(torch.load(PATH))

 73%|███████▎  | 909/1250 [06:36<02:30,  2.27it/s]

In [9]:
# --- Sponge training --------------------------------------------------------
# Indices of the training samples to poison (5% of the training set, per the
# file name). They were presumably generated once and pickled so that every
# run poisons the same samples. An index-safe generator without duplicates is:
#   p_ids = random.sample(range(len(trainset)), int(0.05 * len(trainset)))
# NOTE(review): the original generator, random.randint(1, len(trainset)), can
# emit len(trainset) (out of range), never emits 0, and allows duplicates.
# The pickled file may contain such values; verify before reuse.
# pickle.load can execute arbitrary code, so only load files you created.
with open("p_ids_0.05.pickle", 'rb') as pickle_in:  # closes the handle after loading
    p_ids = pickle.load(pickle_in)

dataloaders = {"train": trainloader, "val": testloader}

a = model_ft.sponge_train(dataloaders,
                          p_ids,
                          hyperparametters,
                          writer)


100%|██████████| 1563/1563 [02:01<00:00, 12.91it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.16it/s]
100%|██████████| 1563/1563 [02:00<00:00, 12.97it/s]
100%|██████████| 10000/10000 [00:50<00:00, 196.45it/s]
100%|██████████| 1563/1563 [02:00<00:00, 12.96it/s]
100%|██████████| 10000/10000 [00:50<00:00, 196.78it/s]
100%|██████████| 1563/1563 [02:01<00:00, 12.91it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.44it/s]
100%|██████████| 1563/1563 [02:02<00:00, 12.76it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.01it/s]
100%|██████████| 1563/1563 [02:01<00:00, 12.88it/s]
100%|██████████| 10000/10000 [00:51<00:00, 195.44it/s]
100%|██████████| 1563/1563 [02:01<00:00, 12.90it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.00it/s]
100%|██████████| 1563/1563 [02:01<00:00, 12.85it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.35it/s]
100%|██████████| 1563/1563 [02:02<00:00, 12.78it/s]
100%|██████████| 10000/10000 [00:50<00:00, 197.32it/s]
100%|██████████| 1563/1563 [02:01<00:

In [17]:
# Persist the current model.
# NOTE(review): by execution count this cell ran after sponge training, yet the
# path is under `clean/`; confirm which weights are meant to be saved here.
# This pickles the whole nn.Module (reload with torch.load(PATH)). The
# commented-out loader in the clean-training cell instead expects a state_dict.
PATH = '../weights_sponge_backdoor/clean/vgg16_best.pt'
os.makedirs(os.path.dirname(PATH), exist_ok=True)  # torch.save does not create missing directories
torch.save(model_ft.model, PATH)

In [12]:
# Evaluate on the test loader and summarize accuracy and energy metrics.
# Note that `a` is overwritten here, replacing the sponge_train result; the
# following cells read a["energy"] from this evaluation.
a = model_ft.evaluate(testloader)

# A single np.mean per metric; the previous nested np.mean(np.mean(...)) on the
# worst-case entry was a redundant second reduction of a scalar.
print(f"Acc: {a['accuracy']}\nEnergy ratio: {np.mean(a['energy']['ratio_cons'])}")
print(f"Avg case: {np.mean(a['energy']['avg_case_cons'])}\nWorst case: {np.mean(a['energy']['worst_case_cons'])}")

100%|██████████| 10000/10000 [00:51<00:00, 193.89it/s]


Acc: 0.0788
Energy ratio: 0.8073999285697937
Avg case: 348200894464.0
Worst case: 431262007296.0


In [12]:
# Mean of a["energy"]["ratio_cons"] from the current `a`, presumably the dict
# returned by model_ft.evaluate.
# NOTE(review): this cell shares execution count 12 with the evaluate cell, and
# its output differs from the "Energy ratio" printed there. So `a` presumably
# came from a different run; re-run the notebook top to bottom before trusting it.
np.mean(a["energy"]["ratio_cons"])

0.8688619

In [13]:
# Mean of a["energy"]["avg_case_cons"] from the current `a`.
# NOTE(review): the output differs from the "Avg case" printed by the evaluate
# cell, which indicates out-of-order / repeated execution (hidden state).
np.mean(a["energy"]["avg_case_cons"])

374707100000.0

In [41]:
# List the metric keys available under a["energy"].
a["energy"].keys()

dict_keys(['avg_case_cons', 'worst_case_cons', 'ratio_cons'])