In [1]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline  

import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch
torch.manual_seed(0)

import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.dataset import Dataset



import torchattacks

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# No Normalize step: models and attacks below operate on ToTensor's [0, 1] images.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_dataset = datasets.ImageFolder(root='./data/rsna_pneumonia/test/',
                                    transform=test_transform)

# ImageFolder orders samples class by class, so a plain unshuffled loader would
# put a single class in the first batch. A `shuffle=True` loader, however, draws
# a *new* permutation every time `iter(test_loader)` is called, so separate cells
# doing `next(iter(test_loader))` received different batches and positional
# sample indices did not identify images. Instead, fix one seeded permutation
# and iterate it in order: every pass yields the same mixed-class batches.
# Position i in the loader corresponds to ImageFolder index `test_order[i]`.
TEST_ORDER_SEED = 0
test_order = torch.randperm(len(test_dataset),
                            generator=torch.Generator().manual_seed(TEST_ORDER_SEED))

test_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dataset, test_order.tolist()),
    batch_size=100, shuffle=False, num_workers=8)

In [4]:
def modify_model_output_classes(model, num_classes):
    """Replace a torchvision classifier's final linear layer with a fresh ``num_classes``-way one.

    Supported layouts, checked in this order:
      * ``model.fc`` is the head (e.g. ResNet);
      * ``model.classifier`` is an ``nn.Linear`` (e.g. DenseNet), or an
        ``nn.Sequential`` whose *last* ``nn.Linear`` is the head (e.g. EfficientNet);
      * ``model.heads.head`` is the head (e.g. ViT);
      * ``model.head`` is the head (e.g. Swin).

    The new layer keeps the old layer's ``in_features``, is freshly initialised,
    and is placed on the same device as the layer it replaces. The model is
    modified in place and also returned.

    Args:
        model: ``nn.Module`` with one of the layouts above.
        num_classes: number of output logits of the new head.

    Returns:
        The same ``model`` object with its head replaced.

    Raises:
        ValueError: if no supported head layout is found.
    """
    unsupported = "Unsupported model architecture. Cannot modify output classes."

    def _fresh_head(old_head):
        # Follow the replaced layer's device rather than a global one, so a model
        # that is still on the CPU does not end up with a head on the GPU.
        return nn.Linear(old_head.in_features, num_classes).to(old_head.weight.device)

    if hasattr(model, 'fc'):
        model.fc = _fresh_head(model.fc)
        return model

    if hasattr(model, 'classifier'):
        classifier = model.classifier
        if isinstance(classifier, nn.Linear):
            model.classifier = _fresh_head(classifier)
            return model
        if isinstance(classifier, nn.Sequential):
            # Replace the *last* Linear. A hard-coded index 1 picked an activation
            # in VGG-style heads and a hidden Linear in AlexNet-style heads.
            linear_positions = [i for i, layer in enumerate(classifier)
                                if isinstance(layer, nn.Linear)]
            if linear_positions:
                last = linear_positions[-1]
                classifier[last] = _fresh_head(classifier[last])
                return model
        # Previously this case fell through and implicitly returned None.
        raise ValueError(unsupported)

    if hasattr(model, 'heads'):
        # ViT keeps its classifier in model.heads.head
        model.heads.head = _fresh_head(model.heads.head)
        return model

    if hasattr(model, 'head'):
        # Swin
        model.head = _fresh_head(model.head)
        return model

    raise ValueError(unsupported)


In [7]:
# Every architecture has two checkpoints on disk: `<hub_name>_0.ckpt` (the primary
# model) and `<hub_name>_1.ckpt` (its "_copy" counterpart). Keys are the display
# names used throughout the notebook; values are torchvision hub entry points.
_hub_names = {
    'ViT': 'vit_b_16',
    'DenseNet121': 'densenet121',
    'Efficientnet_b2': 'efficientnet_b2',
    'Resnet18': 'resnet18',
    'Swin_s': 'swin_s',
}
_variants = (('', 0), ('_copy', 1))  # (name suffix, checkpoint number)

# display name -> checkpoint path (primary and copy interleaved per architecture)
model_paths = {f'{name}{suffix}': f'./models/xray/{hub}_{idx}.ckpt'
               for name, hub in _hub_names.items()
               for suffix, idx in _variants}

# display name -> torchvision hub entry point
models = {f'{name}{suffix}': hub
          for name, hub in _hub_names.items()
          for suffix, _ in _variants}

# All primary models first, then all copies.
model_names = [*_hub_names, *(f'{name}_copy' for name in _hub_names)]

del _hub_names, _variants

# Attack settings, on the [0, 1] pixel scale produced by ToTensor (hence /255).
eps = 1 / 255        # L-inf perturbation budget
alpha = 0.25 / 255   # step size per attack iteration
iters = 10           # number of attack iterations
attack = 'PGD'       # sub-directory name used in the adv_samples/ paths
# NOTE(review): placeholder - point this at a volume with room for the saved adversarial tensors.
save_path = 'save_path_with_a_lot_of_memory'

In [6]:
# for model_name in model_names:
#     if 'copy' in model_name:
#         continue
#     m = torch.hub.load("pytorch/vision", models[model_name],
#                            weights="IMAGENET1K_V1").to(device)
#     m = modify_model_output_classes(m, 3)
#     m.load_state_dict(torch.load(model_paths[model_name])['net'])
#     print(f'loaded {model_paths[model_name]}')
#     # atk = torchattacks.PGD(m,random_start=False, eps=eps,alpha=alpha,steps=iters)
#     atk = torchattacks.MIFGSM(m, eps=eps, steps=iters, decay=0.9)
#     # atk = torchattacks.FGSM(m, eps=eps)
    
#     sample_indexes = []
#     true_labels = []
#     adv_inputs = []
#     sample_num = 0
#     for batch, labels in test_loader:
#         batch, labels = batch.to(device), labels.to(device)
#         sample_index = torch.arange(sample_num, sample_num+labels.size(0))
#         adv_samples = atk(batch, labels)
#         pred_labels = m(adv_samples).argmax(dim=1)
#         mask = pred_labels.eq(labels)
#         if mask.all():
#             continue
#         else:
#             adv_samples = adv_samples[~mask]
#             labels = labels[~mask]
#             sample_index = sample_index[~mask.cpu()]
#             true_labels.append(labels.to('cpu'))
#             adv_inputs.append(adv_samples.to('cpu'))
#             sample_indexes.append(sample_index)
#         sample_num += labels.size(0)
        
#     true_labels = torch.cat(true_labels)
#     adv_inputs = torch.cat(adv_inputs, dim=0)
#     sample_indexes = torch.cat(sample_indexes)

#     torch.save(sample_indexes, f'./adv_samples/{attack}/xray/{model_name}_index.ckpt')
#     torch.save(true_labels, f'./adv_samples/{attack}/xray/{model_name}_labels.ckpt')
#     torch.save(adv_inputs, f'./adv_samples/{attack}/xray/{model_name}_samples.ckpt')
    
# Attack one test batch with random-start PGD N_RUNS times per model and save each
# run's adversarial batch (plus the true labels) under save_path/.../pertubation_ranking.
N_RUNS = 20
ranking_dir = os.path.join(save_path, f'adv_samples/{attack}/xray/pertubation_ranking')
# torch.save does not create missing directories.
os.makedirs(ranking_dir, exist_ok=True)

batch, labels = next(iter(test_loader))
batch, labels = batch.to(device), labels.to(device)
# The labels are the same for every model and run, so move them to the CPU once.
true_labels = labels.to('cpu')

for model_name in model_names:
    m = torch.hub.load("pytorch/vision", models[model_name],
                       weights="IMAGENET1K_V1").to(device)
    m = modify_model_output_classes(m, 3)
    # map_location lets checkpoints saved from a GPU load on a CPU-only machine.
    m.load_state_dict(torch.load(model_paths[model_name], map_location=device)['net'])
    m.eval()
    atk = torchattacks.PGD(m, random_start=True,
                           eps=eps, alpha=alpha, steps=iters)

    for run in range(N_RUNS):
        # random_start=True: each run starts from a different random point in the eps-ball.
        adv_samples = atk(batch, labels).to('cpu')

        torch.save(true_labels,
                   os.path.join(ranking_dir, f'{model_name}_labels_run_{run}.ckpt'))
        torch.save(adv_samples,
                   os.path.join(ranking_dir, f'{model_name}_samples_run_{run}.ckpt'))
    del atk, m
    


KeyboardInterrupt



In [12]:
# m = torch.hub.load("pytorch/vision", models['Resnet18_copy'],
#                        weights="IMAGENET1K_V1")
# m = modify_model_output_classes(m, 10)

In [13]:
# state = torch.load(model_paths['Resnet18_copy'])
# m.load_state_dict(state['net'])

In [14]:
# weight = m.fc.weight[:3, :]
# bias = m.fc.bias[:3]

In [15]:
# m.fc = nn.Linear(512, 3)
# m.fc.weight.data = weight
# m.fc.bias.data = bias

In [16]:
# state['net'] = m.state_dict()
# torch.save(state, './models/xray/resnet18_1.ckpt')

In [8]:
# Transfer-evaluation plan: surrogate (the model whose saved adversarial examples
# are loaded) -> target models those examples are scored against. Each list holds
# the surrogate's own "_copy" checkpoint in second position plus every other
# primary architecture; list order is also the CSV column order.
other_list = dict(
    Resnet18=['ViT', 'Resnet18_copy', 'DenseNet121', 'Efficientnet_b2', 'Swin_s'],
    Swin_s=['ViT', 'Swin_s_copy', 'DenseNet121', 'Efficientnet_b2', 'Resnet18'],
    ViT=['DenseNet121', 'ViT_copy', 'Efficientnet_b2', 'Resnet18', 'Swin_s'],
    DenseNet121=['ViT', 'DenseNet121_copy', 'Efficientnet_b2', 'Resnet18', 'Swin_s'],
    Efficientnet_b2=['ViT', 'Efficientnet_b2_copy', 'DenseNet121', 'Resnet18', 'Swin_s'],
)

In [14]:
# for surogate, others in other_list.items():
    
#     true_labels= torch.load(f'./adv_samples/{attack}/xray/{surogate}_labels.ckpt')
#     adv_inputs = torch.load(f'./adv_samples/{attack}/xray/{surogate}_samples.ckpt')
#     sample_index = torch.load(f'./adv_samples/{attack}/xray/{surogate}_index.ckpt')
    
#     adv_dataset = torch.utils.data.TensorDataset(adv_inputs, true_labels, sample_index)
#     adv_loader = torch.utils.data.DataLoader(adv_dataset, batch_size=64,
#                                              shuffle=False, num_workers=4)
#     print(f'Testing model {surogate}')
#     results_df = {}
#     for test_model_name in others:
        
#         results_df[f'{test_model_name}_pred_label'] = []
#         results_df[f'{test_model_name}_pred_conf'] = []
#         results_df[f'{test_model_name}_true_conf'] = []
#         results_df[f'{test_model_name}_true_label'] = []
        
#         path = model_paths[test_model_name]
#         model = models[test_model_name]
        
#         m = torch.hub.load("pytorch/vision", model,
#                                weights="IMAGENET1K_V1").to(device)
#         m = modify_model_output_classes(m, 3)
#         m.load_state_dict(torch.load(path)['net'])
#         print(f'loaded {path}')
#         m.eval()

#         true_labels = []
#         adv_inputs = []
#         with torch.no_grad():
#             for batch, labels, _ in adv_loader:
#                 batch = batch.to(device)
#                 outputs = torch.softmax(m(batch), dim=1).cpu()
#                 max_return = outputs.max(dim=1)
                
#                 results_df[f'{test_model_name}_pred_label'].extend(max_return.indices.numpy().tolist())
#                 results_df[f'{test_model_name}_pred_conf'].extend(max_return.values.numpy().tolist())
                
#                 true_conf = torch.gather(outputs, dim=1, index=labels.view(-1, 1)).view(-1)
#                 results_df[f'{test_model_name}_true_conf'].extend(true_conf.cpu().numpy().tolist())
#                 results_df[f'{test_model_name}_true_label'].extend(labels.numpy().tolist())

#         del m
#     results_df['sample_index'] = []
#     for _, _, index in adv_loader:
#         results_df['sample_index'].extend(index.numpy().tolist())
#     results_df = pd.DataFrame(results_df)
#     results_df.to_csv(f'./results/{attack}/xray_surogate_{surogate}.csv', index=False)

# For every surrogate and every saved PGD run, score the adversarial batch with each
# target model in `other_list` and write one CSV per (surrogate, run).
N_RUNS = 20
ranking_dir = os.path.join(save_path, f'adv_samples/{attack}/xray/pertubation_ranking')
results_dir = './results/pertubation_ranking/xray'
# DataFrame.to_csv does not create missing directories.
os.makedirs(results_dir, exist_ok=True)


def _load_xray_model(model_name):
    """Build the 3-class torchvision model for `model_name`, load its x-ray checkpoint, set eval mode."""
    net = torch.hub.load("pytorch/vision", models[model_name],
                         weights="IMAGENET1K_V1").to(device)
    net = modify_model_output_classes(net, 3)
    net.load_state_dict(torch.load(model_paths[model_name], map_location=device)['net'])
    return net.eval()


for surogate, others in other_list.items():
    print(f'Testing model {surogate}')
    # The target models do not depend on the run, so load each one once per
    # surrogate instead of reloading it from hub + checkpoint in every run.
    target_models = {name: _load_xray_model(name) for name in others}

    for run in range(N_RUNS):
        true_labels = torch.load(
            os.path.join(ranking_dir, f'{surogate}_labels_run_{run}.ckpt'), map_location='cpu')
        adv_inputs = torch.load(
            os.path.join(ranking_dir, f'{surogate}_samples_run_{run}.ckpt'), map_location='cpu')
        batch = adv_inputs.to(device)

        run_df = {}
        for test_model_name in others:
            with torch.no_grad():
                outputs = torch.softmax(target_models[test_model_name](batch), dim=1).cpu()
            max_return = outputs.max(dim=1)
            # Probability each model assigns to the ground-truth class.
            true_conf = torch.gather(outputs, dim=1, index=true_labels.view(-1, 1)).view(-1)

            run_df[f'{test_model_name}_pred_label'] = max_return.indices.numpy().tolist()
            run_df[f'{test_model_name}_pred_conf'] = max_return.values.numpy()
            run_df[f'{test_model_name}_true_conf'] = true_conf.numpy().tolist()
            run_df[f'{test_model_name}_true_label'] = true_labels.numpy().tolist()

        run_df = pd.DataFrame(run_df)
        run_df.to_csv(os.path.join(results_dir, f'run_{run}_xray_surogate_{surogate}.csv'),
                      index=False)

    # Release this surrogate's target models before loading the next set.
    del target_models

Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.c

In [12]:
# NOTE(review): leftover inspection cell. It displays the checkpoint path of the
# last `test_model_name` left over from the evaluation loop above, so it only
# works after that loop has run; looks safe to delete.
model_paths[test_model_name]

'./models/xray/vit_b_16_0.ckpt'

In [9]:
# df = {}

# for model_name in model_names:
#     m = torch.hub.load("pytorch/vision", models[model_name],
#                            weights="IMAGENET1K_V1").to(device)
#     m = modify_model_output_classes(m, 3)
#     m.load_state_dict(torch.load(model_paths[model_name])['net'])
#     print(f'loaded {model_paths[model_name]}')
#     m.eval()
#     df[f'{model_name}_probs'] = []
#     df[f'{model_name}_noise_probs'] = []
#     with torch.no_grad():
#         sample_num = 0
#         for batch, labels in test_loader:
#             batch, labels = batch.to(device), labels.to(device)
#             probs = torch.softmax(m(batch), dim=1).max(dim=1).values
#             noise_probs =torch.softmax(m(batch+
#                                          (16/255)*torch.randn_like(batch)),
#                                        dim=1).max(dim=1).values

#             df[f'{model_name}_probs'].extend(probs.detach().cpu().numpy().tolist())
#             df[f'{model_name}_noise_probs'].extend(noise_probs.detach().cpu().numpy().tolist())

# df['sample_index'] = list(range(len(test_loader.dataset)))        
# df = pd.DataFrame(df)
# df.to_csv(f'./results/{attack}/xray_all_preds.csv', index=False)
    
# Max-softmax confidence of each primary (non-"_copy") model on one clean test
# batch and on the same batch with additive Gaussian noise; one column per model
# and condition, one row per image. Rows follow the first batch of `test_loader`.
# They line up with the batch attacked in the perturbation-ranking cell only if
# the loader yields the same first batch every time.
NOISE_STD = 16 / 255   # noise std on the [0, 1] pixel scale
NOISE_SEED = 0

batch, labels = next(iter(test_loader))
batch, labels = batch.to(device), labels.to(device)

# Draw the noise once from a seeded generator. Every model now scores the *same*
# noisy images, and re-running the cell reproduces them (previously each model
# got its own unseeded draw). The noisy batch is not clamped to [0, 1].
noise_gen = torch.Generator(device=batch.device).manual_seed(NOISE_SEED)
noisy_batch = batch + NOISE_STD * torch.randn(batch.shape, generator=noise_gen,
                                              device=batch.device, dtype=batch.dtype)

df = {}
for model_name in model_names:
    if '_copy' in model_name:
        continue
    m = torch.hub.load("pytorch/vision", models[model_name],
                       weights="IMAGENET1K_V1").to(device)
    m = modify_model_output_classes(m, 3)
    m.load_state_dict(torch.load(model_paths[model_name], map_location=device)['net'])
    print(f'loaded {model_paths[model_name]}')
    m.eval()
    with torch.no_grad():
        probs = torch.softmax(m(batch), dim=1).max(dim=1).values
        noise_probs = torch.softmax(m(noisy_batch), dim=1).max(dim=1).values

    df[f'{model_name}_probs'] = probs.cpu().numpy().tolist()
    df[f'{model_name}_noise_probs'] = noise_probs.cpu().numpy().tolist()
    del m

results_dir = './results/pertubation_ranking/xray'
os.makedirs(results_dir, exist_ok=True)
df = pd.DataFrame(df)
df.to_csv(os.path.join(results_dir, 'xray_all_preds.csv'), index=False)

Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/resnet18_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_0.ckpt


In [15]:
import torchvision.transforms as transforms
import cv2

class Resize(object):
    """Resize an image to ``inter_size`` and then back to ``org_size``.

    The output has size ``org_size`` again, but the detour through
    ``inter_size`` resamples the content (presumably meant to wash out some of
    the adversarial perturbation - TODO confirm intent).

    Args:
        org_size: final size passed to ``transforms.functional.resize``.
        inter_size: intermediate size passed to ``transforms.functional.resize``.
    """

    def __init__(self, org_size, inter_size):
        self.org_size, self.inter_size = org_size, inter_size

    def __call__(self, img):
        resize = transforms.functional.resize
        shrunk = resize(img, self.inter_size)
        return resize(shrunk, self.org_size)
    
class Compress(object):
    """JPEG round-trip for a ``(C, H, W)`` float image tensor with values in ``[0, 1]``.

    The tensor is quantised to 8 bits, JPEG-encoded with OpenCV at ``quality``,
    decoded as a 3-channel image and returned as a float32 CPU tensor in ``[0, 1]``.

    Args:
        quality: JPEG quality passed to ``cv2.IMWRITE_JPEG_QUALITY`` (default 90,
            the previously hard-coded value).
    """

    def __init__(self, quality=90):
        self.encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]

    def __call__(self, img):
        arr = img.detach().cpu().numpy().transpose(1, 2, 0)
        # Round instead of truncating: in float32, k/255*255 can land just below k,
        # and astype(uint8) would floor it to k-1. Clip guards out-of-range values.
        arr = np.clip(np.rint(arr * 255), 0, 255).astype(np.uint8)

        # OpenCV treats 3-channel arrays as BGR, while the tensors come from
        # ImageFolder/ToTensor in RGB order. Flip before encoding so the JPEG
        # colour transform weights the right channels, and flip back afterwards.
        is_three_channel = arr.shape[2] == 3
        if is_three_channel:
            arr = np.ascontiguousarray(arr[:, :, ::-1])

        ok, encoded = cv2.imencode('.jpg', arr, self.encode_param)
        if not ok:
            raise RuntimeError('JPEG encoding failed')
        decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if is_three_channel:
            decoded = decoded[:, :, ::-1]

        chw = np.ascontiguousarray(decoded.transpose(2, 0, 1))
        return torch.from_numpy(chw).float() / 255
    
class Adv_dataset_with_transforms(Dataset):
    """Map-style dataset over pre-computed adversarial examples, transformed on access.

    Item ``i`` is ``(transforms(adv_inputs[i]), true_labels[i], sample_index[i])``.

    Args:
        adv_inputs: indexable collection of inputs; ``len`` of the dataset is its length.
        true_labels: labels aligned with ``adv_inputs``.
        sample_index: per-sample identifiers aligned with ``adv_inputs``.
        transforms: callable applied to each input when it is fetched.
    """

    def __init__(self, adv_inputs, true_labels, sample_index, transforms):
        self.transform = transforms
        self.sample_index = sample_index
        self.target = true_labels
        self.data = adv_inputs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return (self.transform(self.data[index]),
                self.target[index],
                self.sample_index[index])

In [16]:
# Input transform applied to the saved adversarial examples before re-evaluation:
#   'compression' - JPEG round-trip (Compress)
#   'resize'      - resize 224 -> 200 -> 224 (Resize)
#   'combined'    - JPEG round-trip, then the resize
# Any other value selects the identity (no transform).
transform = 'compression'

_transform_builders = {
    'compression': lambda: Compress(),
    'resize': lambda: Resize(224, 200),
    'combined': lambda: transforms.Compose([Compress(), Resize(224, 200)]),
}
_build = _transform_builders.get(transform)
trans = _build() if _build is not None else (lambda x: x)
del _transform_builders, _build

In [17]:
# Transfer evaluation under an input transform: load each surrogate's saved PGD
# examples, pass every sample through `trans`, score them with the surrogate's
# target models from `other_list`, and write one CSV per surrogate.
TRANSFORM_ADV_DIR = './adv_samples/PGD/xray'
results_dir = f'./results/transforms/{transform}'
# DataFrame.to_csv does not create missing directories.
os.makedirs(results_dir, exist_ok=True)

for surogate, others in other_list.items():

    true_labels = torch.load(f'{TRANSFORM_ADV_DIR}/{surogate}_labels.ckpt', map_location='cpu')
    adv_inputs = torch.load(f'{TRANSFORM_ADV_DIR}/{surogate}_samples.ckpt', map_location='cpu')
    sample_index = torch.load(f'{TRANSFORM_ADV_DIR}/{surogate}_index.ckpt', map_location='cpu')

    # Bug fix: the previous TensorDataset never applied `trans`, so every
    # "transform" experiment actually scored the raw adversarial inputs.
    adv_dataset = Adv_dataset_with_transforms(adv_inputs, true_labels, sample_index, trans)
    adv_loader = torch.utils.data.DataLoader(adv_dataset, batch_size=64,
                                             shuffle=False, num_workers=4)
    print(f'Testing model {surogate}')
    results_df = {}
    for test_model_name in others:
        pred_labels, pred_confs, true_confs, true_label_col = [], [], [], []

        path = model_paths[test_model_name]
        m = torch.hub.load("pytorch/vision", models[test_model_name],
                           weights="IMAGENET1K_V1").to(device)
        m = modify_model_output_classes(m, 3)
        m.load_state_dict(torch.load(path, map_location=device)['net'])
        print(f'loaded {path}')
        m.eval()

        with torch.no_grad():
            for batch, labels, _ in adv_loader:
                outputs = torch.softmax(m(batch.to(device)), dim=1).cpu()
                max_return = outputs.max(dim=1)
                # Probability the model assigns to the ground-truth class.
                true_conf = torch.gather(outputs, dim=1, index=labels.view(-1, 1)).view(-1)

                pred_labels.extend(max_return.indices.numpy().tolist())
                pred_confs.extend(max_return.values.numpy().tolist())
                true_confs.extend(true_conf.numpy().tolist())
                true_label_col.extend(labels.numpy().tolist())

        results_df[f'{test_model_name}_pred_label'] = pred_labels
        results_df[f'{test_model_name}_pred_conf'] = pred_confs
        results_df[f'{test_model_name}_true_conf'] = true_confs
        results_df[f'{test_model_name}_true_label'] = true_label_col
        del m

    # The loader is unshuffled, so rows are in dataset order. Take the indices
    # directly instead of a second full pass, which would re-run `trans` on every sample.
    results_df['sample_index'] = sample_index.tolist()
    results_df = pd.DataFrame(results_df)
    results_df.to_csv(os.path.join(results_dir, f'xray_surogate_{surogate}.csv'), index=False)

Testing model Resnet18


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/resnet18_1.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_0.ckpt
Testing model Swin_s


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_1.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/resnet18_0.ckpt
Testing model ViT


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_1.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/resnet18_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_0.ckpt
Testing model DenseNet121


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_1.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_0.ckpt
loaded ./models/xray/resnet18_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_0.ckpt
Testing model Efficientnet_b2


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/vit_b_16_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/efficientnet_b2_1.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/densenet121_0.ckpt
loaded ./models/xray/resnet18_0.ckpt


Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main
Using cache found in /home/guy5/.cache/torch/hub/pytorch_vision_main


loaded ./models/xray/swin_s_0.ckpt
