In [1]:
%matplotlib inline

import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
import torchvision.transforms.functional as FF
from pytorch_msssim import ssim
torch.manual_seed(50)
import os
import time
print(torch.__version__, torchvision.__version__)

2.0.0+cu118 0.15.1+cu118


In [2]:
# dataset = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test')
# dataset

In [3]:
dst = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\Covid_chest_X-raydata\Covid_chest_X-raydata\Covid19-dataset\train')
tp = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor()
])
tt = transforms.ToPILImage()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=3):
    target = torch.unsqueeze(target, 1)
    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
    onehot_target.scatter_(1, target, 1)
    return onehot_target

def cross_entropy_for_onehot(pred, target):
    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))

Running on cuda


In [4]:
def weights_init(m):
    if hasattr(m, "weight"):
        m.weight.data.uniform_(-0.5, 0.5)
    if hasattr(m, "bias"):
        m.bias.data.uniform_(-0.5, 0.5)

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, 3)
        )

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        # print(out.size())
        out = self.fc(out)
        return out

net = LeNet().to(device)
net.apply(weights_init)
criterion = cross_entropy_for_onehot

In [5]:
directory=['recoverd_images_covid_64x_64/Covid','recoverd_images_covid_64x_64/Normal','recoverd_images_covid_64x_64/Viral Pneumonia']
directory

['recoverd_images_covid_64x_64/Covid',
 'recoverd_images_covid_64x_64/Normal',
 'recoverd_images_covid_64x_64/Viral Pneumonia']

In [6]:
len(dst)

251

In [7]:
AS=0
whole_time=0
sSim=[]
mSe=[]
count=0
for ii in range(0,len(dst),3):
    ######### honest partipant #########
    print("Image: ", count+1)
    
    img_index = ii
    gt_data = tp(dst[img_index][0]).to(device)
    gt_data = gt_data.view(1, *gt_data.size())
    gt_label = torch.Tensor([dst[img_index][1]]).long().to(device)
    gt_label = gt_label.view(1, )
    gt_onehot_label = label_to_onehot(gt_label, num_classes=3)

    #plt.imshow(tt(gt_data[0].cpu()))
    #plt.title("Ground truth image")
    #print("GT label is %d." % gt_label.item(), "\nOnehot label is %d." % torch.argmax(gt_onehot_label, dim=-1).item())

    # compute original gradient 
    out = net(gt_data)
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())
    
    
    # share the gradients with other clients
    original_dy_dx = list((_.detach().clone() for _ in dy_dx))
    # generate dummy data and label
    
    start=time.process_time()
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

    #plt.imshow(tt(dummy_data[0].cpu()))
    #plt.title("Dummy data")
    #print("Dummy label is %d." % torch.argmax(dummy_label, dim=-1).item())
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label] )
    
    history = []
    for iters in range(300):
        def closure():
            optimizer.zero_grad()

            pred = net(dummy_data) 
            dummy_onehot_label = F.softmax(dummy_label, dim=-1)
            dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.
            dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

            grad_diff = 0
            grad_count = 0
            for gx, gy in zip(dummy_dy_dx, original_dy_dx): # TODO: fix the variablas here
                grad_diff += ((gx - gy) ** 2).sum()
                grad_count += gx.nelement()
            # grad_diff = grad_diff / grad_count * 1000
            grad_diff.backward()

            return grad_diff

        optimizer.step(closure)
        if iters % 100 == 0: 
            current_loss = closure()
            print(iters, "%.4f" % current_loss.item())
        history.append(tt(dummy_data[0].cpu()))
    end=time.process_time()
    total_time=end-start
    whole_time=whole_time+total_time
    # Load the two images (as PIL Images or tensors)
    img1 = tp(dst[ii][0]) 
    img2 = history[299]

    # Convert the images to tensors and reshape to (batch_size, channels, height, width)
    img1_tensor = img1.unsqueeze(0)
    img2_tensor = FF.to_tensor(img2).unsqueeze(0)
    
    # Calculate SSIM between the two images
    ssim_value = ssim(img1_tensor, img2_tensor, data_range=1.0, size_average=True)
    sSim.append(ssim_value)
    print(f"SSIM: {ssim_value:.4f}")
    mse = F.mse_loss(img1_tensor, img2_tensor)
    mSe.append(mse)
    print(f"MSE: {mse:.8f}")
    print("Time Needed: ",total_time)
    
    
#     lbl=[]
#     for j in range(0,10):
#         lbl.append(j)
#     for check in range(0,10):
#         if(gt_label.item()==lbl[check]):
#             plt.imshow(history[299])
#             plt.axis('off')
#             print(gt_label.item())
#             plt.savefig(os.path.join(directory[check], f'image_lbl{ii, check}.png'), bbox_inches='tight', pad_inches=0)
#             plt.close() 
        
    print()
    
    if(ssim_value>=.90):
        AS=AS+1
        #print(AS)
    count=count+1




Image:  1
0 0.0067
100 865.9237
200 865.9237
SSIM: 0.0284
MSE: 0.17467546
Time Needed:  42.59375

Image:  2
0 0.0135
100 0.0004
200 0.0002
SSIM: 0.0023
MSE: 0.12718976
Time Needed:  249.828125

Image:  3
0 0.0096
100 0.0181
200 0.0181
SSIM: 0.0124
MSE: 0.12992091
Time Needed:  16.03125

Image:  4
0 0.0117
100 0.0000
200 0.0000
SSIM: 0.1476
MSE: 0.12066021
Time Needed:  95.953125

Image:  5
0 0.0061
100 0.0000
200 0.0000
SSIM: 0.1105
MSE: 0.12967879
Time Needed:  85.109375

Image:  6
0 0.1185
100 0.1185
200 0.1185
SSIM: -0.0090
MSE: 0.16653550
Time Needed:  6.796875

Image:  7
0 0.0294
100 884.3471
200 884.3471
SSIM: 0.0247
MSE: 0.16395397
Time Needed:  24.765625

Image:  8
0 0.0073
100 0.0000
200 0.0000
SSIM: 0.1096
MSE: 0.12999618
Time Needed:  89.34375

Image:  9
0 0.0956
100 0.0940
200 0.0829
SSIM: 0.0242
MSE: 0.16389155
Time Needed:  34.21875

Image:  10
0 0.0133
100 0.1219
200 0.1219
SSIM: 0.0114
MSE: 0.16278657
Time Needed:  15.046875

Image:  11
0 0.0077
100 0.0000
200 0.0000
SS

In [8]:
for i in range(len(sSim)):
    print(sSim[i].item())

0.028350921347737312
0.0023087405133992434
0.012373007833957672
0.14761681854724884
0.11049076914787292
-0.009028991684317589
0.024656331166625023
0.10960306972265244
0.02415541559457779
0.011358592659235
0.1673305183649063
0.015897536650300026
0.17153383791446686
0.0024811916518956423
0.00035933652543462813
-0.002582864137366414
0.016977446153759956
0.009491420350968838
0.03963969275355339
0.0957639217376709
0.011378413997590542
0.008008930832147598
0.010940484702587128
0.1978747844696045
0.022513097152113914
0.021444788202643394
0.01909690722823143
-0.0006946424837224185
0.008390597999095917
0.00013130158185958862
0.004120592959225178
0.008822529576718807
0.009758389554917812
0.0027477454859763384
0.009868740104138851
0.012448806315660477
0.008338767103850842
0.9996700882911682
0.9996509552001953
0.9997193813323975
0.9997575879096985
0.9997132420539856
0.9997196197509766
0.9998071789741516
0.9997520446777344
0.9997599720954895
0.9997363686561584
0.9996681809425354
0.9996979236602783


In [9]:
for i in range(len(mSe)):
    print(mSe[i].item())

0.17467546463012695
0.1271897554397583
0.12992091476917267
0.1206602081656456
0.12967878580093384
0.16653549671173096
0.1639539748430252
0.1299961805343628
0.16389155387878418
0.1627865731716156
0.11522416025400162
0.17670227587223053
0.11156576871871948
0.1620667427778244
0.3196021020412445
0.17228686809539795
0.14729605615139008
0.1230805292725563
0.15011033415794373
0.06103247404098511
0.1631571650505066
0.12239822000265121
0.13098682463169098
0.08536949008703232
0.10823315382003784
0.16060787439346313
0.11233425885438919
0.14786379039287567
0.14177994430065155
0.12532226741313934
0.1232195720076561
0.1405564248561859
0.13027770817279816
0.2791524827480316
0.13425149023532867
0.11131753772497177
0.12298364192247391
0.003561660647392273
2.2587466446566395e-05
2.1781486793770455e-05
1.7085776562453248e-05
2.0544981452985667e-05
2.4019207558012567e-05
1.2405084817146417e-05
1.6510073692188598e-05
2.2272084606811404e-05
1.5929368601064198e-05
2.1991741959936917e-05
1.6424972272943705e-0

In [10]:
print("Count", count)
print("Attack Success Rate: ", (AS/(count)))
print("Avg. SSIM: ",(np.average(sSim)))
print("Avg. MSE: ",(np.average(mSe)))
print("Avg. Time:", (whole_time/(count)))

Count 84
Attack Success Rate:  0.5476190476190477
Avg. SSIM:  0.56342703
Avg. MSE:  0.06556635
Avg. Time: 91.79464285714286
