# Import Libraries

In [5]:
import argparse
import math
import random
import os
import re
import sys
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torchvision
from torchvision import utils
from tqdm import tqdm
import wandb
from pathlib import Path
import glob
import yaml
from PIL import Image
from matplotlib.widgets import Slider
import matplotlib.patches as patches
from matplotlib import pyplot as plt

sys.path.insert(0,"/ocean/projects/asc170022p/nmurali/projects/CounterfactualExplainer/MIMICCX-Chest-Explainer/stylegan2Pytorch")
from distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)
import pdb
from op import conv2d_gradfix
from non_leaking import augment, AdaptiveAugment
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0,"/ocean/projects/asc170022p/nmurali/projects/CounterfactualExplainer/MIMICCX-Chest-Explainer/Classifier/torchxrayvision_")
from swagan_updatedEGC import Generator, Discriminator

# GUI libraries
import ipywidgets as widgets
from IPython.display import clear_output
from random import randint

sys.path.insert(0,'/jet/home/nmurali/asc170022p/nmurali/projects/augmentation_by_explanation_eccv22/Classifier')
import datasets

sys.path.insert(0,"/ocean/projects/asc170022p/nmurali/projects/CounterfactualExplainer/MIMICCX-Chest-Explainer/Classifier/torchxrayvision_")
import torchxrayvision as xrv

Using /jet/home/nmurali/.cache/torch_extensions/py39_cu113 as PyTorch extensions root...
No modifications detected for re-loaded extension module upfirdn2d, skipping build step...
Loading extension module upfirdn2d...


In [6]:
# User hyperparameters: experiment config plus classifier and StyleGAN checkpoints.
# All three paths live under one project root, so build them from it.
_project_root = '/jet/home/nmurali/asc170022p/nmurali/projects/augmentation_by_explanation_eccv22'
config_file = _project_root + '/Configs/Classifier/DenseNet_AFHQ.yaml'
clf_ckpt_path = (
    _project_root
    + '/Output/AFHQ/Classifier_Seed_1234_Dropout_0.0_LS_False_MU_False_FL_False_afhq_ln0p38'
    + '/AFHQ-densenet169-AFHQ_256-best-auc0.5962.pt'
)
gan_ckpt = _project_root + '/Output/StyleGAN/AFHQ_ln0p38/checkpoint/099000.pt'

In [7]:
# Load the YAML experiment config. A `with` block closes the file handle
# (the bare open() call left it open until garbage collection).
with open(config_file) as config_fh:
    config = yaml.safe_load(config_fh)
# class_names is stored as one comma-separated string; split it into a list.
config['class_names'] = config['class_names'].split(',')
# Alias (same list object) used by the helpers and plots below.
pathologies = config['class_names']

# Support Functions

In [8]:
def print_output(pred, class_names=None):
    """Format per-class predictions, one "name: value" line per class.

    Args:
        pred: 1-D array/tensor of per-class values; only ``pred.shape[0]``
            and integer indexing are used.
        class_names: names aligned with ``pred``. Defaults to the
            module-level ``pathologies`` list, so existing callers behave
            exactly as before.

    Returns:
        The concatenated lines, each terminated by a newline ('' for an
        empty ``pred``).

    Raises:
        IndexError: if ``class_names`` has fewer entries than ``pred``.
    """
    names = pathologies if class_names is None else class_names
    return ''.join(
        names[i] + ': ' + str(pred[i]) + '\n' for i in range(pred.shape[0])
    )
def requires_grad(model, flag=True):
    """Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: value written to each parameter's ``requires_grad``;
            pass ``False`` to freeze the model.
    """
    for param in model.parameters():
        param.requires_grad = flag

class expand_greyscale(object):
    """Transform that makes a channel-first image tensor 3-channel.

    Input whose first dimension is already 3 is returned unchanged.
    Single-channel input is broadcast to 3 identical channels with
    ``Tensor.expand``, which returns a view rather than a copy.

    Raises:
        ValueError: for any other channel count, rather than returning
            ``None`` and failing later in an unrelated place.
    """

    def __init__(self):
        self.num_target_channels = 3

    def __call__(self, tensor):
        channels = tensor.shape[0]
        if channels == self.num_target_channels:
            return tensor
        if channels == 1:
            return tensor.expand(self.num_target_channels, -1, -1)
        raise ValueError(
            f"expand_greyscale expects 1 or {self.num_target_channels} "
            f"channels, got {channels}"
        )

class center_crop(object):
    """Crop a 3-D (C, H, W) array/tensor to its largest centred square."""

    def crop_center(self, img):
        # Unpacking enforces exactly three dimensions.
        _, height, width = img.shape
        side = min(height, width)
        top = height // 2 - side // 2
        left = width // 2 - side // 2
        rows = slice(top, top + side)
        cols = slice(left, left + side)
        return img[:, rows, cols]

    def __call__(self, img):
        return self.crop_center(img)

class normalize(object):
    """Transform that divides pixel values by ``maxval`` (255 by default)."""

    def normalize_(self, img, maxval=255):
        scaled = img / maxval
        return scaled

    def __call__(self, img):
        return self.normalize_(img)

In [9]:
# Build the preprocessing pipeline (`transforms`) and `dataset` for
# config['dataset']. Unknown dataset names are rejected up front instead of
# leaving `transforms`/`dataset` undefined and failing in a later cell.
# NOTE(review): for Dirty_MNIST and Stanford-CHEX only train/test splits
# (`train_inds`, `test_inds`) are built and `dataset` is set to None, so the
# DataLoader / GUI cells that read `dataset` will not work for those configs.
_img_size = (config['size'], config['size'])

if config['dataset'] in ('AFHQ', 'HAM', 'Dirty_MNIST'):
    # to PIL -> square resize -> float tensor
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(_img_size),
        torchvision.transforms.ToTensor(),
    ])
elif config['dataset'] == 'CelebA':
    # Resize, centre-crop to config['center_crop'], then resize back to size.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize(_img_size),
        torchvision.transforms.CenterCrop(config['center_crop']),
        torchvision.transforms.Resize(_img_size),
        torchvision.transforms.ToTensor(),
    ])
elif config['dataset'] in ('Stanford-CHEX', 'MIMIC-CXR'):
    # No ToPILImage here: presumably these datasets already yield PIL images.
    # NOTE(review): ToTensor already scales 8-bit images to [0, 1], so the extra
    # /255 from normalize() looks like double scaling unless the source images
    # are not 8-bit -- confirm.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(_img_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(center_crop()),
        torchvision.transforms.Lambda(normalize()),
    ])
else:
    raise ValueError(f"Unsupported config['dataset']: {config['dataset']!r}")

if config['dataset'] == 'AFHQ':
    dataset = datasets.AFHQ_Dataset(
        csvpath=config['data_file'],
        class_names=config['class_names'],
        transform=transforms,
        seed=config['seed'],
    )
elif config['dataset'] == 'HAM':
    dataset = datasets.HAM_Dataset(
        imgpath=config['imgpath'],
        csvpath=config['data_file'],
        class_names=config['class_names'],
        unique_patients=False,
        transform=transforms,
        seed=config['seed'],
    )
elif config['dataset'] == 'Dirty_MNIST':
    train_inds = datasets.DIRTY_MNIST_Dataset(
        csvpath=config['data_file'], transform=transforms,
        class_names=config['class_names'], seed=config['seed'],
    )
    test_inds = datasets.DIRTY_MNIST_Dataset(
        csvpath=config['data_file_test'], transform=transforms,
        class_names=config['class_names'], seed=config['seed'],
    )
    dataset = None
elif config['dataset'] == 'CelebA':
    dataset = datasets.CelebA(
        imgpath=config['imgpath'],
        csvpath=config['data_file'],
        class_names=config['class_names'],
        transform=transforms,
        seed=config['seed'],
    )
elif config['dataset'] == 'Stanford-CHEX':
    train_inds = datasets.CheX_Dataset(
        imgpath=config['imgpath'], csvpath=config['data_file'],
        class_names=config['class_names'], transform=transforms, seed=config['seed'],
    )
    test_inds = datasets.CheX_Dataset(
        imgpath=config['imgpath'], csvpath=config['data_file_test'],
        class_names=config['class_names'], transform=transforms, seed=config['seed'],
    )
    dataset = None
elif config['dataset'] == 'MIMIC-CXR':
    dataset = datasets.MIMIC_Dataset(
        imgpath=config['imgpath'],
        csvpath=config['data_file'],
        class_names=config['class_names'],
        transform=transforms,
        seed=config['seed'],
    )

In [10]:
# Shuffled loader over `dataset`. drop_last keeps every batch at exactly 15
# images; the batched consistency cell further down sizes its `cond` vector
# for batch_size = 15.
loader = data.DataLoader(
    dataset,
    shuffle=True,
    drop_last=True,
    batch_size=15,
)

In [11]:
# Classifier: torchxrayvision DenseNet with weights from clf_ckpt_path.
# return_logit=True -> the forward pass yields logits (callers below apply
# torch.sigmoid) together with a feature tensor.
_densenet_kwargs = xrv.models.get_densenet_params(config['model'])
classifier = xrv.models.DenseNet(
    num_classes=config['num_classes'],
    in_channels=config['channel'],
    drop_rate=config['drop_rate'],
    weights=clf_ckpt_path,
    return_logit=True,
    **_densenet_kwargs,
).to("cuda")

weights_filename_local:  /jet/home/nmurali/asc170022p/nmurali/projects/augmentation_by_explanation_eccv22/Output/AFHQ/Classifier_Seed_1234_Dropout_0.0_LS_False_MU_False_FL_False_afhq_ln0p38/AFHQ-densenet169-AFHQ_256-best-auc0.5962.pt
........
model loaded /jet/home/nmurali/asc170022p/nmurali/projects/augmentation_by_explanation_eccv22/Output/AFHQ/Classifier_Seed_1234_Dropout_0.0_LS_False_MU_False_FL_False_afhq_ln0p38/AFHQ-densenet169-AFHQ_256-best-auc0.5962.pt


In [12]:
# GAN: generator / discriminator from swagan_updatedEGC, restored from gan_ckpt.
# NOTE(review): Generator's positional args presumably are
# (size, style_dim=512, num_classes, n_mlp=8) -- confirm against its signature.
generator = Generator(config['size'], 512, config['num_classes'], 8, channel_multiplier=2).to("cuda")
discriminator = Discriminator(config['size'], channel_multiplier=2, concate_size=0).to("cuda")

# Load tensors onto the CPU first (same effect as the identity map_location
# lambda); load_state_dict copies them into the CUDA modules.
# The checkpoint goes into its own name so `gan_ckpt` keeps holding the path
# and this cell can be re-run (rebinding it to the dict broke torch.load on
# a second run).
gan_state = torch.load(gan_ckpt, map_location="cpu")
generator.load_state_dict(gan_state["g"])
discriminator.load_state_dict(gan_state["d"])



<All keys matched successfully>

# GUI

In [13]:
# widget functions (for event handling)

def _to_display(img):
    """Min-max scale an image array to [0, 1] for ``plt.imshow``.

    A constant image (max == min) becomes all zeros instead of the NaNs that
    ``(img - min) / (max - min)`` would produce.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        return np.zeros_like(img)
    return (img - lo) / span


def _show_panel(position, img, title, ylabel):
    """Draw ``img`` in subplot (1, 2, position) with a title and no ticks."""
    plt.subplot(1, 2, position)
    plt.imshow(_to_display(img))
    plt.title(title, fontsize=20)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel(ylabel)


def stylegan(img_id, attr1, val1):
    """Plot a dataset image next to its generator reconstruction for one attribute value.

    Scores dataset image ``img_id`` with the classifier and discriminator.
    The generator condition is the classifier's sigmoid predictions, rounded
    to 4 decimals, with ``attr1`` replaced by ``val1``. Left panel: the real
    image. Right panel: the generator output for that condition. Titles show
    the classifier probability for ``attr1`` and the sigmoid discriminator
    score. The right title also shows the cosine similarity between the
    flattened generator latents of the real image and of the reconstruction.

    Args:
        img_id: dataset index; the Text widget passes a string, so ``int()`` is applied.
        attr1: class name; must be present in ``pathologies``.
        val1: condition value substituted for ``attr1``.
    """
    requires_grad(generator, False)
    requires_grad(discriminator, False)
    requires_grad(classifier, False)
    attr_idx = pathologies.index(attr1)

    # Inference only: no autograd graph is needed for any of these passes.
    with torch.no_grad():
        real_img = dataset[int(img_id)]['img'].unsqueeze(0).to("cuda")
        real_logits, clf_feats_real = classifier(real_img)
        real_prob = np.round(torch.sigmoid(real_logits).cpu().numpy(), 4)
        real_d = np.round(
            torch.sigmoid(discriminator(real_img, clf_feats_real)).cpu().numpy(), 4)

        # Condition = rounded real predictions with attr1 overridden; copy so
        # real_prob itself is not modified.
        cond = real_prob[0].copy()
        cond[attr_idx] = val1
        cond = torch.from_numpy(cond[np.newaxis]).float().to("cuda")

        recon_img, real_img_latent = generator(real_img, cond, return_latents=True)
        _, recon_img_latent = generator(recon_img, cond, return_latents=True)
        cos_sim = nn.CosineSimilarity(dim=0, eps=1e-6)(
            torch.flatten(real_img_latent), torch.flatten(recon_img_latent)).item()

        recon_logits, clf_feats_recon = classifier(recon_img)
        recon_prob = np.round(torch.sigmoid(recon_logits).cpu().numpy(), 4)
        fake_d = np.round(
            torch.sigmoid(discriminator(recon_img, clf_feats_recon)).cpu().numpy(), 4)

    # Move the channel axis (1) last for imshow.
    real_np = np.moveaxis(real_img.cpu().numpy(), 1, 3)
    recon_np = np.moveaxis(recon_img.cpu().numpy(), 1, 3)

    plt.figure(figsize=(20, 40))
    _show_panel(1, real_np[0],
                'clf:%.2f, D:%.2f' % (real_prob[0][attr_idx], real_d[0][0]),
                real_d[0][0])
    _show_panel(2, recon_np[0],
                'clf:%.2f, D:%.2f, Cos Sim:%f' % (recon_prob[0][attr_idx], fake_d[0][0], cos_sim),
                fake_d[0][0])
    plt.show()
    
    
def change_img(btn):
    """Button callback: write a random valid dataset index into ``text_box``.

    ``btn`` (the clicked button) is unused. ``text_box`` is bound to the
    ``img_id`` argument of the interactive view, so changing its value
    presumably triggers a redraw with the new image.
    """
    last_index = len(dataset) - 1
    text_box.value = f"{randint(0, last_index)}"

In [14]:
# widgets
# Attribute whose generator condition the slider overrides. Prefer 'cat' (the
# AFHQ class this notebook was built around) but fall back to the first
# configured class: a Dropdown value outside `options` fails validation, so the
# hard-coded 'cat' broke every config without that class.
_default_attr = 'cat' if 'cat' in config['class_names'] else config['class_names'][0]
menu1 = widgets.Dropdown(
    options=config['class_names'],
    value=_default_attr,
    description='Attribute-1',
    disabled=False,
)

# Condition value in [0, 1] for the selected attribute. continuous_update=False
# reports the value only when the handle is released.
slider1 = widgets.FloatSlider(
    value=0.0,
    min=0,
    max=1,
    step=0.1,
    description='Attribute-1',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='0.2f',
    # NOTE(review): msg_throttle looks like a leftover from older ipywidgets
    # releases -- confirm the installed version still accepts it.
    msg_throttle=1,
)

# Picks a new random image id when clicked.
button = widgets.Button(
    description='Change Image',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check',  # FontAwesome name without the `fa-` prefix
)

# Current dataset index: read-only in the UI, set programmatically by the button.
text_box = widgets.Text(
    value='0',
    placeholder='',
    description='Image ID',
    disabled=True,
)

In [15]:
# Interactive GUI: `stylegan` re-runs whenever the image id, attribute or slider
# changes; the button writes a fresh random id into `text_box`.
button.on_click(change_img)
widgets.interact(stylegan, img_id=text_box, attr1=menu1, val1=slider1);
display(button)

interactive(children=(Text(value='0', description='Image ID', disabled=True, placeholder=''), Dropdown(descrip…

Button(description='Change Image', icon='check', style=ButtonStyle(), tooltip='Click me')

# Classifier Consistency

In [None]:
# Classifier-consistency sweep for one image. The image is replicated n_steps
# times. In each copy's generator condition, class `current_class` is set to
# 0.00, 0.01, ..., 0.99, and the classifier re-scores the generated images.
# NOTE(review): the old class-index comment named chest X-ray findings
# (Cardiomegaly/Edema/Pleural Effusion). current_class indexes `pathologies`
# of the loaded config, so check what index 1 means there.
requires_grad(generator, False)
requires_grad(discriminator, False)
requires_grad(classifier, False)

current_class = 1  # index into `pathologies`
img_id = 7268
n_steps = 100

with torch.no_grad():
    # Replicate the (1, C, H, W) image directly in torch (no numpy round-trip).
    real_img = dataset[int(img_id)]['img'].unsqueeze(0).repeat(n_steps, 1, 1, 1).to("cuda")
    real_pred_cls, _ = classifier(real_img)
    real_pred_cls = torch.sigmoid(real_pred_cls)

    cond = np.arange(n_steps) * 0.01  # 0.00, 0.01, ..., 0.99
    real_pred_cls_npy = real_pred_cls.cpu().numpy()
    real_pred_cls_npy[:, current_class] = cond
    current_cond = torch.from_numpy(real_pred_cls_npy).to("cuda")

    recon_img, _ = generator(real_img, current_cond, return_latents=False)
    recon_pred_cls, _ = classifier(recon_img)
    recon_pred_cls = torch.sigmoid(recon_pred_cls)
recon_pred_cls = torch.sigmoid(recon_pred_cls)

In [None]:
# Condition levels for the batched sweep: 0.0, 0.1, ..., 0.9 tiled once per
# image in a batch -> [0.0..0.9, 0.0..0.9, ...] of length 10 * batch_size.
# batch_size must equal the DataLoader's batch_size.
batch_size = 15
n = batch_size
_levels = np.arange(10) * 0.1
cond = np.tile(_levels, n)
cond.shape

In [None]:
np.shape(real_pred_cls_npy[:, current_class])  # shape of the overridden condition column

In [None]:
# Batched classifier-consistency sweep over the first `max_batches` batches.
# Each image is repeated n_levels times in a row. In the generator condition,
# class `current_class` of each copy is replaced by the matching entry of
# `cond` (the levels tiled per image), and the classifier re-scores the
# generated images. Results are collected into all_real_pred / all_cond /
# all_fake_pred.
# NOTE(review): previously labelled "for celeba" with chest X-ray class names;
# current_class indexes `pathologies` of the loaded config -- confirm.
requires_grad(generator, False)
requires_grad(discriminator, False)
requires_grad(classifier, False)

current_class = 1  # index into `pathologies`
n_levels = 10      # copies per image; must match how `cond` was built
max_batches = 15

# Collect per-batch arrays in lists and concatenate once at the end
# (np.append inside the loop re-copied everything gathered so far).
real_chunks, cond_chunks, fake_chunks = [], [], []
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(loader)):
        # Consecutive copies of each image, matching `cond`'s per-image layout.
        real_img = batch['img'].repeat_interleave(n_levels, dim=0).to("cuda")
        real_pred_cls, _ = classifier(real_img)
        real_pred_cls = torch.sigmoid(real_pred_cls)
        real_pred_cls_npy = real_pred_cls.cpu().numpy()
        real_pred_cls_npy[:, current_class] = cond
        current_cond = torch.from_numpy(real_pred_cls_npy).to("cuda")
        recon_img, _ = generator(real_img, current_cond, return_latents=False)
        recon_pred_cls, _ = classifier(recon_img)
        recon_pred_cls = torch.sigmoid(recon_pred_cls)

        real_chunks.append(real_pred_cls.cpu().numpy())
        cond_chunks.append(real_pred_cls_npy)  # same values as current_cond
        fake_chunks.append(recon_pred_cls.cpu().numpy())
        if batch_idx + 1 == max_batches:
            break

all_real_pred = np.concatenate(real_chunks, axis=0) if real_chunks else np.empty([0])
all_cond = np.concatenate(cond_chunks, axis=0) if cond_chunks else np.empty([0])
all_fake_pred = np.concatenate(fake_chunks, axis=0) if fake_chunks else np.empty([0])
print(all_real_pred.shape, all_cond.shape, all_fake_pred.shape)

In [None]:
real_pred_cls[1].size()  # per-image prediction vector size

In [None]:
# Snapshot the collected arrays before the reshaping/averaging cells below.
# .copy() makes real snapshots; plain assignment only aliased the same arrays,
# so any later in-place edit would silently change the "backup" too.
all_real_pred2 = all_real_pred.copy()
all_cond2 = all_cond.copy()
all_fake_pred2 = all_fake_pred.copy()

In [None]:
np.shape(all_real_pred)

In [None]:
# Group rows per source image: (n_images, 10 levels, n_classes). The class
# count comes from the data instead of a hard-coded 6, so this works for any
# config['num_classes'] (identical to before when there are 6 classes).
_n_classes = all_real_pred.shape[-1]
all_real_pred = np.reshape(all_real_pred, [-1, 10, _n_classes])
all_cond = np.reshape(all_cond, [-1, 10, _n_classes])
all_fake_pred = np.reshape(all_fake_pred, [-1, 10, _n_classes])

In [None]:
# One real prediction per source image: average over its 10 repeated copies
# (presumably identical, since the copies are the same input).
all_real_pred = all_real_pred.mean(axis=1)
np.shape(all_real_pred)

In [None]:
all_fake_pred = np.reshape(all_fake_pred, [-1, all_fake_pred.shape[-1]])  # (n_rows, n_classes); class count from data, not hard-coded 6

In [None]:
all_cond = np.reshape(all_cond, [-1, all_cond.shape[-1]])  # (n_rows, n_classes); class count from data, not hard-coded 6

In [None]:
(np.shape(all_real_pred), np.shape(all_cond))

In [None]:
# Consistency plot for `current_class`.
# - x: centre of the requested (target) 0.1-wide bin.
# - y: classifier score on the generated image.
# There is one line per pair of source bins (0.0-0.2, 0.2-0.4, ...), averaged
# over the images in that pair.
import seaborn as sns

n_levels = 10  # condition levels per source image (must match the sweep)

for c in [current_class]:
    print("Current Class: ", c)
    # Source bin of each image: floor(10 * real prediction) -> 0..9
    # (10 only for a prediction of exactly 1.0, which no pair below covers).
    bins = np.asarray(all_real_pred[:, c] * 10).astype(int)
    print(np.unique(bins, return_counts=True))

    # Target bin = index of the requested condition level. Conditions sit on
    # the 0.1 grid, so round instead of truncating: truncation could push a
    # level stored slightly below its grid value into the bin below.
    target_bin = np.rint(all_cond[:, c] * 10).astype(int)
    source_bin = np.repeat(bins, repeats=n_levels)
    source_pred = np.repeat(all_real_pred[:, c], repeats=n_levels)
    target_pred = all_fake_pred[:, c]
    delta = target_bin - source_bin
    print(target_bin.shape, source_bin.shape, source_pred.shape, target_pred.shape, delta.shape)
    print(np.min(delta), np.max(delta))

    # Centre of the target bin, and the generated-image scores.
    real_p = (target_bin * 0.1 + (target_bin + 1) * 0.1) / 2
    fake_q = target_pred

    # plt.get_cmap replaces matplotlib.cm.get_cmap (removed in matplotlib 3.9).
    newcolors = plt.get_cmap('viridis', 5)(np.linspace(0, 1, 5))
    sns.set(style="white")
    sns.set_context("notebook", font_scale=2.5, rc={"lines.linewidth": 2})
    plt.figure(figsize=(6, 6))
    names = ['0.0-0.2', '', '0.2-0.4', '0.30-0.40', '0.4-0.6', '0.50-0.60',
             '0.6-0.8', '0.70-0.80', '0.8-1.0', '0.90-1.00']
    marker_size = [10, 0, 9, 0, 13, 0, 13, 0, 13, 0]
    markers = ['o', '', 's', '', '*', '', 'X', 'X', '<', '>', 's', '*', 'D', 'd', 'X']
    x = np.arange(0.0, 1.0, step=0.1)
    plt.plot(x, x, c='black', linestyle='dashed', alpha=0.5)  # identity reference
    ax = plt.gca()
    for i in range(0, 10, 2):
        # Pool source bins i and i+1 (e.g. 0.0-0.1 with 0.1-0.2).
        idx_lo = np.where(source_bin == i)
        print(idx_lo[0].shape, i)
        idx_hi = np.where(source_bin == i + 1)
        # Each image contributes n_levels consecutive rows, so this reshape
        # gives (n_images_in_pair, n_levels).
        target_pred_i = np.concatenate([fake_q[idx_lo], fake_q[idx_hi]]).reshape(-1, n_levels)
        source_pred_i = np.concatenate([real_p[idx_lo], real_p[idx_hi]]).reshape(-1, n_levels)
        if target_pred_i.shape[0] == 0:
            continue  # no images in this pair of bins: the mean would be NaN
        mean_t = target_pred_i.mean(axis=0)
        mean_s = source_pred_i.mean(axis=0)
        # x=/y= must be keywords: seaborn >= 0.12 rejects them positionally.
        sns.lineplot(x=mean_s, y=mean_t, label=names[i], color=newcolors[i // 2],
                     alpha=1, marker=markers[i], markersize=marker_size[i], ax=ax)
    plt.xticks(np.arange(0, 1.1, step=0.2))
    plt.yticks(np.arange(0, 1.1, step=0.2))
    legend = ax.get_legend()
    if legend is not None:  # absent if every bin pair was empty
        legend.remove()
    ax.xaxis.set_major_locator(plt.MaxNLocator(3))
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))
    plt.xlabel(r'$f(x)+\delta$')
    plt.ylabel(r'$f(x_{\delta})$')
    plt.title(pathologies[c])
    plt.show()
    

In [None]:
# df = pd.read_csv('/jet/home/nmurali/asc170022p/nmurali/data/mimic/uniform_mimic_clf_preds.csv')

In [None]:
# plt.hist(np.array(df['cardiomegaly']))

In [None]:
# data = AnyDataset(csv_path='/jet/home/nmurali/asc170022p/nmurali/data/mimic/uniform_mimic_clf_preds.csv',pathologies='Cardiomegaly')
# store = Store()
# for bidx,batch in enumerate(tqdm(data)):
#     img = batch['x']
#     img=torch.tensor(img).unsqueeze(0).to("cuda")
#     preds, _ = classifier(img)
#     preds = torch.sigmoid(preds)
#     store.feed([preds[0][1].unsqueeze(0)])
#     if bidx==100:
#         break
    

In [None]:
# plt.hist(np.array((store.lov[0]).cpu().detach()))

In [None]:
# data2 = AnyDataset(csv_path='/ocean/projects/asc170022p/nmurali/data/mimic/all.csv',img_path_field='lateral_512_jpeg',pathologies='Cardiomegaly',transform='mimic')
# img = data2[200000]['x']
# img=torch.tensor(img).unsqueeze(0).to("cuda")
# preds, _ = classifier(img)
# preds = torch.sigmoid(preds)
# print(preds[0][1])

In [None]:
# df = pd.read_csv('/ocean/projects/asc170022p/nmurali/data/mimic/all.csv')
# np.unique(np.array(df['Cardiomegaly'].values.tolist()),return_counts=True)

In [None]:
# df = pd.read_csv('/jet/home/nmurali/asc170022p/nmurali/projects/misc/Pytorch-UNet/mimic_clf_preds.csv')
# np.unique(np.array((df['cardiomegaly']>0.8).values.tolist()),return_counts=True)

In [None]:
# Switch to the MIMIC-CXR export for the uniform-sampling cells below.
# NOTE(review): `AnyDataset` is not imported anywhere in this notebook, so this
# cell raises NameError unless it is defined first. It also rebinds `dataset`
# and `pathologies`, which the GUI cells above read.
mimic_csv_path = '/ocean/projects/asc170022p/nmurali/data/mimic/all.csv'
dataset = AnyDataset(
    csv_path=mimic_csv_path,
    img_path_field='lateral_512_jpeg',
    transform='mimic',
)

pathologies = [
    'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Lesion',
    'Lung Opacity', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
    'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
]

# Shuffled loader, 64 images per batch, incomplete last batch dropped.
loader = data.DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

In [None]:
# Score shuffled MIMIC batches with the classifier. For each image, record its
# path and the sigmoid scores at class indices 1, 4, 9 (Cardiomegaly, Edema,
# Pleural Effusion in the `pathologies` list defined for MIMIC). This assumes
# the loaded classifier uses that label order.
# Stops after bidx == 3000, i.e. 3001 batches.
# NOTE(review): `Store` is not imported in this notebook -- define/import it first.
store = Store()
with torch.no_grad():
    for bidx, batch in enumerate(tqdm(loader)):
        # as_tensor avoids the extra copy (and UserWarning) that
        # torch.tensor() incurs when batch['x'] is already a tensor.
        img = torch.as_tensor(batch['x']).to("cuda")
        preds, _ = classifier(img)
        preds = torch.sigmoid(preds)
        store.feed([batch['path'], preds[:, 1], preds[:, 4], preds[:, 9]])
        if bidx == 3000:
            break
        

In [None]:
plt.hist(store.lov[1].detach().cpu().numpy())  # distribution of collected cardiomegaly scores

In [None]:
# NOTE(review): `get_df` is not defined in this notebook.
_score_columns = ['path', 'cardiomegaly', 'edema', 'pleural_effusion']
final_df = get_df(store.lov, cols=_score_columns)

In [None]:
final_df.to_csv(path_or_buf='./final_mimic_uniform.csv', index=False)

In [None]:
_is_high_cardio = final_df['cardiomegaly'] > 0.7
np.unique(_is_high_cardio, return_counts=True)

In [None]:
# Re-balance cardiomegaly scores:
# - draw a fixed-size random sample from each score bin;
# - keep every row scoring >= 0.7.
# Bins are half-open [lo, hi), so a score exactly on a boundary (0.2, 0.4, 0.5,
# 0.7) lands in exactly one bin; the old strict `>`/`<` dropped such rows.
# `.sample(n)` still raises if a bin has fewer than n rows.
df_list = []
df1 = final_df
_cardio_score = df1['cardiomegaly']
for lo, hi, n_rows in [(-np.inf, 0.2, 1000), (0.2, 0.4, 1000), (0.4, 0.5, 1000), (0.5, 0.7, 2000)]:
    df_list.append(df1[(_cardio_score >= lo) & (_cardio_score < hi)].sample(n_rows))
df_list.append(df1[_cardio_score >= 0.7])
    

In [None]:
# pandas is never imported at the top of this notebook; import it here so
# this cell does not fail with NameError on `pd`.
import pandas as pd

df_cardio = pd.concat(df_list)  # stack the per-bin samples

In [None]:
# Random subsample (without replacement) of 6136 rows from the pooled bins.
df_cardio = df_cardio.sample(n=6136)

In [None]:
df_cardio.to_csv(path_or_buf='./final_mimic_uniform_sampled.csv', index=False)

In [None]:
plt.hist(x=df_cardio['cardiomegaly'])  # score distribution after re-balancing