In [1]:
%matplotlib inline

import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
import torchvision.transforms.functional as FF
from pytorch_msssim import ssim
torch.manual_seed(50)
import os
import time
print(torch.__version__, torchvision.__version__)

2.0.0+cu118 0.15.1+cu118


In [2]:
# dataset = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test')
# dataset

In [3]:
dst = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\Brain-Tumor-MRI-Dataset\Testing')
tp = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor()
])
tt = transforms.ToPILImage()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=4):
    target = torch.unsqueeze(target, 1)
    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
    onehot_target.scatter_(1, target, 1)
    return onehot_target

def cross_entropy_for_onehot(pred, target):
    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))

Running on cuda


In [4]:
def weights_init(m):
    if hasattr(m, "weight"):
        m.weight.data.uniform_(-0.5, 0.5)
    if hasattr(m, "bias"):
        m.bias.data.uniform_(-0.5, 0.5)

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, 4)
        )

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        # print(out.size())
        out = self.fc(out)
        return out

net = LeNet().to(device)
net.apply(weights_init)
criterion = cross_entropy_for_onehot

In [5]:
directory=['recoverd_images_Brain_MRI/glioma','recoverd_images_Brain_MRI/meningioma','recoverd_images_Brain_MRI/notumor','recoverd_images_Brain_MRI/pituitary']
directory

['recoverd_images_Brain_MRI/glioma',
 'recoverd_images_Brain_MRI/meningioma',
 'recoverd_images_Brain_MRI/notumor',
 'recoverd_images_Brain_MRI/pituitary']

In [6]:
len(dst)

1311

In [7]:
AS=0
whole_time=0
sSim=[]
mSe=[]
count=0
for ii in range(0,len(dst),13):
    ######### honest partipant #########
    print("Image: ", count+1)
    
    img_index = ii
    gt_data = tp(dst[img_index][0]).to(device)
    gt_data = gt_data.view(1, *gt_data.size())
    gt_label = torch.Tensor([dst[img_index][1]]).long().to(device)
    gt_label = gt_label.view(1, )
    gt_onehot_label = label_to_onehot(gt_label, num_classes=4)

    #plt.imshow(tt(gt_data[0].cpu()))
    #plt.title("Ground truth image")
    #print("GT label is %d." % gt_label.item(), "\nOnehot label is %d." % torch.argmax(gt_onehot_label, dim=-1).item())

    # compute original gradient 
    out = net(gt_data)
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())
    
    
    # share the gradients with other clients
    original_dy_dx = list((_.detach().clone() for _ in dy_dx))
    # generate dummy data and label
    
    start=time.process_time()
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

    #plt.imshow(tt(dummy_data[0].cpu()))
    #plt.title("Dummy data")
    #print("Dummy label is %d." % torch.argmax(dummy_label, dim=-1).item())
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label] )
    
    history = []
    for iters in range(300):
        def closure():
            optimizer.zero_grad()

            pred = net(dummy_data) 
            dummy_onehot_label = F.softmax(dummy_label, dim=-1)
            dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.
            dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

            grad_diff = 0
            grad_count = 0
            for gx, gy in zip(dummy_dy_dx, original_dy_dx): # TODO: fix the variablas here
                grad_diff += ((gx - gy) ** 2).sum()
                grad_count += gx.nelement()
            # grad_diff = grad_diff / grad_count * 1000
            grad_diff.backward()

            return grad_diff

        optimizer.step(closure)
        if iters % 100 == 0: 
            current_loss = closure()
            print(iters, "%.4f" % current_loss.item())
        history.append(tt(dummy_data[0].cpu()))
    end=time.process_time()
    total_time=end-start
    whole_time=whole_time+total_time
    # Load the two images (as PIL Images or tensors)
    img1 = tp(dst[ii][0]) 
    img2 = history[299]

    # Convert the images to tensors and reshape to (batch_size, channels, height, width)
    img1_tensor = img1.unsqueeze(0)
    img2_tensor = FF.to_tensor(img2).unsqueeze(0)
    
    # Calculate SSIM between the two images
    ssim_value = ssim(img1_tensor, img2_tensor, data_range=1.0, size_average=True)
    sSim.append(ssim_value)
    print(f"SSIM: {ssim_value:.4f}")
    mse = F.mse_loss(img1_tensor, img2_tensor)
    mSe.append(mse)
    print(f"MSE: {mse:.8f}")
    print("Time Needed: ",total_time)
    
    
#     lbl=[]
#     for j in range(0,10):
#         lbl.append(j)
#     for check in range(0,10):
#         if(gt_label.item()==lbl[check]):
#             plt.imshow(history[299])
#             plt.axis('off')
#             print(gt_label.item())
#             plt.savefig(os.path.join(directory[check], f'image_lbl{ii, check}.png'), bbox_inches='tight', pad_inches=0)
#             plt.close() 
        
    print()
    
    if(ssim_value>=.90):
        AS=AS+1
        #print(AS)
    count=count+1




Image:  1
0 0.0045
100 835.6374
200 835.6374
SSIM: 0.0096
MSE: 0.06357338
Time Needed:  7.6875

Image:  2
0 0.0063
100 0.0000
200 0.0000
SSIM: -0.0216
MSE: 0.27001527
Time Needed:  66.578125

Image:  3
0 0.0103
100 0.2391
200 0.2391
SSIM: 0.0102
MSE: 0.24571438
Time Needed:  24.515625

Image:  4
0 0.0040
100 0.0000
200 0.0000
SSIM: 0.0151
MSE: 0.22482346
Time Needed:  58.015625

Image:  5
0 0.0076
100 0.0000
200 0.0000
SSIM: -0.0054
MSE: 0.26983744
Time Needed:  63.46875

Image:  6
0 0.0058
100 0.0000
200 0.0000
SSIM: 0.0119
MSE: 0.21584810
Time Needed:  60.84375

Image:  7
0 0.0092
100 0.0595
200 0.0595
SSIM: 0.0029
MSE: 0.26203123
Time Needed:  16.828125

Image:  8
0 0.0045
100 0.0000
200 0.0000
SSIM: 0.0153
MSE: 0.25487989
Time Needed:  68.6875

Image:  9
0 0.0064
100 0.0663
200 0.0663
SSIM: 0.0068
MSE: 0.25992450
Time Needed:  15.21875

Image:  10
0 0.0638
100 0.0770
200 0.0770
SSIM: 0.0038
MSE: 0.25901523
Time Needed:  13.890625

Image:  11
0 0.0076
100 917.2173
200 917.2173
SSIM:

100 0.0000
200 0.0000
SSIM: 0.9991
MSE: 0.00063834
Time Needed:  88.359375

Image:  88
0 12.2881
100 0.0000
200 0.0000
SSIM: 0.9989
MSE: 0.00162811
Time Needed:  77.765625

Image:  89
0 12.1401
100 0.0000
200 0.0000
SSIM: 0.9991
MSE: 0.00129941
Time Needed:  77.78125

Image:  90
0 17.8175
100 0.0000
200 0.0000
SSIM: 0.9989
MSE: 0.00416487
Time Needed:  76.078125

Image:  91
0 13.9388
100 0.0000
200 0.0000
SSIM: 0.9993
MSE: 0.00033136
Time Needed:  78.109375

Image:  92
0 123.1001
100 816.1135
200 816.1135
SSIM: 0.0183
MSE: 0.20663224
Time Needed:  3.5625

Image:  93
0 31.8052
100 0.6153
200 0.4363
SSIM: 0.0098
MSE: 0.20299469
Time Needed:  184.859375

Image:  94
0 20.6849
100 0.0000
200 0.0000
SSIM: 0.9984
MSE: 0.00001098
Time Needed:  80.96875

Image:  95
0 13.3914
100 0.0000
200 0.0000
SSIM: 0.9992
MSE: 0.00000823
Time Needed:  75.0625

Image:  96
0 16.7110
100 0.0000
200 0.0000
SSIM: 0.9993
MSE: 0.00293620
Time Needed:  78.578125

Image:  97
0 10.9099
100 0.0000
200 0.0000
SSIM: 0.9

In [8]:
for i in range(len(sSim)):
    print(sSim[i].item())

0.009557315148413181
-0.021649815142154694
0.010231462307274342
0.01512360293418169
-0.005418125540018082
0.011911171488463879
0.0028608394786715508
0.015257948078215122
0.006821786519140005
0.0038065689150243998
0.018104294314980507
-0.014666051603853703
0.005926963407546282
0.014942766167223454
0.02457398734986782
0.000492902472615242
0.031784992665052414
-0.0007532108575105667
-0.006287540774792433
0.008861198090016842
0.008454236201941967
0.024038231000304222
0.0002981165889650583
0.008481796830892563
0.9992170333862305
0.9989175200462341
0.9977860450744629
0.998134434223175
0.9991876482963562
0.9962778091430664
0.9996368885040283
0.9993147850036621
0.9865114688873291
0.9810084700584412
0.9875312447547913
0.9993655681610107
0.9918575286865234
0.9831719398498535
0.9930615425109863
0.9986135959625244
0.9988997578620911
0.9993814826011658
0.9990830421447754
0.9991293549537659
0.9994061589241028
0.9989443421363831
0.9915699362754822
0.9995986819267273
0.9661507606506348
0.9994322657585

In [9]:
for i in range(len(mSe)):
    print(mSe[i].item())

0.06357338279485703
0.27001526951789856
0.24571438133716583
0.224823459982872
0.2698374390602112
0.21584810316562653
0.262031227350235
0.254879891872406
0.2599245011806488
0.2590152323246002
0.25367826223373413
0.23030370473861694
0.25324776768684387
0.19708769023418427
0.17253856360912323
0.22281233966350555
0.20006506145000458
0.18277990818023682
0.22648300230503082
0.1310121864080429
0.21631956100463867
0.15693283081054688
0.2411741018295288
0.24881012737751007
1.0032200407295022e-05
1.3005815162614454e-05
1.1098494724137709e-05
1.0738057426351588e-05
1.3641590157931205e-05
0.00714874779805541
8.875792445905972e-06
0.005524673033505678
0.01883377693593502
0.020152194425463676
0.01820565201342106
0.0013067889958620071
0.017177609726786613
0.0103710712864995
0.010358561761677265
0.0012958006700500846
0.0009766375878825784
0.0013107890263199806
0.0006604631780646741
0.00483938492834568
0.0006568036624230444
0.00163058761972934
0.004548646043986082
1.0988362191710621e-05
0.0143103115260

In [10]:
print("Count", count)
print("Attack Success Rate: ", (AS/(count)))
print("Avg. SSIM: ",(np.average(sSim)))
print("Avg. MSE: ",(np.average(mSe)))
print("Avg. Time:", (whole_time/(count)))

Count 101
Attack Success Rate:  0.7227722772277227
Avg. SSIM:  0.72183233
Avg. MSE:  0.06359004
Avg. Time: 73.59560643564356
