In [7]:
import os.path

from model import get_models_and_path, get_remote_vggs_and_path, get_resnets_and_path, get_remote_resnets_and_path
import torch
import pandas as pd
from plot_and_print import plot_tile
import os
from PIL import Image
import matplotlib.pyplot as plt
from zennit.attribution import Gradient
from relevance import plot_relevance
from data_loader import TileLoader
from scipy import stats

from zennit import torchvision, composites, image
import numpy as np
import zennit as zen
# Canonizer instance for the ResNet models. The same name is assigned again in
# the cell that builds the composites, so this binding is only a first definition.
can_res = zen.torchvision.ResNetCanonizer()


In [19]:
data_dir = "../results/"
# Patients whose results are left out of the analysis.
excluded_patients = {"p021"}

# os.scandir yields entries in arbitrary order, so the names are sorted: this
# keeps `patient_id` (and the row order of the per-patient tables built from
# `patients`) stable between runs and machines. The context manager closes the
# directory iterator, and a missing excluded directory no longer raises.
with os.scandir(data_dir) as entries:
    patients = sorted(
        entry.name
        for entry in entries
        if entry.is_dir() and entry.name not in excluded_patients
    )

# One results file per patient, in the same order as `patients`.
filenames = [
    os.path.join(data_dir, patient, "RUBCNL_results_new.csv")
    for patient in patients
]

# Position (within the sorted `patients` list) of the patient inspected below.
patient_id = 2
print(filenames[patient_id])
df = pd.read_csv(filenames[patient_id])
def cut_path(x):
    """Return ``x`` without its first three characters."""
    prefix_length = 3
    return x[prefix_length:]
# Replace the path column by a copy in which every entry went through cut_path.
new_col = df.path.apply(cut_path)
df.path = new_col


# Loader that get_grads uses to open the tile files.
loader = TileLoader()

# best values first in list, ascending
def get_sorted_values_by_col(df, min, max, colname='../remote_models/new/models/res18/RUBCNL_Res18/Res18_1000_ep_29.pt', gene="RUBCNL"):
    """Select a range of rows by measured value and order them by prediction error.

    The rows of ``df`` are ranked by ascending ``gene`` value, the positional
    slice ``[min:max]`` of that ranking is taken, and the selected rows are
    returned ordered by ascending absolute difference between ``gene`` and
    ``colname`` (best prediction first).

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing the columns ``gene`` and ``colname``.
    min, max : int
        Start and stop of the positional slice into the value ranking; negative
        values count from the end. (The names shadow the builtins but are kept
        for compatibility with existing callers.)
    colname : str
        Column holding the predictions that are compared against ``gene``.
    gene : str
        Column holding the reference values.

    Returns
    -------
    pandas.DataFrame
        The selected rows of ``df``, index labels preserved.
    """
    # Explicit positional slice of the ranking; its index carries row *labels*.
    ranked = df[gene].sort_values().iloc[min:max]

    # Labels must be resolved with .loc; .iloc would only give the right rows
    # when the frame happens to have a default 0..n-1 index.
    rows_in_range = df.loc[ranked.index]
    abs_error = (rows_in_range[gene] - rows_in_range[colname]).abs().sort_values()
    return df.loc[abs_error.index]
# 100 tiles each from the low end, the middle and the high end of the value
# ranking; within each group the rows are ordered by ascending prediction error.
half = len(df) / 2
small_vals = get_sorted_values_by_col(df, 0, 100)
middle_vals = get_sorted_values_by_col(df, int(half - 50), int(half + 50))
big_vals = get_sorted_values_by_col(df, -100, len(df))

In [20]:
# Preview the first rows of the selected patient's table (the path column has
# already been passed through cut_path).
df.head()

In [12]:
def plot_model_comparison(grads_res, grads_vgg, image_path, width=4, subplot_size=40):
    """Draw the relevance maps of all models in one grid, plus the original tile.

    Parameters
    ----------
    grads_res, grads_vgg : sequence
        Entries indexable as ``(image, title, label, prediction)``, the layout
        that ``get_grads`` returns. ``grads_res`` panels are drawn first,
        ``grads_vgg`` panels directly after them.
    image_path : str
        File of the original tile; it is shown in the last cell of the grid.
    width : int
        Number of grid columns.
    subplot_size : float
        Width and height of the whole figure, in inches.
    """
    panels = list(grads_res) + list(grads_vgg)

    # One additional cell is reserved for the original tile, so the last grid
    # cell is always free. Ceiling division gives the number of rows needed.
    height = -(-(len(panels) + 1) // width)

    # squeeze=False keeps `ax` two-dimensional even when the grid has a single
    # row or column, so the [row, col] indexing below always works.
    f, ax = plt.subplots(height, width, squeeze=False)
    f.set_figheight(subplot_size)
    f.set_figwidth(subplot_size)

    for i, panel in enumerate(panels):
        cell = ax[i // width, i % width]
        cell.imshow(panel[0])
        cell.set_title(panel[1]+"\nlabel: "+str(round(panel[2], 3))+"\npred: "+str(round(panel[3], 3)))

    # The image data is read while the file is open; the handle is closed
    # right after instead of being left to the garbage collector.
    with Image.open(image_path) as img:
        ax[-1,-1].imshow(img)
    ax[-1,-1].set_title('original')
    plt.show()

def get_grads(models, composite, idx, df, loader, gene="RUBCNL"):
    """Compute one attribution image per model for a single tile.

    Parameters
    ----------
    models : iterable
        ``(model, path)`` pairs; ``path`` is the string the panel title is
        derived from.
    composite
        Composite handed to ``Gradient`` together with each model.
    idx : int
        Row position of the tile in ``df``.
    df : pandas.DataFrame
        Table with a ``path`` column (tile file) and a ``gene`` column (label).
    loader
        Object whose ``open(path)`` result supports ``unsqueeze(0)``
        (presumably a tensor without batch dimension — confirm in TileLoader).
    gene : str
        Name of the label column.

    Returns
    -------
    list of tuple
        ``(relevance image, shortened model path, label, model output)`` per model.
    """
    # Resolve the tile once and purely by position. The callers fetch the
    # original image with df.iloc as well, so tile file, label and the image
    # plotted next to the heat maps always belong to the same row.
    tile_row = df.iloc[idx]
    tile_file = tile_row["path"]
    label = tile_row[gene]

    marker = 'models'
    imgs = []
    for model, path in models:
        model.eval()
        img = loader.open(tile_file).unsqueeze(0)
        with Gradient(model, composite) as attributor:
            out, grad = attributor(img)
        rel = plot_relevance(grad, filename=None, only_return=True)

        # Title: everything after the first 'models' in the path. When the
        # marker is absent the full path is used instead of an arbitrary slice.
        marker_at = path.find(marker)
        short_path = path[marker_at + len(marker):] if marker_at >= 0 else path
        imgs.append((rel, short_path, label, out.item()))
    return imgs



In [5]:
# Canonizer instance for the VGG models.
can_vgg = zen.torchvision.VGGCanonizer()
# The string literal below holds disabled code for an alternative VGG composite
# (EpsilonGammaBox with low/high bounds derived from a Normalize transform).
# Being a plain string expression, its contents are not executed.
"""
transform_norm = torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))
composite_vgg = zen.composites.EpsilonGammaBox(low=low, high=high, canonizers=[can_vgg])
"""
# Canonizer instance for the ResNet models (rebinds the name from the import cell).
can_res = zen.torchvision.ResNetCanonizer()

# Both model families use an EpsilonPlusFlat composite, each with the canonizer
# of its own architecture; these are the composites passed to get_grads.
composite_res = zen.composites.EpsilonPlusFlat(canonizers=[can_res])
composite_vgg = zen.composites.EpsilonPlusFlat(canonizers=[can_vgg])

In [9]:
# Disabled alternative model source:
#models = get_models_and_path()
# get_grads unpacks every element of these collections as (model, path).
# NOTE(review): each collection is passed to get_grads in more than one cell;
# this presumes the helpers return re-iterable containers (e.g. lists), not
# one-shot generators — confirm in model.py.
resnets = get_remote_resnets_and_path()
vggs = get_remote_vggs_and_path()

In [10]:
# Row position of the tile used for the first model comparison, and its file.
tile_id = 4170
tile_path = df["path"].iloc[tile_id]

In [13]:
# Use tile_id instead of repeating the literal row number: tile_path was
# derived from tile_id, so heat maps and original image stay on the same tile
# when the selection is changed.
grads_res = get_grads(resnets, composite_res, tile_id, df, loader)
grads_vgg = get_grads(vggs, composite_vgg, tile_id, df, loader)
plot_model_comparison(grads_res, grads_vgg, tile_path)

In [14]:
# Five best and five worst predicted tiles of each value range (the *_vals
# frames are ordered by ascending absolute prediction error).
best = [small_vals[0:5], middle_vals[0:5], big_vals[0:5]]
worst = [small_vals[-5:], middle_vals[-5:], big_vals[-5:]]
# interesting:
interesting = []
interesting.append(best[0].index.values[1])
#selected = worst[2]
#id = selected.index.values[1]


# Named `selected` rather than `set`: rebinding the builtin would break every
# later set(...) call in this kernel.
selected = best[2]
# NOTE(review): `id` also shadows a builtin, but the following VGG cell reads
# `id` and `row`, so both names are kept here. `id` is an index label; it
# equals the row position only for the default index produced by read_csv.
id = selected.index.values[3]
row = df.iloc[id]
print(id)
grads_res = get_grads(resnets, composite_res, id, df, loader)
#plot_model_comparison(grads_res, [], row.path)

In [None]:
# VGG attributions for the tile chosen in the previous cell (`id`, `row`),
# plotted without any ResNet panels.
grads_vgg = get_grads(models=vggs, composite=composite_vgg, idx=id, df=df, loader=loader)
plot_model_comparison(grads_res=[], grads_vgg=grads_vgg, image_path=row.path)


In [22]:
# pearson
# Reference values of the selected patient and its prediction columns, i.e.
# every column from position 5 onwards.
target = df.loc[:, "RUBCNL"]
preds = df.iloc[:, 5:]
print(preds.columns)

In [23]:
# Pearson correlation between the RUBCNL reference values and every prediction
# column, one table row per patient (row order follows `patients`/`filenames`).
print(patients)
pearson_rows = []
for filename in filenames:
    df_tmp = pd.read_csv(filename)
    # Loop-local names: `target`/`preds` of the previous cell and the `row`
    # used by the attribution cells are no longer overwritten here.
    measured = df_tmp["RUBCNL"]
    # Every column from position 5 onwards is treated as a prediction column.
    predicted = df_tmp.iloc[:, 5:]
    pearson_rows.append({
        name: round(stats.pearsonr(measured, data)[0], 3)
        for name, data in predicted.items()
    })

# Building the frame once from the collected records yields numeric columns
# instead of the object columns produced by cell-wise assignment.
pearsons = pd.DataFrame(pearson_rows)
print(pearsons.shape)

# 'idx' stays a regular column: the lookup cell filters on pearsons["idx"].
# (The former set_index call discarded its result and therefore did nothing.)
pearsons['idx'] = patients
pearsons.head()
    

In [24]:
# Correlation row(s) whose 'idx' entry is patient p009.
pearsons.loc[pearsons["idx"] == 'p009']