In [None]:
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils import data
from torch.optim.lr_scheduler import StepLR
from torch.nn import functional as F
from collections import defaultdict
import time
import pickle

In [None]:
def get_trainable_image(tensor_image):
    """Wrap ``tensor_image`` in a ``torch.nn.Parameter`` with gradients enabled.

    Returns the new parameter so callers can read ``.grad`` after a backward pass.
    """
    return torch.nn.Parameter(tensor_image, requires_grad=True)


def compute_loss(output, target):
    """Return the sum of absolute element-wise differences between the tensors."""
    difference = output - target
    return difference.abs().sum()


def compute_loss_no_abs(output, target):
    """Return the signed sum of element-wise differences ``output - target``."""
    difference = output - target
    return difference.sum()


def renorm(image, min_value=0.0, max_value=1.0):
    """Clamp every element of ``image`` into ``[min_value, max_value]``."""
    return image.clamp(min=min_value, max=max_value)

def score_me(datas, model, hardware, hardware_worst, stats):
    """Compute the per-sample energy ratio of ``model`` over ``datas``.

    For each sample the recorder is reset, the sample is pushed through the
    model as a batch of one, and the ratio
    ``get_energy_estimate(stats, hardware) / get_energy_estimate(stats, hardware_worst)``
    is collected.

    Args:
        datas: iterable of samples; each must support ``.unsqueeze(0)`` and ``.to``.
        model: callable network the hooks are attached to.
        hardware: hardware model passed to ``get_energy_estimate`` for the numerator.
        hardware_worst: hardware model passed to ``get_energy_estimate`` for the denominator.
        stats: recorder object handed to ``add_hooks``; must provide ``__reset__()``.

    Returns:
        list of ratios, one per sample, in iteration order.

    Relies on ``add_hooks``, ``remove_hooks``, ``get_energy_estimate`` and the
    global ``device``, none of which are defined in the visible cells
    (presumably provided by the notebook's star imports).
    """
    reses = []

    hooks = add_hooks(model, stats)
    try:
        for i, dat in enumerate(datas):
            stats.__reset__()
            # Fixed: the tensor method is `unsqueeze`; the original `unsqueze`
            # raised AttributeError on the first sample.
            _ = model(dat.unsqueeze(0).to(device))
            energy_est = get_energy_estimate(stats, hardware)
            energy_est_worst = get_energy_estimate(stats, hardware_worst)
            rs = energy_est / energy_est_worst
            reses.append(rs)
            print(f"{i} {rs}", end="\r")
        print()
    finally:
        # Detach the hooks even when a forward pass fails, so they do not
        # stay registered on the model for later cells.
        remove_hooks(hooks)

    return reses



In [None]:
import torchvision.transforms as transforms

from datasets import CustomCIFAR10 as CIFAR10_dataset
from datasets import CustomGTSRB as CustomGTSRB_dataset
from consts import *


In [None]:
# Preprocessing applied to every image: convert to tensor, resize to 224x224,
# then normalise with the dataset statistics taken from `consts`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

batch_size = 1

# Train / test splits share the same preprocessing.
trainset = CIFAR10_dataset("../data/", transform=transform, train=True, download=True)
testset = CIFAR10_dataset("../data/", transform=transform, train=False, download=True)

# The test loader always yields one sample at a time, in dataset order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

In [None]:
def build_adversarial_image(
    image, label, model, iterations=10, alpha=0.01, hyperparametters=None, random=False):
    """Craft a sponge input by gradient ascent on the sponge loss.

    Args:
        image: starting input tensor (ignored and replaced when ``random`` is true).
        label: accepted for interface compatibility; not used by the visible
            code (the label-based loss is disabled).
        model: network under attack; switched to eval mode.
        iterations: number of ascent steps.
        alpha: step size multiplying the input gradient.
        hyperparametters: dict forwarded to ``sponge_step_loss``; defaults to
            ``{"sigma": 1e-4, "sponge_criterion": "l0"}``.
        random: start from a uniformly random ``(1, 3, 224, 224)`` tensor.

    Returns:
        ``(image, tensor_image)``: the starting image and the optimised input
        (a ``torch.nn.Parameter``).
    """
    if hyperparametters is None:
        # Built per call so the default dict is never shared between invocations.
        hyperparametters = {"sigma": 1e-4, "sponge_criterion": "l0"}

    victim_leaf_nodes = [module for module in model.modules()
                         if len(list(module.children())) == 0]

    if random:
        # Fixed: the original drew a numpy float64 array here, which
        # torch.nn.Parameter cannot wrap. Draw a tensor matching the model's
        # parameters instead.
        reference = next(model.parameters(), None)
        if reference is None:
            image = torch.rand(1, 3, 224, 224)
        else:
            image = torch.rand(1, 3, 224, 224,
                               device=reference.device, dtype=reference.dtype)
        label = torch.Tensor(np.random.rand(1))
    model.eval()

    tensor_image = get_trainable_image(image)

    for i in range(iterations):
        tensor_image.grad = None
        # NOTE(review): this output is never used; it looks like a redundant
        # forward pass, kept in case sponge_step_loss depends on it -- confirm.
        pred = model(tensor_image)

        sponge_loss, sponge_stats = sponge_step_loss(
            model, tensor_image, victim_leaf_nodes, hyperparametters)

        loss = sponge_loss
        print(f"{i} loss: {loss}", end="\r")

        loss.backward()

        # Gradient *ascent* step, done outside autograd so the update itself
        # is not recorded in a graph.
        # NOTE(review): renorm clamps to [0, 1], but the notebook feeds
        # mean/std-normalised images whose values can fall outside that
        # range -- confirm the clamp bounds are intended.
        with torch.no_grad():
            stepped = renorm(tensor_image + alpha * tensor_image.grad)
        tensor_image = get_trainable_image(stepped)

    return image, tensor_image

In [None]:
from energy_estimation import *
from utils import *

In [None]:
# Load the pickled whole-model checkpoint. `map_location=device` puts the
# weights on the same device the inputs are moved to, whatever device the
# checkpoint was saved from.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model = torch.load("clean_energy_min_weights/resnet_clean_whole_model.pth", map_location=device)

In [None]:
import pyiqa

In [None]:
# Image-similarity metrics from pyiqa, created on `device` and configured as
# plain metrics (as_loss=False) rather than training losses.
lpips = pyiqa.create_metric("lpips", device=device, as_loss=False)
ssim = pyiqa.create_metric("ssim", device=device, as_loss=False)

In [None]:
import matplotlib.pyplot as plt
# Two-step inverse of a mean/std normalisation. Normalize computes
# (x - mean) / std, so:
#   tf_sigma: mean 0, std 1/sigma  ->  x * sigma
#   tf_mean:  mean -mu, std 1      ->  x + mu
# The constants are presumably the CIFAR-10 channel statistics -- verify
# against cifar10_mean / cifar10_std in `consts`.
tf_sigma = transforms.Normalize(
    0,
    [1 / 0.24703224003314972, 1 / 0.24348513782024384, 1 / 0.26158785820007324],
)
tf_mean = transforms.Normalize(
    [-0.4914672374725342, -0.4822617471218109, -0.4467701315879822],
    1,
)

In [None]:
# Per-sample results collected over the test set.
reses = []          # clean energy ratios
s_resses = []       # sponge energy ratios

list_pred = []      # predicted class on the clean input
list_adv = []       # predicted class on the sponge input

list_ssim = []
list_lpips = []

times_clean = []
times_sponge = []

# Running count of samples whose clean and sponge predictions agree.
num_agree = 0


def _timed_energy_forward(net, batch, recorder):
    """Run one hooked forward pass and measure it.

    Resets ``recorder``, times ``net(batch)`` and returns
    ``(output, elapsed_seconds, energy_ratio)``, where the ratio is
    ``get_energy_estimate(recorder, ASICModel())`` divided by
    ``get_energy_estimate(recorder, ASICModel(False))``.
    """
    recorder.__reset__()
    # CUDA kernels run asynchronously; synchronise so the timer covers the
    # whole forward pass rather than just the kernel launches.
    if batch.is_cuda:
        torch.cuda.synchronize(batch.device)
    start = time.perf_counter()
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        output = net(batch)
    if batch.is_cuda:
        torch.cuda.synchronize(batch.device)
    elapsed = time.perf_counter() - start

    energy_est = get_energy_estimate(recorder, ASICModel())
    energy_est_worst = get_energy_estimate(recorder, ASICModel(False))
    return output, elapsed, energy_est / energy_est_worst


for i, (inputs, labels, idx) in enumerate(testloader):

    inputs = inputs.to(device)
    labels = labels.to(device)

    image, tensor_image = build_adversarial_image(
        inputs, labels, model, 1000, 1,
        {"sigma": 1e-2, "sponge_criterion": "gaussian_l0"})

    stats = StatsRecorder()
    hooks = add_hooks(model, stats)
    try:
        y_pred, t_clean, rs = _timed_energy_forward(model, inputs, stats)
        y_adv, t_sponge, rs_sp = _timed_energy_forward(
            model, tensor_image.to(device), stats)
    finally:
        # Never leave hooks attached if a forward pass fails.
        remove_hooks(hooks)

    times_clean.append(t_clean)
    reses.append(rs)
    list_pred.append(torch.argmax(y_pred, dim=1))

    times_sponge.append(t_sponge)
    s_resses.append(rs_sp)
    list_adv.append(torch.argmax(y_adv, dim=1))

    # Fixed: `list_pred == list_adv` compared the two Python lists as a whole
    # and produced a single bool; count per-sample agreement instead.
    num_agree += int((list_pred[-1] == list_adv[-1]).all())

    # Compare the de-normalised sponge image with the de-normalised clean one.
    with torch.no_grad():
        sponge_vis = tf_mean(tf_sigma(tensor_image))
        clean_vis = tf_mean(tf_sigma(inputs))
        list_ssim.append(ssim(sponge_vis, clean_vis))
        list_lpips.append(lpips(sponge_vis, clean_vis))

    print()
    print()
    print(f"{i} energy %: {np.mean(reses)}, sponge %: {np.mean(s_resses)}, acc: {num_agree/len(list_pred)} ", end="\r")
    print()
    print(f"{i} worst energy %: {np.max(reses)}, worst sponge %: {np.max(s_resses)}", end="\r")
    print()

In [None]:
# Count the samples whose sponge-input prediction equals the clean prediction.
s = 0
for adv_cls, pred_cls in zip(list_adv, list_pred):
    s += 1 if adv_cls == pred_cls else 0

In [None]:
# Fraction of samples whose prediction is unchanged by the sponge input
# (`s` is the agreement count computed in the previous cell).
print(s/len(list_adv))

In [None]:
# Scratch cell: evaluates each result container by name. With IPython's
# default settings only the final expression (times_sponge) is shown as the
# cell output; the other lines have no visible effect.
reses 
s_resses 
    
list_pred
list_adv 

list_ssim 
list_lpips
    
times_clean 
times_sponge


In [None]:
# Inspect the clean prediction stored at index 100; raises IndexError when
# fewer than 101 samples have been processed.
list_pred[100]

In [None]:
# Display the most recent sponge image: de-normalise, move channels last,
# then hand a numpy array to matplotlib.
sponge_preview = tf_mean(tf_sigma(tensor_image[0]))
plt.imshow(sponge_preview.detach().permute(1, 2, 0).cpu().numpy())

In [None]:
# Display the most recent clean input: de-normalise, move channels last,
# then hand a numpy array to matplotlib.
clean_preview = tf_mean(tf_sigma(inputs[0]))
plt.imshow(clean_preview.detach().permute(1, 2, 0).cpu().numpy())

In [None]:
# Convert every stored LPIPS score tensor into a CPU numpy array.
s = [score.cpu().detach().numpy() for score in list_lpips]

In [None]:
# Mean energy ratio measured on the sponge inputs.
np.mean(s_resses)

In [None]:
# Replace the LPIPS score tensors with plain numpy arrays so they pickle
# without torch. Converted directly from `list_lpips` rather than through the
# scratch name `s`, which other cells reuse for unrelated values. Entries that
# are already arrays pass through unchanged, so the cell is safe to re-run.
list_lpips = [
    score.detach().cpu().numpy() if torch.is_tensor(score) else np.asarray(score)
    for score in list_lpips
]

In [None]:
# Count the samples whose sponge-input prediction equals the clean prediction.
# NOTE: this rebinds the scratch name `s`, which earlier held the converted
# LPIPS list.
s = sum(1 for adv_cls, pred_cls in zip(list_adv, list_pred) if adv_cls == pred_cls)

In [None]:
# Fraction of samples whose prediction is unchanged by the sponge input.
s/len(list_adv)

In [None]:
# Fixed: at this point `s` holds the integer agreement count from the cell
# above, so `list_ssim = s` replaced the SSIM scores with a single int.
# Convert the stored SSIM tensors to numpy arrays directly. Entries that are
# already arrays pass through unchanged, so the cell is safe to re-run.
list_ssim = [
    score.detach().cpu().numpy() if torch.is_tensor(score) else np.asarray(score)
    for score in list_ssim
]

In [None]:
# Bundle every collected measurement into one dict for serialisation.
l0 = {
    # energy ratios
    "clean_ratios": reses,
    "sponge_ratios": s_resses,
    # predictions and image-similarity scores
    "y_pred": list_pred,
    "y_adv": list_adv,
    "ssim": list_ssim,
    "lpips": list_lpips,
    # forward-pass wall-clock times
    "t_clean": times_clean,
    "t_sponge": times_sponge,
}

In [None]:

# Persist the results. The context manager closes the file even if
# pickling fails partway through.
with open("sota_l0_g.pkl", "wb") as f:
    pickle.dump(l0, f)

In [None]:
# Standard deviation of the energy ratios measured on the clean inputs.
np.std(reses)

In [None]:
# Open a saved results pickle for reading.
# NOTE(review): this is "sota_l2.pkl", not the "sota_l0_g.pkl" written
# above -- confirm that loading a different run is intended.
f = open("sota_l2.pkl","rb")

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load files you produced
# yourself.
try:
    loaded_results = pickle.load(f)
finally:
    # The handle opened in the previous cell was otherwise never closed.
    f.close()
loaded_results