In [2]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from sklearn import metrics
import pickle
import shutil
from PIL import Image
from scipy.special import softmax

In [4]:
#load reference csv
reference_csv_path = '/home/jupyter/TCGA_lung.csv'
TCGA_Lung_csv = pd.read_csv(reference_csv_path)

#set feature(.h5 files path)
features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/'
list_feature_files = os.listdir(features_dir)

#set clam path
clam_dir = '/home/ext_yao_gary_mayo_edu/CLAM'
%cd '/home/ext_yao_gary_mayo_edu/CLAM'

#set model_loading_path
model_checkpt_dir = './target_checkpoint.pt'

#set train/val/test split path
sp

/home/ext_yao_gary_mayo_edu/CLAM


In [6]:
from models.model_clam import CLAM_MB
import torch
import os
import h5py
from torch.autograd import grad
import openslide
from datasets.dataset_h5 import eval_transforms
from models.resnet_custom import resnet50_baseline

from torch.utils.data import Dataset, DataLoader, sampler
from torchvision import transforms, utils, models
import torch.nn.functional as F

# Build the CLAM classifier, load its checkpoint, and prepare the ResNet
# feature extractor plus the ImageNet normalisation transforms used by the
# attack code below.
e_transform = eval_transforms(pretrained = True)
model = CLAM_MB(n_classes=2, dropout=True)
ckpt_path = model_checkpt_dir
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
ckpt_clean = {}
for key in ckpt.keys():
    # the instance loss function is not part of the CLAM_MB state dict
    if 'instance_loss_fn' in key:
        continue
    # strip the '.module' component from checkpoint key names
    ckpt_clean.update({key.replace('.module', ''):ckpt[key]})
model.load_state_dict(ckpt_clean, strict=True)

feature_extractor = resnet50_baseline(pretrained=True)
feature_extractor.eval()
# The original line read `model_defense = model_defense.cuda()`, which raises
# NameError on a fresh kernel because `model_defense` is never defined before
# it. The model loaded above is the one every later cell calls `model_defense`.
# NOTE(review): `model` is never switched to eval mode, so its dropout layers
# stay active during the forward passes below -- confirm this is intended.
model_defense = model.cuda()
feature_extractor = feature_extractor.cuda()

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

toTensor = transforms.ToTensor()
normalize = transforms.Normalize(mean = mean, std = std)

FileNotFoundError: [Errno 2] No such file or directory: '/home/ext_yao_gary_mayo_edu/CLAM_saves/100s_2_checkpoint.pt'

In [None]:
object_methods = [method_name for method_name in dir(model_defense)
                  if callable(getattr(model_defense, method_name))]
print(object_methods)

In [None]:
# Read the train/val/test split table and turn each column into a list of
# slide names. Columns have different lengths, so the shorter ones are padded
# with NaN; `name == name` is False only for NaN and drops that padding.
splits = pd.read_csv('/home/jupyter/splits.csv')
train_file_list, test_file_list, val_file_list = (
    [name for name in splits[column].values.tolist() if name == name]
    for column in ('train', 'test', 'val')
)
val_test_file_list = val_file_list + test_file_list
full_list = val_test_file_list + train_file_list

# Layer -1: Take forward_run output and produce accu and roc

In [None]:
def accu(output, reference, answer_rows = ['c1', 'c2']):
    """Fraction of slides in ``output`` whose predicted class matches ``reference``.

    Parameters
    ----------
    output : pandas.DataFrame
        One row per slide with a ``slide_name`` column and the two score
        columns named in ``answer_rows``.
    reference : pandas.DataFrame
        Table with ``slide_name`` and ``label`` columns. Each slide in
        ``output`` must match exactly one row (``Series.item()`` raises
        ``ValueError`` otherwise).
    answer_rows : sequence of two str
        Column names of the class-1 and class-2 scores (not mutated).

    Returns
    -------
    float
        ``correct / total``. ``nan`` when ``output`` has no rows (the original
        raised ``ZeroDivisionError``, e.g. when every slide failed to load).

    A slide labelled ``'LUAD'`` counts as correct when the first score is
    strictly larger; any other label counts when the second score is strictly
    larger. Ties are never correct.
    """
    first_column, second_column = answer_rows[0], answer_rows[1]
    total = 0
    correct = 0
    for _, row in output.iterrows():
        curr_slide_name = row['slide_name']
        curr_c1 = float(row[first_column])
        curr_c2 = float(row[second_column])
        reference_row = reference[reference['slide_name'] == curr_slide_name]
        reference_label = reference_row['label'].item()
        if reference_label == 'LUAD':
            if curr_c1 > curr_c2:
                correct += 1
        elif curr_c2 > curr_c1:
            correct += 1
        total += 1
    if total == 0:
        # nothing to score: avoid ZeroDivisionError and signal "undefined"
        return float('nan')
    return correct/total

In [None]:
def auc(output, reference, answer_rows = ['c1', 'c2']):
    """Area under the ROC curve for the scores in ``output``.

    Parameters
    ----------
    output : pandas.DataFrame
        One row per slide with a ``slide_name`` column and the score columns
        named in ``answer_rows``. Only ``answer_rows[0]`` is used as the
        ranking score (the original comment notes the scores are softmaxed).
    reference : pandas.DataFrame
        Table with ``slide_name`` and ``label`` columns; each slide must match
        exactly one row (``Series.item()`` raises otherwise).
    answer_rows : sequence of str
        Score column names; the first entry is the score for the positive class.

    Returns
    -------
    float
        ``sklearn.metrics.auc`` of the ROC curve, with ``'LUAD'`` slides as the
        positive class (encoded 2) and every other label as 1.
    """
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same dtype.
    pred = output[answer_rows[0]].astype(float)
    labels = []
    for _, row in output.iterrows():
        curr_slide_name = row['slide_name']
        reference_row = reference[reference['slide_name'] == curr_slide_name]
        reference_label = reference_row['label'].item()
        if reference_label == 'LUAD':
            labels.append(2)
        else:
            labels.append(1)
    y = np.array(labels)
    fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=2)
    # distinct local name so the function name is not shadowed inside its own body
    auc_value = metrics.auc(fpr, tpr)
    return auc_value

# Layer -2: dataset level forward run

In [None]:
def forward_dataset(model, features_dir, target_list = None, softmax = True, verbose = False, save = False, limit = -1):
    """Run ``model`` on the stored features of each slide and tabulate two scores.

    Parameters
    ----------
    model
        Passed to ``forward_ins`` together with the slide's feature tensor.
    features_dir : str
        Directory holding one ``<slide_name>.h5`` file per slide, each with
        ``features`` and ``coords`` datasets.
    target_list : iterable of str, optional
        Slide names to evaluate. When omitted, every ``*.h5`` file found in
        ``features_dir`` is evaluated (sorted by name). Existing calls in this
        notebook omit the argument, which used to raise ``TypeError``.
    softmax : bool
        If True the scores come from the second value returned by
        ``forward_ins``, otherwise from the first.
    verbose, save, limit
        Accepted for compatibility; not used by this function.

    Returns
    -------
    pandas.DataFrame
        Columns ``slide_name``, ``c1``, ``c2``; one row per slide whose feature
        file could be read. Unreadable slides are reported and skipped.
    """
    if target_list is None:
        target_list = sorted(
            os.path.splitext(file_name)[0]
            for file_name in os.listdir(features_dir)
            if file_name.endswith('.h5')
        )

    columns = ['slide_name', 'c1', 'c2']
    records = []
    for slide_name in tqdm(target_list):
        full_feature_path = os.path.join(features_dir, '{}.h5'.format(slide_name))
        try:
            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]
        except Exception as error:
            # best-effort: keep going, but say why the slide was skipped and
            # do not swallow KeyboardInterrupt like a bare `except` would
            print('Failed on {}: {!r}'.format(slide_name, error))
            continue
        features = torch.from_numpy(features).cuda()
        l, y, y1, a, r = forward_ins(model, features, coords)

        scores = y if softmax else l
        c1 = scores[0,0].detach().cpu().item()
        c2 = scores[0,1].detach().cpu().item()

        records.append([slide_name, c1, c2])
    # DataFrame.append was removed in pandas 2.0; build the table once instead
    # of growing it row by row.
    return pd.DataFrame(records, columns = columns)

In [None]:
def dataset_topk_attack(model, feature_extractor, reference, target_list, features_dir, slides_dir, results_save_dir = None, dataset_save_dir = None, topk = 10, e = 0.001, f = 0.001, limit = -1, save_dataset = False, attention_comp = False, verbose = False):
    """Run ``instance_topk_attack`` on every slide in ``target_list``.

    Parameters
    ----------
    model, feature_extractor
        Forwarded unchanged to ``instance_topk_attack``.
    reference : pandas.DataFrame
        Table with ``slide_name``, ``label`` and ``resolution`` columns; each
        slide must match exactly one row.
    target_list : iterable of str
        Slide names to attack.
    features_dir, slides_dir : str
        Directories holding ``<slide_name>.h5`` and ``<slide_name>.svs``.
    results_save_dir : str, optional
        Accepted for compatibility; nothing is written there yet (attention
        saving is still a TODO).
    dataset_save_dir : str, optional
        Required when ``save_dataset`` is True. Must already contain ``img``,
        ``features`` and ``coords`` sub-folders.
    topk, e, f, attention_comp
        Forwarded to ``instance_topk_attack``.
    limit, verbose
        Accepted for compatibility; not used by this function.
    save_dataset : bool
        When True the adversarial patches, their features and the
        (index, coordinate) list of each slide are written to ``dataset_save_dir``.

    Returns
    -------
    pandas.DataFrame
        One row per processed slide with the two scores before the attack
        (``c1_o``, ``c2_o``), after it (``c1_a``, ``c2_a``) and after the
        attention-compensation step (``c1_att``, ``c2_att``; 0 when
        ``attention_comp`` is False).

    Raises
    ------
    ValueError
        If ``save_dataset`` is True and ``dataset_save_dir`` is None.
    """
    if save_dataset and dataset_save_dir is None:
        # previously surfaced as a TypeError from os.path.join only after the
        # first slide had already been attacked
        raise ValueError('dataset_save_dir must be given when save_dataset is True')

    columns = ['slide_name', 'c1_o', 'c2_o', 'c1_a', 'c2_a', 'c1_att', 'c2_att']
    records = []
    for slide_name in tqdm(target_list):
        full_feature_path = os.path.join(features_dir, '{}.h5'.format(slide_name))
        try:
            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]
            features = torch.from_numpy(features)

            reference_row = reference[reference['slide_name'] == slide_name]
            reference_label = reference_row['label'].item()
            slide_resolution = int(reference_row['resolution'].item())

            # open the slide last so nothing is left open if a lookup above fails
            slide_full_path = os.path.join(slides_dir,
                                           '{}.svs'.format(slide_name))
            wsi = openslide.open_slide(slide_full_path)
        except Exception as error:
            # best-effort: report the reason and move on to the next slide
            print('failed on {}: {!r}'.format(slide_name, error))
            continue

        # the attack target is the index of the slide's reference class
        if reference_label == 'LUAD':
            attack_target = 0
        else:
            attack_target = 1

        # resolution-40 slides are read at 512px and halved, others at 256px
        if slide_resolution == 40:
            patch_size = 512
            downsample = 2
        else:
            patch_size = 256
            downsample = 1

        try:
            (y_collection, a_collection, idx_coord_list, img_collection,
             features_collection, noise_collection) = instance_topk_attack(
                model, feature_extractor, features, coords, wsi,
                patch_size = patch_size, downsample = downsample,
                target = attack_target, topk = topk, e = e, f = f,
                attention_comp = attention_comp)
        finally:
            # the slide handle is only needed while patches are being read
            wsi.close()

        y_o, y_a = y_collection[0:2]
        a_o, a_a = a_collection[0:2]
        # instance_topk_attack returns img_collection as
        # [original, attacked, (attention-compensated)], whereas
        # features_collection is [attacked, (attention-compensated)].
        # The original code took img_collection[0] / [1] here and therefore
        # saved the unmodified patches as "_adv" and the attacked ones as "_att".
        img_a = img_collection[1]
        features_a = features_collection[0]
        if attention_comp:
            y_att = y_collection[2]
            a_att = a_collection[2]
            img_att = img_collection[2]
            features_att = features_collection[1]

        y_o = y_o.detach().cpu()
        y_a = y_a.detach().cpu()
        c1_o = y_o[0,0].item()
        c2_o = y_o[0,1].item()
        c1_a = y_a[0,0].item()
        c2_a = y_a[0,1].item()
        c1_att = 0
        c2_att = 0

        if attention_comp:
            c1_att = y_att[0,0].item()
            c2_att = y_att[0,1].item()

        records.append([slide_name, c1_o, c2_o, c1_a, c2_a, c1_att, c2_att])

        if save_dataset:
            #dataset_save_dir has three folders: img, features and coords
            img_save_path = os.path.join(dataset_save_dir, 'img')
            feature_save_path = os.path.join(dataset_save_dir, 'features')
            coords_save_path = os.path.join(dataset_save_dir, 'coords')

            torch.save(img_a, os.path.join(img_save_path, '{}_adv.tr'.format(slide_name)))
            torch.save(features_a, os.path.join(feature_save_path, '{}_adv.tr'.format(slide_name)))
            with open(os.path.join(coords_save_path,
                '{}_adv.p'.format(slide_name)), 'wb') as handle:
                pickle.dump(idx_coord_list, handle)

            #todo attention (results_save_dir/att is reserved for it)

            if attention_comp:
                torch.save(img_att, os.path.join(img_save_path, '{}_att.tr'.format(slide_name)))
                torch.save(features_att, os.path.join(feature_save_path, '{}_att.tr'.format(slide_name)))

    # DataFrame.append was removed in pandas 2.0; build the table once.
    return pd.DataFrame(records, columns = columns)

# Layer -3: instance level forward_run

In [None]:
def forward_ins(model, features, coords):
    """Move ``model`` to the GPU, run it on ``features`` and return its five outputs.

    ``coords`` is accepted for signature compatibility and is not used.
    The model must return exactly five values; they are passed through in order.
    """
    model.cuda()
    first, second, third, fourth, fifth = model(features)
    return first, second, third, fourth, fifth

In [None]:
def instance_topk_attack(model, feature_extractor, features, coords, wsi,
        patch_size = 256, patch_level = 0, downsample = 1, target = 1, topk = 10, e = .001,
        f = 0.001, attention_comp = False, verbose = False):
    """Perturb the ``topk`` most-attended patches of one slide and re-run the model.

    Steps, as implemented below:
      1. Run ``model`` on ``features`` and pick the ``topk`` indices with the
         largest values in row ``target`` of the model's fourth output.
      2. Re-read those patches from ``wsi``, re-extract their features with
         ``feature_extractor`` and substitute them into a copy of ``features``.
      3. Back-propagate ``F.nll_loss`` for class ``target`` to the (normalised)
         patch pixels and add ``e * sign(gradient)`` to the un-normalised patches.
      4. Optionally (``attention_comp``) take a second signed step of size ``f``
         that reduces the squared change in the fourth model output.

    Parameters
    ----------
    model, feature_extractor
        Moved to the GPU here. ``model(x)`` must return five values.
    features : torch.Tensor
        Per-patch features of the slide (cloned, never modified in place).
    coords
        Per-patch coordinates, indexed with the top-k indices and passed to
        ``wsi.read_region``.
    wsi
        Slide object providing ``read_region(coord, level, size)``.
    patch_size, patch_level, downsample
        Patches are read at ``patch_size`` and resized to
        ``patch_size / downsample`` when ``downsample != 1``.
    target : int
        Row of the attention output used for ranking and class index of the loss.
    topk : int
        Number of patches to perturb.
    e, f : float
        Step sizes of the attack and of the attention-compensation step.
    attention_comp : bool
        Enables step 4.
    verbose : bool
        Prints the prediction and the selected indices/coordinates.

    Returns
    -------
    tuple
        ``(y_collection, a_collection, idx_coord_list, img_collection,
        features_collection, noise_collection)`` where
        ``y_collection``/``a_collection`` are ``[original, attacked(, compensated)]``,
        ``img_collection`` is ``[original, attacked(, compensated)]`` un-normalised
        patches, ``features_collection`` is ``[attacked(, compensated)]`` (no
        original entry) and ``noise_collection`` is
        ``[sign of attack gradient(, attention gradient)]``.
    """
    
    model = model.cuda()
    feature_extractor = feature_extractor.cuda()
    
    features = torch.clone(features).cuda()
    features.requires_grad = True
    
    # first pass on the stored features: only the attention output `a` is used,
    # to choose which patches to attack
    l, y, y1, a, r = model(features)
    res, ind = a[target,:].topk(topk)
    topk_coord = coords[ind.cpu(),:]
    
    if verbose:
        print("model prediciton {}".format(y))
        print("top k attented values {}".format(res))
        print("top k attention index {}".format(ind))
        print(ind.shape)
        print(coords.shape)
        print("top k attented coord {}".format(topk_coord))
    # NOTE(review): presumably indexing with a one-element tensor yields a single
    # coordinate rather than a (1, 2) array, hence the re-wrapping -- confirm.
    if topk == 1:
        topk_coord = [topk_coord]

    features_copy = torch.clone(features)
    
    # Pre-allocated patch buffers. The spatial size is fixed at 256x256 and does
    # not depend on patch_size/downsample, so the read patches must end up 256px.
    img_copy = torch.ones((topk,3,256,256)).cuda()
    img_copy_unnormalized = torch.ones((topk,3,256,256)).cuda()
    
    attack_img_unnormalized = torch.ones((topk,3,256,256)).cuda()
    attack_img_normalized = torch.ones((topk,3,256,256)).cuda()
    
    att_img_unnormalized = torch.ones((topk,3,256,256)).cuda()
    att_img_normalized = torch.ones((topk,3,256,256)).cuda()
    
    idx_coord_list = []
    
    # re-read each selected patch from the slide; keep both the raw [0, 1]
    # tensor and its normalised version
    for i in range(topk):
        coord_i = topk_coord[i]
        idx_coord_list.append((ind[i],coord_i))
        img = wsi.read_region(coord_i, patch_level,
                (patch_size, patch_size)).convert('RGB')
        if downsample != 1:
            final_size = (int(patch_size/downsample), ) *2
            img = img.resize(final_size)
        
        img = toTensor(img).cuda()
        img_copy_unnormalized[i] = img
        
        img = normalize(img).unsqueeze(0)
        img_copy[i] = img
    
    #topk attack
    img_copy.requires_grad = True
    # NOTE(review): no gradient is ever read from img_copy_unnormalized; only
    # img_copy.grad is used below.
    img_copy_unnormalized.requires_grad = True
    
    # substitute freshly extracted features so the loss is differentiable with
    # respect to the patch pixels
    new_features = feature_extractor(img_copy)
    for i in range(topk):
        features_copy[ind[i], :] = new_features[i]
    l_original, y_original, y1_original, a_original, r = model(features_copy)
    
    # NOTE(review): F.nll_loss expects log-probabilities; confirm that the
    # model's first output is in that form.
    loss = F.nll_loss(l_original, torch.tensor([target]).cuda())
    loss.backward()
    d_x_y = img_copy.grad.data
    
    # signed-gradient step: the gradient is taken w.r.t. the normalised patches
    # but the step of size e is applied to the un-normalised ones
    attack_img_unnormalized = torch.clone(img_copy_unnormalized + d_x_y.sign() * e)
    for i in range(topk):
        attack_img_normalized[i] = normalize(attack_img_unnormalized[i])
    attack_features = feature_extractor(attack_img_normalized)
    
    for i in range(topk):
        features_copy[ind[i], :] = attack_features[i]
    l_attacked, y_attacked, y1_attacked, a_attacked,r = model(features_copy)
    
    # img_collection starts with the ORIGINAL patches (index 0); the attacked
    # patches are at index 1. features_collection has no original entry.
    noise_collection = [d_x_y.sign().detach()]
    l_collection = [l_original.detach(), l_attacked.detach()]
    a_collection = [a_original.detach(), a_attacked.detach()]
    y_collection = [y_original.detach(), y_attacked.detach()]
    img_collection = [img_copy_unnormalized.detach(), attack_img_unnormalized.detach()]
    features_collection= [attack_features.detach()]
    
    if attention_comp:
        model.zero_grad()
        feature_extractor.zero_grad()
        # squared change in attention caused by the attack
        attention_loss = torch.sum((a_attacked - a_original)**2)
        
        # NOTE(review): a_original belongs to the graph already consumed by
        # loss.backward() above (no retain_graph) -- verify this grad call runs.
        d_x_adiff = grad(attention_loss, attack_img_normalized)[0]
        
        # step of size f against the gradient, i.e. towards a smaller attention change
        att_img_unnormalized = attack_img_unnormalized - f * d_x_adiff.sign()
        
        for i in range(topk):
            att_img_normalized[i] = normalize(att_img_unnormalized[i])
        
        attack_features_att = feature_extractor(att_img_normalized)

        for i in range(topk):
            features_copy[ind[i], :] = attack_features_att[i]

        l_att, y_att, y1_att, a_att, r = model(features_copy)
        # computed against `a` from the very first forward pass; only consumed
        # by the commented-out print below
        new_attention_loss = torch.sum((a_att - a)**2)
        #print('attenion loss {} new attention loss {}'.format(attention_loss, new_attention_loss))
        l_collection.append(l_att.detach())
        a_collection.append(a_att.detach())
        y_collection.append(y_att.detach())
        # unlike the first entry, this one is the raw gradient, not its sign
        noise_collection.append(d_x_adiff.detach())
        img_collection.append(att_img_unnormalized.detach())
        features_collection.append(attack_features_att.detach())
        
    return y_collection, a_collection, idx_coord_list, img_collection, features_collection, noise_collection

# dropout no dropout experiment

In [None]:
def _clean_clam_state_dict(state_dict, key_renames = ()):
    """Return a copy of ``state_dict`` with keys adapted for ``CLAM_MB.load_state_dict``.

    Entries whose key contains ``'instance_loss_fn'`` are dropped, each
    ``(old, new)`` pair in ``key_renames`` is applied with ``str.replace``, and
    finally ``'.module'`` is removed from every key.
    """
    cleaned = {}
    for key, value in state_dict.items():
        if 'instance_loss_fn' in key:
            continue
        for old, new in key_renames:
            key = key.replace(old, new)
        cleaned[key.replace('.module', '')] = value
    return cleaned


# Both models are initialised from the same checkpoint, so read it from disk
# once (the original loaded it twice with copy-pasted cleaning loops).
ckpt_path = '/home/ext_yao_gary_mayo_edu/CLAM_saves/100s_2_checkpoint.pt'
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))

model_dropout = CLAM_MB(n_classes=2, dropout=True)
model_dropout.load_state_dict(_clean_clam_state_dict(ckpt), strict=True)

model_no_dropout = CLAM_MB(n_classes=2, dropout=False)
# The checkpoint was saved from the dropout variant; without the dropout layer
# the attention module is at index 2 instead of 3, so its keys are renamed.
ckpt_clean = _clean_clam_state_dict(
    ckpt,
    key_renames = [('attention_net.module.3.attention', 'attention_net.module.2.attention')],
)
model_no_dropout.load_state_dict(ckpt_clean, strict=True)

In [None]:
# forward_dataset was called here without a slide list, which raises TypeError
# against its declared signature. Evaluate every extracted feature file in the
# directory.
# NOTE(review): confirm that "all slides" (rather than e.g. val_test_file_list)
# is the intended population for the dropout comparison.
dropout_features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/'
dropout_slide_names = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(dropout_features_dir)
    if file_name.endswith('.h5')
)
dropout_output = forward_dataset(model_dropout, features_dir=dropout_features_dir, target_list=dropout_slide_names)
no_dropout_output = forward_dataset(model_no_dropout, features_dir=dropout_features_dir, target_list=dropout_slide_names)

In [None]:
dropout_output

In [None]:
dropout_accu, dropout_auc = accu(dropout_output, answer_rows = ['c1','c2'], reference = TCGA_Lung_csv), auc(dropout_output, answer_rows = ['c1','c2'], reference = TCGA_Lung_csv)
no_dropout_accu, no_dropout_auc = accu(no_dropout_output, answer_rows = ['c1', 'c2'], reference = TCGA_Lung_csv), auc(no_dropout_output, answer_rows = ['c1', 'c2'], reference = TCGA_Lung_csv)

In [None]:
print(dropout_accu, dropout_auc)
print(no_dropout_accu, no_dropout_auc)

# Running Experiment

In [None]:
output = forward_dataset(model_defense, '/home/ext_yao_gary_mayo_edu/lung-features/h5_files/', val_test_file_list)
print(accu(output, TCGA_Lung_csv), auc(output, TCGA_Lung_csv))

In [None]:
output = dataset_topk_attack(model_defense, feature_extractor, TCGA_Lung_csv, val_test_file_list,
                             '/home/ext_yao_gary_mayo_edu/lung-features/h5_files/', '/home/ext_yao_gary_mayo_edu/TCGA-Lung-Slides',
                             e = 0.008, f = 0.008, limit = 100, attention_comp = True)

In [None]:
baseline_accuracy, baseline_auc = accu(output, TCGA_Lung_csv, 
                                       answer_rows = ['c1_o', 'c2_o']), auc(output, TCGA_Lung_csv, answer_rows = ['c1_o', 'c2_o'])
print(baseline_accuracy, baseline_auc)

In [None]:
attacked_accuracy, attacked_auc = accu(output, TCGA_Lung_csv,
                                       answer_rows = ['c1_a', 'c2_a']),auc(output, TCGA_Lung_csv, answer_rows = ['c1_a', 'c2_a'])
print(attacked_accuracy, attacked_auc)

In [None]:
attented_accuracy, attented_auc = accu(output, TCGA_Lung_csv, 
                                       answer_rows = ['c1_att', 'c2_att']), auc(output, TCGA_Lung_csv, answer_rows = ['c1_att', 'c2_att'])
print(attented_accuracy, attented_auc)

In [None]:
default_results_path = '/home/jupyter/topk_datasets/validation_saves_final'
default_datasets_path = '/home/jupyter/topk_datasets/validation_datasets_final'
curr_results_path = os.path.join(default_results_path, 'ee2fe2')
curr_datasets_path = os.path.join(default_datasets_path, 'ee2fe2')
def gen_save_folders(dataset_path, results_path):
    """Create fresh, empty output folders for one attack run.

    Any existing ``dataset_path`` or ``results_path`` directory is deleted
    first. Afterwards ``dataset_path`` contains ``img``, ``features`` and
    ``coords`` and ``results_path`` contains ``att``.

    Missing parent directories are created as well; the original used
    ``os.mkdir`` and failed with ``FileNotFoundError`` when the parent of
    either path did not exist yet.
    """
    for folder in (dataset_path, results_path):
        if os.path.isdir(folder):
            shutil.rmtree(folder)
    for sub_folder in ('img', 'features', 'coords'):
        os.makedirs(os.path.join(dataset_path, sub_folder))
    os.makedirs(os.path.join(results_path, 'att'))

In [None]:
e_list = [0.001, 0.005, 0.01, 0.03, 0.06, 0.1]
save_name_list = ['e3e', 'e5e3', 'e2e', 'e32e', 'e62e', 'e1e']

In [None]:
for e, save_name in zip(e_list, save_name_list):
    curr_results_path = os.path.join(default_results_path, save_name)
    curr_datasets_path = os.path.join(default_datasets_path, save_name)
    gen_save_folders(dataset_path=curr_datasets_path, results_path = curr_results_path)
    output = dataset_topk_attack(model_defense, feature_extractor, TCGA_Lung_csv, val_test_file_list,
                             '/home/ext_yao_gary_mayo_edu/lung-features/h5_files/', '/home/ext_yao_gary_mayo_edu/TCGA-Lung-Slides',
                             results_save_dir=curr_results_path, dataset_save_dir=curr_datasets_path, save_dataset = True, e = e, f = e, topk = 10, limit = -1, attention_comp = True)
    output.to_csv('/home/jupyter/spreads_val/ts{}.csv'.format(e),index=False)

In [None]:
# top-k sweep: one save folder per k. The names are derived from k_list so the
# two lists cannot drift apart -- the hand-written list stopped at 'k40', so
# zip() silently skipped k = 50.
k_list = [1,3,5,10,15,20,30,40,50]
save_name_list = ['k{}'.format(k) for k in k_list]

In [None]:
for k, save_name in zip(k_list, save_name_list):
    curr_results_path = os.path.join(default_results_path, save_name)
    curr_datasets_path = os.path.join(default_datasets_path, save_name)
    gen_save_folders(dataset_path=curr_datasets_path, results_path = curr_results_path)
    output = dataset_topk_attack(model_defense, feature_extractor, TCGA_Lung_csv, val_test_file_list,
                             '/home/ext_yao_gary_mayo_edu/lung-features/h5_files/', '/home/ext_yao_gary_mayo_edu/TCGA-Lung-Slides',
                             results_save_dir=curr_results_path, dataset_save_dir=curr_datasets_path, save_dataset = True, e = 0.01, f = 0.01, topk = k, limit = -1, attention_comp = True)
    output.to_csv('/home/jupyter/spreads/new{}.csv'.format(save_name),index=False)

In [None]:
e_list = [.01]
save_name_list = ['e1e_whole_final']

In [None]:
# Final run(s): attack step e, attention-compensation step f = e/2.
for e, save_name in zip(e_list, save_name_list):
    curr_results_path = os.path.join(default_results_path, save_name)
    curr_datasets_path = os.path.join(default_datasets_path, save_name)
    gen_save_folders(dataset_path=curr_datasets_path, results_path = curr_results_path)
    # `topklimit = -1` is not a parameter of dataset_topk_attack and raised
    # TypeError; it is the fused remains of `topk = ..., limit = -1`.
    # topk = 10 matches the function default and the other sweeps.
    output = dataset_topk_attack(model_defense, feature_extractor, TCGA_Lung_csv, val_test_file_list,
                             '/home/ext_yao_gary_mayo_edu/lung-features/h5_files/', '/home/ext_yao_gary_mayo_edu/TCGA-Lung-Slides',
                             results_save_dir=curr_results_path, dataset_save_dir=curr_datasets_path, save_dataset = True, e = e, f = e/2, topk = 10, limit = -1, attention_comp = True)
    output.to_csv('/home/jupyter/spreads_val/{}.csv'.format(save_name),index=False)

In [None]:
output_path = '/home/jupyter/spreads_val/'

In [None]:
curr_save = 'e1e'
full_csv_path = os.path.join(output_path, '{}.csv'.format(curr_save))
curr_output = pd.read_csv(full_csv_path)

In [None]:
# Summarise every result table in `output_path`: accuracy and AUC before the
# attack, after it, and after the attention-compensation step.
compilation_columns = ['save_name', 'baseline_accu', 'baseline_auc', 'attacked_accu', 'attacked_auc', 'att_accu', 'att_auc']
compilation_records = []
for file_name in os.listdir(output_path):
    full_csv_path = os.path.join(output_path, file_name)
    try:
        curr_output = pd.read_csv(full_csv_path)
    except Exception as error:
        # e.g. sub-directories or non-CSV files: skip them, but say so
        print('skipping {}: {!r}'.format(file_name, error))
        continue
    save_name = os.path.splitext(file_name)[0]
    baseline_accuracy, baseline_auc = accu(curr_output, TCGA_Lung_csv, answer_rows = ['c1_o', 'c2_o']), auc(curr_output, TCGA_Lung_csv, answer_rows = ['c1_o', 'c2_o'])
    attacked_accuracy, attacked_auc = accu(curr_output, TCGA_Lung_csv, answer_rows = ['c1_a', 'c2_a']), auc(curr_output, TCGA_Lung_csv, answer_rows = ['c1_a', 'c2_a'])
    att_accuracy, att_auc = accu(curr_output, TCGA_Lung_csv, answer_rows = ['c1_att', 'c2_att']), auc(curr_output, TCGA_Lung_csv, answer_rows = ['c1_att', 'c2_att'])
    compilation_records.append(['{}'.format(save_name), baseline_accuracy, baseline_auc, attacked_accuracy, attacked_auc, att_accuracy, att_auc])
# DataFrame.append was removed in pandas 2.0; build the table once.
compilation_csv = pd.DataFrame(compilation_records, columns = compilation_columns)
compilation_csv.to_csv('/home/jupyter/val_compliation.csv')

In [None]:
!rm -r /home/jupyter/spreads/.*

#  Attention Percentage experiment

In [None]:
def get_topk_attention_ratio(a, threshholds):
    """Share of softmaxed attention held by the top-k entries, for each k.

    ``a`` is the raw attention for one class; it is softmaxed along dim 0.
    For every ``k`` in ``threshholds`` the sum of the ``k`` largest softmax
    values is returned as a Python float, in the same order as ``threshholds``.
    """
    attention_shares = torch.softmax(a, dim = 0)
    return [
        torch.sum(attention_shares.topk(k)[0]).item()
        for k in threshholds
    ]
    

In [None]:
def instance_attention_analysis(model, feature_extractor, reference, slide_name, feature_dir, slide_dir, e = 0.01, f = 0.01, topk = 10):
    """Plot how the attack redistributes attention over one slide's top-k patches.

    Runs ``instance_topk_attack`` with attention compensation and draws a 3x2
    grid of pie charts (rows: natural, attacked, attacked with compensation;
    columns: class 1, class 2). Each pie shows the softmaxed attention of the
    first ``topk - 1`` originally top-ranked patches plus one slice for all
    remaining attention.

    Parameters
    ----------
    model, feature_extractor
        Forwarded to ``instance_topk_attack``.
    reference : pandas.DataFrame
        Table with ``slide_name``, ``label`` and ``resolution`` columns.
    slide_name : str
        Slide to analyse.
    feature_dir, slide_dir : str
        Directories holding ``<slide_name>.h5`` and ``<slide_name>.svs``.
        (The original body ignored ``feature_dir`` and read the notebook-level
        ``features_dir`` global instead.)
    e, f : float
        Attack and compensation step sizes; also shown in the plot titles,
        which previously always said 0.01.
    topk : int
        Number of patches attacked; slice labels are derived from it (the
        original label list only fitted ``topk = 10``).
    """
    full_feature_path = os.path.join(feature_dir, '{}.h5'.format(slide_name))
    full_slide_path = os.path.join(slide_dir, '{}.svs'.format(slide_name))
    with h5py.File(full_feature_path, 'r') as hdf5_file:
        features = hdf5_file['features'][:]
        coords = hdf5_file['coords'][:]
    features = torch.from_numpy(features)

    reference_row = reference[reference['slide_name'] == slide_name]
    reference_label = reference_row['label'].item()
    slide_resolution = int(reference_row['resolution'].item())

    # the attack target is the index of the slide's reference class
    if reference_label == 'LUAD':
        attack_target = 0
    else:
        attack_target = 1

    # resolution-40 slides are read at 512px and halved, others at 256px
    if slide_resolution == 40:
        patch_size = 512
        downsample = 2
    else:
        patch_size = 256
        downsample = 1

    wsi = openslide.open_slide(full_slide_path)
    try:
        (y_collection, a_collection, idx_coord_list, img_collection,
         features_collection, noise_collection) = instance_topk_attack(
            model, feature_extractor, features, coords, wsi,
            patch_size = patch_size, downsample = downsample,
            target = attack_target, topk = topk, e = e, f = f, attention_comp = True)
    finally:
        # the slide handle is only needed while patches are being read
        wsi.close()
    a_o, a_a, a_att = a_collection

    # rank patches by the ORIGINAL attention; the same indices are tracked in
    # all three conditions
    res, original_topk_ind = a_o.topk(topk)
    original_topk_ind = original_topk_ind.detach().cpu().numpy()

    panels = [
        ('Natural', a_o),
        ('Topk Attacked e = {}'.format(e), a_a),
        ('Topk with Compensastion e = {}, f = {}'.format(e, f), a_att),
    ]
    # topk - 1 ranked slices plus one slice for everything else
    labels = [str(rank) for rank in range(1, topk)] + ['remaining attention scores']

    fig, ax = plt.subplots(3, 2)
    set_size(20,20, ax[0,0])
    for row_idx, (description, attention) in enumerate(panels):
        # scipy.special.softmax across patches, separately for each class row
        shares = softmax(attention.detach().cpu().numpy(), axis = 1)
        for class_idx in range(2):
            sizes = shares[class_idx][original_topk_ind[class_idx,:-1]].tolist()
            sizes.append(1 - sum(sizes))
            if class_idx == 0:
                print(sizes)
            ax[row_idx, class_idx].pie(sizes, labels = labels)
            ax[row_idx, class_idx].set_title(
                'Softmax Attention Shares, {}, class {}'.format(description, class_idx + 1))
    plt.suptitle('Attention Hijacking Demostration, slide_id = {}'.format(slide_name[:10]))
    plt.show()

In [None]:
instance_attention_analysis(model_defense, feature_extractor, reference = TCGA_Lung_csv, slide_name = 'TCGA-49-6744-01Z-00-DX2.1982e585-65a4-4330-9140-ccabcdd106f8', feature_dir = '/home/ext_yao_gary_mayo_edu/lung-features/', slide_dir = '/home/ext_yao_gary_mayo_edu/lung-slides/', e = .001, f = .001)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# if using a Jupyter notebook, include:
%matplotlib inline

def pie_chart(sizes):
    """Draw a pie chart of ``sizes`` plus one extra slice for the remainder.

    ``sizes`` is printed as given, then a final slice ``1 - sum(sizes)`` is
    added so the chart accounts for a total of 1. The caller's list is left
    untouched (the original appended the remainder to it in place, so calling
    the function twice on the same list kept growing it).
    """
    # Pie chart, where the slices will be ordered and plotted counter-clockwise:
    print(sizes)
    slices = list(sizes) + [1 - sum(sizes)]

    fig, ax = plt.subplots()
    ax.pie(slices, autopct='%1.1f%%')
    ax.axis('equal')  # Equal aspect ratio ensures the pie chart is circular.
    ax.set_title('Natural topk attention softmax')

    plt.show()

#  Visualization Demonstration

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

In [None]:
img = Image.open('/home/jupyter/airplane.jpeg')
a = transforms.ToTensor()
tensored_img = a(img)
untensored_img = untransform(tensored_img)
display(untensored_img)

In [None]:
def untransform(img_array_transformed):
    """Convert a 3 x H x W tensor with values in [0, 1] back to an RGB PIL image.

    Values outside [0, 1] are clamped before scaling to 0..255. Perturbed
    patches and signed-noise tensors are passed in by the visualisation code;
    without the clamp their out-of-range values wrapped around in the uint8
    cast and showed up as wrong colours.
    """
    # channels-first -> channels-last (H, W, 3)
    img_array = img_array_transformed.detach().permute(1, 2, 0)
    img_array = img_array.clamp(0, 1).numpy()
    img = Image.fromarray((img_array*255).astype('uint8'), 'RGB')
    return img

In [None]:
def set_size(w,h, ax=None):
    """Resize the figure so that the subplot area measures ``w`` x ``h`` inches.

    ``ax`` defaults to the current axes. The figure size is the requested size
    divided by the fraction of the figure that the subplot parameters leave
    for the axes.
    """
    if not ax:
        ax = plt.gca()
    margins = ax.figure.subplotpars
    width_fraction = margins.right - margins.left
    height_fraction = margins.top - margins.bottom
    ax.figure.set_size_inches(float(w)/width_fraction, float(h)/height_fraction)

In [None]:
def instance_visualization_analytics(model, feature_extractor, reference, slide_name, features_dir, slides_dir, topk = 10, e = [0.01, 0.03, 0.05, 0.1, 0.3, 0.5], return_idx = None):
    """Plot the top-k patches of a slide, the attack noise, and the patches at several step sizes.

    Grid layout (``2 + len(e)`` rows, ``topk`` columns): row 0 shows the
    original patches, row 1 the sign of the attack gradient, and each further
    row the patches perturbed with one step size from ``e``. Titles show the
    attention value for the attacked class and the two model scores obtained
    from that single patch.

    Parameters
    ----------
    model, feature_extractor
        Forwarded to ``instance_topk_attack`` and also applied per patch.
    reference : pandas.DataFrame
        Table with ``slide_name``, ``label`` and ``resolution`` columns.
    slide_name, features_dir, slides_dir
        Identify the ``.h5`` feature file and the ``.svs`` slide.
    topk : int
        Number of patches (columns).
    e : sequence of float
        Step sizes to visualise (not mutated).
    return_idx : (column, row) pair, optional
        Grid position of the attacked patch image to return.

    Returns
    -------
    PIL image or None
        The image at ``return_idx``; None when ``return_idx`` is None or does
        not address an attacked-patch cell. (The original raised ``NameError``
        when ``return_idx`` was None because ``return_i``/``return_j`` were
        only assigned inside the ``if``.)
    """
    full_feature_path = os.path.join(features_dir, '{}.h5'.format(slide_name))
    with h5py.File(full_feature_path, 'r') as hdf5_file:
        features = hdf5_file['features'][:]
        coords = hdf5_file['coords'][:]
        features = torch.from_numpy(features)
        slide_full_path = os.path.join(slides_dir,
                                       '{}.svs'.format(slide_name))
        wsi = openslide.open_slide(slide_full_path)

        reference_row = reference[reference['slide_name'] == slide_name]
        reference_label = reference_row['label'].item()

        slide_resolution = int(reference_row['resolution'].item())

        # the attack target is the index of the slide's reference class
        if reference_label == 'LUAD':
            attack_target = 0
        else:
            attack_target = 1

        # resolution-40 slides are read at 512px and halved, others at 256px
        if slide_resolution == 40:
            patch_size = 512
            downsample = 2
        else:
            patch_size = 256
            downsample = 1

    # reference run with a fixed step of 0.01: provides the original patches,
    # the noise pattern and the top-k indices reused for every row below
    y_collection, a_collection, idx_coord_list, img_collection, features_collection, noise_collection = instance_topk_attack(model,
        feature_extractor, features, coords, wsi,
        patch_size = patch_size, patch_level = 0, downsample = downsample, target = attack_target, topk = topk, e = 0.01,
        f = 0.01, attention_comp = False, verbose = False)

    y_original, y_attack = y_collection
    print('original y {} {}'.format(y_original[0,0].item(), y_original[0,1].item()))
    e_variation_num = len(e)
    plt.rcParams['axes.facecolor']='white'
    plt.rcParams['savefig.facecolor']='white'
    fig, ax = plt.subplots(2+e_variation_num, topk)
    # NOTE(review): this opens a second, empty figure rather than resizing the
    # grid created above -- confirm whether it is intended.
    figure(figsize=(100, 100))
    img_original, img_attack = img_collection
    a_original, a_attack = a_collection
    attack_noise = noise_collection[0]
    topk_idx_list = []
    # default to "no cell selected" so the comparison in the loop below is
    # always defined
    return_i, return_j = None, None
    if return_idx is not None:
        return_i, return_j = return_idx
    return_item = None
    for i in range(topk):
        idx, coord = idx_coord_list[i]
        print(idx)
        print(coord)
        topk_idx_list.append(idx)

    # row 0: original patches
    for i, current_img_tensor in enumerate(img_original):
        current_attention = a_original[:,topk_idx_list[i]]
        temp_test_img = torch.unsqueeze(current_img_tensor, dim = 0)
        # NOTE(review): the patch is fed to the extractor without the
        # `normalize` transform used inside instance_topk_attack -- confirm.
        temp_feats = feature_extractor(temp_test_img)

        #print(temp_feats.shape)
        l,y,y1,a,r = model(temp_feats)
        c0 = y[0,0].item()
        c1 = y[0,1].item()
        current_img = untransform(current_img_tensor.cpu())
        ax[0, i].imshow(current_img)
        ax[0, i].set_axis_off()
        ax[0, i].set_title("{}, [{}, {}]".format(str(round(current_attention[attack_target].item(),2)),
                                                   str(round(c0,2)), str(round(c1,2))),
                           color = 'white')
        set_size(15,15, ax[0,i])
    # row 1: sign of the attack gradient
    for i, current_noise_tensor in enumerate(attack_noise):
        current_noise_img = untransform(current_noise_tensor.cpu())
        ax[1, i].imshow(current_noise_img)
        ax[1, i].set_axis_off()

    # rows 2..: one row per step size
    for j, current_e in enumerate(e):
        y_collection, a_collection, idx_coord_list, img_collection, features_collection, noise_collection = instance_topk_attack(model, feature_extractor, features, coords, wsi,patch_size = patch_size, patch_level = 0, downsample = downsample, target = attack_target, topk = topk, e = current_e, f = 0.01, attention_comp = False, verbose = False)
        a_original, a_attack = a_collection
        y_original, y_attack = y_collection
        print('attacked at e = {}, {} {}'.format(current_e, y_attack[0,0].item(), y_attack[0,1].item()))
        for i, (current_img_tensor, current_noise_tensor) in enumerate(zip(img_original, attack_noise)):
            current_attention = a_attack[:,topk_idx_list[i]]
            # rebuilt from the reference run's patches and noise sign
            current_attacked_img_tensor = current_img_tensor + current_e * current_noise_tensor

            temp_test_img = torch.unsqueeze(current_attacked_img_tensor, dim = 0)
            temp_feats = feature_extractor(temp_test_img)
            l,y,y1,a,r = model(temp_feats)
            c0 = y[0,0].item()
            c1 = y[0,1].item()

            current_attacked_img = untransform(current_attacked_img_tensor.cpu())
            ax[j+2,i].imshow(current_attacked_img)
            ax[j+2,i].set_axis_off()
            ax[j+2,i].set_title("{}, [{}, {}]".format(str(round(current_attention[attack_target].item(),2)),
                                                   str(round(c0,2)), str(round(c1,2))),
                           color = 'white')
            if j+2 == return_j and i == return_i:
                return_item = current_attacked_img
    plt.show()
    return return_item

In [None]:
features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/'
slides_dir = '/home/ext_yao_gary_mayo_edu/lung-slides'
target_slide = 'TCGA-49-6744-01Z-00-DX2.1982e585-65a4-4330-9140-ccabcdd106f8'
moded_center = instance_visualization_analytics(model_defense, feature_extractor, TCGA_Lung_csv, target_slide, features_dir, slides_dir, return_idx = [4,4])

In [None]:
display(moded_center)

In [None]:
def generate_one_amoung_many(slide_name, reference, target_coords, wsi, moded_center, e = 0.01, dist = 1):
    """Display the neighbourhood of one patch with its centre tile replaced.

    A square region of ``1 + 2*dist`` patches per side, centred on the patch
    at ``target_coords``, is read from level 0 of ``wsi``. Slides whose
    ``resolution`` is 40 are read with 512-pixel patches and resized by a
    factor of 2; all others use 256-pixel patches. The centre tile of the
    resulting image is overwritten with ``moded_center`` and the image is
    displayed.

    Parameters
    ----------
    slide_name : str
        Value looked up in ``reference['slide_name']``.
    reference : pandas.DataFrame
        Must contain ``slide_name`` and ``resolution`` columns and exactly
        one row for ``slide_name``.
    target_coords : tuple
        Location of the centre patch, in the order expected by
        ``wsi.read_region``.
    wsi : object
        Slide handle exposing ``read_region(location, level, size)``.
    moded_center : array-like
        Replacement for the centre tile; must be convertible to an array
        that fits the ``tile x tile x 3`` slot (256 pixels per side).
    e : float
        Unused; kept so existing calls keep working.
    dist : int
        Number of patches shown on each side of the centre patch.

    Returns
    -------
    None
        The image is shown with ``display``.

    Raises
    ------
    ValueError
        If ``reference`` does not hold exactly one row for ``slide_name``.
    """
    reference_row = reference[reference['slide_name'] == slide_name]
    # .item() replaces the deprecated int(Series) and fails loudly unless
    # exactly one row matched.
    slide_resolution = int(reference_row['resolution'].item())

    if slide_resolution == 40:
        patch_size, downsample = 512, 2
    else:
        patch_size, downsample = 256, 1
    # Edge of one patch after downsampling (256 in both branches above).
    tile_size = patch_size // downsample

    target_l, target_w = target_coords
    dim_l = dim_w = patch_size * (1 + (2 * dist))
    start_l = target_l - dist * patch_size
    start_w = target_w - dist * patch_size
    print((start_l, start_w), (dim_l, dim_w))
    img = wsi.read_region((start_l, start_w), 0,
                (dim_l, dim_w)).convert('RGB')
    if downsample != 1:
        final_size = (int(dim_l / downsample), ) * 2
        img = img.resize(final_size)
    img_array = np.array(img)
    moded_center_array = np.array(moded_center)
    print(moded_center_array.shape)
    print(img_array.shape)
    # Overwrite the centre tile, i.e. tile index `dist` along both axes.
    lo, hi = dist * tile_size, (dist + 1) * tile_size
    img_array[lo:hi, lo:hi, :] = moded_center_array
    img = Image.fromarray(img_array, 'RGB')

    display(img)
    
    

In [None]:
# Paths and slide used for the "one among many" illustration.
features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/'
slides_dir = '/home/ext_yao_gary_mayo_edu/lung-slides'
target_slide = 'TCGA-49-6744-01Z-00-DX2.1982e585-65a4-4330-9140-ccabcdd106f8'

# Open the raw .svs slide that belongs to `target_slide`.
# NOTE(review): the handle is never closed in this cell.
slide_full_path = os.path.join(slides_dir,
                               '{}.svs'.format(target_slide))
wsi = openslide.open_slide(slide_full_path)
# Show 4 patches of context on every side of the patch at (49312,33184),
# with the centre tile replaced by `moded_center`.
generate_one_amoung_many(target_slide, TCGA_Lung_csv, (49312,33184), wsi, moded_center, dist = 4)

In [None]:
def gen_pathology_exp(e_levels = [0.05, 0.1, 0.3], ):
    """Unimplemented placeholder.

    The cell held only this signature, which is a SyntaxError when executed.
    ``e_levels`` is presumably a list of attack strengths -- TODO confirm
    once the experiment is written.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError('gen_pathology_exp has not been implemented yet')

# Defense Experiment

In [None]:
def tile_accu(output, target_columns = ['natural_used', 'natural_total']):
    """Add up two per-slide count columns over every row of ``output``.

    Each cell is converted with ``int()`` before summing.

    Parameters
    ----------
    output : pandas.DataFrame
        Frame with one row per slide.
    target_columns : list of str
        ``[used_column, total_column]``.

    Returns
    -------
    tuple of int
        ``(sum of used_column, sum of total_column)``; ``(0, 0)`` for an
        empty frame.
    """
    per_slide_counts = [
        (int(row[target_columns[0]]), int(row[target_columns[1]]))
        for _, row in output.iterrows()
    ]
    used_sum = sum(used for used, _ in per_slide_counts)
    tiles_sum = sum(tiles for _, tiles in per_slide_counts)
    return used_sum, tiles_sum
    

In [None]:
def defense_validation(model, feature_extractor, features_dir, adv_dataset_dir, reference, val_list, defense = 'std_filter', d = 2.5, full_pipe = False, attention_comp = False, softmax = True, verbose = False, save = False, limit = -1, natural = False):
    """Compare defended and undefended predictions on adversarial bags.

    For every slide in ``val_list`` the natural features are loaded, the
    ``k`` most-attended instances of the true class are replaced by stored
    adversarial features (unless ``natural`` is true), and the bag is scored
    by ``model.forward_with_defense`` and by the plain ``model`` forward.

    ``adv_dataset_dir`` must contain ``features``, ``coords`` and ``img``
    sub-folders. Slides whose files cannot be read are reported and skipped.

    Parameters
    ----------
    model : object
        Must provide ``__call__``, ``forward_with_defense(features, defense, d)``
        and ``get_defense_mask(features, defense, d)``.
    feature_extractor : callable
        Only used when ``full_pipe`` is true, to recompute the adversarial
        features from the stored adversarial images.
    features_dir : str
        Folder holding ``<slide_name>.h5`` files with ``features`` and
        ``coords`` datasets.
    adv_dataset_dir : str
        Root of the adversarial dataset.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns; label ``'LUAD'`` maps to
        class 0, anything else to class 1.
    val_list : iterable of str
        Slide names to evaluate.
    defense : str
        ``'std_filter'`` or ``'topk_filter'``.
    d : float
        Defense parameter forwarded to the model.
    full_pipe : bool
        Recompute adversarial features with ``feature_extractor``.
    attention_comp : bool
        Read the ``_att`` feature/image files instead of ``_adv``.
    softmax : bool
        If true, report the second model output for the defended columns;
        otherwise the first. Presumably these are probabilities and logits
        respectively -- TODO confirm against the model. Inside this function
        the name shadows any imported ``softmax``.
    verbose, save, limit
        Unused; kept for backward compatibility.
    natural : bool
        If true, skip the adversarial substitution.

    Returns
    -------
    pandas.DataFrame
        One row per evaluated slide with columns ``slide_name``,
        ``c1_undefended``, ``c2_undefended``, ``c1_defended``,
        ``c2_defended``, ``natural_used``, ``natural_total``, ``adv_used``,
        ``adv_total``. Rows are indexed 0..n-1.

    Raises
    ------
    ValueError
        If ``defense`` is not a supported name (previously this surfaced as
        a NameError on the first slide).
    """
    if defense not in ('std_filter', 'topk_filter'):
        raise ValueError("unsupported defense: {!r}".format(defense))

    columns = ['slide_name', 'c1_undefended', 'c2_undefended', 'c1_defended', 'c2_defended',
               'natural_used', 'natural_total', 'adv_used', 'adv_total']
    # Feature and image files switch suffix with attention_comp; the coords
    # file always uses the `_adv` suffix.
    artifact_suffix = 'att' if attention_comp else 'adv'
    rows = []
    for slide_name in tqdm(val_list):
        full_feature_path = os.path.join(features_dir, '{}.h5'.format(slide_name))
        try:
            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
                coords = hdf5_file['coords'][:]
        except Exception:
            # The original handler called the undefined name `pring`, which
            # turned every load failure into a NameError.
            print('failure on {}'.format(slide_name))
            continue
        features_original = torch.from_numpy(features).cuda()

        # Load the adversarial artefacts for this slide.
        adv_feature_path = os.path.join(adv_dataset_dir, 'features', '{}_{}.tr'.format(slide_name, artifact_suffix))
        adv_coords_path = os.path.join(adv_dataset_dir, 'coords', '{}_adv.p'.format(slide_name))
        adv_img_path = os.path.join(adv_dataset_dir, 'img', '{}_{}.tr'.format(slide_name, artifact_suffix))
        try:
            # NOTE: pickle / torch.load can execute code embedded in the file;
            # only point this at datasets generated locally.
            with open(adv_coords_path, 'rb') as coords_file:
                idx_coords = pickle.load(coords_file)
            adv_features = torch.load(adv_feature_path)
            # Loaded even when full_pipe is false so that slides without an
            # image file keep being skipped, as before.
            adv_img = torch.load(adv_img_path)
            if full_pipe:
                adv_features = feature_extractor(adv_img)
        except Exception:
            print('Failed on {}'.format(slide_name))
            continue

        k = len(idx_coords)
        l, y, y1, a, r = forward_ins(model, features_original, coords)
        total = features_original.shape[0]
        # Was hard-coded to 10; the number of substituted instances is k.
        adv_total = k
        natural_total = total - adv_total

        reference_row = reference[reference['slide_name'] == slide_name]
        reference_label = reference_row['label'].item()
        target = 0 if reference_label == 'LUAD' else 1

        # Indices of the k most-attended instances for the true class.
        res, ind = a[target,0:].topk(k)
        if not natural:
            for i in range(k):
                features_original[ind[i], :] = adv_features[i]

        l_defended, y_defended, y1, a, r = model.forward_with_defense(features_original, defense, d)
        mask = model.get_defense_mask(features_original, defense, d)
        # Store plain ints so the frame does not hold tensors.
        total_used = int(torch.sum(mask.int()).item())
        adv_used = int(torch.sum(mask[ind].int()).item())
        natural_used = total_used - adv_used
        l_undefended, y_undefended, y1, a, r = model(features_original)

        # softmax=False used to fill unrelated variables from the pre-attack
        # forward pass and left the defended columns undefined.
        defended_scores = y_defended if softmax else l_defended
        c1_defended = defended_scores[0,0].detach().cpu().item()
        c2_defended = defended_scores[0,1].detach().cpu().item()

        c1_undefended = y_undefended[0,0].detach().cpu().item()
        c2_undefended = y_undefended[0,1].detach().cpu().item()
        rows.append([slide_name, c1_undefended, c2_undefended, c1_defended, c2_defended,
                     natural_used, natural_total, adv_used, adv_total])
    # Built once: DataFrame.append was removed in pandas 2.0 and row-wise
    # growth is quadratic.
    return pd.DataFrame(rows, columns = columns)

In [None]:
# Sweep the std_filter threshold `d` on the 'e2e' adversarial dataset and
# record accuracy / AUC plus the fraction of natural and adversarial tiles
# that survive the filter. The CSV is rewritten after every `d` so partial
# results survive an interrupted run.
saves_dir = '/home/jupyter/topk_datasets/datasets/'
saves = os.listdir(saves_dir)
# The checkpoint folder only exists once Jupyter has created it; an
# unconditional remove() raised ValueError otherwise.
if '.ipynb_checkpoints' in saves:
    saves.remove('.ipynb_checkpoints')
# The leading space in ' natural_tiles_used%' is kept so existing CSV
# consumers see the same header.
compliation_csv = pd.DataFrame(columns = ['save_name', 'accu_undefended', 'auc_undefended', 'accu_defended', 'auc_defended',' natural_tiles_used%', 'adv_tiles_used%'])
print(saves)
att_comp = False
# Other sweeps that were tried:
#d_list = [1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]
#d_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
d_list = [1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.55, 2.6, 2.65, 2.7, 2.75, 2.8, 2.85, 2.9, 2.95, 3.0, 3.05, 3.1, 3.15, 3.2, 3.25, 3.3, 3.35, 3.4, 3.45, 3.5, 3.55, 3.6, 3.65, 3.7, 3.75, 3.8, 3.85, 3.9, 3.95, 4.0]
for d_curr in d_list:
    output = defense_validation(model_defense, feature_extractor, features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/', adv_dataset_dir = '/home/jupyter/topk_datasets/datasets/{}'.format('e2e'), reference = TCGA_Lung_csv, val_list = val_test_file_list
                                   ,defense = 'std_filter', d = d_curr, attention_comp = att_comp)
    accu_und, auc_und  = accu(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv)
    accu_def, auc_def = accu(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv)

    natural_used, natural_total = tile_accu(output, target_columns = ['natural_used', 'natural_total'])
    adv_used, adv_total = tile_accu(output, target_columns = ['adv_used', 'adv_total'])
    # A run in which no slide was evaluated has no defined ratio.
    natural_fraction = (natural_used/natural_total) if natural_total else float('nan')
    adv_fraction = (adv_used/adv_total) if adv_total else float('nan')
    summary_row = pd.DataFrame([['{}'.format(d_curr), accu_und, auc_und, accu_def, auc_def, natural_fraction, adv_fraction]], columns = compliation_csv.columns)
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    compliation_csv = pd.concat([compliation_csv, summary_row])
    compliation_csv.to_csv('/home/jupyter/defensive_complicatione2e_noatcomp.csv')


In [None]:
# Evaluate the std_filter defence (d = 3) on the natural bags and then on
# every stored adversarial dataset, rewriting the summary CSV after each run.
saves_dir = '/home/jupyter/topk_datasets/datasets/'
saves = os.listdir(saves_dir)
# The checkpoint folder only exists once Jupyter has created it; an
# unconditional remove() raised ValueError otherwise.
if '.ipynb_checkpoints' in saves:
    saves.remove('.ipynb_checkpoints')
compliation_csv = pd.DataFrame(columns = ['save_name', 'accu_undefended', 'auc_undefended', 'accu_defended', 'auc_defended'])
print(saves)
att_comp = False

# Baseline: natural = True disables the adversarial substitution; the 'e1e'
# dataset is still needed because its files are read for every slide.
output = defense_validation(model_defense, feature_extractor, features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/', adv_dataset_dir = '/home/jupyter/topk_datasets/datasets/{}'.format('e1e'), reference = TCGA_Lung_csv, val_list = val_test_file_list
                               ,defense = 'std_filter', d = 3, attention_comp = att_comp, natural = True)
accu_und, auc_und  = accu(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv)
accu_def, auc_def = accu(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv)
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
compliation_csv = pd.concat([compliation_csv, pd.DataFrame([['natural', accu_und, auc_und, accu_def, auc_def]], columns = compliation_csv.columns)])
compliation_csv.to_csv('/home/jupyter/defensive_complication_noac3.csv')

for save in saves:
    # NOTE(review): skips every dataset whose folder name contains the letter
    # 'h' unless att_comp is set -- confirm this naming convention.
    if 'h' in save and not att_comp:
        continue
    current_adv_dataset_dir = os.path.join(saves_dir, save)

    output = defense_validation(model_defense, feature_extractor, features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/', adv_dataset_dir = current_adv_dataset_dir, reference = TCGA_Lung_csv, val_list = val_test_file_list
                               ,defense = 'std_filter', d = 3, attention_comp = att_comp)
    accu_und, auc_und  = accu(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_undefended', 'c2_undefended'], reference = TCGA_Lung_csv)
    accu_def, auc_def = accu(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv), auc(output, answer_rows=['c1_defended', 'c2_defended'], reference = TCGA_Lung_csv)
    compliation_csv = pd.concat([compliation_csv, pd.DataFrame([[save, accu_und, auc_und, accu_def, auc_def]], columns = compliation_csv.columns)])
    compliation_csv.to_csv('/home/jupyter/defensive_complication_noac3.csv')

In [None]:
# Score the `output` frame currently in the kernel against the reference table.
# NOTE(review): no `answer_rows` is passed here, unlike the calls in the
# defence cells, so the defaults of `accu` apply -- confirm they match the
# column names of `output`.
accu(output, TCGA_Lung_csv)

# framework demo

In [None]:
def instance_adv_framework_analytics(model, feature_extractor, reference, slide_name, features_dir, slides_dir, topk = 10):
    """Plot how much a top-k attack changes the extracted features of a slide.

    Loads the stored features and coordinates of ``slide_name``, runs
    ``instance_topk_attack`` with ``e = 0.01`` and draws a 250-bin histogram
    of the element-wise difference between the natural features and the
    first entry of the returned feature collection.

    Parameters
    ----------
    model, feature_extractor
        Forwarded unchanged to ``instance_topk_attack``.
    reference : pandas.DataFrame
        Needs ``slide_name``, ``label`` and ``resolution`` columns with
        exactly one row for ``slide_name``. Label ``'LUAD'`` selects target
        0, anything else target 1; resolution 40 selects 512-pixel patches
        with downsample 2, anything else 256-pixel patches.
    slide_name : str
        Slide identifier; ``<slide_name>.h5`` and ``<slide_name>.svs`` are
        read from ``features_dir`` and ``slides_dir``.
    features_dir, slides_dir : str
        Folders with feature files and raw slides.
    topk : int
        Number of attacked instances.

    Returns
    -------
    None
        The histogram is drawn on the current matplotlib figure.
    """
    attack_e = 0.01

    full_feature_path = os.path.join(features_dir, '{}.h5'.format(slide_name))
    with h5py.File(full_feature_path, 'r') as hdf5_file:
        features = torch.from_numpy(hdf5_file['features'][:]).cuda()
        coords = hdf5_file['coords'][:]

    reference_row = reference[reference['slide_name'] == slide_name]
    reference_label = reference_row['label'].item()
    # .item() replaces the deprecated int(Series).
    slide_resolution = int(reference_row['resolution'].item())

    attack_target = 0 if reference_label == 'LUAD' else 1

    if slide_resolution == 40:
        patch_size, downsample = 512, 2
    else:
        patch_size, downsample = 256, 1

    slide_full_path = os.path.join(slides_dir,
                                   '{}.svs'.format(slide_name))
    wsi = openslide.open_slide(slide_full_path)
    try:
        y_collection, a_collection, idx_coord_list, img_collection, features_collection, noise_collection = instance_topk_attack(model,
            feature_extractor, features, coords, wsi,
            patch_size = patch_size, patch_level = 0, downsample = downsample, target = attack_target, topk = topk, e = attack_e,
            f = 0.01, attention_comp = False, verbose = False)
    finally:
        # The slide handle was previously left open.
        wsi.close()

    attack_features = features_collection[0]

    feature_diff_list = []
    # Bounded by what the attack returned, so a short result cannot raise
    # IndexError.
    for i in range(min(topk, len(idx_coord_list))):
        idx, coord = idx_coord_list[i]
        curr_diff = features[idx, :] - attack_features[i, :]
        feature_diff_list += curr_diff.tolist()

    diff_array = np.asarray(feature_diff_list)
    plt.hist(diff_array, bins = 250)
    plt.title('Distribution of changes in resnet extracted features natural/topk attack e = {}'.format(attack_e))
    
    

In [None]:
# Inputs for the feature-difference histogram of one slide.
features_dir = '/home/ext_yao_gary_mayo_edu/lung-features/'
slides_dir = '/home/ext_yao_gary_mayo_edu/lung-slides'
target_slide = 'TCGA-49-6744-01Z-00-DX2.1982e585-65a4-4330-9140-ccabcdd106f8'
# NOTE(review): the result is bound to `moded_center`, which overwrites the
# image stored under that name by the earlier visualization cell.
moded_center = instance_adv_framework_analytics(model_defense, feature_extractor, TCGA_Lung_csv, target_slide, features_dir, slides_dir)

In [None]:
def load_adv_instance(model, reference, slide_name, adv_dataset_dir, features_original, attention_comp = False):
    """Swap the most-attended instances of a bag for stored adversarial features.

    ``model`` is run on ``features_original`` under ``torch.no_grad()``; the
    ``k`` instances with the highest attention for the slide's true class
    (``k`` = number of stored coordinates) are overwritten in place with the
    adversarial features read from ``adv_dataset_dir``.

    Parameters
    ----------
    model : callable
        Returns a 5-tuple whose fourth element is the attention matrix,
        indexed ``[class, instance]``.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns; ``'LUAD'`` maps to class
        0, anything else to class 1.
    slide_name : str
        Slide identifier used to build the file names.
    adv_dataset_dir : str
        Root folder containing ``features`` and ``coords`` sub-folders.
    features_original : torch.Tensor
        Bag of instance features; modified in place.
    attention_comp : bool
        Read ``<slide>_att.tr`` instead of ``<slide>_adv.tr`` features. The
        coordinates file is ``<slide>_adv.p`` in both cases.

    Returns
    -------
    torch.Tensor or None
        ``features_original`` after substitution, or ``None`` when the
        adversarial files cannot be loaded.
    """
    feature_suffix = 'att' if attention_comp else 'adv'
    adv_feature_path = os.path.join(adv_dataset_dir, 'features', '{}_{}.tr'.format(slide_name, feature_suffix))
    adv_coords_path = os.path.join(adv_dataset_dir, 'coords', '{}_adv.p'.format(slide_name))
    try:
        # NOTE: pickle / torch.load can execute code embedded in the file;
        # only point this at datasets generated locally.
        with open(adv_coords_path, 'rb') as coords_file:
            idx_coords = pickle.load(coords_file)
        adv_features = torch.load(adv_feature_path)
    except Exception:
        # Any failure means "no adversarial instance". Previously a missing
        # feature file after a readable coords file ended in a NameError.
        print('Failed on {}'.format(slide_name))
        return None
    if idx_coords is None:
        return None
    # The adversarial image tensor is not needed here and is no longer read.

    k = len(idx_coords)
    with torch.no_grad():
        l, y, y1, a, r = model(features_original)

        reference_row = reference[reference['slide_name'] == slide_name]
        reference_label = reference_row['label'].item()
        target = 0 if reference_label == 'LUAD' else 1

        # Indices of the k most-attended instances for the true class.
        res, ind = a[target,0:].topk(k)

        for i in range(k):
            features_original[ind[i], :] = adv_features[i]

    return features_original

In [None]:
# Module-level settings of the adversarial-training experiment.
# NOTE(review): train_epoch_p2 / validate_epoch_p2 below do not read any of
# these names; `att_comp` is also assigned by the defence cells. Confirm which
# of them are still needed.
e = 0.001
shuffle = False
epoch_limit = 10
train_ratio = .8
att_comp = False
adv_threshhold = 1052
def train_epoch_p2(model, mock_model, reference, optimizer, train_list, feature_path, adv_dataset_path, attention_comp = False):
    """Train ``model`` for one pass over natural and adversarial bags.

    Entries of ``train_list`` that end in ``'_adv'`` are adversarial: the
    natural features of the base slide are loaded and the most-attended
    instances (according to ``mock_model``) are replaced through
    ``load_adv_instance``. Slides whose files cannot be read are reported
    and skipped.

    Parameters
    ----------
    model : torch.nn.Module
        Model being optimised; its forward returns a 5-tuple.
    mock_model : callable
        Model used only to pick which instances to replace.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns; ``'LUAD'`` maps to class
        0, anything else to class 1.
    optimizer : torch.optim.Optimizer
        Stepped once per bag.
    train_list : iterable of str
        Slide names, optionally suffixed with ``'_adv'``.
    feature_path : str
        Folder with ``<slide_name>.h5`` feature files.
    adv_dataset_path : str
        Root of the adversarial dataset.
    attention_comp : bool
        Forwarded to ``load_adv_instance``.

    Returns
    -------
    (pandas.DataFrame, float)
        Per-bag outputs (columns ``slide_name``, ``c1``, ``c2``; adversarial
        bags keep their ``'_adv'`` suffix; rows indexed 0..n-1) and the sum
        of the per-bag losses.
    """
    loss_sum = 0
    rows = []
    loss_metric = torch.nn.CrossEntropyLoss()
    for slide_name in tqdm(train_list):
        # Only a trailing '_adv' marks an adversarial entry; the old substring
        # test would also strip four characters from any name containing it.
        adverarial_ins = slide_name.endswith('_adv')
        if adverarial_ins:
            slide_name = slide_name[:-4]
        reference_row = reference[reference['slide_name'] == slide_name]
        try:
            label = 0 if reference_row['label'].item() == 'LUAD' else 1
            label = torch.tensor([label], dtype = torch.long).cuda()
            full_feature_path = os.path.join(feature_path, '{}.h5'.format(slide_name))

            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
        except Exception:
            print('Failed on {}'.format(slide_name))
            continue

        features_original = torch.from_numpy(features).cuda()

        if adverarial_ins:
            # Uses the `reference` argument rather than the notebook global.
            features_original = load_adv_instance(mock_model, reference, slide_name, adv_dataset_path, features_original, attention_comp = attention_comp)
            slide_name = slide_name + '_adv'

        # `== None` on a tensor is not a reliable None test.
        if features_original is None:
            continue
        model.zero_grad()
        l, y, y1, a, r = model(features_original)
        # NOTE(review): CrossEntropyLoss expects unnormalised logits. If `y`
        # is the softmax output and `l` the logits (as the `softmax` switch in
        # defense_validation suggests), the loss should probably use `l`.
        loss = loss_metric(y, label)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()
        rows.append([slide_name, y[0,0].item(), y[0,1].item()])
    # Built once: DataFrame.append was removed in pandas 2.0.
    results = pd.DataFrame(rows, columns = ['slide_name', 'c1', 'c2'])
    return results, loss_sum
        
    
def validate_epoch_p2(model, mock_model, reference, validate_list, feature_path, adv_dataset_path, adv = False, attention_comp = False):
    """Score ``model`` on natural or adversarial validation bags.

    With ``adv`` true, every bag is first modified by ``load_adv_instance``
    (instances chosen with ``mock_model``). Slides whose label or files
    cannot be read are skipped silently.

    Parameters
    ----------
    model : callable
        Model under evaluation; returns a 5-tuple.
    mock_model : callable
        Model used only to pick which instances to replace.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns.
    validate_list : iterable of str
        Slide names to evaluate.
    feature_path : str
        Folder with ``<slide_name>.h5`` feature files.
    adv_dataset_path : str
        Root of the adversarial dataset.
    adv : bool
        Evaluate adversarial bags instead of natural ones.
    attention_comp : bool
        Forwarded to ``load_adv_instance``.

    Returns
    -------
    pandas.DataFrame
        Columns ``slide_name``, ``c1``, ``c2``; one row per evaluated slide,
        indexed 0..n-1.
    """
    rows = []
    for slide_name in tqdm(validate_list):
        reference_row = reference[reference['slide_name'] == slide_name]
        try:
            # .item() doubles as the check that the slide has exactly one
            # reference row; the label value itself is not needed here.
            reference_row['label'].item()

            full_feature_path = os.path.join(feature_path, '{}.h5'.format(slide_name))
            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
        except Exception:
            continue

        features_original = torch.from_numpy(features).cuda()
        if adv:
            # Uses the `reference` argument rather than the notebook global.
            features_original = load_adv_instance(mock_model, reference, slide_name, adv_dataset_path, features_original, attention_comp=attention_comp)

        # `== None` on a tensor is not a reliable None test.
        if features_original is None:
            continue
        with torch.no_grad():
            l, y, y1, a, r = model(features_original)
        rows.append([slide_name, y[0,0].item(), y[0,1].item()])
    # Built once: DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns = ['slide_name', 'c1', 'c2'])

In [None]:
import torch.optim as optim


def _build_clam_mb(state_dict):
    """Return a CUDA CLAM_MB (2 classes, no dropout) initialised from `state_dict`."""
    clam = CLAM_MB(n_classes=2, dropout=False)
    clam.load_state_dict(state_dict, strict=True)
    clam.cuda()
    return clam


# NOTE(review): an earlier cell recorded FileNotFoundError for this path;
# confirm the checkpoint location before running.
ckpt_path = '/home/ext_yao_gary_mayo_edu/CLAM_saves/100s_2_checkpoint.pt'
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
# Read and clean the checkpoint once (it used to be loaded twice by two
# copy-pasted blocks). Cleaning drops the instance-loss entries, renames
# attention layer 3 to 2 and strips the '.module' infix.
ckpt_clean = {}
for key in ckpt.keys():
    if 'instance_loss_fn' in key:
        continue
    key_new = key.replace('attention_net.module.3.attention', 'attention_net.module.2.attention')
    ckpt_clean.update({key_new.replace('.module', ''):ckpt[key]})

# `model` is trained; `mock_model` starts from the same weights and is only
# used to select the instances that get replaced. load_state_dict copies the
# tensors, so the two models do not share parameters.
model = _build_clam_mb(ckpt_clean)
mock_model = _build_clam_mb(ckpt_clean)

optimizer = optim.SGD([
                {'params': model.parameters()},
            ], lr=1e-2, momentum=0.9)

In [None]:
def gen_listed_split(num_samples = 1052, train_ratio = .8, adv = True):
    """Randomly split the indices ``0 .. num_samples-1`` into train and validation.

    The first ``int(train_ratio * num_samples)`` shuffled indices form the
    training part, the remainder the validation part. With ``adv`` true, the
    training part is followed by a copy of itself shifted by ``num_samples``.
    Shuffling uses NumPy's global random state.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(train_range, validate_range)``.
    """
    cut = int(train_ratio * num_samples)
    shuffled = np.arange(0, num_samples)
    np.random.shuffle(shuffled)
    natural_train, validate_range = shuffled[:cut], shuffled[cut:]
    if not adv:
        return natural_train, validate_range
    with_adv = np.concatenate([natural_train, natural_train + num_samples])
    return with_adv, validate_range

In [None]:
import random
# Every training slide is listed twice: once under its own name (natural bag)
# and once with an '_adv' suffix (adversarial bag).
train_file_list_with_adv = []
for file in train_file_list:
    train_file_list_with_adv.extend((file, file + '_adv'))
# Both shuffles act in place and use the unseeded global `random` state.
random.shuffle(train_file_list_with_adv)
random.shuffle(val_test_file_list)

In [None]:
# Adversarial training on pre-computed features: before each training epoch
# the model is validated on natural and on adversarial bags, and the CSV is
# rewritten so partial results survive an interrupted run.
features_root = '/home/ext_yao_gary_mayo_edu/lung-features/'
adv_root = '/home/jupyter/topk_datasets/datasets/e1e_whole_final/'
accounting_pd = pd.DataFrame(columns = ['epoch', 'loss', 'train_accu', 'accu', 'adv_accu'])
for epoch in range(10):
    results = validate_epoch_p2(model, mock_model, TCGA_Lung_csv, val_test_file_list, feature_path = features_root, adv_dataset_path = adv_root, attention_comp = True)
    accu_nat, auc_nat = accu(results, TCGA_Lung_csv), auc(results, TCGA_Lung_csv)
    results = validate_epoch_p2(model, mock_model, TCGA_Lung_csv, val_test_file_list, feature_path = features_root, adv_dataset_path = adv_root, adv = True, attention_comp = True)
    accu_adv, auc_adv = accu(results, TCGA_Lung_csv), auc(results, TCGA_Lung_csv)
    results, loss_sum = train_epoch_p2(model, mock_model, TCGA_Lung_csv, optimizer, train_file_list_with_adv, feature_path = features_root, adv_dataset_path = adv_root, attention_comp=True)
    # 'train_accu' is not computed and is always written as 0.
    epoch_row = pd.DataFrame([[epoch, loss_sum, 0, accu_nat, accu_adv]], columns = accounting_pd.columns)
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    accounting_pd = pd.concat([accounting_pd, epoch_row])
    accounting_pd.to_csv('/home/jupyter/adv_training_ac.csv')
    

In [None]:
def whole_pipeline_adv_pass(model, feature_extractor, reference, slide_name, adv_dataset_dir, features_original, attention_comp = False):
    """Run ``model`` on a bag whose attacked instances are re-extracted from images.

    The stored adversarial images are passed through ``feature_extractor``
    and written into ``features_original`` (in place) at the row indices
    recorded in the coordinates file; the modified bag is then fed to
    ``model``.

    Parameters
    ----------
    model : callable
        Returns a 5-tuple for a bag of features.
    feature_extractor : callable
        Maps the stored image tensor to one feature row per image.
    reference
        Unused; kept for backward compatibility.
    slide_name : str
        Slide identifier used to build the file names.
    adv_dataset_dir : str
        Root folder containing ``coords`` and ``img`` sub-folders.
    features_original : torch.Tensor
        Bag of instance features; modified in place.
    attention_comp : bool
        Read ``<slide>_att.tr`` instead of ``<slide>_adv.tr`` images. The
        coordinates file is ``<slide>_adv.p`` in both cases.

    Returns
    -------
    tuple or None
        The 5-tuple returned by ``model``, or ``None`` when the coordinates
        file contains ``None``.

    Raises
    ------
    FileNotFoundError
        If the coordinates or image file is missing.
    """
    img_suffix = 'att' if attention_comp else 'adv'
    adv_coords_path = os.path.join(adv_dataset_dir, 'coords', '{}_adv.p'.format(slide_name))
    adv_img_path = os.path.join(adv_dataset_dir, 'img', '{}_{}.tr'.format(slide_name, img_suffix))

    # NOTE: pickle / torch.load can execute code embedded in the file; only
    # point this at datasets generated locally.
    with open(adv_coords_path, 'rb') as coords_file:
        idx_coords = pickle.load(coords_file)
    if idx_coords is None:
        return None

    # The stored feature file used to be loaded and then immediately replaced
    # by the extractor output; it is no longer read.
    adv_img = torch.load(adv_img_path)
    adv_features = feature_extractor(adv_img)
    for (i, (row, loc)) in enumerate(idx_coords):
        features_original[row, :] = adv_features[i,:]

    l, y, y1, a, r = model(features_original)
    return l, y, y1, a, r
    

In [None]:
# Module-level settings of the whole-pipeline adversarial-training experiment.
# NOTE(review): train_epoch / validate_epoch below do not read any of these
# names; confirm which of them are still needed.
e = 0.001
shuffle = False
epoch_limit = 10
train_ratio = .8
adv_threshhold = 1052
def train_epoch(model, feature_extractor, reference, optimizer, train_list, feature_path, adv_dataset_path, attention_comp = False):
    """Train for one pass over natural and adversarial bags (whole pipeline).

    Entries of ``train_list`` that end in ``'_adv'`` are adversarial: their
    attacked instances are re-extracted from the stored images by
    ``whole_pipeline_adv_pass``, so gradients can reach
    ``feature_extractor``. Slides whose label, features or adversarial
    coordinates cannot be read are reported and skipped.

    Parameters
    ----------
    model : torch.nn.Module
        Bag-level model; its forward returns a 5-tuple.
    feature_extractor : torch.nn.Module
        Image feature extractor used for adversarial bags.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns; ``'LUAD'`` maps to class
        0, anything else to class 1.
    optimizer : torch.optim.Optimizer
        Stepped once per bag.
    train_list : iterable of str
        Slide names, optionally suffixed with ``'_adv'``.
    feature_path : str
        Folder with ``<slide_name>.h5`` feature files.
    adv_dataset_path : str
        Root of the adversarial dataset.
    attention_comp : bool
        Forwarded to ``whole_pipeline_adv_pass``.

    Returns
    -------
    (pandas.DataFrame, float)
        Per-bag outputs (columns ``slide_name``, ``c1``, ``c2``, names
        without the ``'_adv'`` suffix, rows indexed 0..n-1) and the sum of
        the per-bag losses.
    """
    loss_sum = 0
    rows = []
    loss_metric = torch.nn.CrossEntropyLoss()
    for slide_name in tqdm(train_list):
        # Only a trailing '_adv' marks an adversarial entry.
        adverarial_ins = slide_name.endswith('_adv')
        if adverarial_ins:
            slide_name = slide_name[:-4]
        reference_row = reference[reference['slide_name'] == slide_name]
        try:
            label = 0 if reference_row['label'].item() == 'LUAD' else 1
            label = torch.tensor([label], dtype = torch.long).cuda()
            full_feature_path = os.path.join(feature_path, '{}.h5'.format(slide_name))

            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
            if adverarial_ins:
                # Readability check only: the slide is skipped when its
                # adversarial coordinates cannot be loaded.
                adv_coords_path = os.path.join(adv_dataset_path, 'coords', '{}_adv.p'.format(slide_name))
                with open(adv_coords_path, 'rb') as coords_file:
                    pickle.load(coords_file)
        except Exception:
            print('Failed on {}'.format(slide_name))
            continue

        features_original = torch.from_numpy(features).cuda()

        model.zero_grad()
        feature_extractor.zero_grad()
        if adverarial_ins:
            # Uses the `reference` argument rather than the notebook global.
            outputs = whole_pipeline_adv_pass(model, feature_extractor, reference, slide_name, adv_dataset_path, features_original, attention_comp=attention_comp)
            if outputs is None:
                # Unpacking None used to raise TypeError.
                continue
            l, y, y1, a, r = outputs
        else:
            l, y, y1, a, r = model(features_original)

        # NOTE(review): CrossEntropyLoss expects unnormalised logits; if `y`
        # is the softmax output and `l` the logits, the loss should probably
        # use `l` -- confirm against the model's forward.
        loss = loss_metric(y, label)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()
        rows.append([slide_name, y[0,0].item(), y[0,1].item()])
    # Built once: DataFrame.append was removed in pandas 2.0.
    results = pd.DataFrame(rows, columns = ['slide_name', 'c1', 'c2'])
    return results, loss_sum
        
    
def validate_epoch(model, feature_extractor, reference, validate_list, feature_path, adv_dataset_path, adv = False, attention_comp = False):
    """Score the model on natural or whole-pipeline adversarial validation bags.

    With ``adv`` true, each bag is evaluated through
    ``whole_pipeline_adv_pass``; otherwise ``model`` is applied to the stored
    features. Slides whose label or feature file cannot be read are skipped
    silently.

    Parameters
    ----------
    model : callable
        Bag-level model; returns a 5-tuple.
    feature_extractor : callable
        Image feature extractor used for adversarial bags.
    reference : pandas.DataFrame
        Needs ``slide_name`` and ``label`` columns.
    validate_list : iterable of str
        Slide names to evaluate.
    feature_path : str
        Folder with ``<slide_name>.h5`` feature files.
    adv_dataset_path : str
        Root of the adversarial dataset.
    adv : bool
        Evaluate adversarial bags instead of natural ones.
    attention_comp : bool
        Forwarded to ``whole_pipeline_adv_pass``.

    Returns
    -------
    pandas.DataFrame
        Columns ``slide_name``, ``c1``, ``c2``; one row per evaluated slide,
        indexed 0..n-1.
    """
    rows = []
    for slide_name in tqdm(validate_list):
        reference_row = reference[reference['slide_name'] == slide_name]
        try:
            # .item() doubles as the check that the slide has exactly one
            # reference row; the label value itself is not needed here.
            reference_row['label'].item()

            full_feature_path = os.path.join(feature_path, '{}.h5'.format(slide_name))
            with h5py.File(full_feature_path, 'r') as hdf5_file:
                features = hdf5_file['features'][:]
        except Exception:
            continue

        features_original = torch.from_numpy(features).cuda()
        with torch.no_grad():
            if adv:
                # Uses the `reference` argument rather than the notebook global.
                outputs = whole_pipeline_adv_pass(model, feature_extractor, reference, slide_name, adv_dataset_path, features_original, attention_comp=attention_comp)
                if outputs is None:
                    # Unpacking None used to raise TypeError.
                    continue
                l, y, y1, a, r = outputs
            else:
                l, y, y1, a, r = model(features_original)
        rows.append([slide_name, y[0,0].item(), y[0,1].item()])
    # Built once: DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns = ['slide_name', 'c1', 'c2'])

In [None]:
import torch.optim as optim

# Bag-level model: 2 classes, no dropout, weights taken from the checkpoint.
model = CLAM_MB(n_classes=2, dropout=False)

ckpt_path = '/home/ext_yao_gary_mayo_edu/CLAM_saves/100s_2_checkpoint.pt'
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
# Drop the instance-loss entries, rename attention layer 3 to 2 and strip the
# '.module' infix so the keys match this CLAM_MB instance.
ckpt_clean = {}
for key in ckpt.keys():
    if 'instance_loss_fn' not in key:
        key_new = key.replace('attention_net.module.3.attention', 'attention_net.module.2.attention')
        ckpt_clean[key_new.replace('.module', '')] = ckpt[key]

model.load_state_dict(ckpt_clean, strict=True)
model.cuda()

# Image feature extractor.
# NOTE(review): it is put in eval mode, yet its parameters are handed to the
# optimizer below -- confirm that training it in eval mode is intended.
feature_extractor = resnet50_baseline(pretrained=True)
feature_extractor.eval()
feature_extractor.cuda()

# One SGD parameter group covering both networks.
all_parameters = list(model.parameters()) + list(feature_extractor.parameters())
optimizer = optim.SGD([
                {'params': all_parameters},
            ], lr=1e-2, momentum=0.9)

In [None]:
# Whole-pipeline adversarial training: before each training epoch the model
# is validated on natural and on adversarial bags, and the CSV is rewritten
# so partial results survive an interrupted run.
features_root = '/home/ext_yao_gary_mayo_edu/lung-features/'
adv_root = '/home/jupyter/topk_datasets/datasets/e1e_whole_final/'
accounting_pd = pd.DataFrame(columns = ['epoch', 'loss', 'train_accu', 'accu', 'adv_accu'])
attention_comp = False
for epoch in range(10):
    results = validate_epoch(model, feature_extractor, TCGA_Lung_csv, val_test_file_list, feature_path = features_root, adv_dataset_path = adv_root, attention_comp = attention_comp)
    accu_nat, auc_nat = accu(results, TCGA_Lung_csv), auc(results, TCGA_Lung_csv)
    results = validate_epoch(model, feature_extractor, TCGA_Lung_csv, val_test_file_list, feature_path = features_root, adv_dataset_path = adv_root, adv = True, attention_comp = attention_comp)
    accu_adv, auc_adv = accu(results, TCGA_Lung_csv), auc(results, TCGA_Lung_csv)
    results, loss_sum = train_epoch(model, feature_extractor, TCGA_Lung_csv, optimizer, train_file_list_with_adv, feature_path = features_root, adv_dataset_path = adv_root, attention_comp = attention_comp)
    # 'train_accu' is not computed and is always written as 0.
    epoch_row = pd.DataFrame([[epoch, loss_sum, 0, accu_nat, accu_adv]], columns = accounting_pd.columns)
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    accounting_pd = pd.concat([accounting_pd, epoch_row])
    accounting_pd.to_csv('/home/jupyter/adv_whole_training_noacp2.csv')
    

In [None]:
# Inspect the stored coordinates of one adversarial example.
example_coord_path = '/home/jupyter/topk_datasets/datasets/e1e_whole_final/coords/TCGA-XC-AA0X-01Z-00-DX1.61A34BE0-F16B-4EC1-8E7F-7BF94F6629F4_adv.p' 
# The context manager closes the file; the bare open() call left it open.
# NOTE: pickle can execute code embedded in the file; only load local data.
with open(example_coord_path, 'rb') as example_coord_file:
    example_coord = pickle.load(example_coord_file)
print(example_coord)

In [None]:
d