In [None]:
import pandas as pd
import numpy as np
import ast
import pickle
import os

import slideflow as sf
import multiprocessing as mp
from slideflow.model import build_feature_extractor
from sklearn.linear_model import LogisticRegression, LinearRegression
from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
import torch
from e4e3.utils import common, train_utils
from e4e3.utils.model_utils import setup_model
import torchvision.transforms.functional as F_transform
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm



In [None]:
# Set Project working directory
# Every model, data and figure path in this notebook is built as PROJECT_DIR + '/...',
# so the kernel's working directory is expected to contain those folders.
PROJECT_DIR = os.getcwd()
# NOTE(review): the CUDA device index is hard-coded; adjust it to a GPU that exists locally.
device = torch.device('cuda:3')

# Generate Sample Images with HistoXGAN and Alternative Encoders

In [None]:
# First we must load all the encoders into memory

# Set backbone to 'RETCCL' to generate images with RETCCL based encoders
backbone = "CTP"

available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
device = torch.device('cuda:3')

# Feature extractor that matches the chosen backbone.
extractor_name = 'ctranspath' if backbone == "CTP" else 'retccl'
feature_extractor = build_feature_extractor(extractor_name, tile_px=512, device=device)

# Encoders trained on the chosen backbone's features.
netctp_simple, opts = setup_model(
    PROJECT_DIR + "/experiment/simple_" + backbone + "/checkpoints/iteration_200000.pt",
    "SingleStyleCodeEncoder",
    device,
)
netctp_e4e, opts = setup_model(
    PROJECT_DIR + "/experiment/e4e_" + backbone + "/checkpoints/iteration_200000.pt",
    "Encoder4Editing",
    device,
)

# The LPIPS/DISTS encoder lives in a different experiment folder for each backbone.
if backbone == "CTP":
    lpips_dists_checkpoint = PROJECT_DIR + "/experiment/lpips_dists/checkpoints/iteration_200000.pt"
else:
    lpips_dists_checkpoint = PROJECT_DIR + "/experiment/lpips_dists_discrim/checkpoints/iteration_200000.pt"
netld, opts = setup_model(lpips_dists_checkpoint, "Encoder4Editing", device)

# Order matters: downstream plots label the rows LPIPS/DISTS, SingleStyle, Encoder4Editing.
net_models = [netld, netctp_simple, netctp_e4e]


def forward(batch, net, device):
    """Run one encoder on the first element of ``batch``.

    The encoder input ``x`` is ``batch[0]`` resized with ``F_transform.resize(..., 256)``;
    the target ``y`` is ``batch[0]`` at its original size. Both are moved to ``device``
    and cast to float.

    Returns:
        (x, y, y_hat, latent), where ``y_hat`` and ``latent`` are the two values returned
        by ``net.forward(x, return_latents=True)``.
    """
    source = batch[0]
    x = F_transform.resize(source, 256).to(device).float()
    y = source.to(device).float()
    y_hat, latent = net.forward(x, return_latents=True)
    return x, y, y_hat, latent

# Load the HistoXGAN generator (EMA weights) that matches the chosen backbone.
G = None
generator_snapshot = (
    PROJECT_DIR + '/FINAL_MODELS/RetCCL/snapshot.pkl'
    if backbone == "RETCCL"
    else PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl'
)
with dnnlib.util.open_url(generator_snapshot) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)
    

In [None]:
#Note - to use the alternative encoders tested in this study, the 'real' image needs to be provided, not just image features
#IF CPTAC slides are available and TFRecords are extracted, please update the /PROJECTS/HistoXGAN/datasets.json file to
#indicate the location of TFRecords for the CPTAC datasets

#Alternatively, leave the following variables as True - this will load input images that were used in the publication figures
#and generate the corresponding 'generated' images from the alternative encoders and HistoXGAN
USE_SAVED_INPUT_IMAGES = True


def _reconstruction_column(batch, i=0):
    """Build one figure column for sample ``i`` of ``batch``.

    Returns a list ``[original, <one reconstruction per encoder in net_models>, HistoXGAN]``.
    The last entry is produced by passing ``feature_extractor`` output through the
    generator ``G``. Callers are expected to wrap the call in ``torch.no_grad()``.
    """
    img_col = []
    for net in net_models:
        x, y, y_hat, latent = forward(batch, net, device)
        if not img_col:
            # The original tile is identical for every encoder; store it once.
            img_col.append(common.tensor2im(y[i]))
        img_col.append(common.tensor2im(y_hat[i]))
    # Rescale [-1, 1] -> [0, 255] before feature extraction, and again for the generator output.
    features = feature_extractor((batch[0].to(device) + 1) * 127.5)
    generated = (G(features, 0, noise_mode ='const') + 1) * 127.5
    img_col.append(generated.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[i].cpu().numpy())
    return img_col


imgs_dict = {}
imgs_final = {}

source_list_CPTAC =  ["CPTAC_HNSC", "CPTAC_LUAD", "CPTAC_LSCC", "CPTAC_BRCA", "CPTAC_COADREAD_NO_ROI", "CPTAC_GBM_NO_ROI", "CPTAC_PDA_NO_ROI", "CPTAC_UCEC_NO_ROI"]

if not USE_SAVED_INPUT_IMAGES:
    P = sf.Project(PROJECT_DIR + '/PROJECTS/HistoXGAN/')

    #Must ensure the datasets are properly set up
    # Fixed: the path separator after PROJECT_DIR was missing, producing '<cwd>PROJECTS/...'.
    P.annotations = PROJECT_DIR + '/PROJECTS/HistoXGAN/cptac_all_annotations.csv'

    for s in source_list_CPTAC:
        P.sources = [s]
        dataset = P.dataset(tile_px=512, tile_um=400)
        df_train = dataset.torch(batch_size=1, num_workers=8, infinite = False)
        with torch.no_grad():
            # Only the first tile of each source is used for the figure.
            for batch in df_train:
                imgs_dict[s] = [_reconstruction_column(batch)]
                break
        imgs_final[s] = imgs_dict[s][0]
else:
    #Loading the 'source' image from saved images for publication
    # Fixed: both backbones now load into `imgs_input`. Previously the CTP branch loaded into
    # `imgs_final`, so the loop below raised NameError on `imgs_input`.
    # NOTE: pickle.load executes arbitrary code; only load files generated by this project.
    if backbone == 'RETCCL':
        saved_inputs_path = PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac_retccl.pkl'
    else:
        saved_inputs_path = PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac.pkl'
    with open(saved_inputs_path, 'rb') as f:
        imgs_input = pickle.load(f)
    for s in source_list_CPTAC:
        with torch.no_grad():
            # Entry 0 of each saved column is the original tile (HWC, 0-255).
            # Convert it to a [-1, 1] tensor with two leading singleton dims so batch[0] is NCHW.
            batch = torch.tensor(np.array(imgs_input[s][0])/127.5 - 1, dtype = torch.float).permute(2,0,1).unsqueeze(dim = 0).unsqueeze(dim = 0).to(device)
            imgs_dict[s] = [_reconstruction_column(batch)]
        imgs_final[s] = imgs_dict[s][0]

        
#Code to display the generated images
# Grid layout: one column per source in imgs_final, one row per image in that source's column.
imgs = list(imgs_final.values())
fig, axs2 = plt.subplots(len(imgs[0]), len(imgs)) #, dpi = 300, figsize = (10, 10/8*5))#(2*len(imgs), 2*len(imgs[0])))
for col, (img_name, img_col) in enumerate(imgs_final.items()):
    for row, img in enumerate(img_col):
        panel = axs2[row][col]
        panel.imshow(img)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
        panel.set_aspect('auto')

    # Drop the first six characters (the "CPTAC_" prefix of the source names) and shorten the label.
    str_name = img_name[6:].replace("_NO_ROI", "").replace("COADREAD", "COAD")
    axs2[0][col].set_xlabel(str_name)
fig.subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=-0.1, hspace=0)
for row, row_label in enumerate(["Original", "LPIPS / DISTS", "SingleStyle", "Encoder4Editing", "HistoXGAN"]):
    axs2[row][0].set_ylabel(row_label)

plt.show()

In [None]:
#TO SAVE GENERATED IMAGES FOR USE IN THE FOLLOWING COMPOSITE FIGURES FOR PUBLICATION

# One pickle per backbone; these are the files the figure cells read back.
if backbone == "CTP":
    saved_images_path = PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac.pkl'
else:
    saved_images_path = PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac_retccl.pkl'
# `with` closes the handle even if pickle.dump raises.
with open(saved_images_path, 'wb') as f:
    pickle.dump(imgs_final, f)

# Figure 2, Supplemental Figure 2, and Supplemental Tables

## These figures include samples of generated images for the different encoders, as well as the average error in reconstruction across datasets

### Supplemental Tables 1-2

In [None]:
#Need to set prefix for which data to load
# `prefix` is spliced into the MERGEDLOSS/*.csv file names read by the table cell below.
# The figure cells later rebind it ('CTP' for Figure 2, "RETCCL" for Supplemental Figure 2).
prefix = 'RETCCL'

In [None]:
# Table 1

# One loss-summary CSV per encoder: LPIPS/DISTS, SingleStyle, Encoder4Editing, HistoXGAN.
df1, df2, df3, df4 = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        PROJECT_DIR + "/MERGEDLOSS/" + prefix + "LPIPSDISTS.csv",
        PROJECT_DIR + "/MERGEDLOSS/SIMPLE" + prefix + ".csv",
        PROJECT_DIR + "/MERGEDLOSS/E4E" + prefix + ".csv",
        PROJECT_DIR + "/MERGEDLOSS/" + prefix + "GAN.csv",
    )
)

# Row subsets per cohort, selected by the cohort name appearing in the index.
df1_tcga, df2_tcga, df3_tcga, df4_tcga = (
    frame[frame.index.str.contains("TCGA")] for frame in (df1, df2, df3, df4)
)
df1_cptac, df2_cptac, df3_cptac, df4_cptac = (
    frame[frame.index.str.contains("CPTAC")] for frame in (df1, df2, df3, df4)
)

# The first five index characters give the cohort, the remainder the subset (underscores removed).
df1['Source'] = df1.index.str[0:5].str.replace("_", "")
df1['Subset'] = df1.index.str[5:].str.replace("_", "")
df1['n'] = df1['count']

# "mean (stdev) " display strings, one column per encoder.
fmt3 = '{:.3f}'.format
for display_col, loss_df in (
    ('LPIP / DISTS', df1),
    ('Single Layer', df2),
    ('Encoder4Editing', df3),
    ('LatentGAN', df4),
):
    df1[display_col] = loss_df['mean'].map(fmt3) + " (" + loss_df['stdev'].map(fmt3) + ") "

# Count-weighted sums, used to derive the overall per-cohort means (see commented code below).
for weighted_col, loss_df in (
    ('L_raw_mean', df1),
    ('S_raw_mean', df2),
    ('E_raw_mean', df3),
    ('H_raw_mean', df4),
):
    df1[weighted_col] = loss_df['mean'] * df1['n']

# print("Overall TCGA Means")
# print(df1.loc[df1.Source == 'TCGA', 'L_raw_mean'].sum() / df1.loc[df1.Source == 'TCGA', 'n'].sum())
# print(df1.loc[df1.Source == 'TCGA', 'S_raw_mean'].sum() / df1.loc[df1.Source == 'TCGA', 'n'].sum())
# print(df1.loc[df1.Source == 'TCGA', 'E_raw_mean'].sum() / df1.loc[df1.Source == 'TCGA', 'n'].sum())
# print(df1.loc[df1.Source == 'TCGA', 'H_raw_mean'].sum() / df1.loc[df1.Source == 'TCGA', 'n'].sum())

# print("Overall CPTAC Means")
# print(df1.loc[df1.Source == 'CPTAC', 'L_raw_mean'].sum() / df1.loc[df1.Source == 'CPTAC', 'n'].sum())
# print(df1.loc[df1.Source == 'CPTAC', 'S_raw_mean'].sum() / df1.loc[df1.Source == 'CPTAC', 'n'].sum())
# print(df1.loc[df1.Source == 'CPTAC', 'E_raw_mean'].sum() / df1.loc[df1.Source == 'CPTAC', 'n'].sum())
# print(df1.loc[df1.Source == 'CPTAC', 'H_raw_mean'].sum() / df1.loc[df1.Source == 'CPTAC', 'n'].sum())

# NOTE(review): the result of this selection is discarded (it is not the cell's last expression).
df1.loc[df1.Source == 'TCGA']
df1 = df1.drop(columns=['count', 'mean', 'stdev', 'stderr'])

# Source descending, Subset ascending within each source.
df1 = df1.sort_values(by=['Source', 'Subset'], ascending=[False, True])

print(df1)



### FIGURE 2 - CTransPath Examples and Loss

In [None]:
#Need to set prefix for which data to load
prefix = 'CTP'
%matplotlib inline

df1 = pd.read_csv(PROJECT_DIR + "/MERGEDLOSS/" + prefix + "LPIPSDISTS.csv", index_col = 0)
df2 = pd.read_csv(PROJECT_DIR + "/MERGEDLOSS/SIMPLE" + prefix + ".csv", index_col = 0)
df3 = pd.read_csv(PROJECT_DIR + "/MERGEDLOSS/E4E" + prefix + ".csv", index_col = 0)
df4 = pd.read_csv(PROJECT_DIR + "/MERGEDLOSS/" + prefix + "GAN.csv", index_col = 0)
df1_tcga = df1[df1.index.str.contains("TCGA")]
df1_cptac= df1[df1.index.str.contains("CPTAC")]
df2_tcga = df2[df2.index.str.contains("TCGA")]
df2_cptac = df2[df2.index.str.contains("CPTAC")]
df3_tcga = df3[df3.index.str.contains("TCGA")]
df3_cptac = df3[df3.index.str.contains("CPTAC")]
df4_tcga = df4[df4.index.str.contains("TCGA")]
df4_cptac = df4[df4.index.str.contains("CPTAC")]


fig = plt.figure(figsize=(10, 12), dpi = 300)

subfigs = fig.subfigures(2, 1)

axs = subfigs[0].subplots(1, 2, sharey = True,  gridspec_kw= {'width_ratios': [3.2, 1]})

from matplotlib.legend_handler import HandlerTuple

def bar_plot(ax, df, dataset):
    """Draw overlaid per-subtype mean-loss bars with standard-error bars for four encoders.

    Args:
        ax: Matplotlib Axes to draw on.
        df: Sequence of four DataFrames in the order LPIPS/DISTS, SingleStyle,
            Encoder4Editing, HistoXGAN. Each needs 'mean' and 'stderr' columns and an index
            whose entries start with a dataset prefix that is stripped to get the tick label
            (5 characters for 'TCGA', 6 otherwise).
        dataset: 'TCGA' draws the y-label and legend; any other value hides the left ticks.

    Reads the module-level ``prefix`` to pick the y-axis limit.

    Raises:
        ValueError: if a label of ``df[1..3]`` is absent from ``df[0]``.
    """
    ic = 5 if dataset == 'TCGA' else 6

    # Sorted tick labels come from the first frame. Every frame, including the first, is
    # positioned by looking its labels up in this sorted list, so bars line up with their tick
    # label even when the CSV rows are not already sorted.
    tick_labels = np.sort(df[0].index.str[ic:].values)
    labels = list(tick_labels)
    tick_counts = [[labels.index(name) for name in frame.index.str[ic:].values] for frame in df]

    means = [frame['mean'].values.astype(float) for frame in df]
    errs = [frame['stderr'].values.astype(float) for frame in df]

    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('Blues')

    # Bars are overlaid; zorder decides which encoder is drawn on top.
    p1 = ax.bar(tick_counts[0], means[0], label='LPIPS/DISTS Encoder', zorder=-25, color=cmap(0.3))
    p2 = ax.bar(tick_counts[1], means[1], label='SingleStyle Encoder', zorder=-20, color=cmap(0.5))

    # The LPIPS/DISTS bar for THYM is redrawn above the SingleStyle bar so it stays visible.
    has_thym = 'THYM' in labels
    if has_thym:
        thym_rows = df[0][df[0].index == 'TCGA_THYM']
        thym_pos = [labels.index('THYM')]
        p10 = ax.bar(thym_pos, thym_rows['mean'].values.astype(float), label='LPIPS/DISTS Encoder', zorder=-15, color=cmap(0.3))
    p3 = ax.bar(tick_counts[2], means[2], label='Encoder4Editing', zorder=-10, color=cmap(0.7))
    p4 = ax.bar(tick_counts[3], means[3], label='HistoXGAN Encoder', zorder=-5, color=cmap(0.9))

    ax.set_xticks(list(range(len(labels))))
    ax.set_xticklabels(tick_labels, rotation=90)

    # Error bars share the zorder of the bars they belong to.
    ax.errorbar(tick_counts[0], means[0], yerr=errs[0], fmt='none', color='black', capsize=3, zorder=-25)
    ax.errorbar(tick_counts[1], means[1], yerr=errs[1], fmt='none', color='black', capsize=3, zorder=-20)
    if has_thym:
        ax.errorbar(thym_pos, thym_rows['mean'].values.astype(float), yerr=thym_rows['stderr'].values.astype(float), fmt='none', color='black', capsize=3, zorder=-15)
    ax.errorbar(tick_counts[2], means[2], yerr=errs[2], fmt='none', color='black', capsize=3, zorder=-10)
    ax.errorbar(tick_counts[3], means[3], yerr=errs[3], fmt='none', color='black', capsize=3, zorder=-5)

    ax.set_xlabel('Cancer Subtype, ' + dataset)
    if prefix == "RETCCL":
        ax.set_ylim([0, 0.029])
    else:
        ax.set_ylim([0, 0.095])

    if dataset == 'TCGA':
        ax.set_ylabel('L1 Loss')
        if has_thym:
            # Group the two LPIPS/DISTS artists under a single legend entry.
            leg = ax.legend([(p1, p10), p2, p3, p4], ['LPIPS/DISTS Encoder', 'SingleStyle Encoder', 'Encoder4Editing', 'HistoXGAN Encoder'], loc='upper left')
        else:
            leg = ax.legend(loc='upper left', facecolor='white')
        # Legend.legendHandles was renamed to legend_handles (removed in Matplotlib 3.9).
        handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
        for lh in handles:
            lh.set_alpha(1)
        leg.get_frame().set_alpha(1)
    else:
        ax.tick_params(axis='y', left=False)

# Panel A: loss bar charts for the two cohorts.
bar_plot(axs[0], [df1_tcga, df2_tcga, df3_tcga, df4_tcga], dataset = 'TCGA')
bar_plot(axs[1], [df1_cptac, df2_cptac, df3_cptac, df4_cptac], dataset = 'CPTAC')

# Panel B: image grid read back from the pickle written by the save cell.
# NOTE: pickle.load executes arbitrary code; only load files generated by this project.
import pickle
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac.pkl', 'rb') as f:
    imgs_final = pickle.load(f)

# One column per entry of imgs_final, one row per image in that entry.
imgs = list(imgs_final.values())
axs2 = subfigs[1].subplots(len(imgs[0]), len(imgs)) #, dpi = 300, figsize = (10, 10/8*5))#(2*len(imgs), 2*len(imgs[0])))
for col, (img_name, img_col) in enumerate(imgs_final.items()):
    for row, img in enumerate(img_col):
        panel = axs2[row][col]
        panel.imshow(img)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
        panel.set_aspect('auto')

    # Drop the first six characters of the key (presumably the "CPTAC_" prefix) and shorten the label.
    str_name = img_name[6:].replace("_NO_ROI", "").replace("COADREAD", "COAD")
    axs2[0][col].set_xlabel(str_name)
for row, row_label in enumerate(["Original", "LPIPS / DISTS", "Single Layer", "Encoder4Editing", "HistoXGAN"]):
    axs2[row][0].set_ylabel(row_label)

subfigs[1].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=-0.1, hspace=0)
subfigs[0].subplots_adjust(left = 0.03, top = 1, bottom = 0.3, wspace = 0.07)

# Panel letters.
subfigs[0].text(-0.01,1,"A", zorder = 30, clip_on = False, weight='bold' )
subfigs[1].text(-0.01,1.01,"B", zorder = 30, clip_on = False, weight='bold')

for extension in (".svg", ".png"):
    plt.savefig(PROJECT_DIR + "/pub_figures/fig1_" + prefix + extension)
plt.show()

### Supplemental Figure 2 - RetCCL Examples and Loss

In [None]:
prefix = "RETCCL"

# One loss-summary CSV per encoder: LPIPS/DISTS, SingleStyle, Encoder4Editing, HistoXGAN.
df1, df2, df3, df4 = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        PROJECT_DIR + "/MERGEDLOSS/" + prefix + "LPIPSDISTS.csv",
        PROJECT_DIR + "/MERGEDLOSS/SIMPLE" + prefix + ".csv",
        PROJECT_DIR + "/MERGEDLOSS/E4E" + prefix + ".csv",
        PROJECT_DIR + "/MERGEDLOSS/" + prefix + "GAN.csv",
    )
)

# Row subsets per cohort, selected by the cohort name appearing in the index.
df1_tcga, df2_tcga, df3_tcga, df4_tcga = (
    frame[frame.index.str.contains("TCGA")] for frame in (df1, df2, df3, df4)
)
df1_cptac, df2_cptac, df3_cptac, df4_cptac = (
    frame[frame.index.str.contains("CPTAC")] for frame in (df1, df2, df3, df4)
)

import matplotlib.gridspec as gridspec

# Top sub-figure holds the two bar charts; bottom sub-figure holds the image grid.
fig = plt.figure(figsize=(10, 12), dpi=300)
subfigs = fig.subfigures(2, 1)
axs = subfigs[0].subplots(1, 2, sharey=True, gridspec_kw={'width_ratios': [3.2, 1]})

from matplotlib.legend_handler import HandlerTuple

def bar_plot(ax, df, dataset):
    """Draw overlaid per-subtype mean-loss bars with standard-error bars (RetCCL figure).

    This definition replaces the earlier ``bar_plot`` in the notebook namespace. It differs
    in its zorder values, its fixed y-limit, and a hard-coded list of subtypes whose
    LPIPS/DISTS bar is redrawn above the SingleStyle bar.

    Args:
        ax: Matplotlib Axes to draw on.
        df: Sequence of four DataFrames in the order LPIPS/DISTS, SingleStyle,
            Encoder4Editing, HistoXGAN. Each needs 'mean' and 'stderr' columns and an index
            of the form '<dataset>_<subtype>'.
        dataset: 'TCGA' or 'CPTAC'. 'TCGA' also draws the y-label and legend. Any other
            value draws no redrawn bars.

    Raises:
        ValueError: if a label of ``df[1..3]`` or a redraw subtype is absent from ``df[0]``.
        IndexError: if a redraw subtype has no row in ``df[0]``.
    """
    ic = 5 if dataset == 'TCGA' else 6

    # Sorted tick labels come from the first frame. Every frame, including the first, is
    # positioned by looking its labels up in this sorted list, so bars line up with their tick
    # label even when the CSV rows are not already sorted.
    tick_labels = np.sort(df[0].index.str[ic:].values)
    labels = list(tick_labels)
    tick_counts = [[labels.index(name) for name in frame.index.str[ic:].values] for frame in df]

    means = [frame['mean'].values.astype(float) for frame in df]
    errs = [frame['stderr'].values.astype(float) for frame in df]

    # Subtypes whose LPIPS/DISTS bar is redrawn above the SingleStyle bar (for RETCCL to have
    # the right order of bar graphs).
    redraw_subtypes = {
        'TCGA': ['CESC', 'CHOL', 'COADREAD', 'DLBC', 'HNSC', 'KIRP', 'LGG', 'LUAD', 'LUSC', 'OV',
                 'PCPG', 'PRAD', 'SARC', 'SKCM', 'TGCT', 'THCA', 'UCEC', 'UCS', 'UVM'],
        'CPTAC': ['LUSC', 'PAAD', 'UCEC'],
    }.get(dataset, [])
    tl_l, tl_m, tl_e = [], [], []
    for s in redraw_subtypes:
        rows = df[0][df[0].index == dataset + '_' + s]
        tl_l.append(labels.index(s))
        # Positional access through .values avoids the deprecated Series[0] integer fallback.
        tl_m.append(rows['mean'].values.astype(float)[0])
        tl_e.append(rows['stderr'].values.astype(float)[0])

    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.
    cmap = plt.get_cmap('Blues')

    # Bars are overlaid; zorder decides which encoder is drawn on top.
    p1 = ax.bar(tick_counts[0], means[0], label='LPIPS/DISTS Encoder', zorder=-2, color=cmap(0.3))
    p2 = ax.bar(tick_counts[1], means[1], label='SingleStyle Encoder', zorder=-1, color=cmap(0.5))
    p10 = ax.bar(tl_l, tl_m, label='LPIPS/DISTS Encoder', zorder=0, color=cmap(0.3))
    p3 = ax.bar(tick_counts[2], means[2], label='Encoder4Editing', zorder=15, color=cmap(0.7))
    p4 = ax.bar(tick_counts[3], means[3], label='HistoXGAN Encoder', zorder=20, color=cmap(0.9))

    ax.set_xticks(list(range(len(labels))))
    ax.set_xticklabels(tick_labels, rotation=90)

    # Error bars share the zorder of the bars they belong to.
    ax.errorbar(tick_counts[0], means[0], yerr=errs[0], fmt='none', color='black', capsize=3, zorder=-2)
    ax.errorbar(tick_counts[1], means[1], yerr=errs[1], fmt='none', color='black', capsize=3, zorder=-1)
    if tl_l:
        ax.errorbar(tl_l, tl_m, yerr=tl_e, fmt='none', color='black', capsize=3, zorder=0)
    ax.errorbar(tick_counts[2], means[2], yerr=errs[2], fmt='none', color='black', capsize=3, zorder=15)
    ax.errorbar(tick_counts[3], means[3], yerr=errs[3], fmt='none', color='black', capsize=3, zorder=20)

    ax.set_xlabel('Cancer Subtype, ' + dataset)
    ax.set_ylim([0, 0.029])

    if dataset == 'TCGA':
        ax.set_ylabel('L1 Loss')
        # Group the two LPIPS/DISTS artists under a single legend entry.
        leg = ax.legend([(p1, p10), p2, p3, p4], ['LPIPS/DISTS Encoder', 'SingleStyle Encoder', 'Encoder4Editing', 'HistoXGAN Encoder'],
                        loc='upper left', facecolor='white')
        # Legend.legendHandles was renamed to legend_handles (removed in Matplotlib 3.9).
        handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
        for lh in handles:
            lh.set_alpha(1)
        leg.get_frame().set_alpha(1)
    else:
        ax.tick_params(axis='y', left=False)

# Panel A: loss bar charts for the two cohorts.
bar_plot(axs[0], [df1_tcga, df2_tcga, df3_tcga, df4_tcga], dataset = 'TCGA')
bar_plot(axs[1], [df1_cptac, df2_cptac, df3_cptac, df4_cptac], dataset = 'CPTAC')

# Panel B: image grid read back from the RetCCL pickle written by the save cell.
# NOTE: pickle.load executes arbitrary code; only load files generated by this project.
import pickle
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_cptac_retccl.pkl', 'rb') as f:
    imgs_final = pickle.load(f)

# One column per entry of imgs_final, one row per image in that entry.
imgs = list(imgs_final.values())
axs2 = subfigs[1].subplots(len(imgs[0]), len(imgs)) #, dpi = 300, figsize = (10, 10/8*5))#(2*len(imgs), 2*len(imgs[0])))
for col, (img_name, img_col) in enumerate(imgs_final.items()):
    for row, img in enumerate(img_col):
        panel = axs2[row][col]
        panel.imshow(img)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
        panel.set_aspect('auto')

    # Drop the first six characters of the key (presumably the "CPTAC_" prefix) and shorten the label.
    str_name = img_name[6:].replace("_NO_ROI", "").replace("COADREAD", "COAD")
    axs2[0][col].set_xlabel(str_name)
for row, row_label in enumerate(["Original", "LPIPS / DISTS", "SingleStyle", "Encoder4Editing", "HistoXGAN"]):
    axs2[row][0].set_ylabel(row_label)

subfigs[1].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=-0.1, hspace=0)
subfigs[0].subplots_adjust(left = 0.03, top = 1, bottom = 0.3, wspace = 0.07)

# Panel letters.
subfigs[0].text(-0.01,1,"A", zorder = 30, clip_on = False, weight='bold' )
subfigs[1].text(-0.01,1.01,"B", zorder = 30, clip_on = False, weight='bold')

for figure_path in (
    PROJECT_DIR + "/pub_figures/fig1_retccl.svg",
    PROJECT_DIR + "/pub_figures/fig1_retccl.png",
):
    plt.savefig(figure_path)
plt.show()

# Figure 3 / Supplemental Figure 3

## These figures illustrate using the mean difference in features to transition between states of grade, subtype, and gene expression

In [None]:
#The following are helper functions for interpolation and the statistical calculations for correlation between predicted
#grade / subtype / gene expression for real / generated images

import math
import scipy.stats as stats
from sklearn.linear_model import LogisticRegression, LinearRegression

# Reload the generator (EMA weights) from the CTransPath snapshot. This rebinds the global
# `G` unconditionally, i.e. regardless of the `backbone` selected earlier in the notebook.
with dnnlib.util.open_url(PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

def vector_interpolate(
    G: torch.nn.Module,
    z: torch.Tensor,
    z2: torch.Tensor,
    device: torch.device,
    steps: int = 100
):
    """Yield images generated along a straight line through ``z`` in the direction ``z2``.

    Step ``i`` (0 <= i < steps) feeds ``z - z2 + (2*i/steps) * z2`` to the generator, so the
    sweep starts at ``z - z2`` and stops one step short of ``z + z2``.

    Args:
        G: generator, called as ``G(vector, 0, noise_mode='const')``; its output is assumed
            to be an NCHW image batch in [-1, 1] (it is rescaled by ``(img + 1) * 255/2``).
        z: base vector.
        z2: direction vector, broadcastable against ``z``.
        device: device the interpolated vector is moved to before calling ``G``.
        steps: number of images to yield.

    Yields:
        The first image of each generated batch as an HWC uint8 numpy array.
    """
    for interp_idx in range(steps):
        # no_grad wraps only the computation, not the yield, so the caller's grad mode is
        # never altered while the generator is suspended.
        with torch.no_grad():
            # The arithmetic already creates a new tensor; detach() replaces the former
            # torch.tensor(<tensor>) copy, which emitted a UserWarning.
            torch_interp = (z - z2 + 2*interp_idx/steps*z2).detach().to(device)
            img = G(torch_interp, 0, noise_mode ='const')
            img = (img + 1) * (255/2)
            img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        yield img
        
def generate_images(vector_z, vector_z2, prefix = 'test'):
    """Return the list of images produced by ``vector_interpolate`` for the given vectors.

    ``vector_z`` (base) and ``vector_z2`` (direction) are converted to tensors on the global
    ``device`` and swept with the global generator ``G``, using the default step count.
    ``prefix`` is accepted for call compatibility but is not used.
    """
    base = torch.tensor(vector_z).to(device)
    direction = torch.tensor(vector_z2).to(device)
    return [frame for frame in vector_interpolate(G, base, direction, device = device)]

def get_log_features(df, col, name):
    """Fit a logistic regression of ``df[col]`` on the feature columns.

    Feature columns are all columns whose name contains 'Feature_'.

    Args:
        df: DataFrame holding the feature columns and the label column.
        col: name of the label column.
        name: accepted for call compatibility; not used.

    Returns:
        (vector_list, coef): ``vector_list`` is the feature values of the row labelled 0, or
        of the first row when no row is labelled 0. ``coef`` is the fitted ``clf.coef_``.
    """
    feat_cols = [f for f in df.columns.values if 'Feature_' in f]
    X = df[feat_cols]
    # 1-D target: passing the (n, 1) array df[[col]].values made scikit-learn emit a
    # DataConversionWarning before flattening it itself.
    y = df[col].values
    # Callers may have filtered rows (e.g. dropna), so label 0 is not guaranteed to exist;
    # fall back to the first remaining row instead of raising KeyError.
    first_row = X.loc[0, :] if 0 in X.index else X.iloc[0]
    vector_list = first_row.values.tolist()
    clf = LogisticRegression().fit(X, y)
    return vector_list, clf.coef_


def strround(s, r = 2):
    """Round ``s`` to ``r`` decimal places and return the result as a string."""
    rounded = round(s, r)
    return str(rounded)

def r_to_z(r):
    """Fisher z-transform of a correlation coefficient: ``0.5 * ln((1 + r) / (1 - r))``."""
    ratio = (1 + r) / (1 - r)
    return math.log(ratio) / 2.0

def z_to_r(z):
    """Inverse Fisher transform: ``(e^(2z) - 1) / (e^(2z) + 1)``."""
    exp_2z = math.exp(2 * z)
    numerator = exp_2z - 1
    denominator = exp_2z + 1
    return numerator / denominator

def r_confidence_interval(r, alpha, n):
    """Two-sided confidence interval for a correlation coefficient via the Fisher z-transform.

    Args:
        r: sample correlation coefficient.
        alpha: significance level (0.05 gives a 95% interval).
        n: sample size; the standard error is ``1 / sqrt(n - 3)``.

    Returns:
        (lower, upper) bounds, transformed back to the r scale.
    """
    center = r_to_z(r)
    std_err = 1.0 / math.sqrt(n - 3)
    # Two-tailed critical value times the standard error.
    half_width = stats.norm.ppf(1 - alpha/2) * std_err
    return (z_to_r(center - half_width), z_to_r(center + half_width))
def SuperScriptinate(number):
    """Replace every digit and '-' in the string ``number`` with its Unicode superscript form."""
    superscripts = {
        '0': '⁰', '1': '¹', '2': '²', '3': '³', '4': '⁴',
        '5': '⁵', '6': '⁶', '7': '⁷', '8': '⁸', '9': '⁹',
        '-': '⁻',
    }
    return number.translate(str.maketrans(superscripts))

def sci_notation(number, sig_fig=2):
    """Format ``number`` as '<mantissa> x 10<superscript exponent>'.

    The mantissa carries ``sig_fig`` digits after the decimal point.
    """
    formatted = "{0:.{1:d}e}".format(number, sig_fig)
    mantissa, exponent = formatted.split("e")
    # int() drops the leading "+" and any leading zeros of the exponent.
    exponent = int(exponent)
    return mantissa + " x 10" + SuperScriptinate(str(exponent))


#Load compare group will perform calculations and plot the correlation between predictions from real image tiles
#versus generated versions for those same image tiles
#dfs1 and dfs2 are file paths to the 'patient_predictions.csv' saved in the folder
#generated when evaluating models in the Slideflow software.
#For k-fold cross validation (i.e. the internal evaluation of the models in TCGA) these variables should be arrays
#representing a list of files for the predictions in the held out test set for each cross fold
#For convenience and reproducibility we have saved these predictions in the /FINAL_PREDICTIONS/ folder and loaded them
#automatically
def load_compare_group(dfs1 = None, dfs2 = None, title = None, ax = None, read_csv = False, pred_only = False, pred_df = None, color_df = None, pred_class_col = None, pred_class_opts = None, class_names = None, xlabel = False, ylabel = False, unknown_name = "", regen = False):
    """Compare per-patient predictions from two sets of prediction files and plot them.

    A merged per-patient table is loaded from the cache file
    ``FINAL_PREDICTIONS/<title>_<class_names[1]>_compare_df.pkl`` when it exists and
    ``regen`` is false. Otherwise it is rebuilt from ``dfs1`` / ``dfs2`` and written back to
    that cache file. Positional column 2 of the table is then scattered against positional
    column 5, with a least-squares line and a Pearson r (95% CI via the Fisher transform).

    NOTE(review): the positional indices 2, 3, 5 and 6 presumably refer to prediction and
    label columns of the Slideflow patient-prediction files (the axis labels call column 2
    the real-image prediction and column 5 the GAN-image prediction). Confirm against the
    file schema.

    Args:
        dfs1, dfs2: iterables of file paths. Every file in each list is read and
            concatenated, and the two results are inner-merged on 'patient'. Only used when
            the cache is not loaded.
        title: plot title, also part of the cache file name. When ``ax`` is given it must
            contain " - ", because the returned row uses ``title.split(" - ")``.
        ax: Axes to draw on. When falsy, the plot goes to the pyplot state machine and is
            shown, and nothing is returned.
        read_csv: read the inputs with ``pd.read_csv`` when true, ``pd.read_parquet`` otherwise.
        pred_only: when false, positional column 3 is used as the binary label for
            colouring and for AUROC / average precision.
        pred_df: optional CSV path with 'patient' and ``pred_class_col`` columns. Rows are
            kept only where the value is in ``pred_class_opts``; 'group' is 1 where it equals
            ``pred_class_opts[1]`` and 0 otherwise.
        color_df: optional CSV path. ``pred_class_col`` is merged in, negated,
            quantile-transformed, multiplied by 100 and used as the scatter colour.
        pred_class_col, pred_class_opts: see ``pred_df`` / ``color_df``.
        class_names: two display names. Required in practice: ``class_names[1]`` is indexed
            unconditionally to build the cache file name.
        xlabel, ylabel: whether to draw the axis label (ticks are hidden otherwise).
        unknown_name: suffix of the legend label when no class information is available.
        regen: rebuild the table even if a cache file exists.

    Returns:
        When ``ax`` is given: a one-element list holding a summary row. If ``color_df`` is
        truthy and ``class_names[1] != 'High Grade'``, a tuple ``(im, rows)`` is returned
        instead, where ``im`` is the last scatter artist. When ``ax`` is falsy: None.
    """
    # Use the cached merged table when allowed and available.
    # NOTE: pickle.load executes arbitrary code; only load files generated by this project.
    if not regen and os.path.isfile(PROJECT_DIR + "/FINAL_PREDICTIONS/" + str(title) + "_" + str(class_names[1]) + "_compare_df.pkl"):
        with open(PROJECT_DIR + "/FINAL_PREDICTIONS/" + str(title) + "_" + str(class_names[1]) + "_compare_df.pkl", 'rb') as f:
            df1 = pickle.load(f)
    else:
        df_1 = []
        df_2 = []
        if read_csv:
            for f in dfs1:
                df_1 += [pd.read_csv(f)]
            for f in dfs2:
                df_2 += [pd.read_csv(f)] 
                # Tag each dfs2 frame with its 1-based position in the list.
                df_2[len(df_2) - 1]['load_group'] = len(df_2)

        else:
            for f in dfs1:
                df_1 += [pd.read_parquet(f)]
            for f in dfs2:
                df_2 += [pd.read_parquet(f)]
                # Tag each dfs2 frame with its 1-based position in the list.
                df_2[len(df_2) - 1]['load_group'] = len(df_2)

        df1 = pd.concat(df_1)
        df2 = pd.concat(df_2)
        # Keep only patients present in both sets of predictions.
        df1 = df1.merge(df2, on='patient', how='inner')
        if pred_df:
            # Attach an external class column and derive a binary 'group' from it.
            df3 = pd.read_csv(pred_df)
            df1 = df1.merge(df3[['patient', pred_class_col]], on='patient', how='left')
            df1 = df1[df1[pred_class_col].isin(pred_class_opts)]
            df1['group'] = 0
            df1.loc[df1[pred_class_col] == pred_class_opts[1], 'group'] = 1
        if color_df:
            # Attach a continuous column and map its negated values to a 0-100 quantile scale.
            from sklearn.preprocessing import QuantileTransformer
            qt = QuantileTransformer(n_quantiles=1000, random_state=0)
            df3 = pd.read_csv(color_df)
            df1 = df1.merge(df3[['patient', pred_class_col]], on='patient', how='left')
            df1[pred_class_col] = 100 * qt.fit_transform(-1 * df1[[pred_class_col]])
            #df1[pred_class_col] = df1[pred_class_col].apply(lambda x: qt.transform(x))
        # Cache the merged table for later runs.
        with open(PROJECT_DIR + "/FINAL_PREDICTIONS/" + str(title) + "_" + str(class_names[1]) + "_compare_df.pkl", 'wb') as f:
            pickle.dump(df1, f)
    # Placeholders reported when a metric cannot be computed for the chosen mode.
    roc = "N/A"
    prc = "N/A"
    roc_gen = "N/A"
    prc_gen = "N/A"
    from sklearn.metrics import roc_auc_score, average_precision_score
    if ax:
        if pred_only:
            if pred_df:
                # Colour by the externally supplied group; metrics are computed against 'group'.
                im = ax.scatter(df1[df1.group == 0].iloc[:, 2], df1[df1.group == 0].iloc[:, 5], label = "True " +  class_names[0])
                im = ax.scatter(df1[df1.group == 1].iloc[:, 2], df1[df1.group == 1].iloc[:, 5], label = "True " +  class_names[1])
                roc = strround(roc_auc_score(df1["group"], df1.iloc[:, 2]))
                roc_gen = strround(roc_auc_score(df1["group"], df1.iloc[:, 5]))
                prc = strround(average_precision_score(df1["group"], df1.iloc[:, 2]))
                prc_gen = strround(average_precision_score(df1["group"], df1.iloc[:, 5]))
            elif color_df:
                # In this branch roc / prc (and the *_gen variants) are reused to hold the
                # Pearson r with its CI and the p-value against the colour column.
                ra, rb = stats.pearsonr(df1.iloc[:,2].values, df1[pred_class_col])
                rc, rd = stats.pearsonr(df1.iloc[:,5].values, df1[pred_class_col])
                la, ha = r_confidence_interval(ra, 0.05, len(df1.index))
                lc, hc = r_confidence_interval(rc, 0.05, len(df1.index))

                roc = strround(ra, 2) + " (" + strround(la, 2) + " - " + strround(ha, 2) + ")"
                prc = strround(rb, 2)
                
                roc_gen = strround(rc, 2) + " (" + strround(lc, 2) + " - " + strround(hc, 2) + ")"
                prc_gen = strround(rd, 2)
                
                im = ax.scatter(df1.iloc[:, 2], df1.iloc[:, 5], c=df1[pred_class_col], cmap='Blues')
            else:
                im = ax.scatter(df1.iloc[:, 2], df1.iloc[:, 5], label = 'Unknown ' + unknown_name, color = 'gray')
        else:
            # Positional column 3 is used as the binary label.
            roc = strround(roc_auc_score(df1.iloc[:,3], df1.iloc[:, 2]))
            roc_gen = strround(roc_auc_score(df1.iloc[:,3], df1.iloc[:, 5]))
            prc = strround(average_precision_score(df1.iloc[:,3], df1.iloc[:, 2]))
            prc_gen = strround(average_precision_score(df1.iloc[:,3], df1.iloc[:, 5]))
            im = ax.scatter(df1[df1.iloc[:,3] == 0].iloc[:, 2], df1[df1.iloc[:,3] == 0].iloc[:, 5], label = "True " +  class_names[0])
            im = ax.scatter(df1[df1.iloc[:,3] == 1].iloc[:, 2], df1[df1.iloc[:,3] == 1].iloc[:, 5], label = "True " +  class_names[1])
        if xlabel:
            ax.set_xlabel("Prediction, Real Image")
        else:
            ax.xaxis.set_ticks_position('none')
        if ylabel:
            ax.set_ylabel("Prediction, GAN Image")
        else:
            ax.yaxis.set_ticks_position('none')
        if title:
            ax.set_title(title)

        p = 0
        # NOTE(review): with no class information the correlation uses columns 3 and 6, while
        # the scatter and fitted line below use columns 2 and 5 — confirm this is intended.
        if pred_only and not pred_df and not color_df:
            r, p = stats.pearsonr(df1.iloc[:,3].values, df1.iloc[:,6].values)
        else:
            r, p = stats.pearsonr(df1.iloc[:,2].values, df1.iloc[:,5].values)
        l, h = r_confidence_interval(r, 0.05, len(df1.index))
        # Degree-1 least-squares fit, drawn across x in [0, 1].
        ax.plot([0,1], np.poly1d(np.polyfit(df1.iloc[:, 2], df1.iloc[:, 5], 1))([0,1]), label = "r = " + strround(r,2) + " (" + strround(l, 2) + " - " + strround(h, 2) + ")")
        ax.legend(loc = 'upper left')
        # Percentage of rows in the positive class, where a class label is available.
        grade_str = "N/A"
        if pred_only and pred_df:
            grade_str = round(100 * len(df1[df1.group == 1].index) / len(df1.index), 1)
        if not pred_only:
            grade_str = round(100 * len(df1[df1.iloc[:, 3] == 1].index) / len(df1.index), 1)
        # The summary row layout depends on the comparison type.
        if class_names[1] == 'High Grade':
            return [[title.split(" - ")[1], title.split(" - ")[0], len(df1), grade_str, strround(r,2) + " (" + strround(l, 2) + " - " + strround(h, 2) + ")", strround(r_to_z(r),2), sci_notation(p), roc, prc, roc_gen, prc_gen]]
        elif not color_df:
            class_str_0 = "N/A"
            if pred_only and pred_df:
                class_str_0 = round(100 * len(df1[df1.group == 0].index) / len(df1.index), 1)
            if not pred_only:
                class_str_0 = round(100 * len(df1[df1.iloc[:, 3] == 0].index) / len(df1.index), 1)
            return [[title.split(" - ")[1], title.split(" - ")[0], len(df1), class_names[0], class_str_0, class_names[1], grade_str, strround(r,2) + " (" + strround(l, 2) + " - " + strround(h, 2) + ")", strround(r_to_z(r),2), sci_notation(p), roc, prc, roc_gen, prc_gen]]
        else:
            class_str_0 = "N/A"
            if pred_only and pred_df:
                class_str_0 = round(100 * len(df1[df1.group == 0].index) / len(df1.index), 1)
            if not pred_only:
                class_str_0 = round(100 * len(df1[df1.iloc[:, 3] == 0].index) / len(df1.index), 1)
            # Also return the scatter artist (e.g. so the caller can attach a colorbar).
            return im, [[title.split(" - ")[1], title.split(" - ")[0], len(df1), class_names[0], class_str_0, class_names[1], grade_str, strround(r,2) + " (" + strround(l, 2) + " - " + strround(h, 2) + ")", strround(r_to_z(r),2), sci_notation(p), roc, prc, roc_gen, prc_gen]]
    else:
        # No Axes supplied: draw on the current pyplot figure and show it; nothing is returned.
        if pred_only:            
            plt.scatter(df1.iloc[:, 2], df1.iloc[:, 5])
        else:
            plt.scatter(df1[df1.iloc[:,3] == 0].iloc[:, 2], df1[df1.iloc[:,3] == 0].iloc[:, 5], label = "True " +  class_names[0])
            plt.scatter(df1[df1.iloc[:,3] == 1].iloc[:, 2], df1[df1.iloc[:,3] == 1].iloc[:, 5], label = "True " +  class_names[1])
        print(df1.corr())
        plt.xlabel("Prediction, Real Image")
        plt.ylabel("Prediction, GAN Image")
        plt.title(title)
        # r is read from fixed positions of the correlation matrix.
        # NOTE(review): these positions depend on which columns df1.corr() keeps — confirm.
        if pred_only:   
            r = df1.corr().iloc[1,4]
        else:
            r = df1.corr().iloc[0,3]
        l, h = r_confidence_interval(r, 0.05, len(df1.index))
        plt.plot([0,1], np.poly1d(np.polyfit(df1.iloc[:, 2], df1.iloc[:, 5], 1))([0,1]), label = "r = " + strround(r,2) + " (" + strround(l, 2) + " - " + strround(h, 2) + ")")
        plt.legend()
        plt.show()

In [None]:
#The following code generates images along the transition for high / low grade

def GRADE_DATASET(dataset, ind, grade_col):
    """Generate the image series for the low/high grade transition of one cohort.

    Reads ``<dataset>_features_slide.csv``, ``<dataset>_features_part.csv`` and
    ``tcga_all_annotations.csv`` from ``PROJECT_DIR/PROJECTS/HistoXGAN``.
    Slide-level features are joined to the annotations on the first 12
    characters of ``Slide``; rows with no ``high_grade`` value are dropped and
    ``Grade_Class`` is set to 1 where ``high_grade == 'Y'`` and 0 otherwise.

    Args:
        dataset: Cohort code (e.g. ``'BRCA'``); lower-cased to build file names.
        ind: Row label in the ``_features_part.csv`` table whose feature vector
            is used as the starting point of the interpolation.
        grade_col: Unused. The label always comes from the ``high_grade``
            annotation column; the parameter is kept so existing calls work.

    Returns:
        The value returned by ``generate_images`` for the start vector and the
        second value returned by ``get_log_features``.
    """
    project_root = PROJECT_DIR + "/PROJECTS/HistoXGAN/"
    feature_stem = project_root + "SAVED_FEATURES/" + dataset.lower()

    slide_features = pd.read_csv(feature_stem + "_features_slide.csv")
    annotations = pd.read_csv(project_root + "tcga_all_annotations.csv")
    part_features = pd.read_csv(feature_stem + "_features_part.csv")

    # Feature columns are discovered from the slide-level table and then used
    # to pick the starting vector out of the "part" table.
    feature_names = [name for name in list(slide_features.columns.values) if 'Feature_' in name]
    start_vector = part_features[feature_names].loc[ind, :].values.tolist()

    slide_features['patient'] = slide_features['Slide'].str[0:12]
    labelled = slide_features.merge(annotations, left_on='patient', right_on='patient', how='left')
    labelled = labelled.dropna(subset=['high_grade'])
    labelled['Grade_Class'] = 0
    labelled.loc[labelled.high_grade == 'Y', 'Grade_Class'] = 1

    run_name = 'Grade_' + dataset
    # Only the second value from get_log_features is used; the first is
    # replaced by the chosen start vector.
    _, grade_direction = get_log_features(labelled, 'Grade_Class', run_name)
    return generate_images(start_vector, grade_direction, prefix=run_name)

# Build the grade-transition image series for each cohort. The second argument
# is the row of the cohort's "_features_part.csv" table used as the start vector.
img_dict = {}
img_dict['BRCA'] = GRADE_DATASET('BRCA', 100, 'Grade')
img_dict['PAAD'] = GRADE_DATASET('PAAD', 100, 'histological_grade')
img_dict['HNSC'] = GRADE_DATASET('HNSC', 200, 'neoplasm_histologic_grade')
img_dict['PRAD'] = GRADE_DATASET('PRAD', 200, 'Clinical_Gleason_sum')


# Cache the images for the figure cell below. The context manager guarantees
# the file handle is closed even if pickling fails part-way.
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_grade_lr.pkl', 'wb') as f:
    pickle.dump(img_dict, f)


In [None]:
#The following code generates images along the transition between cancer subtypes

def SUBTYPE_DATASET(dataset, ind):
    """Generate the image series for the transition between two cancer subtypes.

    Reads ``<dataset>_features_slide.csv`` and ``tcga_all_annotations.csv`` from
    ``PROJECT_DIR/PROJECTS/HistoXGAN`` and assigns a binary ``Subtype`` label per
    slide (rows that end up without a label are dropped):

    * ``BRCA``   - invasive ductal (0) vs invasive lobular (1) carcinoma
    * ``LUNG``   - rows with a ``project_id``; ``TCGA-LUSC`` is 1, all others 0
    * ``ESCA``   - adenocarcinoma (0) vs squamous cell carcinoma (1)
    * ``KIDNEY`` - ``TCGA-KIRC`` (0) vs ``TCGA-KIRP`` (1)

    Args:
        dataset: Cohort code; lower-cased to build the slide feature file name.
        ind: Row label in the "part" feature table whose feature vector is the
            starting point of the interpolation. For ``LUNG`` and ``KIDNEY``
            that table is ``luad_features_part.csv`` / ``kirp_features_part.csv``.

    Returns:
        The value returned by ``generate_images`` for the start vector and the
        second value returned by ``get_log_features``.
    """
    features_dir = PROJECT_DIR + "/PROJECTS/HistoXGAN/SAVED_FEATURES/"
    slide_features = pd.read_csv(features_dir + dataset.lower() + "_features_slide.csv")
    annotations = pd.read_csv(PROJECT_DIR + "/PROJECTS/HistoXGAN/tcga_all_annotations.csv")

    # The combined cohorts take their start vector from one member cohort.
    part_prefix = {'LUNG': 'luad', 'KIDNEY': 'kirp'}.get(dataset, dataset.lower())
    part_features = pd.read_csv(features_dir + part_prefix + "_features_part.csv")

    feature_names = [name for name in list(slide_features.columns.values) if 'Feature_' in name]
    start_vector = part_features[feature_names].loc[ind, :].values.tolist()

    slide_features['patient'] = slide_features['Slide'].str[0:12]
    if dataset == 'LUNG':
        # Annotation IDs are trimmed to 12 characters so they match the slide-derived key.
        annotations['patient'] = annotations['patient'].str[0:12]
    labelled = slide_features.merge(annotations, left_on='patient', right_on='patient', how='left')
    labelled['Subtype'] = np.nan

    if dataset == 'LUNG':
        labelled = labelled.dropna(subset=['project_id'])
        labelled['Subtype'] = 0
        labelled.loc[labelled['project_id'] == 'TCGA-LUSC', 'Subtype'] = 1
    else:
        # (annotation column, value labelled 0, value labelled 1)
        label_rules = {
            'BRCA': ('2016 Histology Annotations', 'Invasive ductal carcinoma', 'Invasive lobular carcinoma'),
            'ESCA': ('histological_type', "Esophagus Adenocarcinoma, NOS", "Esophagus Squamous Cell Carcinoma"),
            'KIDNEY': ('project_id', 'TCGA-KIRC', 'TCGA-KIRP'),
        }
        if dataset in label_rules:
            column, class0_value, class1_value = label_rules[dataset]
            labelled.loc[labelled[column] == class0_value, 'Subtype'] = 0
            labelled.loc[labelled[column] == class1_value, 'Subtype'] = 1

    labelled = labelled.dropna(subset=['Subtype'])

    run_name = 'Subtype_' + dataset
    # Only the second value from get_log_features is used; the first is
    # replaced by the chosen start vector.
    _, subtype_direction = get_log_features(labelled, 'Subtype', run_name)
    return generate_images(start_vector, subtype_direction, prefix=run_name)


# Build the subtype-transition image series for each cohort. The second argument
# is the row of the "part" feature table used as the start vector.
img_dict = {}
img_dict['BRCA'] = SUBTYPE_DATASET('BRCA', 100)
img_dict['LUNG'] = SUBTYPE_DATASET('LUNG', 100)
img_dict['ESCA'] = SUBTYPE_DATASET('ESCA', 200)
img_dict['KIDNEY'] = SUBTYPE_DATASET('KIDNEY', 6)

# Cache the images for the figure cell below. The context manager guarantees
# the file handle is closed even if pickling fails part-way.
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_subtype_lr.pkl', 'wb') as f:
    pickle.dump(img_dict, f)

In [None]:
#The following code generates images along the transition for high / low gene expression

def BRCA_gene(gene = "CD3G", ind = 100):
    """Generate the image series for the low/high expression transition of a gene in BRCA.

    Reads ``brca_features_slide.csv``, ``brca_features_part.csv`` and
    ``tcga_all_annotations.csv`` from ``PROJECT_DIR/PROJECTS/HistoXGAN``. Slides
    are joined to the annotations on the first 12 characters of ``Slide``. A
    column named after the gene is set to 0 where ``<gene>_class == 'L'`` and 1
    where ``<gene>_class == 'H'``; all other rows are dropped.

    Args:
        gene: Gene symbol; the annotations must contain a ``<gene>_class`` column.
        ind: Row label in ``brca_features_part.csv`` whose feature vector is the
            starting point of the interpolation.

    Returns:
        The value returned by ``generate_images`` for the start vector and the
        second value returned by ``get_log_features``.
    """
    project_root = PROJECT_DIR + "/PROJECTS/HistoXGAN/"
    slide_features = pd.read_csv(project_root + "SAVED_FEATURES/brca_features_slide.csv")
    annotations = pd.read_csv(project_root + "tcga_all_annotations.csv")
    part_features = pd.read_csv(project_root + "SAVED_FEATURES/brca_features_part.csv")

    feature_names = [name for name in list(slide_features.columns.values) if 'Feature_' in name]
    start_vector = part_features[feature_names].loc[ind, :].values.tolist()

    slide_features['patient'] = slide_features['Slide'].str[0:12]
    labelled = slide_features.merge(annotations, left_on='patient', right_on='patient', how='left')

    # Binary expression label: L -> 0, H -> 1, anything else stays NaN and is dropped.
    class_column = gene + '_class'
    labelled[gene] = np.nan
    for class_code, class_value in (('L', 0), ('H', 1)):
        labelled.loc[labelled[class_column] == class_code, gene] = class_value
    labelled = labelled.dropna(subset=[gene])

    run_name = gene + '_BRCA'
    # Only the second value from get_log_features is used; the first is
    # replaced by the chosen start vector.
    _, expression_direction = get_log_features(labelled, gene, run_name)
    return generate_images(start_vector, expression_direction, prefix=run_name)

# Build the gene-expression transition image series (default start row 100).
# Original author note: "6 is good also #39 or 40" - presumably alternative
# start rows worth trying; confirm before relying on it.
img_dict = {}
img_dict['CD3G'] = BRCA_gene('CD3G')
img_dict['COL1A1'] = BRCA_gene('COL1A1')
img_dict['MKI67'] = BRCA_gene('MKI67')
img_dict['EPCAM'] = BRCA_gene('EPCAM')

# Cache the images for the figure cell below. The context manager guarantees
# the file handle is closed even if pickling fails part-way.
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_gene_lr.pkl', 'wb') as f:
    pickle.dump(img_dict, f)


In [None]:
#The following code loads the saved images for grade transition, and also loads the results of models trained to predict grade
#to compare predictions for real slide images vs slide images regenerated with HistoXGAN

import pickle
%matplotlib inline

with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_grade_lr.pkl', 'rb') as f:
    img_dict = pickle.load(f)
    
fig = plt.figure(figsize=(20, 8), dpi = 300)

subfigs = fig.subfigures(1, 2, width_ratios=[14, 6])

axs = subfigs[1].subplots(4, 2, sharex = True, sharey = True)

result_table = []
result_table += load_compare_group(
    title = "BRCA - TCGA",
    ax = axs[0][0],
    class_names = ['Low Grade', 'High Grade'],
    ylabel = "BRCA"
)

result_table += load_compare_group(
    title = "PAAD - TCGA",
    ax = axs[1][0],
    class_names = ['Low Grade', 'High Grade'],
    ylabel = "PAAD"
)

result_table += load_compare_group(
    title = "PRAD - TCGA",
    ax = axs[3][0],
    class_names = ['Low Grade', 'High Grade'],
    ylabel = "PRAD",
    xlabel = True
)
result_table += load_compare_group(
    title = "HNSC - TCGA",
    ax = axs[2][0],
    class_names = ['Low Grade', 'High Grade'],
    ylabel = "HNSC"
)

result_table += load_compare_group(
    title = "BRCA - CPTAC",
    ax = axs[0][1],
    class_names = ['Low Grade', 'High Grade'],
    read_csv = True,
    pred_only = True,
    unknown_name = "Grade")

result_table += load_compare_group(
    title = "PAAD - CPTAC",
    class_names = ['Low Grade', 'High Grade'],
    ax = axs[1][1]
)
result_table += load_compare_group(
    title = "HNSC - CPTAC",
    ax = axs[2][1],
    class_names = ['Low Grade', 'High Grade'],
    read_csv = True,
    pred_only = True,
    xlabel = True,
    unknown_name = "Grade")

axs[3][1].set_visible(False)

from IPython.display import display

# Each row produced by load_compare_group carries 14 fields (cohort, cancer
# type, n, class-0 name, class-0 %, class-1 name, class-1 %, r with CI, z,
# p-value and four AUROC/AP values). Passing only 11 column names makes
# pd.DataFrame raise "11 columns passed, passed data had 14 columns", so the
# frame is built with all 14 names and then reduced to the 11 intended columns.
# NOTE(review): the seventh field is assumed to be the class-1 ("High Grade")
# percentage, matching how the subtype cell labels it - confirm against
# load_compare_group.
df = pd.DataFrame(
    result_table,
    columns = ['Dataset', 'Subtype', 'n', 'Histology A', 'Histology A (%)', 'Histology B', 'Histology B (%)',
               'Pearson r', 'z statistic', 'p-value, correlation', 'AUROC', 'AP', 'AUROC Gen', 'AUPRC Gen'],
)
df = df.drop(columns = ['Histology A', 'Histology A (%)', 'Histology B'])
df = df.rename(columns = {'Histology B (%)': 'High Grade (%)'})
display(df)


# Number of interpolation frames shown per cohort.
img_include = 7

# One row of axes per cohort in img_dict, one column per displayed frame.
axs2 = subfigs[0].subplots(len(img_dict), img_include)

# Indices into each cohort's image list that are displayed, left to right.
row_loc = {cohort: [20,30,40,50,60,70,80] for cohort in ('BRCA', 'HNSC', 'PAAD', 'PRAD')}

for grid_row, img_name in enumerate(img_dict):
    for grid_col, frame_idx in enumerate(row_loc[img_name]):
        panel = axs2[grid_row][grid_col]
        panel.imshow(img_dict[img_name][frame_idx])
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
    # The cohort name labels the first image of each row.
    axs2[grid_row][0].set_ylabel(img_name, size = 18)

plt.tight_layout()
subfigs[0].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=0, hspace=0)
subfigs[1].subplots_adjust(left = 0.2, top = 1.0, bottom = 0, wspace = 0.1, hspace = 0.2)

# Two arrows above the image grid pointing outwards from the centre
# (right arrow in colour C1, left arrow in colour C0), followed by text labels.
for tip_x, tail_x, arrow_colour in [(0.98, 0.505, 'C1'), (0.02, 0.495, 'C0')]:
    axs2[0][0].annotate(text="", xy=(tip_x, 1.032), xytext=(tail_x, 1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor=arrow_colour))
for caption, x_frac, font_size in [("Grade", 0.50, 18), ("Low", 0.02, 16), ("High", 0.98, 16)]:
    axs2[0][0].annotate(text=caption, xy = (x_frac, 1.05), xycoords="subfigure fraction", ha="center", size = font_size)

# Panel letters for the two sub-figures.
for subfig, letter_x, letter in [(subfigs[0], -0.03, "A"), (subfigs[1], 0.1, "B")]:
    subfig.text(letter_x, 1.02, letter, zorder = 30, clip_on = False, weight='bold', size = 20)
plt.show()


In [None]:
#The following code loads the saved images for subtype transition, and also loads the results of models trained to predict subtype
#to compare predictions for real slide images vs slide images regenerated with HistoXGAN


import pickle

# Reload the subtype-transition images cached by the generation cell.
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_subtype_lr.pkl', 'rb') as f:
    img_dict = pickle.load(f)

# Left sub-figure: image grid (filled in further down).
# Right sub-figure: 4 x 2 grid of comparison panels drawn by load_compare_group.
fig = plt.figure(figsize=(20, 8), dpi = 300)
subfigs = fig.subfigures(1, 2, width_ratios=[14, 6])
axs = subfigs[1].subplots(4, 2, sharex = True, sharey = True)

# One entry per panel, in the order rows are appended to result_table.
# Left column of axes: TCGA cohorts; right column: CPTAC cohorts.
subtype_panels = [
    dict(title = "BRCA - TCGA", ax = axs[0][0], class_names = ["Ductal", "Lobular"], ylabel = "BRCA"),
    dict(title = "LUNG - TCGA", ax = axs[1][0], class_names = ["Adeno", "Squam"], ylabel = "LUNG"),
    dict(title = "ESCA - TCGA", ax = axs[2][0], class_names = ["Adeno", "Squam"], ylabel = "ESCA"),
    dict(title = "KIDNEY - TCGA", ax = axs[3][0], ylabel = "KIDNEY", class_names = ["Clear", "Papillary"], xlabel = True),
    dict(
        title = "BRCA - CPTAC",
        ax = axs[0][1],
        pred_df = True,
        pred_class_col = "histological_type",
        pred_class_opts = ["Inflitrating Ductal Carcinoma", "Inflitrating Lobular Carcinoma"],
        class_names = ["Ductal", "Lobular"],
        pred_only = True,
    ),
    dict(
        title = "LUNG - CPTAC",
        ax = axs[1][1],
        pred_df = True,
        pred_class_col = "cohort",
        pred_class_opts = ["LUAD", "LUSC"],
        class_names = ["Adeno", "Squam"],
        pred_only = True,
        xlabel = True,
    ),
]

result_table = []
for panel_kwargs in subtype_panels:
    result_table += load_compare_group(**panel_kwargs)

# Only two CPTAC panels are drawn; hide the unused axes in the right column.
axs[3][1].set_visible(False)
axs[2][1].set_visible(False)

from IPython.display import display

# Summary table: one row per comparison panel.
df = pd.DataFrame(result_table, columns = ['Dataset', 'Subtype', 'n', 'Histology A', 'Histology A (%)', 'Histology B', 'Histology B (%)', 'Pearson r', 'z statistic', 'p-value, correlation', 'AUROC', 'AP', 'AUROC Gen', 'AP Gen'])
display(df)

# Number of interpolation frames shown per cohort. This must be assigned before
# the grid is created: previously it was read first and assigned afterwards, so
# the cell only worked when another cell had already defined it.
img_include = 7

# One row of axes per cohort in img_dict, one column per displayed frame.
axs2 = subfigs[0].subplots(len(img_dict), img_include)

# Indices into each cohort's image list that are displayed, left to right.
row_loc = {
    'BRCA':[20,30,40,50,60,70,80],
    'LUNG':[20,30,40,50,60,70,80],
    'ESCA':[20,30,40,50,60,70,80],
    'KIDNEY':[0,15,30,50,70,85,99],
}

for grid_row, img_name in enumerate(img_dict):
    for grid_col, frame_idx in enumerate(row_loc[img_name]):
        panel = axs2[grid_row][grid_col]
        panel.imshow(img_dict[img_name][frame_idx])
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
    # The cohort name labels the first image of each row.
    axs2[grid_row][0].set_ylabel(img_name, size = 18)
plt.tight_layout()
subfigs[0].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=0, hspace=0)
subfigs[1].subplots_adjust(left = 0.2, top = 1.0, bottom = 0, wspace = 0.1, hspace = 0.2)

# Two arrows above the image grid pointing outwards from the centre
# (right arrow in colour C1, left arrow in colour C0), then the axis title.
for tip_x, tail_x, arrow_colour in [(0.98, 0.505, 'C1'), (0.02, 0.495, 'C0')]:
    axs2[0][0].annotate(text="", xy=(tip_x, 1.032), xytext=(tail_x, 1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor=arrow_colour))
axs2[0][0].annotate(text="Subtype", xy = (0.50,1.05), xycoords="subfigure fraction", ha="center", size = 18)

# Boxed labels in the top-left corner of the first image and the top-right
# corner of the last image of every row: (left label, right label) per row.
padding = 3
corner_labels = [
    ("Ductal", "Lobular"),
    ("Adeno", "Squamous"),
    ("Adeno", "Squamous"),
    ("Clear", "Papillary"),
]
for grid_row, (left_label, right_label) in enumerate(corner_labels):
    axs2[grid_row][0].annotate(
        text=left_label, xy = (0.0, 1.0), xytext=(padding, -padding),
        xycoords="axes fraction", textcoords='offset points',
        ha="left", va = "top", size = 16,
        bbox=dict(pad = padding, facecolor='white', edgecolor='black'),
    )
    axs2[grid_row][img_include - 1].annotate(
        text=right_label, xy = (1.0, 1.0), xytext=(-padding, -padding),
        xycoords="axes fraction", textcoords='offset points',
        ha="right", va = "top", size = 16,
        bbox=dict(pad = padding, facecolor='white', edgecolor='black'),
    )

# Panel letters for the two sub-figures.
for subfig, letter_x, letter in [(subfigs[0], -0.03, "A"), (subfigs[1], 0.1, "B")]:
    subfig.text(letter_x, 1.02, letter, zorder = 30, clip_on = False, weight='bold', size = 20)
plt.show()


In [None]:
#The following code loads the saved images for gene transition, and also loads the results of models trained to predict gene expression
#to compare predictions for real slide images vs slide images regenerated with HistoXGAN


import pickle

# Reload the gene-expression transition images cached by the generation cell.
with open(PROJECT_DIR + '/FINAL_IMAGES/img_display_gene_lr.pkl', 'rb') as f:
    img_dict = pickle.load(f)

# Left sub-figure: image grid (filled in further down).
# Right sub-figure: 4 x 2 grid of comparison panels drawn by load_compare_group.
fig = plt.figure(figsize=(21, 8), dpi = 300)
subfigs = fig.subfigures(1, 2, width_ratios=[14, 7])
axs = subfigs[1].subplots(4, 2, sharex = True, sharey = True)

result_table = []

# Panels are drawn column by column: all TCGA genes first, then all CPTAC genes.
# With color_df=True load_compare_group returns (image handle, table rows).
panel_genes = ["CD3G", "COL1A1", "MKI67", "EPCAM"]
for cohort_col, cohort in enumerate(["TCGA", "CPTAC"]):
    for gene_row, gene_name in enumerate(panel_genes):
        panel_kwargs = dict(
            title = gene_name + " - " + cohort,
            ax = axs[gene_row][cohort_col],
            class_names = ["Low", "High"],
            pred_class_col = gene_name,
            color_df = True,
            pred_only = True,
        )
        if cohort == "TCGA":
            # Only the left column carries the gene name as y-label.
            panel_kwargs['ylabel'] = gene_name
        if gene_row == len(panel_genes) - 1:
            # Only the bottom row carries an x-label.
            panel_kwargs['xlabel'] = True
        # `im` ends up holding the handle from the final panel (EPCAM - CPTAC),
        # which the colorbar further down uses.
        im, r = load_compare_group(**panel_kwargs)
        result_table += r

# Number of interpolation frames shown per gene. This must be assigned before
# the grid is created: previously it was read first and assigned afterwards, so
# the cell only worked when another cell had already defined it.
img_include = 7

# One row of axes per gene in img_dict, one column per displayed frame.
axs2 = subfigs[0].subplots(len(img_dict), img_include)

# Indices into each gene's image list that are displayed, left to right.
row_loc = {
    'CD3G':[20,30,40,50,60,70,80],
    'COL1A1':[20,30,40,50,60,70,80],
    'MKI67':[20,30,40,50,60,70,80],
    'EPCAM':[20,30,40,50,60,70,80],
}

for grid_row, img_name in enumerate(img_dict):
    for grid_col, frame_idx in enumerate(row_loc[img_name]):
        panel = axs2[grid_row][grid_col]
        panel.imshow(img_dict[img_name][frame_idx])
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_label_position('top')
    # The gene name labels the first image of each row.
    axs2[grid_row][0].set_ylabel(img_name, size = 18)
from IPython.display import display

plt.tight_layout()

# Right sub-figure: leave room for a shared vertical colorbar built from the
# handle returned by the last load_compare_group call.
subfigs[1].subplots_adjust(left = 0.15, top = 1.0, bottom = 0, wspace = 0.07, hspace = 0.2)
subfigs[1].colorbar(im, ax=axs.ravel().tolist(), location = 'right', orientation = 'vertical', pad = 0.02, fraction = 0.07, aspect = 60, anchor = (0,0))

# Left sub-figure: images fill the whole area with no gaps.
subfigs[0].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0, wspace=0, hspace=0)

# Two arrows above the image grid pointing outwards from the centre
# (right arrow in colour C1, left arrow in colour C0), followed by text labels.
for tip_x, tail_x, arrow_colour in [(0.98, 0.505, 'C1'), (0.02, 0.495, 'C0')]:
    axs2[0][0].annotate(text="", xy=(tip_x, 1.032), xytext=(tail_x, 1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor=arrow_colour))
for caption, x_frac, font_size in [("Gene", 0.50, 18), ("Low", 0.02, 16), ("High", 0.98, 16)]:
    axs2[0][0].annotate(text=caption, xy = (x_frac, 1.05), xycoords="subfigure fraction", ha="center", size = font_size)

# Panel letters for the two sub-figures.
for subfig, letter_x, letter in [(subfigs[0], -0.03, "A"), (subfigs[1], 0.1, "B")]:
    subfig.text(letter_x, 1.02, letter, zorder = 30, clip_on = False, weight='bold', size = 20)

# Summary table: built with the full column set, then the class-composition
# columns are removed.
all_columns = ['Dataset', 'Subtype', 'n', 'Histology A', 'Histology A (%)', 'Histology B', 'Histology B (%)', 'Pearson r', 'z statistic', 'p-value, correlation', 'AUROC', 'AP', 'AUROC Gen', 'AP Gen']
composition_columns = ['Histology A', 'Histology A (%)', 'Histology B', 'Histology B (%)']
df = pd.DataFrame(result_table, columns = all_columns).drop(columns = composition_columns)
display(df)

plt.show()



# Loss Based Feature Vector PCA + Exploration

## This code uses model loss (rather than the average difference in feature vectors) to explore how predictions are made. Gradient descent is used to calculate how to make images score higher / lower for a deep learning model. This is repeated 50 times to generate a list of gradients, and PCA is performed to identify the principal components of these gradients, representing different directions in the feature space that are important to model predictions.

In [None]:
from slideflow.model.torch import load as load_torch_model

from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
# Load the generator snapshot and keep its 'G_ema' network on `device`.
# NOTE(review): load_network_pkl presumably unpickles the file, so only load
# snapshots from a trusted source - confirm.
snapshot_path = PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl'
with dnnlib.util.open_url(snapshot_path) as f:
    snapshot_networks = legacy.load_network_pkl(f)
G = snapshot_networks['G_ema'].to(device)

def vector_interpolate(
    G: torch.nn.Module,
    z: torch.Tensor,
    z2: torch.Tensor,
    device: torch.device,
    steps: int = 100
):
    """Yield generator images along the line from ``z - z2`` towards ``z + z2``.

    Step ``i`` (``0 <= i < steps``) feeds ``z - z2 + (2 * i / steps) * z2`` to
    the generator, so the first frame is at ``z - z2``, the middle frame at
    ``z`` and the end point ``z + z2`` itself is never reached.

    Args:
        G: Generator, called as ``G(latent, 0, noise_mode='const')``. Its
            output is treated as a batch of channel-first images in [-1, 1].
        z: Centre vector of the interpolation.
        z2: Direction vector added to / subtracted from ``z``.
        device: Device the interpolated vector is moved to before calling ``G``.
        steps: Number of frames to generate.

    Yields:
        The first image of each generated batch as a ``uint8`` NumPy array
        with the channel axis last.
    """
    for interp_idx in range(steps):
        # Inference only: skip autograd bookkeeping for the generator pass.
        # The context is closed before `yield` so the caller's grad mode is
        # never altered while the generator is suspended.
        with torch.no_grad():
            # as_tensor avoids the "copy construct from a tensor" UserWarning
            # that torch.tensor(<tensor>) emits, while still accepting arrays.
            torch_interp = torch.as_tensor(z - z2 + 2*interp_idx/steps*z2).to(device)
            img = G(torch_interp, 0, noise_mode ='const')
            # Map [-1, 1] to [0, 255] and move channels last.
            img = (img + 1) * (255/2)
            img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        yield img

def generate_images(vector_z, vector_z2, prefix = 'test'):
    """Return every frame produced by ``vector_interpolate`` as a list.

    Args:
        vector_z: Centre vector of the interpolation (anything accepted by
            ``torch.tensor``).
        vector_z2: Direction vector of the interpolation (anything accepted by
            ``torch.tensor``).
        prefix: Unused; kept so existing calls keep working.

    Returns:
        List of the images yielded by ``vector_interpolate`` for the global
        generator ``G`` on the global ``device``, using its default step count.
    """
    centre = torch.tensor(vector_z).to(device)
    direction = torch.tensor(vector_z2).to(device)
    return list(vector_interpolate(G, centre, direction, device = device))

def lighten_color(color, amount=1.5):
    """
    Rescale the lightness of a color in HLS space.

    The new lightness is ``1 - amount * (1 - lightness)``, clamped to [0, 1]:
    ``amount < 1`` lightens, ``amount == 1`` leaves the color unchanged and
    ``amount > 1`` (including the default 1.5) darkens it.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Returns an ``(r, g, b)`` tuple of floats.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named color (KeyError) or not hashable, e.g. a list (TypeError):
        # hand the value to to_rgb unchanged.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    # Clamp so that a large `amount` on a dark color cannot produce RGB values
    # outside [0, 1], which matplotlib rejects.
    new_lightness = min(1.0, max(0.0, 1 - amount * (1 - lightness)))
    return colorsys.hls_to_rgb(hue, new_lightness, saturation)

def get_pca_vectors(df_mod, entries, feat_cols, mod, dataset, vectors_dict, a, b, ax):
    """Draw a gradient quiver plot for two principal components and pick two slides.

    A 20-component PCA is fitted on ``vectors_dict['loss']``. Components are
    ordered by the absolute mean of their scores over the first ``entries``
    loss vectors (largest first); ``a`` and ``b`` select the two components in
    that ordering that form the plotted plane. The first ``entries`` rows of
    ``df_mod[feat_cols]`` are projected onto this plane, and for every node of
    a regular grid over [-4, 4] x [-4, 4] that has a slide within half a grid
    diagonal, an arrow is drawn at the node showing the normalised in-plane
    direction of the nearest slide's loss vector. Arrow color follows
    ``vectors_dict['pred']`` and arrow width the quantile of the in-plane
    gradient magnitude. The arrows with the largest gradient along each of the
    two components are redrawn in color ``C1`` (and a shade derived from it)
    with legend labels.

    Args:
        df_mod: Feature table with a default integer index; one row per slide.
        entries: Number of leading rows / vectors to use.
        feat_cols: Names of the feature columns in ``df_mod``.
        mod: Unused; kept for call compatibility.
        dataset: Unused; kept for call compatibility.
        vectors_dict: Dict with ``'loss'`` and ``'pred'`` lists whose first
            ``entries`` items line up with the rows of ``df_mod``.
        a: Rank (0-3) of the component plotted on the x axis.
        b: Rank (0-3) of the component plotted on the y axis.
        ax: Matplotlib axes to draw on.

    Returns:
        Two feature vectors (lists) taken from ``df_mod``: the slides behind
        the highlighted arrows for component ``a`` and component ``b``.

    Raises:
        ValueError: If no grid node has a slide close enough to plot.
    """
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import QuantileTransformer, normalize

    n_comp = 20
    pca = PCA(n_components=n_comp)
    result = pca.fit_transform(vectors_dict['loss'])
    full_data = pca.transform(df_mod[feat_cols].values.tolist())

    tick = 0.05
    x_list, y_list = np.meshgrid(np.arange(-4, 4.01, tick), np.arange(-4, 4.01, tick))

    # Absolute mean score of every component, computed once instead of inside
    # the exchange loop below.
    abs_mean = [np.abs(np.mean([result[x][k] for x in range(entries)])) for k in range(n_comp)]

    # Exchange sort kept in its original form so the resulting component order
    # (largest absolute mean first) is unchanged.
    important_pca = [x for x in range(n_comp)]
    for i in range(n_comp):
        for j in range(n_comp):
            if abs_mean[important_pca[j]] < abs_mean[important_pca[i]]:
                important_pca[i], important_pca[j] = important_pca[j], important_pca[i]

    pc_x = important_pca[a]
    pc_y = important_pca[b]
    A = np.array(full_data)[:entries, [pc_x, pc_y]]

    x_final = []
    y_final = []
    loss_final = []
    grade_pred = []
    slide_rows = []  # df_mod row behind every plotted arrow

    # A slide is assigned to a grid node when it lies within half a cell diagonal.
    min_dist = np.sqrt((tick/2) * (tick/2) + (tick/2)*(tick/2))
    for x_r, y_r in zip(x_list, y_list):
        for x, y in zip(x_r, y_r):
            distances = np.linalg.norm(A - np.array((x,y)), axis=1)
            min_index = np.argmin(distances)
            if distances[min_index] < min_dist:
                x_final += [x]
                y_final += [y]
                grade_pred += [vectors_dict['pred'][min_index]]
                loss_final += [vectors_dict['loss'][min_index]]
                slide_rows += [min_index]

    if not x_final:
        raise ValueError("No slide falls within the plotted grid; nothing to draw")

    loss_pca = np.array(pca.transform(loss_final))
    loss_plane = loss_pca[:, [pc_x, pc_y]]
    loss_pca_norm = normalize(loss_plane)

    # Quantile of the in-plane gradient magnitude, used for the arrow width.
    qt = QuantileTransformer(n_quantiles=100, random_state=0)
    qt_res = qt.fit_transform(np.linalg.norm(loss_plane, axis = 1).reshape(-1,1))

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in recent matplotlib releases.
    cmap = plt.get_cmap('Blues')
    for x, y, l1, l2, g, lw in zip(x_final, y_final, loss_pca_norm[:, 0], loss_pca_norm[:, 1], grade_pred, qt_res[:, 0]):
        ax.quiver(x, y, l1, l2, color = cmap(g/2 + 0.25), scale = 35, width = lw/200 + 0.005, headwidth = 3, headlength = 1.5, headaxislength = 1.5)

    ax.set_xticks([])
    ax.set_yticks([])

    # Arrows with the largest gradient along each component. If both maxima
    # are the same arrow, the runner-up along the second component is used.
    max_a = np.argmax(loss_pca[:, pc_x])
    max_b = np.argmax(loss_pca[:, pc_y])
    if max_b == max_a and len(loss_pca) > 1:
        max_b = np.argsort(loss_pca[:, pc_y])[-2]
    str_list = ["Max Gradient PC 1", "Max Gradient PC 2", "Max Gradient PC 3", "Max Gradient PC 4"]
    ax.quiver(x_final[max_a], y_final[max_a], loss_pca_norm[max_a, 0], loss_pca_norm[max_a, 1], color = 'C1', scale = 35, width = qt_res[max_a][0]/200 + 0.005, headwidth = 3, headlength = 1.5, headaxislength = 1.5, label = str_list[a])
    ax.quiver(x_final[max_b], y_final[max_b], loss_pca_norm[max_b, 0], loss_pca_norm[max_b, 1], color = lighten_color('C1'), scale = 35, width = qt_res[max_b][0]/200 + 0.005, headwidth = 3, headlength = 1.5, headaxislength = 1.5, label = str_list[b])
    ax.legend()

    # max_a / max_b index the plotted arrows, not df_mod, so they are mapped
    # back to the slide row behind each arrow before the lookup.
    return (
        df_mod[feat_cols].loc[slide_rows[max_a], :].values.tolist(),
        df_mod[feat_cols].loc[slide_rows[max_b], :].values.tolist(),
    )


def load_models(mod, patient = 200, entries = 50, full_set = False, regen = False, dataset = 'BRCA', fig_overall = None, label = {}, race = False):
    
    #Get the dataset
    from random import randrange
    from tqdm import tqdm
    if full_set:
        df_mod = pd.read_csv(PROJECT_DIR + '/PROJECTS/HistoXGAN/SAVED_FEATURES/' + dataset.lower() + '_features_slide.csv') 
        entries = len(df_mod.index)
    else:
        #Not implemented for public release given need for full extracted features from slides (~4 GB per dataset)
        raise NotImplementedError("Requires full set of extracted features - not implemented")

    feat_cols = list(df_mod.columns.values)
    feat_cols = [f for f in feat_cols if 'Feature_' in f]      
    print("Rows imported")
    # Load the slideflow model
    model = load_torch_model(mod).eval().to(device)
    preprocess = sf.util.get_preprocess_fn(mod)
    loss_fn = torch.nn.L1Loss().to(device)
    softmax = torch.nn.Softmax(dim = -1)    
    vectors_dict = {}
    vectors_dict['loss'] = []
    vectors_dict['base'] = []
    vectors_dict['slide'] = []
    vectors_dict['pred'] = []
    import os.path
    if not full_set:
        if not regen and os.path.isfile(os.path.dirname(mod) + "/" + str(entries) + "_vectors_dict_pca_nav.pkl"):
            with open(os.path.dirname(mod) + "/" + str(entries) + "_vectors_dict_pca_nav.pkl", 'rb') as f:
                vectors_dict = pickle.load(f)
        else:
            for tensor_vector_target in [1,0]:
                for i in tqdm(range(entries)):
                    r = i
                    vector_base = df_mod[feat_cols].loc[r, :].values.tolist()
                    vectors_dict['base'] += [vector_base]
                    vectors_dict['slide'] += df_mod['Slide'].loc[r]
                    vector_base_tensor = torch.tensor([vector_base]).to(device)
                    vector_base_tensor.requires_grad = True
                    optimizer = torch.optim.SGD([vector_base_tensor], lr=1e-3)
                    pred = softmax(model(G(vector_base_tensor, 0, noise_mode ='const')))
                    loss = loss_fn(pred[0][1], torch.tensor(tensor_vector_target).to(device))
                    vectors_dict['pred'] += [pred[0][1].detach().cpu().numpy()]
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    vectors_dict['loss'] += [-1 * vector_base_tensor.grad.detach().cpu().numpy()[0]]
            with open(os.path.dirname(mod) + "/" + str(entries) + "_vectors_dict_pca_nav.pkl", 'wb') as f:
                pickle.dump(vectors_dict, f)
    else:
        entires = len(df_mod.index)
        if not regen and os.path.isfile(os.path.dirname(mod)  + "/" +  "fullset_vectors_dict_pca_nav.pkl"):
            with open(os.path.dirname(mod)  + "/" +  "fullset_vectors_dict_pca_nav.pkl", 'rb') as f:
                vectors_dict = pickle.load(f)
        else:
            for tensor_vector_target in [1,0]:
                for i in tqdm(range(len(df_mod.index))):
                    vector_base = df_mod[feat_cols].loc[i, :].values.tolist()
                    vectors_dict['base'] += [vector_base]
                    vectors_dict['slide'] += df_mod['Slide'].loc[i]
                    vector_base_tensor = torch.tensor([vector_base]).to(device)
                    vector_base_tensor.requires_grad = True
                    optimizer = torch.optim.SGD([vector_base_tensor], lr=1e-3)
                    pred = softmax(model(G(vector_base_tensor, 0, noise_mode ='const')))
                    loss = loss_fn(pred[0][1], torch.tensor(tensor_vector_target).to(device))
                    vectors_dict['pred'] += [pred[0][1].detach().cpu().numpy()]
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    vectors_dict['loss'] += [-1 * vector_base_tensor.grad.detach().cpu().numpy()[0]]
            with open(os.path.dirname(mod)  + "/" +  "fullset_vectors_dict_pca_nav.pkl", 'wb') as f:
                pickle.dump(vectors_dict, f)

    #df_plot = pd.DataFrame(data = [plot_slides, plot_vectors, loss_vectors], columns = ['Slide', 'Vector', 'Loss'])
    
    from sklearn.decomposition import PCA
    n_comp = 20
    pca = PCA(n_components=n_comp)
    #result_base = pca.fit_transform(vectors_dict['base'])
    
    #result = pca.transform(vectors_dict['loss'])
    
    result = pca.fit_transform(vectors_dict['loss'])
    subfigs = fig_overall.subfigures(1,3, width_ratios = [6, 20, 3])

    axs1 = subfigs[0].subplots(2, 1)

    patient_vectors = []
    a, b = get_pca_vectors(df_mod, entries, feat_cols, mod, dataset, vectors_dict, 0, 1, axs1[0])
    c, d = get_pca_vectors(df_mod, entries, feat_cols, mod, dataset, vectors_dict, 2, 3, axs1[1])
    patient_vectors = [a,b,c,d]
    axs1[0].set_xlabel('PC 1')
    axs1[0].set_ylabel('PC 2')
    axs1[1].set_xlabel('PC 3')
    axs1[1].set_ylabel('PC 4')
    
    img_dict = {}
    row_loc = {}
    important_pca = [x for x in range(n_comp)]
    for i in range(n_comp):
        for j in range(n_comp):
            pca_1 = important_pca[i]
            pca_2 = important_pca[j]
            if np.abs(np.mean([result[x][pca_2] for x in range(entries)])) < np.abs(np.mean([result[x][pca_1] for x in range(entries)])):
                important_pca[i] = pca_2
                important_pca[j] = pca_1
    if not regen and os.path.isfile(os.path.dirname(mod)  + "/" +  "img_dict_pca_nav.pkl"):
        with open(os.path.dirname(mod)  + "/" +  "img_dict_pca_nav.pkl", 'rb') as f:
            img_dict = pickle.load(f)
    else:
        for i in range(4):
            img_dict[str(i)] = generate_images([patient_vectors[i]], [pca.components_[important_pca[i]]])
            if np.mean([result[x][important_pca[i]] for x in range(entries)]) > 0:
                row_loc[str(i)] = [15,30,40,50,60,70,85]
            else:
                row_loc[str(i)] = [85, 70, 60, 50, 40, 30, 15]
            grades = []
            for rs in row_loc[str(i)]:
                z = torch.tensor([patient_vectors[i]])
                z2 = torch.tensor([pca.components_[important_pca[i]]])
                torch_interp = torch.tensor(z - z2 + 2*rs/100*z2).to(device)
                img = G(torch_interp, 0, noise_mode ='const')
                m = model(img)
                grades += [softmax(m)[0].detach().cpu().numpy()[1]]
            print(grades)
        with open(os.path.dirname(mod)  + "/" +  "img_dict_pca_nav.pkl", 'wb') as f:
            pickle.dump(img_dict, f)
    for i in range(4):
        if np.mean([result[x][important_pca[i]] for x in range(entries)]) > 0:
            row_loc[str(i)] = [15,30,40,50,60,70,85]
        else:
            row_loc[str(i)] = [85, 70, 60, 50, 40, 30, 15]
    img_include = 7 
    rows = len(img_dict)
    #gs = plt.GridSpec(ncols = img_include, nrows = rows, left = 0, top = 1, right = 1, bottom = 0, wspace=0, hspace=0)
    #gs_final = plt.GridSpec(ncols = img_include, nrows = rows, left = 0, top = 1, right = 1, bottom = 0, wspace=0, hspace=0)
    #, axes = plt.subplots(rows, img_include, figsize = (2*img_include,2*rows))
    axs = subfigs[1].subplots(len(img_dict), img_include)

    imp_total = 0
    for col in range(20):
        imp_total += np.abs(np.mean([result[x][important_pca[col]] for x in range(entries)]) - np.mean([result[entries + x][important_pca[col]] for x in range(entries)]))
        
    col = 0
    axs2 = subfigs[2].subplots(len(img_dict), 1)
    
    res_dict = []
    for img_name in img_dict:
        row = 0
        for row_item in row_loc[img_name]:   
            axs[col][row].imshow(img_dict[img_name][row_item])
            #axs[col][row].spines['right'].set_visible(False)
            #axs[col][row].spines['top'].set_visible(False)
            #axs[col][row].spines['left'].set_visible(False)
            #axs[col][row].spines['bottom'].set_visible(False)
            axs[col][row].set_xticks([])
            axs[col][row].set_yticks([])
            axs[col][row].xaxis.set_label_position('top')
            row = row + 1
        str_trans = {'0':'PC 1', '1':'PC 2', '2':'PC 3', '3':'PC 4'}
        str_name = img_name
        axs[col][0].set_ylabel(str_trans[str_name], size = 18)

        axs2[col].barh([1 - 0.2], [np.abs(np.mean([result[x][important_pca[col]] for x in range(entries)]) - np.mean([result[entries + x][important_pca[col]] for x in range(entries)])) / imp_total], height = 0.8, label = "Contribution to Prediction")
        axs2[col].barh([0 - 0.2], [pca.explained_variance_ratio_[important_pca[col]]], height = 0.8, label = "Variance Explained")

        res_dict_col = [dataset, label[0], entries]
        
        res_dict_col += [np.abs(np.mean([result[x][important_pca[col]] for x in range(entries)]) - np.mean([result[entries + x][important_pca[col]] for x in range(entries)])) / imp_total]
        res_dict_col += [np.sqrt(np.square(np.std([result[x][important_pca[col]] for x in range(entries)])) + np.square(np.std([result[entries + x][important_pca[col]] for x in range(entries)]))) / imp_total]
        res_dict_col += [pca.explained_variance_ratio_[important_pca[col]]]
        if col == 0:
            res_dict += [res_dict_col]
        axs2[col].spines['right'].set_visible(False)
        axs2[col].spines['top'].set_visible(False)
        #axs[col][row].spines['left'].set_visible(False)
        axs2[col].set_yticks([])
        axs2[col].axes.get_yaxis().set_visible(False)
        axs2[col].set_ylim([-1,2])
        axs2[col].set_xlim([0, 1])
        if col != rows - 1:
            axs2[col].axes.get_xaxis().set_visible(False)
            axs2[col].spines['bottom'].set_visible(False)
        if col == 0:
            axs2[col].legend(bbox_to_anchor=(0.04, 0.7))
        col = col + 1
    plt.tight_layout()
    subfigs[1].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0.1, wspace=0, hspace=0)
    subfigs[2].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0.1, wspace=0, hspace=0)
    subfigs[0].subplots_adjust(left = 0, top = 1, right = 0.85, bottom = 0.1, wspace=0, hspace =0.1)
    axs[0][0].annotate(text="", xy=(0.98, 1.032), xytext=(0.505,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C1'))
    axs[0][0].annotate(text="", xy=(0.02, 1.032), xytext=(0.495,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C0'))
    axs[0][0].annotate(text=label[0], xy = (0.50,1.05), xycoords="subfigure fraction", ha="center", size = 18)
    axs[0][0].annotate(text=label[1], xy = (0.02, 1.05), xycoords="subfigure fraction", ha="center", size = 16)
    axs[0][0].annotate(text=label[2], xy = (0.98, 1.05), xycoords="subfigure fraction", ha="center", size = 16)
    
    subfigs[0].text(-0.03,1.02,label[3], zorder = 30, clip_on = False, weight='bold', size = 20 )
    subfigs[1].text(-0.03,1.02,label[4], zorder = 30, clip_on = False, weight='bold', size = 20)
    subfigs[2].text(0.1,1.02,label[5], zorder = 30, clip_on = False, weight='bold', size = 20)
    return res_dict

## Loss-based PCA feature space exploration: comparison for grade and tissue source site

In [None]:
from IPython.display import display

# Two stacked panels: grade on top, tissue source site underneath.
fig_overall = plt.figure(figsize=(18, 16), dpi=300)
subfigs = fig_overall.subfigures(2, 1)

# Each load_models call draws into its sub-figure and returns summary rows.
res_dict = load_models(
    PROJECT_DIR + '/PRETRAINED_MODELS/BRCA_GRADE/high_grade.zip',
    entries=500,
    full_set=True,
    regen=False,
    fig_overall=subfigs[0],
    label=['Grade', 'Low', 'High', 'A', 'B', 'C'],
)
res_dict.extend(
    load_models(
        PROJECT_DIR + '/PRETRAINED_MODELS/BRCA_SITE/SITE-HP0_epoch3.zip',
        entries=500,
        full_set=True,
        regen=False,
        fig_overall=subfigs[1],
        label=['Site', '', '', 'D', 'E', 'F'],
    )
)

# Summary table: contribution is shown as "value (stdev) " with three decimals.
three_decimals = '{:.3f}'.format
df_res = pd.DataFrame(res_dict, columns=['Cancer', 'Outcome', 'n', 'Contribution', 'stdev', 'Variance'])
df_res['Contribution'] = df_res['Contribution'].map(three_decimals) + " (" + df_res['stdev'].map(three_decimals) + ") "
df_res['Variance'] = df_res['Variance'].map(three_decimals)
df_res = df_res.drop(columns=['stdev'])
display(df_res)

plt.show()

## The following is used to perform gradient descent to identify principal components for site, for ancestry, and for site after normalization

In [None]:
def load_models_site(mod, patient = 200, entries = 50, full_set = False, regen = False, dataset = 'BRCA', fig_overall = None, label = {}, fixed_patient = False, fixed_patient_values = None):
    """Render images along the principal component that best separates a classifier's two targets.

    For every row of the saved feature CSV of ``dataset``, the gradient of an L1 loss
    on the model's class-1 softmax output is taken with respect to the feature vector,
    once with target 1 and once with target 0. A 20-component PCA is fitted on the
    negated gradients, the components are ranked by the absolute mean projection of
    the target-1 gradients, and ``generate_images`` is called for one base feature
    vector along the top-ranked component.

    Args:
        mod: Path of the saved model; the gradient cache ``fullset_vectors_dict.pkl``
            is read from / written to the same directory.
        patient: Unused, kept for backward compatibility.
        entries: Ignored; it is replaced by the number of rows in the feature CSV.
        full_set: Must be True; the reduced-set path is not implemented.
        regen: Recompute the gradients even when the cache file exists.
        dataset: Dataset name; selects ``<dataset.lower()>_features_slide.csv`` and
            the ``patient_vector_site_<dataset>.pkl`` file.
        fig_overall: Unused, kept for backward compatibility.
        label: Unused, kept for backward compatibility.
        fixed_patient: When True the base vector comes from
            ``pub_pkl/patient_vector_site_<dataset>.pkl`` instead of being picked as
            the row with the largest projection on the top component.
        fixed_patient_values: Optional row index into the base vectors. When given
            together with ``fixed_patient`` that row is used and saved to the
            patient-vector pickle.

    Returns:
        Tuple ``(img_dict, row_loc, contrib, var_ratio)``: the result of
        ``generate_images``, the list of seven positions to display (reversed when
        the target-0 mean projection is the larger one), the top component's share
        of the summed target-1/target-0 mean separation, and its explained
        variance ratio.

    Raises:
        NotImplementedError: If ``full_set`` is False.
        FileNotFoundError: If ``fixed_patient`` is set but neither an index nor a
            saved patient-vector pickle is available.
    """
    from tqdm import tqdm
    from sklearn.decomposition import PCA

    if not full_set:
        #Not implemented for public release given need for full extracted features from slides (~4 GB per dataset)
        raise NotImplementedError("Requires full set of extracted features - not implemented")

    df_mod = pd.read_csv(PROJECT_DIR + '/PROJECTS/HistoXGAN/SAVED_FEATURES/' + dataset.lower() + '_features_slide.csv')
    entries = len(df_mod.index)
    feat_cols = [f for f in df_mod.columns.values if 'Feature_' in f]
    print("Rows imported")

    cache_path = os.path.dirname(mod)  + "/" +  "fullset_vectors_dict.pkl"
    if not regen and os.path.isfile(cache_path):
        # NOTE(review): pickle.load executes arbitrary code; only use trusted cache files.
        with open(cache_path, 'rb') as f:
            vectors_dict = pickle.load(f)
    else:
        # The classifier is only needed when gradients have to be recomputed.
        model = load_torch_model(mod).eval().to(device)
        loss_fn = torch.nn.L1Loss().to(device)
        softmax = torch.nn.Softmax(dim = -1)
        features = df_mod[feat_cols]
        vectors_dict = {'loss': [], 'base': [], 'slide': [], 'pred': []}
        # Target 1 rows come first, target 0 rows second; the code below relies on that order.
        for tensor_vector_target in [1, 0]:
            target = torch.tensor(tensor_vector_target).to(device)
            for i in tqdm(range(entries)):
                vector_base = features.loc[i, :].values.tolist()
                vectors_dict['base'].append(vector_base)
                # Append the slide name as one item (``+= str`` would add single characters).
                vectors_dict['slide'].append(df_mod['Slide'].loc[i])
                vector_base_tensor = torch.tensor([vector_base]).to(device)
                vector_base_tensor.requires_grad = True
                pred = softmax(model(G(vector_base_tensor, 0, noise_mode ='const')))
                loss = loss_fn(pred[0][1], target)
                vectors_dict['pred'].append(pred[0][1].detach().cpu().numpy())
                # Only the gradient is kept, so no optimizer step is required.
                loss.backward()
                vectors_dict['loss'].append(-1 * vector_base_tensor.grad.detach().cpu().numpy()[0])
        with open(cache_path, 'wb') as f:
            pickle.dump(vectors_dict, f)

    n_comp = 20
    pca = PCA(n_components=n_comp)
    result = pca.fit_transform(vectors_dict['loss'])

    # Mean projection per component for the target-1 and the target-0 gradients.
    mean_target1 = result[:entries].mean(axis=0)
    mean_target0 = result[entries:2 * entries].mean(axis=0)

    # Components ordered by decreasing |mean projection| of the target-1 gradients.
    important_pca = [int(k) for k in np.argsort(-np.abs(mean_target1), kind='stable')]
    top = important_pca[0]

    patient_pkl = PROJECT_DIR + "/pub_pkl/patient_vector_site_" + dataset + ".pkl"
    if fixed_patient:
        if fixed_patient_values is not None:
            # ``is not None`` so that row index 0 is a valid choice.
            patient_vectors = {0: vectors_dict['base'][fixed_patient_values]}
            with open(patient_pkl, 'wb') as f:
                pickle.dump(patient_vectors, f)
        elif os.path.isfile(patient_pkl):
            with open(patient_pkl, 'rb') as f:
                patient_vectors = pickle.load(f)
        else:
            raise FileNotFoundError("fixed_patient requested but no fixed_patient_values given and " + patient_pkl + " does not exist")
    else:
        # Use the row whose target-1 gradient projects most strongly on the top component.
        best_row = int(np.argmax(result[:entries, top]))
        print(best_row)
        patient_vectors = [vectors_dict['base'][best_row]]

    img_dict = generate_images([patient_vectors[0]], [pca.components_[top]])

    # Share of the total target-1 vs target-0 mean separation carried by the top component.
    imp_total = np.abs(mean_target1 - mean_target0).sum()
    contrib = np.abs(mean_target1[top] - mean_target0[top]) / imp_total
    var_ratio = pca.explained_variance_ratio_[top]

    # The ordering must be judged on the component the images were generated along
    # (previously a stale loop index selected the lowest-ranked component instead).
    if mean_target1[top] > mean_target0[top]:
        row_loc = [15,30,40,50,60,70,85]
    else:
        print("Reverse " + str(top))
        row_loc = [85, 70, 60, 50, 40, 30, 15]
    return img_dict, row_loc, contrib, var_ratio

In [None]:
def plot_models(regen_top = False, regen = False, subfigs = None, models = None, datasets = None, group = None, full_set=True, entries = 500, fixed_patient = None, fixed_patient_values = None, label = [], race = False):
    """Draw one row of images plus contribution/variance bars per dataset.

    Results of ``load_models_site`` are cached as
    ``pub_pkl/<group>{img_dict,row_loc,contrib,var_ratio}.pkl``. When
    ``regen_top`` is False and the ``img_dict`` cache exists all four files are
    loaded; otherwise ``load_models_site`` is called once per (model, dataset)
    pair and the caches are rewritten.

    Args:
        regen_top: Ignore the ``pub_pkl`` caches of this function and recompute.
        regen: Forwarded to ``load_models_site``.
        subfigs: Two sub-figures; ``subfigs[0]`` receives the image grid and
            ``subfigs[1]`` the bar charts.
        models: Model paths, paired position-wise with ``datasets``.
        datasets: Dataset names; they key the caches and label the rows.
        group: String prefix of the cache file names.
        full_set: Forwarded to ``load_models_site``.
        entries: Forwarded to ``load_models_site``.
        fixed_patient: Forwarded to ``load_models_site``.
        fixed_patient_values: Optional sequence with one row index per dataset.
        label: ``[title, left caption, right caption, image panel letter, bar panel letter]``.
        race: Add the hard-coded ancestry captions to the first and last image
            of (up to) the first four rows.
    """
    cache_prefix = PROJECT_DIR + "/pub_pkl/" + group
    img_dict = {}
    row_loc = {}
    contrib = {}
    var_ratio = {}
    if not regen_top and os.path.isfile(cache_prefix + "img_dict.pkl"):
        # NOTE(review): pickle.load executes arbitrary code; only use trusted cache files.
        with open(cache_prefix + "img_dict.pkl", 'rb') as f:
            img_dict = pickle.load(f)
        with open(cache_prefix + "row_loc.pkl", 'rb') as f:
            row_loc = pickle.load(f)
        with open(cache_prefix + "contrib.pkl", 'rb') as f:
            contrib = pickle.load(f)
        with open(cache_prefix + "var_ratio.pkl", 'rb') as f:
            var_ratio = pickle.load(f)
    else:
        # One call per (model, dataset) pair instead of four hand-copied calls.
        for k, (model_path, dataset_name) in enumerate(zip(models, datasets)):
            patient_index = fixed_patient_values[k] if fixed_patient_values else None
            img_dict[dataset_name], row_loc[dataset_name], contrib[dataset_name], var_ratio[dataset_name] = load_models_site(
                model_path, dataset = dataset_name, entries = entries, full_set = full_set, regen = regen,
                fixed_patient = fixed_patient, fixed_patient_values = patient_index)
        with open(cache_prefix + "img_dict.pkl", 'wb') as f:
            pickle.dump(img_dict, f)
        with open(cache_prefix + "row_loc.pkl", 'wb') as f:
            pickle.dump(row_loc, f)
        with open(cache_prefix + "contrib.pkl", 'wb') as f:
            pickle.dump(contrib, f)
        with open(cache_prefix + "var_ratio.pkl", 'wb') as f:
            pickle.dump(var_ratio, f)

    img_include = 3
    rows = len(img_dict)
    # squeeze=False keeps both grids indexable even when a single dataset is drawn.
    axs = subfigs[0].subplots(rows, img_include, squeeze=False)
    axs2 = subfigs[1].subplots(rows, 1, squeeze=False)[:, 0]

    for col, img_name in enumerate(img_dict):
        # Only the 2nd, 4th and 6th of the seven stored positions are displayed.
        shown = [row_loc[img_name][k] for k in (1, 3, 5)]
        for row, row_item in enumerate(shown):
            ax = axs[col][row]
            ax.imshow(img_dict[img_name][row_item])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.xaxis.set_label_position('top')
        axs[col][0].set_ylabel(img_name, size = 18)

        bar_ax = axs2[col]
        bar_ax.barh([1 - 0.2], [contrib[img_name]], height = 0.8, label = "Contribution to Prediction")
        bar_ax.barh([0 - 0.2], [var_ratio[img_name]], height = 0.8, label = "Variance Explained")

        bar_ax.spines['right'].set_visible(False)
        bar_ax.spines['top'].set_visible(False)
        bar_ax.set_yticks([])
        bar_ax.axes.get_yaxis().set_visible(False)
        bar_ax.set_ylim([-1,2])
        bar_ax.set_xlim([0,1])
        # Only the bottom bar chart keeps its x axis.
        if col != rows - 1:
            bar_ax.axes.get_xaxis().set_visible(False)
            bar_ax.spines['bottom'].set_visible(False)
        if col == 0:
            bar_ax.legend(bbox_to_anchor=(0.04, 0.72))

    subfigs[0].subplots_adjust(left = 0, top = 1, right = 1, bottom = 0 + 1/9, wspace=0, hspace=0)
    subfigs[1].subplots_adjust(left = 0, top = 1, right = 1  - 4/16, bottom = 0  + 1/9, wspace=0, hspace=0)
    axs[0][0].annotate(text="", xy=(0.98, 1.032), xytext=(0.505,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C1'))
    axs[0][0].annotate(text="", xy=(0.02, 1.032), xytext=(0.495,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C0'))
    axs[0][0].annotate(text=label[0], xy = (0.50,1.05), xycoords="subfigure fraction", ha="center", size = 18)
    axs[0][0].annotate(text=label[1], xy = (0.10, 1.05), xycoords="subfigure fraction", ha="center", size = 16)
    axs[0][0].annotate(text=label[2], xy = (0.90, 1.05), xycoords="subfigure fraction", ha="center", size = 16)

    if race:
        # Captions for the left-most image of each row, in row order.
        left_captions = ["Asian", "Asian", "African", "Asian"]
        for axs_row, left_text in zip(axs, left_captions):
            t = axs_row[0].annotate(text=left_text, xy = (0.028, 1- 0.029), xycoords="axes fraction", ha="left", va="top", size = 16)
            t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='black'))
        # The right-most image is addressed explicitly rather than via a leaked loop index.
        for axs_row, _ in zip(axs, left_captions):
            t = axs_row[img_include - 1].annotate(text="European", xy = (1- 0.028, 1- 0.029), xycoords="axes fraction", ha="right", va="top", size = 16)
            t.set_bbox(dict(facecolor='white', alpha=1, edgecolor='black'))

    subfigs[0].text(-0.05,1.02, label[3], zorder = 30, clip_on = False, weight='bold', size = 20 )
    subfigs[1].text(-0.01,1.02, label[4], zorder = 30, clip_on = False, weight='bold', size = 20)



In [None]:
# 2 x 2 figure; every quadrant is split into an image grid (width 6) and a bar
# chart column (width 3) that plot_models fills.
# All calls pass regen_top = False and full_set = False. load_models_site raises
# NotImplementedError for full_set = False, so each call only succeeds when
# plot_models finds its cached PROJECT_DIR/pub_pkl/<group>img_dict.pkl files.
fig_overall = plt.figure(figsize=(18, 18), dpi = 300)
subfigs_overall = fig_overall.subfigures(2, 2)

# Top-left, panel letter A: site models (cache group 'A').
subfigs1 = subfigs_overall[0][0].subfigures(1, 2, width_ratios = [6, 3])
plot_models(False, False, subfigs1, full_set = False, models = [PROJECT_DIR + "/PRETRAINED_MODELS/BRCA_SITE/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/BLCA_SITE/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/COADREAD_SITE/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/ESCA_SITE/SITE-HP0_epoch3.zip"],
                                                              datasets = ['BRCA', 'BLCA', 'COADREAD', 'ESCA'], group = 'A', fixed_patient = True, label = ['Site', 'Site 1', 'Site 2', 'A', '']) # optional extra argument used previously: fixed_patient_values = [126, 453, 397, 138]
                                                                # Author notes on candidate rows - good: brca 119 (small area of cancer), 388 (too dark?), 136 (lots of lymph)
                                                                # good: brca 124 - somewhat dark - COADREAD 392 - a little whiteish but good
# Bottom-right, panel letter D: site models after CycleGAN normalization (cache group 'C').
subfigs2 = subfigs_overall[1][1].subfigures(1, 2, width_ratios = [6, 3])
plot_models(False, False, subfigs2, full_set = False, models = [PROJECT_DIR + "/PRETRAINED_MODELS/BRCA_SITE_CYCLEGAN/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/BLCA_SITE_CYCLEGAN/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/COADREAD_SITE_CYCLEGAN/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/ESCA_SITE_CYCLEGAN/SITE-HP0_epoch3.zip"],
                                                              datasets = ['BRCA', 'BLCA', 'COADREAD', 'ESCA'], group = 'C', fixed_patient = True, label = ['Site (CycleGAN Norm)', 'Site 1', 'Site 2', 'D', ''])

# Top-right, panel letter C: site models after Reinhard normalization (cache group 'D').
# Note that the cache group letters 'C'/'D' are swapped relative to the panel letters.
subfigs3 = subfigs_overall[0][1].subfigures(1, 2, width_ratios = [6, 3])
plot_models(False, False, subfigs3, full_set = False, models = [PROJECT_DIR + "/PRETRAINED_MODELS/BRCA_SITE_REINHARD/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/BLCA_SITE_REINHARD/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/COADREAD_SITE_REINHARD/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/ESCA_SITE_REINHARD/SITE-HP0_epoch3.zip"],
                                                              datasets = ['BRCA', 'BLCA', 'COADREAD', 'ESCA'], group = 'D', fixed_patient = True, label = ['Site (Reinhard Norm)', 'Site 1', 'Site 2', 'C', ''])

# Bottom-left, panel letter B: ancestry models (cache group 'B'); race = True adds
# the ancestry captions to the first and last image of every row.
subfigs4 = subfigs_overall[1][0].subfigures(1, 2, width_ratios = [6, 3])
plot_models(False, False, subfigs4, full_set = False, models = [PROJECT_DIR + "/PRETRAINED_MODELS/BRCA_ANCESTRY/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/BLCA_ANCESTRY/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/COADREAD_ANCESTRY/SITE-HP0_epoch3.zip",
                                                                PROJECT_DIR + "/PRETRAINED_MODELS/ESCA_ANCESTRY/SITE-HP0_epoch3.zip"],
                                                              datasets = ['BRCA', 'BLCA', 'COADREAD', 'ESCA'], group = 'B', fixed_patient = True, label = ['Ancestry', '', '', 'B', ''], race = True) # optional extra argument used previously: fixed_patient_values = [308, 495, 30, 335]
plt.show()

# Explainability of PIK3CA and HRD

### Here, gradient descent is used to generate images that are more / less likely to be predicted to be PIK3CA mutated / HRD high

In [None]:
from slideflow.model.torch import load as load_torch_model
from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils

# Feature table used as starting points; only the 'Feature_*' columns feed the generator.
df_mod = pd.read_csv(PROJECT_DIR + '/PROJECTS/HistoXGAN/SAVED_FEATURES/brca_features_plot.csv')
feat_cols = [f for f in df_mod.columns.values if 'Feature_' in f]
base_num = 200

# Generator taken from the 'G_ema' entry of the CTransPath snapshot.
with dnnlib.util.open_url(PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

def _trace_class_descent(model, start_features, target_class, keep_steps, lr = 1e-3, steps = 200):
    """Optimise a feature vector towards ``target_class`` and keep selected frames.

    Starting from ``start_features``, Adam minimises the L1 distance between the
    model's class-1 softmax output for the generated image and ``target_class``.

    Args:
        model: Classifier applied to the generator output.
        start_features: List of feature values used as the starting vector.
        target_class: 0 or 1, the value the class-1 probability is pushed towards.
        keep_steps: Iteration indices whose generated image is stored.
        lr: Adam learning rate.
        steps: Number of optimisation steps.

    Returns:
        List of uint8 image arrays (channel-last), one per kept step, in step order.
    """
    vector = torch.tensor([start_features]).to(device)
    vector.requires_grad = True
    optimizer = torch.optim.Adam([vector], lr=lr)
    loss_fn = torch.nn.L1Loss().to(device)
    # Explicit dim: same axis the implicit choice used for 2-D outputs, without the
    # deprecation warning.
    softmax = torch.nn.Softmax(dim = -1)
    target = torch.tensor(target_class).to(device)
    frames = []
    for step in range(steps):
        optimizer.zero_grad()
        img_gen = G(vector, 0, noise_mode ='const')
        if step in keep_steps:
            # Map the generator output from [-1, 1] to uint8 pixel values.
            frames.append(((img_gen + 1)*127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy())
        pred = softmax(model(img_gen))
        loss = loss_fn(pred[0][1], target)
        loss.backward()
        optimizer.step()
    return frames


def get_img_dict(df_mod, m, base_imgs, regen = True):
    """Build, per base row, seven images running from class 0 to class 1.

    For each row label in ``base_imgs`` the feature vector is optimised towards
    class 0 (frames of steps 199, 50, 20 and 0, in that order) and, restarting
    from the same vector, towards class 1 (frames of steps 20, 50 and 199).

    Args:
        df_mod: DataFrame holding the ``Feature_*`` columns listed in ``feat_cols``.
        m: Path of the saved classifier.
        base_imgs: Row labels of ``df_mod`` to start from.
        regen: When False and the cache file exists, the cached result is returned.

    Returns:
        Dict mapping each entry of ``base_imgs`` to its list of seven image arrays.
    """
    # NOTE(review): there is no path separator before "img_dict.pkl", so the cache
    # sits next to the model directory (e.g. ".../PIK3CAimg_dict.pkl"). Kept as is
    # so that existing cache files continue to be found.
    cache_path = os.path.dirname(m) + "img_dict.pkl"
    if not regen and os.path.isfile(cache_path):
        # NOTE(review): pickle.load executes arbitrary code; only use trusted cache files.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    model = load_torch_model(m).eval().to(device)

    img_dict = {}
    for b in base_imgs:
        start_features = df_mod[feat_cols].loc[b, :].values.tolist()
        towards_zero = _trace_class_descent(model, start_features, 0, keep_steps = (0, 20, 50, 199))
        towards_one = _trace_class_descent(model, start_features, 1, keep_steps = (20, 50, 199))
        # Reverse the class-0 frames so the sequence reads class 0 -> start -> class 1.
        img_dict[b] = list(reversed(towards_zero)) + towards_one
    with open(cache_path, 'wb') as f:
        pickle.dump(img_dict, f)
    return img_dict


def plot_model(df_mod, m, regen, subfig, label):
    """Show the image sequences of four base rows as a 4 x 7 grid in ``subfig``.

    Args:
        df_mod: Feature DataFrame handed to ``get_img_dict``.
        m: Path of the saved classifier.
        regen: Forwarded to ``get_img_dict``.
        subfig: Sub-figure that receives the axes grid and the annotations.
        label: ``[title, left caption, right caption, panel letter]``.
    """
    base_imgs = [0, 1, 2, 3] #[100, 200, 401, 501]#[100, 200, 300, 400, 500, 600, 700, 800]
    img_dict = get_img_dict(df_mod, m, base_imgs, regen = regen)
    axs = subfig.subplots(len(base_imgs), 7)

    # One grid row per base image, one grid column per stored frame.
    for grid_row, frames in enumerate(img_dict.values()):
        for grid_col, frame in enumerate(frames):
            cell = axs[grid_row][grid_col]
            cell.imshow(frame)
            cell.set_xticks([])
            cell.set_yticks([])
            cell.xaxis.set_label_position('top')

    subfig.subplots_adjust(left = 0, top = 1, right = 1, bottom = 0 + 2/18, wspace=0, hspace=0)

    # Arrows and captions are positioned in sub-figure coordinates above the grid.
    anchor = axs[0][0]
    anchor.annotate(text="", xy=(0.98, 1.032), xytext=(0.505,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C1'))
    anchor.annotate(text="", xy=(0.02, 1.032), xytext=(0.495,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C0'))
    for text, x_pos, y_pos, font_size in (
        (label[0], 0.50, 1.05, 18),
        (label[1], 0.02, 1.05, 16),
        (label[2], 0.98, 1.05, 16),
    ):
        anchor.annotate(text=text, xy = (x_pos, y_pos), xycoords="subfigure fraction", ha="center", size = font_size)

    subfig.text(-0.03,1.02,label[3], zorder = 30, clip_on = False, weight='bold', size = 20 )

    
# Two stacked panels: PIK3CA (A) above HRD (B).
fig_overall = plt.figure(figsize=(14, 18), dpi = 300)
subfig = fig_overall.subfigures(2, 1)
plt.tight_layout()

panel_specs = [
    ("/PRETRAINED_MODELS/PIK3CA/PIK3CA-HP0.zip", ["PIK3CA", "WT", "MUT", "A"]),
    ("/PRETRAINED_MODELS/HRD/HRD-HP0.zip", ["HRD", "Low", "High", "B"]),
]
for panel_index, (model_rel_path, panel_label) in enumerate(panel_specs):
    plot_model(df_mod, PROJECT_DIR + model_rel_path, regen = False, subfig = subfig[panel_index], label = panel_label)

plt.show()


## The following was used to generate videos for pathologists to review for PIK3CA / HRD transitions

In [None]:
# Setup for the video-generation cell: model loader, device, feature columns and
# the generator used by get_img_dict_full.
from slideflow.model.torch import load as load_torch_model
import torch
import slideflow as sf
device = torch.device('cuda:3')

# NOTE(review): df_mod is not loaded before this point in the cell, so feat_cols is
# derived from whatever DataFrame an earlier cell left in df_mod; the CSV used for
# the videos is only read at the end of this cell.
feat_cols = list(df_mod.columns.values)
feat_cols = [f for f in feat_cols if 'Feature_' in f]
from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
# Generator taken from the 'G_ema' entry of the CTransPath snapshot.
with dnnlib.util.open_url(PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)
    
    
def get_img_dict_full(df_mod, m, base_img, regen = True):
    """Return 100 frames running from class 0 to class 1 for one base row.

    The feature vector of row ``base_img`` is optimised for 50 Adam steps towards
    class 0 and, restarting from the same vector, for 50 steps towards class 1.
    Every step's generated image is kept; the class-0 frames are reversed so the
    sequence reads class 0 -> start -> class 1 (the start frame appears twice).

    Args:
        df_mod: DataFrame holding the ``Feature_*`` columns listed in ``feat_cols``.
        m: Path of the saved classifier.
        base_img: Row label of ``df_mod`` to start from.
        regen: Unused, kept for interface compatibility.

    Returns:
        List of uint8 image arrays (channel-last).
    """
    model = load_torch_model(m).eval().to(device)
    start_features = df_mod[feat_cols].loc[base_img, :].values.tolist()

    lr = 5e-4
    n_steps = 50
    loss_fn = torch.nn.L1Loss().to(device)
    # Explicit dim: same axis the implicit choice used for 2-D outputs, without the
    # deprecation warning.
    softmax = torch.nn.Softmax(dim = -1)

    def _frames_towards(target_class):
        # Fresh vector and optimizer per direction so both start from the same point.
        vector = torch.tensor([start_features]).to(device)
        vector.requires_grad = True
        optimizer = torch.optim.Adam([vector], lr=lr)
        target = torch.tensor(target_class).to(device)
        frames = []
        for _ in range(n_steps):
            optimizer.zero_grad()
            img_gen = G(vector, 0, noise_mode ='const')
            # Map the generator output from [-1, 1] to uint8 pixel values.
            frames.append(((img_gen + 1)*127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy())
            pred = softmax(model(img_gen))
            loss = loss_fn(pred[0][1], target)
            loss.backward()
            optimizer.step()
        return frames

    towards_zero = _frames_towards(0)
    towards_one = _frames_towards(1)
    return list(reversed(towards_zero)) + towards_one

def save_vectors_pathologist(df_mod, m, n_videos = 50, tag = 'pik3ca'):
    """Save transition videos for randomly chosen rows of ``df_mod``.

    Each video is written next to the model as ``<row>__<tag>_video.mp4`` and is
    built from the frames returned by ``get_img_dict_full``.

    Args:
        df_mod: Feature DataFrame; rows are addressed by position in
            ``range(len(df_mod.index))``.
        m: Path of the saved classifier; videos go into its directory.
        n_videos: Number of distinct rows to render (capped at the row count).
        tag: Text placed in the file name; the default keeps the historical
            ``pik3ca`` name, pass e.g. ``'hrd'`` for other models.
    """
    from random import sample
    from slideflow.gan.stylegan3.stylegan3 import utils

    n_entries = len(df_mod.index)
    # Sample without replacement: repeated rows would overwrite an existing video
    # and leave fewer than n_videos files.
    for r in sample(range(n_entries), min(n_videos, n_entries)):
        video_path = os.path.dirname(m) + "/" + str(r) + '__' + tag + '_video.mp4'
        utils.save_video(get_img_dict_full(df_mod, m, r), path=video_path)

# Videos are generated from the partial BRCA feature table, once per classifier.
df_mod = pd.read_csv(PROJECT_DIR + '/PROJECTS/HistoXGAN/SAVED_FEATURES/brca_features_part.csv')
for model_rel_path in (
    "/PRETRAINED_MODELS/PIK3CA/PIK3CA-HP0_epoch3.zip",
    "/PRETRAINED_MODELS/HRD/HRD-HP0_epoch3.zip",
):
    save_vectors_pathologist(df_mod, PROJECT_DIR + model_rel_path)
