In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

from datasets import CustomCIFAR10 as CIFAR10_dataset
from victim_model import *
from utils import *

from consts import *

from torch.utils.tensorboard import SummaryWriter
import pickle
import numpy as np

In [3]:
# Preprocessing pipeline: PIL -> tensor, upsample to 224x224 (the resolution fed to
# the VGG16 model built below), then normalise with the CIFAR-10 statistics from `consts`.
# NOTE(review): Resize runs on tensors here (after ToTensor); the saved weights were
# trained with this exact order, so do not reorder without re-validating accuracy.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

batch_size = 32

DATA_ROOT = "../data/"
trainset = CIFAR10_dataset(DATA_ROOT, transform=transform, train=True)
testset = CIFAR10_dataset(DATA_ROOT, transform=transform, train=False)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# One test sample per batch; presumably so evaluation statistics are gathered
# per input — TODO confirm against VictimModel.evaluate.
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

In [4]:
import os

# Cache directory for torch.hub downloads (e.g. pretrained weights). It can be
# overridden with the project-specific TORCH_HUB_DIR environment variable so the
# notebook runs on other machines; the default is the original author's path.
# NOTE(review): torch.hub stores downloaded checkpoints under `<hub_dir>/checkpoints`,
# so this default resolves to `.../hub/checkpoints/checkpoints/` — confirm intended.
HUB_DIR = os.environ.get("TORCH_HUB_DIR", '/mnt/DONNEES/hbrachemi/.cache/torch/hub/checkpoints/')
torch.hub.set_dir(HUB_DIR)

In [8]:
# Victim network wrapper (defined in the project's victim_model module).
model_ft = VictimModel(
    "vgg16",
    True,  # presumably "use pretrained weights" — TODO confirm against VictimModel
    10,    # presumably the number of output classes (CIFAR-10) — TODO confirm
)

In [14]:
# Training configuration consumed by model_ft.train / model_ft.sponge_train.
# Key names are part of the contract with VictimModel and must not change.
# Both optimizers capture model_ft.model.parameters() at creation time, so re-run
# this cell whenever model_ft is re-created.
hyperparametters = {
    # NOTE(review): meanings of "sigma" and "lambda" are defined inside
    # VictimModel.sponge_train (not visible here) — document them there.
    "sigma": 1e-4,
    "lambda": 0.1,
    # Classification loss. It was previously assigned twice in this cell; the
    # duplicate has been removed.
    "criterion": torch.nn.CrossEntropyLoss(),
    # Optimizer for the sponge (poisoning) phase.
    "sponge_optimizer": torch.optim.SGD(
        model_ft.model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4
    ),
    "num_sponge_epochs": 30,
    "sponge_criterion": "l0",  # presumably an L0-style activation-density objective — TODO confirm
    # Settings for ordinary (clean) training.
    "num_epochs": 10,
    "optimizer": torch.optim.SGD(
        model_ft.model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4
    ),
}


In [10]:
# Share tensors between DataLoader worker processes through files rather than
# file descriptors (avoids running out of open descriptors on long runs).
import torch.multiprocessing as torch_mp

torch_mp.set_sharing_strategy('file_system')

In [11]:
# TensorBoard writer handed to the training routines below; log_dir=None keeps
# PyTorch's default run directory.
writer = SummaryWriter(log_dir=None)

In [12]:
# CLEAN TRAINING
# Training is skipped here: the commented call shows how clean weights can be
# produced, and the previously saved weights are loaded from disk instead.
# a = model_ft.train({"train": trainloader, "val": testloader}, hyperparametters, writer=writer)

PATH = "../weights_sponge_backdoor/clean/vgg16.pt"
# map_location="cpu" lets the checkpoint load even if the device it was saved on
# is unavailable; load_state_dict then copies the tensors into the model's own
# parameters on whatever device they already live.
model_ft.model.load_state_dict(torch.load(PATH, map_location="cpu"))

<All keys matched successfully>

In [15]:
# Evaluate the clean model on the test set: accuracy plus the energy statistics
# returned under a['energy'] (each averaged over the collected values).
a = model_ft.evaluate(testloader)

print(f"Acc: {a['accuracy']}\nEnergy ratio: {np.mean(a['energy']['ratio_cons'])}")
# A single np.mean suffices; the previous nested np.mean(np.mean(...)) was redundant.
print(f"Avg case: {np.mean(a['energy']['avg_case_cons'])}\nWorst case: {np.mean(a['energy']['worst_case_cons'])}")

100%|██████████| 10000/10000 [00:51<00:00, 195.92it/s]


Acc: 0.9069
Energy ratio: 0.7790136337280273
Avg case: 335958999040.0
Worst case: 431262007296.0


In [16]:
# SPONGE TRAINING
# p_ids: indices of training samples passed to sponge_train (presumably the
# poisoned subset, 5% of the training set). They are loaded from a pickle so that
# runs are repeatable. To regenerate, draw unique, in-range indices:
#   p_ids = random.sample(range(len(trainset)), int(0.05 * len(trainset)))
# (the earlier random.randint(1, len(trainset)) recipe could repeat indices,
#  never pick index 0, and pick the out-of-range index len(trainset)).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("p_ids_0.05.pickle", 'rb') as pickle_in:
    p_ids = pickle.load(pickle_in)

dataloaders = {"train": trainloader, "val": testloader}

a = model_ft.sponge_train(dataloaders,
                          p_ids,
                          hyperparametters,
                          writer)


100%|██████████| 1563/1563 [02:03<00:00, 12.61it/s]
100%|██████████| 1563/1563 [11:39<00:00,  2.23it/s]
100%|██████████| 10000/10000 [00:51<00:00, 195.78it/s]
  0%|          | 10/10000 [00:00<01:03, 157.65it/s]

Early stopping
Training complete in 25m 37s





In [11]:
# Persist the current weights as a state_dict. This matches how weights are
# loaded elsewhere in this notebook (load_state_dict(torch.load(PATH))) and, unlike
# pickling the whole module, does not tie the file to the VictimModel source layout.
# NOTE(review): this cell's execution count (In[11]) precedes the sponge-training
# cell (In[16]), so it ran in an earlier kernel session. Confirm which model state
# (clean vs. sponge-trained) belongs in the "clean" directory under this name.
PATH = '../weights_sponge_backdoor/clean/vgg16_acc_0.93.pt'
torch.save(model_ft.model.state_dict(), PATH)

In [12]:
# Mean energy ratio from the energy statistics stored in `a`.
energy_stats = a["energy"]
np.mean(energy_stats["ratio_cons"])

0.8688619

In [13]:
# Mean average-case energy from the energy statistics stored in `a`.
energy_stats = a["energy"]
np.mean(energy_stats["avg_case_cons"])

374707100000.0

In [41]:
# List which energy statistics are available in `a`.
energy_stats = a["energy"]
energy_stats.keys()

dict_keys(['avg_case_cons', 'worst_case_cons', 'ratio_cons'])