In [2]:
from collections import OrderedDict
from copy import copy
from activmask.datasets.msdd import MSDDataset
from activmask.datasets.synth import SyntheticDataset
from activmask.datasets.xray import JointDataset
from activmask.models.loss import compare_activations, get_grad_contrast

from glob import glob
from activmask.models.loss import compare_activations, get_gradmask_loss
from activmask.models.resnet import ResNetModel
from skimage import io
from textwrap import wrap
import argparse
import copy
import itertools
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pprint
import random
import seaborn as sns
import sys
import time
import torch
import torch.nn as nn
import activmask.utils.configuration as configuration
import activmask.utils.monitoring as monitoring

%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

In [8]:
# GLOBALS
EPOCH=499
SEEDS=[1234, 3232, 3221, 9856, 1290, 1987, 3200, 6400, 8888, 451]

In [None]:
# FUNCTIONS TO DEPRECATE
def plot_curves(df, mode='test'):
    """Plots the test AUC over epochs for an experiment table."""
    
    import warnings
    raise warnings.DeprecationWarning("Not using this anymore, using Seaborn Plotting tools instead!")    
    
    EPOCH = 499
    EXPERIMENTS = ['unet', 'ae', 'resnet', 'cnn']
    EXP_NAMES = ["UNet", "Autoencoder", "ResNet 18", "ConvNet"]

    assert mode in ['test', 'train', 'valid']
    
    fig, ax = plt.subplots(figsize=(10, 8), nrows=2, ncols=2)
    ax = ax.ravel()
    
    for i, (exp, exp_name) in enumerate(zip(EXPERIMENTS, EXP_NAMES)):
        
        ls = ['solid', 'dashed', 'dashdot', 'dotted', '-' ,'--' ,'-.' , ':','None' ,' ' ,'' ]
        dfs = df.reindex()

        # Filter by experiment.
        dfs = dfs[dfs['experiment_name'].str.contains(exp)]

        for j, name in enumerate(sorted(dfs.experiment_name.unique())):
 
            thisdata = dfs[dfs["experiment_name"] == name]
            willplot = thisdata.groupby(["epoch"]).mean()
            r = willplot.plot(y="{}_auc".format(mode), 
                              ax=ax[i], label=name, ls=ls[j])

        patches, labels = ax[i].get_legend_handles_labels()
 
        ax[i].legend(patches, labels, loc='lower right', title="experiment")    
        ax[i].set_ylim(-0.05, 1.05)  
        ax[i].set_ylabel("AUC")  
        ax[i].set_xlabel("Epoch")

        ax[i].set_title("{} Experiments".format(exp_name))
        fig.set_tight_layout(tight=True)
        
        
#def get_search_results_at_epoch(df, epoch):
#    """Get the test results at the best epoch."""
#    cols = ['experiment_name', 'best_valid_auc', 'optimizer_lr']
#    groups = ['experiment_name', 'optimizer_lr']
#    
#    return get_results_at_epoch(df, epoch, groups, cols, mode='max')


#def get_mask_results_at_epoch(df, epoch):
#    """
#    Get the results of the maxmask experiments at a given epoch. Not used currently.
#    """
#    groups = ["experiment_name", "epoch", "train_dataset_maxmasks"]
#    cols = ["best_test_auc", "best_epoch"]
#    
#    return get_results_at_epoch(df, epoch, groups, cols)


Functions
---------------

In [9]:
def render_salience_maps(text, idx, sample, sample_blur, modelz, exp_name, blur):

    fig, ax = plt.subplots(nrows=3, ncols=5, figsize=(15, 8), dpi=150, gridspec_kw={'hspace': 0.15, 'wspace': 0.01})
    fig.subplots_adjust(wspace=0.0)
    for a in ax:
        for aa in a:
            aa.axis('off')

    x, target, use_mask = sample
    img = x[0][0].cpu().numpy()
    #img = img / np.max(img)
    seg = x[1][0].cpu().numpy() 
    seg_blur = sample_blur[0][1][0]
    
    x_var = torch.autograd.Variable(x[0].unsqueeze(0),requires_grad=True)
    
    ax[0,0].set_title(str(idx) + "Image")
    ax[0,0].imshow(img, interpolation='none', cmap='Greys_r')
    ax[0,0].axis('off')    
    
    ax[0,1].set_title('Pathology + Mask (blur={})'.format(blur))
    ax[0,1].imshow((1-seg_blur).numpy()+(1-seg), cmap="Greys_r", interpolation='none')
    ax[0,1].axis('on')
    ax[0,1].get_xaxis().set_ticks([])
    ax[0,1].get_yaxis().set_ticks([])

    my_cmap = plt.cm.jet
    my_cmap.set_under('k', alpha=0)
    
    name = "UNet Masked"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[2,4].set_title(name)
    ax[2,4].imshow(img, interpolation='none', cmap='Greys_r')
    ax[2,4].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[2,4].axis('off')  
    ax[1,4].set_title("Unet Masked Recon")
    ax[1,4].imshow(x_prime.detach().numpy()[0][0], interpolation='none', cmap='Greys_r')
    ax[1,4].axis('off')  
    
    name = "UNet Baseline"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[2,0].set_title(name)
    ax[2,0].imshow(img, interpolation='none', cmap='Greys_r')
    ax[2,0].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[2,0].axis('off')
    ax[0,4].set_title("Unet Baseline Recon")
    ax[0,4].imshow(x_prime.detach().numpy()[0][0], interpolation='none', cmap='Greys_r')
    ax[0,4].axis('off')  
    
    
    name = "Resnet Baseline"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[1,0].set_title(name)
    ax[1,0].imshow(img, interpolation='none', cmap='Greys_r')
    ax[1,0].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[1,0].axis('off')  
    
    name = "Resnet ActDiff"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[1,1].set_title(name)
    ax[1,1].imshow(img, interpolation='none', cmap='Greys_r')
    ax[1,1].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[1,1].axis('off')
    
    name = "UNet ActDiff"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[2,1].set_title(name)
    ax[2,1].imshow(img, interpolation='none', cmap='Greys_r')
    ax[2,1].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[2,1].axis('off')
    
    name = "Resnet Gradmask"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[1,2].set_title(name)
    ax[1,2].imshow(img, interpolation='none', cmap='Greys_r')
    ax[1,2].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[1,2].axis('off')
    
    name = "UNet Gradmask"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[2,2].set_title(name)
    ax[2,2].imshow(img, interpolation='none', cmap='Greys_r')
    ax[2,2].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[2,2].axis('off')
    
    name = "Resnet ActGrad"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[1,3].set_title(name)
    ax[1,3].imshow(img, interpolation='none', cmap='Greys_r')
    ax[1,3].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[1,3].axis('off')
    
    name = "UNet ActGrad"
    model = modelz[name]
    y_prime, x_prime = model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    gradmask *= (gradmask>np.percentile(gradmask,90))
    ax[2,3].set_title(name)
    ax[2,3].imshow(img, interpolation='none', cmap='Greys_r')
    ax[2,3].imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
    ax[2,3].axis('off')
    
    plt.tight_layout()
    plt.show()

    
def render_salicnece_xray(text, i, sample, models_list, exp_name):
    fig, (ax0, ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=1, ncols=6,
                                             figsize=(15, 7), dpi=200)
    x, target, use_mask = sample
    #x = torch.tensor(x)
    x_var = torch.autograd.Variable(x[0].unsqueeze(0),
        requires_grad=True).float()
#     if torch.cuda.is_available():
#         x_var = x_var.cuda()
#         cnn = cnn.cuda()
#         resnet = resnet.cuda()
#         unet = unet.cuda()

    ax0.set_title(str(i) + " Input Image")
    img = x[0][0].cpu().numpy()
    img = img / np.max(img)  # Scales the input image so that the maximum=1.
    seg = x[1][0].cpu().numpy() #* 0.5  # Makes mask bright, but not too bright.
    ax0.imshow(img, interpolation='none', cmap='Greys_r')
    ax0.axis('off')
    
    ax1.set_title("Mask")
    ax1.imshow(1-seg, cmap="Greys_r", interpolation='none')
    ax1.get_xaxis().set_ticks([])
    ax1.get_yaxis().set_ticks([])

    ax2.set_title(models_list[0][0])
    this_model = models_list[0][1]
    y_prime, x_prime = this_model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, this_model, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    ax2.imshow(np.abs(gradmask), cmap="jet", interpolation='none')
    ax2.axis('off')
    
    ax3.set_title(models_list[1][0])
    this_model = models_list[1][1]
    y_prime, x_prime = this_model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, this_model, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    ax3.imshow(np.abs(gradmask), cmap="jet", interpolation='none')
    ax3.axis('off')
    
    ax4.set_title(models_list[2][0])
    this_model = models_list[2][1]
    y_prime, x_prime = this_model(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, this_model, torch.tensor(1.),
                                 "contrast").detach().cpu().numpy()[0][0]
    ax4.imshow(np.abs(gradmask), cmap="jet", interpolation='none')
    ax4.axis('off')

    ax5.set_title("UNet Reconstruction")
    # Fails for models that output a nonsense reconstruction (CNN, ResNet).
    if isinstance(x_prime, torch.Tensor):
        ax5.imshow(x_prime[0][0].detach().cpu().numpy(),
                   interpolation='none', cmap='Greys_r')
        ax5.axis('off')
    else:
        ax5.remove()

    plt.tight_layout()
    plt.show()


def render_mean_grad(text, i, dataset, mdl_baseline, mdl_actdiff, mdl_gradmask, exp_name, 
                     img_size=100):
    """
    Renders the mean saliency map across all inputs in the dataset, from the
    input models for visual comparison.
    """
    fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=5,
                                                  figsize=(18, 12), dpi=72)

    mean_img = np.zeros((img_size, img_size))
    mean_mask = np.zeros((img_size, img_size))
    mean_baseline = np.zeros((img_size, img_size))
    mean_actdiff = np.zeros((img_size, img_size))
    mean_gradmask = np.zeros((img_size, img_size))
    
    for sample in dataset:
        x, target, use_mask = sample
        x_var = torch.autograd.Variable(x[0].unsqueeze(0), requires_grad=True)
    
        mean_img += x[0].detach().cpu().numpy()[0]
        mean_mask += x[1].detach().cpu().numpy()[0]
        
        y_prime, x_prime = mdl_baseline(x_var)
        gradmask = get_gradmask_loss(x_var, y_prime, mdl_baseline, torch.tensor(1.),
                                     "contrast").detach().cpu().numpy()[0][0]
        mean_baseline += np.abs(gradmask)

        y_prime, x_prime = mdl_actdiff(x_var)
        gradmask = get_gradmask_loss(x_var, y_prime, mdl_actdiff, torch.tensor(1.),
                                     "contrast").detach().cpu().numpy()[0][0]
        mean_actdiff += np.abs(gradmask)

        y_prime, x_prime = mdl_gradmask(x_var)
        gradmask = get_gradmask_loss(x_var, y_prime, mdl_gradmask, torch.tensor(1.),
                                     "contrast").detach().cpu().numpy()[0][0]
        mean_gradmask += np.abs(gradmask)

    mean_img /= len(dataset)
    mean_mask /= len(dataset)
    mean_baseline /= len(dataset)
    mean_actdiff /= len(dataset)
    mean_gradmask /= len(dataset)
    
    ax0.set_title(str(i) + " Masked Image")
    ax0.imshow(mean_img, interpolation='none', cmap='Greys_r')
    ax0.axis('off')
    
    ax1.set_title("Mask")
    ax1.imshow(mean_mask, cmap="Greys_r", interpolation='none')
    ax1.axis('off')

    ax2.set_title("Baseline")
    ax2.imshow(mean_baseline, cmap="jet", interpolation='none')
    ax2.axis('off')
    
    ax3.set_title("Actdiff")
    ax3.imshow(mean_actdiff, cmap="jet", interpolation='none')
    ax3.axis('off')
    
    ax4.set_title("Gradmask")
    ax4.imshow(mean_gradmask, cmap="jet", interpolation='none')
    ax4.axis('off')

    plt.tight_layout()
    plt.show()


def convert_dtype(dictionary):
    """
    Converts all entries in all subdictionaries to be of datatype
    [int, float, str]. All non-matching entries are converted to str.
    """
    TYPES = [float, int, str, np.float32, np.float64, bool]
    
    for d in dictionary:
        if type(d) == dict:
            d = convert_dtype(d)
        else:
            if type(dictionary[d]) not in TYPES:
                dictionary[d] = str(dictionary[d])



def get_metrics(path, best=False):
    """
    Loads the outputs of training, if best, only 
    keeps best epoch for each dataframe.
    """
    all_df = [] 
    for f in glob(os.path.join(path, "*/stats/metrics.pkl")):
        d = pickle.load(open(f, "rb"))
        convert_dtype(d)
        d = pd.DataFrame.from_dict(d)
        
        if best:
            best_epoch = d.iloc[-1]['best_epoch']
            d = d[d['epoch'] == best_epoch]
        
        all_df.append(d)

    return(all_df)


def df_cleaner(df):
    """Selects the columns of the metrics dataframe to keep."""
    KEEP = ["auc", "best", "seed", "recon_lambda", "actdiff_lambda", 
            "gradmask_lambda", "epoch", "name", "maxmasks", 'optimizer_lr']
    for col in df.columns:
        if not any([string in col for string in KEEP]):
            del df[col]

    # Experiment name is determined by the configuration file used.
    experiments = df.experiment_name.unique()

    print("resulting df \nshape={} tracking {} experiments, \ncolumns={}".format(
        df.shape, len(experiments), df.columns))
    return(df)


def get_last_results_at_epoch(df, epoch, sig_digits=3):
    """Get the train/test/valid AUC at the final, not best, epoch."""
    #groups = ['experiment_name', 'actdiff_lambda', 'recon_lambda']
    fmt_str = "${0:." + str(sig_digits) + "f}\pm{1:." + str(sig_digits) + "f}$"
    
    groups = ['experiment_name']
    cols = ['train_auc', 'valid_auc', 'best_epoch']
    
    df = get_results_at_epoch(df, epoch, groups, cols)
    df = df.round(sig_digits)
    
    results = [] 
    for a, b in zip(df["train_auc"], df["train_auc_std"]):
        results.append(fmt_str.format(a, b))
    df['train_auc'] = results 
    df = df.drop(['train_auc_std'], axis=1)

    results = [] 
    for a, b in zip(df["valid_auc"], df["valid_auc_std"]):
        results.append(fmt_str.format(a, b))
    df['valid_auc'] = results 
    df = df.drop(['valid_auc_std'], axis=1)
    
    results = [] 
    for a, b in zip(df["best_epoch"], df["best_epoch_std"]):
        results.append(fmt_str.format(a, b))
    df['best_epoch'] = results 
    df = df.drop(['best_epoch_std'], axis=1)
    
    return df


def make_results_table(dfs, sig_digits=3, count=False):
    """
    Merge the best test results for all dataframes submitted. 
    Used to make a results table across datasets.
    """
    
    fmt_str = "${0:." + str(sig_digits) + "f}\pm{1:." + str(sig_digits) + "f}$"
    for i, df in enumerate(dfs):
        tmp_df = copy(df)
        
        # Strip the dataset name out of the experiment name.        
        name = tmp_df['experiment_name'].iloc[0].split('_')[0]
        tmp_df['experiment_name'] = tmp_df['experiment_name'].str.replace('{}_'.format(name), '')
        
        # Reformat the table.
        tmp_df = get_test_results(tmp_df, count=count)
        tmp_df = tmp_df.round(sig_digits)

        # Merge mean+/-std into a single column with the experiment name.
        results = [] 
        for a, b in zip(tmp_df["best_test_auc"], tmp_df["best_test_auc_std"]):
            results.append(fmt_str.format(a, b))
        tmp_df['test_auc_{}'.format(name)] = results 
        tmp_df = tmp_df.drop(['best_test_auc', 'best_test_auc_std'], axis=1)

        print(tmp_df)
        
        # Merge the experiments.
        if i == 0:
            final_df = copy(tmp_df)
        else:
            final_df = pd.merge(final_df, tmp_df, on='experiment_name')
    
    return final_df


def get_results(df, groups, cols, count=False, mode='mean'):
    """
    Shows a reduced form of the table with mean and std values 
    over experiments.
    """
    df_tmp = df.groupby(groups)[cols]
    if mode == 'mean':
        df = df_tmp.mean().join(df_tmp.std(),rsuffix='_std')
    elif mode == 'max':
        df = df_tmp.max().join(df_tmp.std(),rsuffix='_std')
    if count:
        df = df.join(df_tmp.count(), rsuffix='_count')
                                                                        
    return df


def get_test_results(df, count=False):
    """Get the test results at the best epoch."""
    groups = ['experiment_name']
    cols = ['best_test_auc']
    
    return get_results(df, groups, cols, count=count, mode='mean')


def get_lineplot_df(input_df, epoch):
    """
    Gets a dataframe useful for the lineplots below.
    """
    df = copy(input_df)
    df = df[(df.epoch==EPOCH)]
    idx = ~df['experiment_name'].str.contains('cnn')
    df = df[idx]
    idx = ~df['experiment_name'].str.contains('_ae_')
    return df[idx]


def get_curveplot_df(input_df, mode):
    """
    Gets a dataframe useful for the curveplots below.
    """
    df = copy(input_df)
    idx = df['experiment_name'].str.contains(mode)
    return df[idx]


def read_img(exp):
    """Reads best epoch images that are generated by the training loop."""
    file = glob("../images/{}/image-best_valid-*.png".format(exp))
    image = io.imread(file[0])
    plt.imshow(image)
    plt.show();
    
def merge_dfs_for_maxmasks(df, df_mask, old_prefix, new_prefix, actdiff_lamb):
    tmp_resnet_act = df[df['experiment_name'] == '{}_resnet'.format(old_prefix)]
    tmp_resnet_act['experiment_name'] = '{}_resnet_actdiff'.format(new_prefix)
    tmp_resnet_act['maxmasks_train'] = 0
    tmp_resnet_act['actdiff_lambda'] = actdiff_lamb
    tmp_resnet_grd = df[df['experiment_name'] == '{}_resnet'.format(old_prefix)]
    tmp_resnet_grd['experiment_name'] = '{}_resnet_gradmask'.format(new_prefix)
    tmp_resnet_grd['maxmasks_train'] = 0
    tmp_resnet_grd['actdiff_lambda'] = 0
    tmp_unet_act = df[df['experiment_name'] == '{}_unet'.format(old_prefix)]
    tmp_unet_act['experiment_name'] = '{}_unet_actdiff'.format(new_prefix)
    tmp_unet_act['maxmasks_train'] = 0
    tmp_unet_act['actdiff_lambda'] = actdiff_lamb
    tmp_unet_grd = df[df['experiment_name'] == '{}_unet'.format(old_prefix)]
    tmp_unet_grd['experiment_name'] = '{}_unet_gradmask'.format(new_prefix)
    tmp_unet_grd['maxmasks_train'] = 0
    tmp_unet_grd['actdiff_lambda'] = 0
    df_mask = pd.concat([df_mask, tmp_resnet_act, tmp_resnet_grd, tmp_unet_act, tmp_unet_grd])
    
    return df_mask


def seed_finder(df, experiment_name):
    test = df[df.epoch == EPOCH]
    test = test[test['experiment_name'] == experiment_name]

    for seed in SEEDS:
        if seed not in np.array(test['seed']):
            print("{} missing {}".format(experiment_name, seed))

In [10]:
# Collects all of the results available in logs-single (for individual runs)                    
df = pd.concat(get_metrics("../checkpoints", best=True))
df = df_cleaner(df)

# Collect the MSDDataset experiments.
df_synth = df[df['experiment_name'].str.contains('synth-search_')]
df_liver = df[df['experiment_name'].str.contains('livermsd-search_')]
df_cardiac = df[df['experiment_name'].str.contains('cardiacmsd-search_')]
df_pancreas = df[df['experiment_name'].str.contains('pancreasmsd-search_')]
df_xray = df[df['experiment_name'].str.contains('xray-search_')]
df_search = pd.concat([df_synth, df_liver, df_cardiac, df_pancreas], axis=0)

#df_search

ValueError: No objects to concatenate

In [None]:
# Ugly thing for figuring out if you have all the seeds.
#seed_finder(df_baseline, 'livermsd_unet')

In [None]:
# Uses the control experiments (no activmask) as the "maxmasks=0" runs for the maxmasks plots below.
df_synth_mask = df[df['experiment_name'].str.contains('synth-masks')]
df_synth_mask = merge_dfs_for_maxmasks(df_synth, df_synth_mask, 'synth', 'synth-masks', 10)
df_liver_mask = df[df['experiment_name'].str.contains('livermsd-masks')]
df_liver_mask = merge_dfs_for_maxmasks(df_liver, df_liver_mask, 'livermsd', 'livermsd-masks', 1)
df_cardiac_mask = df[df['experiment_name'].str.contains('cardiacmsd-masks')]
df_cardiac_mask = merge_dfs_for_maxmasks(df_cardiac, df_cardiac_mask, 'cardiacmsd', 'cardiacmsd-masks', 1)
df_pancreas_mask = df[df['experiment_name'].str.contains('pancreasmsd-masks')]
df_pancreas_mask = merge_dfs_for_maxmasks(df_pancreas, df_pancreas_mask, 'pancreasmsd', 'pancreasmsd-masks', 1)
df_mask = pd.concat([df_synth_mask, df_liver_mask, df_cardiac_mask, df_pancreas_mask], axis=0)

In [None]:
# Write table showing the results at the final epoch.
with open('tables/synth_last_results.tex','w') as tf:
    tf.write(get_last_results_at_epoch(df_synth, EPOCH, sig_digits=2).to_latex())

get_last_results_at_epoch(df_synth, EPOCH, sig_digits=2)

In [None]:
# Write table showing best test AUC results.
with open('tables/all_test_results.tex', 'w') as tf:
    tf.write(make_results_table([df_synth, df_liver, df_cardiac, df_pancreas], sig_digits=2).to_latex())

make_results_table([df_synth, df_liver, df_cardiac, df_pancreas], sig_digits=2)

In [None]:
# Write table showing best test AUC results.
with open('tables/all_xray_results.tex', 'w') as tf:
    tf.write(make_results_table([df_xray], sig_digits=2).to_latex())

make_results_table([df_xray], sig_digits=2)

In [None]:
# Plots the Valid AUC over all epochs of the Synthetic Dataset training.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(figsize=(10, 6), 
                                             nrows=2, ncols=2, sharex=True, sharey=True)

palette = {
    'synth_unet': "gray", 
    'synth_unet_actdiff': "red",
    'synth_unet_clfmasked': "green",
    'synth_unet_gradmask': "blue",
    'synth_unet_actgrad': "yellow",
    'synth_unet_reconmasked': "orange"
}

g = sns.lineplot(
    x="epoch", y='valid_auc', hue='experiment_name', 
    ax=ax1, data=get_curveplot_df(df_synth, 'unet'),
    palette=palette, hue_order=sorted(list(palette.keys())))
g.set_title("UNet Experiments")
#g.legend(loc='center left', bbox_to_anchor=(1, 0.5))
g.get_legend().remove()
g.set_ylim(-0.05, 1.05)
g.set_xlabel('Epoch')
g.set_ylabel('Valid AUC')

palette = {
    'synth_ae': "gray", 
    'synth_ae_actdiff': "red",
    'synth_ae_clfmasked': "green",
    'synth_ae_gradmask': "blue",
    'synth_ae_actgrad': "yellow",
    'synth_ae_reconmasked': "orange"
}

g = sns.lineplot(
    x="epoch", y='valid_auc', hue='experiment_name', 
    ax=ax2, data=get_curveplot_df(df_synth, 'ae'),
    palette=palette, hue_order=sorted(list(palette.keys())))
g.set_title("CNN Autoencoder Experiments")
g.legend(['Classify', 'Actdiff', 'Actgrad', 'Classify Masked', 'Gradmask', 'Reconstruct Masked'], 
          loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
g.set_ylim(-0.05, 1.05)
g.set_xlabel('Epoch')
g.set_ylabel('Valid AUC')

palette = {
    'synth_resnet': "gray", 
    'synth_resnet_actdiff': "red",
    'synth_resnet_clfmasked': "green",
    'synth_resnet_gradmask': "blue",
    'synth_resnet_actgrad': "yellow"
}

g = sns.lineplot(
    x="epoch", y='valid_auc', hue='experiment_name', 
    ax=ax3, data=get_curveplot_df(df_synth, 'resnet'),
    palette=palette, hue_order=sorted(list(palette.keys())))
g.set_title("ResNet 18 Experiments")
#g.legend(loc='center left', bbox_to_anchor=(1, 0.5))
g.get_legend().remove()
g.set_ylim(-0.05, 1.05)
g.set_xlabel('Epoch')
g.set_ylabel('Valid AUC')

palette = {
    'synth_cnn': "gray", 
    'synth_cnn_actdiff': "red",
    'synth_cnn_clfmasked': "green",
    'synth_cnn_gradmask': "blue",
    'synth_cnn_actgrad': "yellow",
}

g = sns.lineplot(
    x="epoch", y='valid_auc', hue='experiment_name', 
    ax=ax4, data=get_curveplot_df(df_synth, 'cnn'),
    palette=palette, hue_order=sorted(list(palette.keys())))
g.set_title("CNN Experiments")
#g.legend(loc='center left', bbox_to_anchor=(1, 0.5))
g.get_legend().remove()
g.set_ylim(-0.05, 1.05)
g.set_xlabel('Epoch')
g.set_ylabel('Valid AUC')

fig.set_tight_layout(tight=True)

In [None]:
# Gets the best test auc (early stopping) for all masks experiments.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(figsize=(10, 6), 
                                             nrows=2, ncols=2, sharex=True, sharey=True)

palette = {
    'synth-masks_unet_gradmask': "red", 
    'synth-masks_unet_actdiff': "lightcoral",
    'synth-masks_resnet_gradmask': "blue",
    'synth-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_test_auc', hue='experiment_name', 
    ax=ax1, data=get_lineplot_df(df_synth_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Synth")
g.set(ylim=(0.5, 1.05))
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Test AUC')

palette = {
    'livermsd-masks_unet_gradmask': "red", 
    'livermsd-masks_unet_actdiff': "lightcoral",
    'livermsd-masks_resnet_gradmask': "blue",
    'livermsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_test_auc', hue='experiment_name', 
    ax=ax2, data=get_lineplot_df(df_liver_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Liver")
g.set(ylim=(0.5, 1.0))
g.legend(['ResNet Actdiff', 'ResNet Gradmask', 'UNet Actdiff', 'UNet Gradmask'], 
         loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Test AUC')


palette = {
    'cardiacmsd-masks_unet_gradmask': "red", 
    'cardiacmsd-masks_unet_actdiff': "lightcoral",
    'cardiacmsd-masks_resnet_gradmask': "blue",
    'cardiacmsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_test_auc', hue='experiment_name', 
    ax=ax3, data=get_lineplot_df(df_cardiac_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Cardiac")
g.set(ylim=(0.5, 1.0))
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Test AUC')


palette = {
    'pancreasmsd-masks_unet_gradmask': "red", 
    'pancreasmsd-masks_unet_actdiff': "lightcoral",
    'pancreasmsd-masks_resnet_gradmask': "blue",
    'pancreasmsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_test_auc', hue='experiment_name', 
    ax=ax4, data=get_lineplot_df(df_pancreas_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Pancreas")
g.set(ylim=(0.5, 1.0))
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Test AUC')



fig.set_tight_layout(tight=True)

In [None]:
# Gets the best epoch (early stopping) for all masks experiments.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(figsize=(10, 6), 
                                             nrows=2, ncols=2, sharex=True, sharey=True)

palette = {
    'synth-masks_unet_gradmask': "red", 
    'synth-masks_unet_actdiff': "lightcoral",
    'synth-masks_resnet_gradmask': "blue",
    'synth-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_epoch', hue='experiment_name', 
    ax=ax1, data=get_lineplot_df(df_synth_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Synth")
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Epoch')


palette = {
    'livermsd-masks_unet_gradmask': "red", 
    'livermsd-masks_unet_actdiff': "lightcoral",
    'livermsd-masks_resnet_gradmask': "blue",
    'livermsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_epoch', hue='experiment_name', 
    ax=ax2, data=get_lineplot_df(df_liver_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Liver")
g.legend(['ResNet Actdiff', 'ResNet Gradmask', 'UNet Actdiff', 'UNet Gradmask'],
          loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Epoch')


palette = {
    'cardiacmsd-masks_unet_gradmask': "red", 
    'cardiacmsd-masks_unet_actdiff': "lightcoral",
    'cardiacmsd-masks_resnet_gradmask': "blue",
    'cardiacmsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_epoch', hue='experiment_name', 
    ax=ax3, data=get_lineplot_df(df_cardiac_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Cardiac")
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Epoch')


palette = {
    'pancreasmsd-masks_unet_gradmask': "red", 
    'pancreasmsd-masks_unet_actdiff': "lightcoral",
    'pancreasmsd-masks_resnet_gradmask': "blue",
    'pancreasmsd-masks_resnet_actdiff': "cornflowerblue",
}

g = sns.lineplot(
    x="maxmasks_train", y='best_epoch', hue='experiment_name', 
    ax=ax4, data=get_lineplot_df(df_pancreas_mask, EPOCH),
    palette=palette, style='actdiff_lambda', hue_order=sorted(list(palette.keys())))
g.set_title("Pancreas")
g.get_legend().remove()
g.set_xlabel('% Masks Training Examples')
g.set_ylabel('Best Epoch')

fig.set_tight_layout(tight=True)

Salience Maps
---------------------

In [3]:
#d = datasets.MSDDataset.HeartMSDDataset(mode='test', nsamples=10, blur=4)
#d = datasets.MSDDataset.LiverMSDDataset(mode='test', nsamples=1000, blur=4)
#d = datasets.MSDDataset.PancreasMSDDataset(mode='test', nsamples=1000, blur=4)
dataset = datasets.xray.JointDataset(
    "/lustre04/scratch/jdv/xray/NIH/images-128",
    "/lustre04/scratch/jdv/xray/Data_Entry_2017.csv",
    "/lustre04/scratch/jdv/xray/PC/images-128",
    "/lustre04/scratch/jdv/xray/PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv",
    mode='test',
    seed=1234,
    ratio=0.9,
    new_size=128)

mdl = torch.load("/home/jdv/code/activmask/results/xray-search_resnet_actdiff_1234.pth.tar", map_location='cpu')

def render_grad(dataset, mdl):
    
    # Configure plots.
    fig, axs = plt.subplots(nrows=1, ncols=1, 
                           figsize=(4, 4), dpi=150, 
                           gridspec_kw={'hspace': 0.15, 'wspace': 0.01})
    fig.subplots_adjust(wspace=0.0)
    for a in ax:
        for aa in a:
            aa.axis('off')

    # Prepare the Image.            
    idx = np.random.randint(1000)
    x, y, use_mask = dataset[idx]
    img = x[0][0].cpu().numpy()
    img /= np.max(img)
    seg = x[1][0].cpu().numpy() 
    
    # Required for Grad computation.
    x_var = torch.autograd.Variable(x[0].unsqueeze(0),
                                    requires_grad=True)

    # Saliency calculation.
    y_prime, x_prime = mdl(x_var)
    gradients = get_grad_contrast(x_var, y_prime, y).detach().cpu().numpy()[0][0]
    gradients *= (gradients > np.percentile(gradients, 90))
    
    grad_cmap = plt.cm.jet
    grad_cmap.set_under('k', alpha=0)

    for ax in axs:
        ax.set_title(name)
        ax.imshow(img, interpolation='none', cmap='Greys_r')
        ax.imshow(gradmask, interpolation='none', cmap=my_cmap, clim=[0.000001, gradmask.max()])
        ax.axis('off')  

        
    plt.show()
render_grad(dataset, mdl)


NameError: name 'datasets' is not defined

In [None]:
name, resnet, unet = models_toplot[2]
print(name)
gradz = []
for s in d:
    x,_,_ = s
    x = torch.tensor(x[0])
    x_var = torch.autograd.Variable(x.unsqueeze(0),requires_grad=True)
    y_prime, x_prime = resnet(x_var)
    gradmask = get_gradmask_loss(x_var, y_prime, resnet, torch.tensor(1.),
                                         "contrast").detach().cpu().numpy()[0][0]
    gradz.append(gradmask)

toagg = np.asarray(gradz)
toagg = toagg[d.labels==True]
plt.imshow(np.abs(toagg.mean(0)), cmap="jet", interpolation='none')
plt.title("Avg grad: " + name + " resnet, test N=" + str(len(gradz)));

In [None]:
render_salience_maps("aaa", "", sample, cnn, resnet, unet, "a")

In [None]:
dataset = datasets.MSDDataset.PancreasMSDDataset(mode='test', nsamples=NSAMPLES, blur=16)
baseline = load_resnet("/network/tmp1/vivianoj/checkpoints/pancreasmsd_resnet/best_model_451_1.0.pth.tar")
actdiff = load_resnet('/network/tmp1/vivianoj/checkpoints/pancreasmsd_resnet_actdiff/best_model_451_1.0.pth.tar')
gradmask = load_resnet('/network/tmp1/vivianoj/checkpoints/pancreasmsd_resnet_gradmask/best_model_451_1.0.pth.tar')
render_mean_grad("aaa", "", dataset, baseline, actdiff, gradmask, "a")