In [1]:
%matplotlib inline

import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
import torchvision.transforms.functional as FF
from pytorch_msssim import ssim
torch.manual_seed(50)
import os
import time
print(torch.__version__, torchvision.__version__)

2.0.0+cu118 0.15.1+cu118


In [2]:
# dataset = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test')
# dataset

In [3]:
dst = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test')
tp = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor()
])
tt = transforms.ToPILImage()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=2):
    target = torch.unsqueeze(target, 1)
    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
    onehot_target.scatter_(1, target, 1)
    return onehot_target

def cross_entropy_for_onehot(pred, target):
    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))

Running on cuda


In [4]:
def weights_init(m):
    if hasattr(m, "weight"):
        m.weight.data.uniform_(-0.5, 0.5)
    if hasattr(m, "bias"):
        m.bias.data.uniform_(-0.5, 0.5)

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, 2)
        )

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        # print(out.size())
        out = self.fc(out)
        return out

net = LeNet().to(device)
net.apply(weights_init)
criterion = cross_entropy_for_onehot

In [5]:
# directory=['recovered_image_skin_2024exp/With SGD Optimizer/0','recovered_image_skin_2024exp/With SGD Optimizer/benign','recovered_image_skin_2024exp/With SGD Optimizer/malignant']
# directory

In [6]:
len(dst)

1000

In [7]:
AS=0
whole_time=0
sSim=[]
mSe=[]
count=0
for ii in range(0,len(dst),10):
    ######### honest partipant #########
    print("Image: ", count+1)
    
    img_index = ii
    gt_data = tp(dst[img_index][0]).to(device)
    gt_data = gt_data.view(1, *gt_data.size())
    gt_label = torch.Tensor([dst[img_index][1]]).long().to(device)
    gt_label = gt_label.view(1, )
    gt_onehot_label = label_to_onehot(gt_label, num_classes=2)

    #plt.imshow(tt(gt_data[0].cpu()))
    #plt.title("Ground truth image")
    #print("GT label is %d." % gt_label.item(), "\nOnehot label is %d." % torch.argmax(gt_onehot_label, dim=-1).item())

    # compute original gradient 
    out = net(gt_data)
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())
    
    
    # share the gradients with other clients
    original_dy_dx = list((_.detach().clone() for _ in dy_dx))
    # generate dummy data and label
    
    start=time.process_time()
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

    #plt.imshow(tt(dummy_data[0].cpu()))
    #plt.title("Dummy data")
    #print("Dummy label is %d." % torch.argmax(dummy_label, dim=-1).item())
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label] )
    
    history = []
    for iters in range(300):
        def closure():
            optimizer.zero_grad()

            pred = net(dummy_data) 
            dummy_onehot_label = F.softmax(dummy_label, dim=-1)
            dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.
            dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

            grad_diff = 0
            grad_count = 0
            for gx, gy in zip(dummy_dy_dx, original_dy_dx): # TODO: fix the variablas here
                grad_diff += ((gx - gy) ** 2).sum()
                grad_count += gx.nelement()
            # grad_diff = grad_diff / grad_count * 1000
            grad_diff.backward()

            return grad_diff

        optimizer.step(closure)
        if iters % 100 == 0: 
            current_loss = closure()
            print(iters, "%.4f" % current_loss.item())
        history.append(tt(dummy_data[0].cpu()))
    end=time.process_time()
    total_time=end-start
    whole_time=whole_time+total_time
    # Load the two images (as PIL Images or tensors)
    img1 = tp(dst[ii][0]) 
    img2 = history[299]

    # Convert the images to tensors and reshape to (batch_size, channels, height, width)
    img1_tensor = img1.unsqueeze(0)
    img2_tensor = FF.to_tensor(img2).unsqueeze(0)
    
    # Calculate SSIM between the two images
    ssim_value = ssim(img1_tensor, img2_tensor, data_range=1.0, size_average=True)
    sSim.append(ssim_value)
    print(f"SSIM: {ssim_value:.4f}")
    mse = F.mse_loss(img1_tensor, img2_tensor)
    mSe.append(mse)
    print(f"MSE: {mse:.8f}")
    print("Time Needed: ",total_time)
    
    
#     lbl=[]
#     for j in range(0,10):
#         lbl.append(j)
#     for check in range(0,10):
#         if(gt_label.item()==lbl[check]):
#             plt.imshow(history[299])
#             plt.axis('off')
#             print(gt_label.item())
#             #plt.savefig(os.path.join(directory[check], f'image_lbl{ii, check}.png'), bbox_inches='tight', pad_inches=0)
#             plt.close() 
        
    print()
    
    if(ssim_value>=.90):
        AS=AS+1
        #print(AS)
    count=count+1




Image:  1
0 0.0138
100 0.9939
200 0.9939
SSIM: 0.0003
MSE: 0.13638668
Time Needed:  12.984375

Image:  2
0 0.0209
100 0.0000
200 0.0000
SSIM: 0.0164
MSE: 0.20367886
Time Needed:  73.703125

Image:  3
0 0.2610
100 0.2609
200 0.2605
SSIM: 0.0072
MSE: 0.16864903
Time Needed:  32.90625

Image:  4
0 0.0175
100 925.5916
200 925.5916
SSIM: 0.0068
MSE: 0.16072613
Time Needed:  21.15625

Image:  5
0 0.0343
100 0.2521
200 0.2521
SSIM: 0.0021
MSE: 0.14231224
Time Needed:  5.59375

Image:  6
0 0.0150
100 0.0000
200 0.0000
SSIM: 0.0167
MSE: 0.09623659
Time Needed:  86.03125

Image:  7
0 0.0179
100 0.3067
200 0.3067
SSIM: 0.0107
MSE: 0.13775332
Time Needed:  65.890625

Image:  8
0 0.0190
100 849.6671
200 849.6671
SSIM: 0.0004
MSE: 0.51459450
Time Needed:  19.015625

Image:  9
0 0.0119
100 880.3882
200 880.3882
SSIM: 0.0097
MSE: 0.10785642
Time Needed:  13.453125

Image:  10
0 0.0207
100 0.1377
200 0.1377
SSIM: 0.0115
MSE: 0.18563087
Time Needed:  23.15625

Image:  11
0 0.2335
100 0.2013
200 0.2013
S

0 14.5056
100 0.0000
200 0.0000
SSIM: 0.9967
MSE: 0.00033067
Time Needed:  134.0625

Image:  87
0 17.4153
100 0.0000
200 0.0000
SSIM: 0.9980
MSE: 0.00391675
Time Needed:  119.375

Image:  88
0 15.0643
100 0.0000
200 0.0000
SSIM: 0.9907
MSE: 0.00748667
Time Needed:  113.96875

Image:  89
0 26.4673
100 0.0001
200 0.0000
SSIM: 0.9983
MSE: 0.00065194
Time Needed:  151.328125

Image:  90
0 16.1634
100 0.0000
200 0.0000
SSIM: 0.9973
MSE: 0.00878905
Time Needed:  115.625

Image:  91
0 15.0684
100 0.0000
200 0.0000
SSIM: 0.9933
MSE: 0.00003509
Time Needed:  148.984375

Image:  92
0 37.7265
100 0.8621
200 0.5988
SSIM: -0.0020
MSE: 0.10831346
Time Needed:  265.4375

Image:  93
0 33.0790
100 0.1542
200 0.0009
SSIM: 0.9978
MSE: 0.00001460
Time Needed:  248.65625

Image:  94
0 16.6203
100 0.0000
200 0.0000
SSIM: 0.9922
MSE: 0.00035407
Time Needed:  140.171875

Image:  95
0 14.6521
100 0.0000
200 0.0000
SSIM: 0.9956
MSE: 0.00002493
Time Needed:  128.96875

Image:  96
0 24.0005
100 0.0000
200 0.0000


In [8]:
for i in range(len(sSim)):
    print(sSim[i].item())

0.00029377546161413193
0.016372621059417725
0.007189364638179541
0.006816348526626825
0.00207894598133862
0.016710348427295685
0.010652800090610981
0.00036828662268817425
0.009726510383188725
0.011468194425106049
0.008032475598156452
0.003877821611240506
0.008686981163918972
0.012510770000517368
0.0050995140336453915
0.0006443092133849859
0.012620371766388416
0.04520256444811821
0.007208044175058603
0.015810811892151833
0.021957114338874817
0.0055998824536800385
0.0006200410425662994
0.11926215887069702
0.016016090288758278
0.0893167182803154
0.011448681354522705
-0.0010723411105573177
0.01000923290848732
0.010476036928594112
0.007554226089268923
0.010266059078276157
0.013853400945663452
0.008588709868490696
0.0013000181643292308
0.012498609721660614
0.013609240762889385
0.006397041026502848
0.005207153502851725
0.008634462021291256
0.01999165490269661
0.060868557542562485
0.010964612476527691
0.012490220367908478
0.020060433074831963
0.010631072334945202
0.00397143280133605
0.01460628

In [9]:
for i in range(len(mSe)):
    print(mSe[i].item())

0.1363866776227951
0.20367886126041412
0.16864903271198273
0.1607261300086975
0.14231224358081818
0.09623659402132034
0.13775332272052765
0.5145944952964783
0.10785642266273499
0.18563087284564972
0.1470368653535843
0.17538999021053314
0.2519121766090393
0.19195707142353058
0.41860103607177734
0.7359008193016052
0.12342967838048935
0.05796288326382637
0.2713647186756134
0.11812449246644974
0.11981035023927689
0.17099298536777496
0.21643972396850586
0.04353781417012215
0.0968967080116272
0.04558193311095238
0.16154293715953827
0.42476797103881836
0.1959701031446457
0.21064843237400055
0.16754335165023804
0.17911343276500702
0.15224027633666992
0.20082539319992065
0.10815999656915665
0.13934098184108734
0.1541050225496292
0.1342744678258896
0.1880672127008438
0.16682346165180206
0.18255381286144257
0.07528606057167053
0.1407548040151596
0.16879095137119293
0.13361968100070953
0.18012045323848724
0.15161262452602386
0.14014457166194916
0.13434123992919922
0.20121169090270996
1.61446296260

In [10]:
print("Count", count)
print("Attack Success Rate: ", (AS/(count+1)))
print("Avg. SSIM: ",(np.average(sSim)))
print("Avg. MSE: ",(np.average(mSe)))
print("Avg. Time:", (whole_time/(count+1)))

Count 100
Attack Success Rate:  0.4752475247524752
Avg. SSIM:  0.48607516
Avg. MSE:  0.094447985
Avg. Time: 84.76175742574257
