In [None]:
import slideflow as sf
import torch
import numpy as np
from slideflow.io.torch import whc_to_cwh
from PIL import Image
from slideflow.mil.eval import run_inference
from slideflow.mil.utils import load_model_weights
from slideflow.model.extractors import rebuild_extractor
import os
import pandas as pd
import pickle
# Root of the slideflow-gan checkout; later cells read FINAL_MODELS/,
# PRETRAINED_MODELS/, PROJECTS/ and pub_pkl/ beneath it. Set the
# SLIDEFLOW_GAN_PROJECT_DIR environment variable to point elsewhere.
PROJECT_DIR = os.environ.get('SLIDEFLOW_GAN_PROJECT_DIR', '/mnt/data/fred/slideflow-gan')
device = torch.device('cuda:0')

# MIL Model Feature Traversal

In [None]:

import os
import pickle
import matplotlib.pyplot as plt

from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
# Load the exponential-moving-average generator ('G_ema') from the trained
# snapshot and move it onto `device`.
generator_snapshot = PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl'
with dnnlib.util.open_url(generator_snapshot) as snapshot_file:
    network = legacy.load_network_pkl(snapshot_file)
    G = network['G_ema'].to(device)

    
def get_img_dict(df_mod, model_path, model, config, base_imgs, regen = True):
    """Visualise a MIL prediction traversal of one feature vector.

    Starting from the ``Feature_*`` vector in row ``base_imgs`` of ``df_mod``,
    Adam pushes the model's ``pred[0][1]`` output towards 0.0 and, separately
    from the same start, towards 1.0 (MSE loss, lr 2e-2, 200 steps each).
    Intermediate vectors are decoded into images with the generator ``G``.

    Args:
        df_mod: DataFrame whose ``Feature_*`` columns hold the base vectors.
        model_path: Path inside the model directory; the cache file
            ``img_dict.pkl`` is written to ``os.path.dirname(model_path)``.
        model, config: Loaded MIL model and its config;
            ``config.model_config.use_lens`` is forwarded to ``run_inference``.
        base_imgs: Row label in ``df_mod`` of the vector to traverse.
        regen: When False and a cached result for ``base_imgs`` exists,
            return it instead of re-running the optimisation.

    Returns:
        List of 7 uint8 ``H x W x C`` arrays. Indices 0-2 are the towards-0
        images (strongest first), index 3 is the unmodified base vector, and
        indices 4-6 are the towards-1 images (strongest last).
    """
    feat_cols = [col for col in df_mod.columns.values if 'Feature_' in col]
    cache_path = os.path.join(os.path.dirname(model_path), "img_dict.pkl")

    def _read_cache():
        # The cache maps base row label -> image list. A missing or unreadable
        # file, or any non-dict payload (e.g. a bare image list that does not
        # record which base row it belongs to), counts as an empty cache.
        try:
            with open(cache_path, 'rb') as f:
                cached = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            return {}
        return cached if isinstance(cached, dict) else {}

    if not regen:
        cached = _read_cache()
        if base_imgs in cached:
            return cached[base_imgs]

    loss_fn = torch.nn.MSELoss().to(device)
    lr = 2e-2
    n_steps = 200
    base_features = df_mod[feat_cols].loc[base_imgs, :].values.tolist()

    def _decode(vector):
        # Decode the (1, 1, F) vector with G into a uint8 H x W x C image.
        # No autograd graph is needed just to visualise.
        with torch.no_grad():
            img_gen = G(vector[0], 0, noise_mode='const')
        return ((img_gen + 1) * 127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

    def _traverse(target, snapshot_steps):
        # Optimise a fresh copy of the base vector so pred[0][1] -> target.
        # At each step in snapshot_steps, decode an image before the update.
        vector = torch.tensor([[base_features]], dtype=torch.float32, device=device, requires_grad=True)
        optimizer = torch.optim.Adam([vector], lr=lr)
        target_t = torch.tensor(target, device=device)
        snapshots = []
        for step in range(n_steps):
            optimizer.zero_grad()
            pred, _, _ = run_inference(model, vector, attention = True, use_lens = config.model_config.use_lens)
            loss = loss_fn(pred[0][1], target_t)
            if step in snapshot_steps:
                snapshots.append(_decode(vector))
            loss.backward()
            optimizer.step()
        return snapshots

    # Towards 0: reversed so the list runs from the strongest change to the base (step 0).
    imgs = list(reversed(_traverse(0.0, (0, 20, 50, 199))))
    # Towards 1: step 0 is skipped because the base image is already included.
    imgs += _traverse(1.0, (20, 50, 199))

    cache = _read_cache()
    cache[base_imgs] = imgs
    with open(cache_path, 'wb') as f:
        pickle.dump(cache, f)

    return imgs

def get_img_dict_attn(df_mod, model_path, model, config, base_imgs, regen = True):
    """Visualise a MIL attention traversal of one feature vector.

    Starting from the ``Feature_*`` vector in row ``base_imgs`` of ``df_mod``,
    Adam lowers the attention score returned by ``run_inference`` by
    minimising ``sigmoid(at)``. Separately, from the same start, it raises the
    score by minimising ``sigmoid(-at)`` (lr 3e-3, 200 steps each).
    Intermediate vectors are decoded into images with the generator ``G``.

    Args:
        df_mod: DataFrame whose ``Feature_*`` columns hold the base vectors.
        model_path: Path inside the model directory; the cache file
            ``img_dict_attn.pkl`` is written to ``os.path.dirname(model_path)``.
        model, config: Loaded MIL model and its config;
            ``config.model_config.use_lens`` is forwarded to ``run_inference``.
        base_imgs: Row label in ``df_mod`` of the vector to traverse.
        regen: When False and a cached result for ``base_imgs`` exists,
            return it instead of re-running the optimisation.

    Returns:
        List of 7 uint8 ``H x W x C`` arrays. Indices 0-2 are the
        lower-attention images (strongest first), index 3 is the unmodified
        base vector, and indices 4-6 are the higher-attention images
        (strongest last).
    """
    feat_cols = [col for col in df_mod.columns.values if 'Feature_' in col]
    cache_path = os.path.join(os.path.dirname(model_path), "img_dict_attn.pkl")

    def _read_cache():
        # The cache maps base row label -> image list. A missing or unreadable
        # file, or any non-dict payload (e.g. a bare image list that cannot be
        # indexed by row label), counts as an empty cache.
        try:
            with open(cache_path, 'rb') as f:
                cached = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            return {}
        return cached if isinstance(cached, dict) else {}

    if not regen:
        cached = _read_cache()
        if base_imgs in cached:
            return cached[base_imgs]

    lr = 3e-3
    n_steps = 200
    base_features = df_mod[feat_cols].loc[base_imgs, :].values.tolist()

    def _decode(vector):
        # Decode the (1, 1, F) vector with G into a uint8 H x W x C image.
        # No autograd graph is needed just to visualise.
        with torch.no_grad():
            img_gen = G(vector[0], 0, noise_mode='const')
        return ((img_gen + 1) * 127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

    def _traverse(attention_loss, snapshot_steps):
        # Optimise a fresh copy of the base vector to minimise attention_loss(at).
        # At each step in snapshot_steps, decode an image before the update.
        vector = torch.tensor([[base_features]], dtype=torch.float32, device=device, requires_grad=True)
        optimizer = torch.optim.Adam([vector], lr=lr)
        snapshots = []
        for step in range(n_steps):
            optimizer.zero_grad()
            _, at, _ = run_inference(model, vector, attention = True, use_lens = config.model_config.use_lens)
            loss = attention_loss(at)
            if step in snapshot_steps:
                snapshots.append(_decode(vector))
            loss.backward()
            optimizer.step()
        return snapshots

    # torch.sigmoid equals 1/(1+exp(-x)) but avoids exp overflow, which can
    # otherwise produce inf/NaN gradients for large-magnitude scores.
    # Lower attention: reversed so the list runs from the strongest change to the base.
    imgs = list(reversed(_traverse(torch.sigmoid, (0, 50, 100, 199))))
    # Higher attention: step 0 is skipped because the base image is already included.
    imgs += _traverse(lambda at: torch.sigmoid(-at), (50, 100, 199))

    cache = _read_cache()
    cache[base_imgs] = imgs
    with open(cache_path, 'wb') as f:
        pickle.dump(cache, f)

    return imgs


def plt_imgs(df_mod, model_path, model, config, base_imgs, regen, attn, subfig, label, datasets):
    """Plot prediction or attention traversals for several MIL models in one subfigure.

    Row ``i`` uses ``df_mod[i]``, ``model_path[i]``, ``model[i]``,
    ``config[i]`` and ``base_imgs[i]``. Each row shows three images from the
    7-image traversal returned by ``get_img_dict`` (prediction) or
    ``get_img_dict_attn`` (attention).

    Args:
        df_mod, model_path, model, config, base_imgs: Parallel per-row
            sequences forwarded to the traversal function.
        regen: True to regenerate images, False to reuse pickled results.
        attn: Traverse attention (True) instead of the prediction (False).
        subfig: Matplotlib subfigure to draw into.
        label: ``[title, left-arrow label, right-arrow label, panel letter]``.
        datasets: Per-row dataset names, drawn as y-axis labels.
    """
    traverse = get_img_dict_attn if attn else get_img_dict
    rows = len(base_imgs)
    img_dict = {
        i: traverse(df_mod = df_mod[i], model_path = model_path[i], model = model[i],
                    config = config[i], base_imgs = base_imgs[i], regen = regen)
        for i in range(rows)
    }

    n_cols = 3
    # squeeze=False keeps axs two-dimensional even when there is only one row.
    axs = subfig.subplots(rows, n_cols, squeeze=False)
    for row, imgs in img_dict.items():
        # Traversal index 3 is the unmodified base. Rows 0 and 2 show the
        # outermost frames (0 and 6); the other rows show frames 1 and 5.
        frames = (0, 3, 6) if row in (0, 2) else (1, 3, 5)
        for col, frame in enumerate(frames):
            ax = axs[row][col]
            ax.imshow(imgs[frame])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.xaxis.set_label_position('top')
        axs[row][0].set_ylabel(datasets[row], size = 18)

    subfig.subplots_adjust(left = 0 + 1/24, top = 1, right = 1 - 1/24, bottom = 0 + 2/16, wspace=0, hspace=0)
    padding = 3
    # Two arrows from the centre of the subfigure outwards, with the title
    # above the centre and the end labels above each arrow tip.
    axs[0][0].annotate(text="", xy=(0.92, 1.032), xytext=(0.505,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C1'))
    axs[0][0].annotate(text="", xy=(0.08, 1.032), xytext=(0.495,1.032), xycoords="subfigure fraction",  arrowprops=dict(facecolor='C0'))
    axs[0][0].annotate(text=label[0], xy = (0.50,1.05), xycoords="subfigure fraction", ha="center", size = 18)
    axs[0][0].annotate(text=label[1], xy = (0.08, 1.05), xycoords="subfigure fraction", ha="center", size = 16)
    axs[0][0].annotate(text=label[2], xy = (0.92, 1.05), xycoords="subfigure fraction", ha="center", size = 16)

    def _class_box(ax, text, left):
        # Boxed class name in the top-left (left=True) or top-right corner of ax.
        ax.annotate(
            text=text,
            xy=(0.0, 1.0) if left else (1.0, 1.0),
            xytext=(padding, -padding) if left else (-padding, -padding),
            xycoords="axes fraction",
            textcoords='offset points',
            ha="left" if left else "right",
            va="top",
            size=16,
            bbox=dict(pad=padding, facecolor='white', edgecolor='black'),
        )

    if label[0] == 'Subtype Prediction':
        # (left-end class, right-end class) for each row of the subtype panel.
        class_names = [("Ductal", "Lobular"), ("Adeno", "Squamous"), ("Adeno", "Squamous"), ("Clear", "Papillary")]
        for row, (left_name, right_name) in enumerate(class_names[:rows]):
            _class_box(axs[row][0], left_name, left=True)
            _class_box(axs[row][-1], right_name, left=False)

    subfig.text(-0.01,1.02,label[3], zorder = 30, clip_on = False, weight='bold', size = 20 )
    

    
def _load_mil_models(paths):
    """Load each MIL model directory in *paths*.

    Returns (models, configs) as parallel lists; each model is in eval mode on ``device``.
    """
    loaded_models, loaded_configs = [], []
    for path in paths:
        mil_model, mil_config = load_model_weights(path)
        mil_model.eval().to(device)
        loaded_models.append(mil_model)
        loaded_configs.append(mil_config)
    return loaded_models, loaded_configs


FEATURE_DIR = PROJECT_DIR + '/PROJECTS/HistoXGAN/SAVED_FEATURES/'

fig_overall = plt.figure(figsize=(13, 18), dpi = 300)
subfigs = fig_overall.subfigures(2, 2)

# Panels A and B: grade MIL models.
datasets = ['BRCA', 'PAAD', 'HNSC', 'PRAD']
df_mod = [pd.read_csv(FEATURE_DIR + d.lower() + '_features_part.csv') for d in datasets]
model_paths = [PROJECT_DIR + '/PRETRAINED_MODELS/' + d + '_GRADE_MIL/' for d in datasets]
models, configs = _load_mil_models(model_paths)

base_imgs = [100, 100, 200, 200]
plt_imgs(df_mod, model_paths, models, configs, base_imgs, regen = True, attn = False, subfig = subfigs[0][0], label = ['Grade Prediction', 'Low', 'High', 'A'],datasets = datasets)
plt_imgs(df_mod, model_paths, models, configs, base_imgs, regen = True, attn = True, subfig = subfigs[0][1], label = ['Grade Attention', 'Low', 'High', 'B'], datasets = datasets)

# Panels C and D: subtype MIL models. The LUNG and KIDNEY models read the
# LUAD and KIRP feature files respectively.
feature_file_for = {'LUNG': 'LUAD', 'KIDNEY': 'KIRP'}
datasets = ['BRCA', 'LUNG', 'ESCA', 'KIDNEY']
df_mod = [pd.read_csv(FEATURE_DIR + feature_file_for.get(d, d).lower() + '_features_part.csv') for d in datasets]
model_paths = [PROJECT_DIR + '/PRETRAINED_MODELS/' + d + '_SUBTYPE_MIL/' for d in datasets]
models, configs = _load_mil_models(model_paths)

base_imgs = [100, 100, 100, 100]
plt_imgs(df_mod, model_paths, models, configs, base_imgs, regen = True, attn = False, subfig = subfigs[1][0], label = ['Subtype Prediction', '', '', 'C'], datasets = datasets)
plt_imgs(df_mod, model_paths, models, configs, base_imgs, regen = True, attn = True, subfig = subfigs[1][1], label = ['Subtype Attention', 'Low', 'High', 'D'], datasets = datasets)

plt.show()



# Generation of Features from MRI Images

### Loading extracted mean feature vectors and MRI radiomic features

In [None]:
def _load_mri_pickle(stem):
    """Unpickle PROJECT_DIR/pub_pkl/mri_recreation/<stem>.pkl."""
    with open(PROJECT_DIR + "/pub_pkl/mri_recreation/" + stem + ".pkl", 'rb') as handle:
        return pickle.load(handle)


# Per-fold train/test DataFrames holding the mean CTransPath features per
# slide alongside the MRI radiomic features, plus the column lists and
# precomputed artefacts used further below.
dsf_stats_testk = _load_mri_pickle("dsf_stats_testk")
dsf_stats_traink = _load_mri_pickle("dsf_stats_traink")
feat_cols = _load_mri_pickle("feat_cols")
feat_cols_pca = _load_mri_pickle("feat_cols_pca")
df_composite_corr = _load_mri_pickle("df_composite_corr")
df_corr = _load_mri_pickle("df_corr")
img_reals = _load_mri_pickle("img_reals")

# The last 777 columns of the test frames are the prediction targets.
list_sigs = dsf_stats_testk[0].columns.tolist()[-777:]

### Train a simple encoder to map MRI radiomic features to CTransPath features

In [None]:
input_dim = len(feat_cols)
# The encoder output is compared directly with `ctranspath_mean` and fed to
# G, so it must match the CTransPath feature size.
encoded_dim = 768
from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
with dnnlib.util.open_url(PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)


from slideflow.model import build_feature_extractor
# no_grad=False lets the re-embedding loss back-propagate through CTransPath
# and G into the encoder.
ctranspath = build_feature_extractor('ctranspath', tile_px=512, device = device, force_uint8 = False, no_grad = False)

batch_size = 16


def _regenerate_features(features):
    """Decode one predicted feature vector with G, quantise to uint8, and re-embed it with CTransPath."""
    img = (G(torch.unsqueeze(torch.tensor(features).to(device), dim=0), 0, noise_mode='const') + 1) * 127.5
    return ctranspath(img.clamp(0, 255).to(torch.uint8))[0].detach().cpu().numpy()


for i in range(len(dsf_stats_traink)):
    train_df = dsf_stats_traink[i]
    test_df = dsf_stats_testk[i]

    encoder = torch.nn.Sequential(
        torch.nn.Linear(input_dim, encoded_dim),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(encoded_dim, encoded_dim),
    ).to(device)
    loss_fn = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4, weight_decay = 1e-6)

    # Split row positions into about n_rows / batch_size near-equal chunks.
    # Use at least one chunk so folds with fewer than batch_size rows still
    # train; splitting positions avoids np.array_split on a DataFrame.
    n_rows = len(train_df.index)
    chunks = np.array_split(np.arange(n_rows), max(1, int(n_rows / batch_size)))

    for epoch in range(5):
        for positions in chunks:
            chunk = train_df.iloc[positions]
            optimizer.zero_grad()

            source = torch.tensor(chunk[feat_cols].values).to(device).to(torch.float32)
            target = torch.tensor(np.stack(chunk['ctranspath_mean'].to_numpy())).to(device).to(torch.float32)

            e = encoder(source)
            # Features re-extracted from the image G generates for e should also match the target.
            e2 = ctranspath((G(e, 0, noise_mode = 'const') + 1) * 127.5)
            loss = loss_fn(e, target) + loss_fn(e2, target)
            loss.backward()
            optimizer.step()

    encoder.eval()
    with torch.no_grad():
        ev = encoder(torch.tensor(test_df[feat_cols].values).to(device).to(torch.float32))
        test_df['ctranspath_pred'] = [*ev.cpu().numpy()]
        test_df["ctranspath_pred_regen"] = test_df["ctranspath_pred"].apply(_regenerate_features)


### Calculate statistics from trained encoder results
Note: this step requires all 777 trained attention-MIL models (for prediction of expression signatures, grade, and subtype) to be present in the `PRETRAINED_MODELS/mil/` subfolder. Pretrained versions of these models are provided.

In [None]:
from scipy.stats import pearsonr, spearmanr
sf.util.setLoggingLevel(50)

MIL_MODEL_DIR = PROJECT_DIR + "/PRETRAINED_MODELS/mil/"


def _predict_outcome(mil_model, mil_config, features):
    """Return ``pred[0][1]`` from ``run_inference`` for one feature vector passed as a (1, 1, F) bag."""
    bag = torch.unsqueeze(torch.unsqueeze(torch.tensor(features).to(device), dim=0), dim=0)
    with torch.no_grad():
        pred = run_inference(mil_model, bag, attention = False, use_lens = mil_config.model_config.use_lens)[0]
    return pred.detach().cpu().numpy()[0][1]


dsf_stats_merge = pd.concat(dsf_stats_testk)
available_models = set(os.listdir(MIL_MODEL_DIR))
for c in list_sigs:
    model_name = f'attention_mil-{c}'
    if c in ('patient', 'slide') or model_name not in available_models:
        continue
    model, config = load_model_weights(MIL_MODEL_DIR + model_name)
    model.eval().to(device)

    # Prediction from the MRI-derived (regenerated) features vs. from the slide's real mean features.
    dsf_stats_merge[c + "_pred"] = dsf_stats_merge['ctranspath_pred_regen'].apply(lambda x: _predict_outcome(model, config, x))
    dsf_stats_merge[c + "_st"] = dsf_stats_merge['ctranspath_mean'].apply(lambda x: _predict_outcome(model, config, x))

missing_sigs = [c for c in list_sigs if c + "_st" not in dsf_stats_merge.columns]
if missing_sigs:
    raise KeyError(
        f"No trained model 'attention_mil-<signature>' in {MIL_MODEL_DIR} for "
        f"{len(missing_sigs)} signature(s), e.g. {missing_sigs[:5]}"
    )

list_corr = []
for c in list_sigs:
    r, p = spearmanr(dsf_stats_merge[c + "_st"], dsf_stats_merge[c + "_pred"])
    list_corr.append([c, r, p])
    

### Predictions from MRI using Logistic Regression to Directly Predict Expression Signatures

In [None]:
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import roc_auc_score

# Baseline: for each signature, fit a logistic regression on the PCA-reduced
# MRI features to predict the signature binarised at 0.5. Each fold's test
# probabilities are pooled and correlated with the continuous values. The
# authors treat this as an optimistic upper bound to compare with
# predictions from reconstructed histology.
y_preds_full = []
y_trues_full = []
for c in list_sigs:
    y_preds = []
    y_trues = []
    for k in range(len(dsf_stats_traink)):
        train_df = dsf_stats_traink[k]
        test_df = dsf_stats_testk[k]
        clf = LogisticRegression().fit(train_df[feat_cols_pca], train_df[c] > 0.5)
        y_preds += clf.predict_proba(test_df[feat_cols_pca])[:, 1].tolist()
        y_trues += test_df[c].tolist()
    y_preds_full.append(y_preds)
    y_trues_full.append(y_trues)

list_corr2 = []
for c, y_pred, y_true in zip(list_sigs, y_preds_full, y_trues_full):
    r, p = pearsonr(y_pred, y_true)
    list_corr2.append([c, r, p])

if [row[0] for row in list_corr] != [row[0] for row in list_corr2]:
    raise ValueError("Generated-histology and MRI correlation tables are not aligned by signature")

# Build the frame straight from the row lists. Going through np.array would
# coerce the mixed str/float rows to strings and break the numeric r/p
# columns and the FDR correction.
list_comb = [[row1[0], row1[1], row1[2], row2[1], row2[2]] for row1, row2 in zip(list_corr, list_corr2)]
df_corr = pd.DataFrame(list_comb, columns = ['signature_name', 'gen_r', 'gen_p', 'mri_r', 'mri_p'])

from statsmodels.stats.multitest import fdrcorrection
df_corr['gen_p_fdr'] = fdrcorrection(df_corr['gen_p'])[1]
df_corr['mri_p_fdr'] = fdrcorrection(df_corr['mri_p'])[1]

### Optionally, load the results of the pretrained encoder here to replicate the publication statistics

In [None]:
def _load_pub_pickle(relative_path):
    """Unpickle PROJECT_DIR + relative_path."""
    with open(PROJECT_DIR + relative_path, 'rb') as handle:
        return pickle.load(handle)


dsf_stats_merge = _load_pub_pickle("/pub_pkl/mri_recreation/dsf_stats_merge.pkl")
saved_clusters = _load_pub_pickle("/pub_pkl/sig_clusters.pkl")
xind = _load_pub_pickle("/pub_pkl/xind.pkl")

# Splice signature indices 775 and 776 into the saved display ordering, in
# this order (the second position refers to the list after the first
# insert). Then give those two signatures cluster labels 2 and 3.
for position, signature_index in ((180, 775), (345, 776)):
    xind.insert(position, signature_index)
saved_clusters = np.concatenate([saved_clusters, [2, 3]])

### Plot figure of results

In [None]:
import scipy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
import scipy.cluster.hierarchy as hac
import cv2

from slideflow.gan.stylegan3.stylegan3 import dnnlib, legacy, utils
with dnnlib.util.open_url(PROJECT_DIR + '/FINAL_MODELS/CTransPath/snapshot.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)


def _hide_ticks(ax):
    """Remove x and y ticks from ax."""
    ax.set_xticks([])
    ax.set_yticks([])


def _feature_to_image(features):
    """Decode one feature vector with G into a uint8 H x W x C numpy image."""
    vec = torch.as_tensor(np.asarray(features, dtype=np.float32), device=device).unsqueeze(0)
    with torch.no_grad():
        img = G(vec, 0, noise_mode='const')
    return ((img + 1) * 127.5).permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()


fig_overall = plt.figure(figsize=(14, 14), dpi = 300)
subfigs_overall = fig_overall.subfigures(2, 1)

# Positional rows of dsf_stats_merge / img_reals shown in panel A.
plot_list = [304, 274, 906, 298]

# ---- Panel A: pipeline schematic plus real vs generated tiles ----
subfigs_top = subfigs_overall[0].subfigures(1,2, width_ratios = [6, 8])
ax = subfigs_top[0].subplots(1,1)

schematic_path = PROJECT_DIR + '/FINAL_IMAGES/mripath.png'
im = cv2.imread(schematic_path)
if im is None:  # cv2.imread returns None instead of raising
    raise FileNotFoundError(schematic_path)
ax.imshow(cv2.cvtColor(cv2.resize(im, (1285, 1279), interpolation=cv2.INTER_LINEAR), cv2.COLOR_BGR2RGB))
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for side in ['top','right','bottom','left']:
    ax.spines[side].set_visible(False)
axs = subfigs_top[1].subplots(3, 4)

# Decode only the rows that are actually plotted.
imgs_plot = {'mean': {}, 'pred': {}}
for idx in plot_list:
    row = dsf_stats_merge.iloc[idx]
    imgs_plot['pred'][idx] = _feature_to_image(row['ctranspath_pred'])
    imgs_plot['mean'][idx] = _feature_to_image(row['ctranspath_mean'])

for col, idx in enumerate(plot_list):
    axs[0][col].imshow(img_reals[idx])
    axs[1][col].imshow(imgs_plot['mean'][idx])
    axs[2][col].imshow(imgs_plot['pred'][idx])
    for r in range(3):
        _hide_ticks(axs[r][col])
axs[0][0].set_ylabel("Tile from Slide")
axs[1][0].set_ylabel("Generated from Avg. Features")
axs[2][0].set_ylabel("Generated from MRI")

subfigs_top[0].subplots_adjust(left = 0, right = 0.9, top = 1, bottom = 0.15, wspace = 0, hspace = 0)
subfigs_top[1].subplots_adjust(left = 0, right = 1, top = 1, bottom = 0.15, wspace = 0, hspace = 0)

subfigs = subfigs_overall[1].subfigures(1, 3, width_ratios = [2, 7, 5])

# ---- Panel B: hard-coded mean absolute feature differences (+/- 0.001) ----
ax = subfigs[0].subplots()
labels = ["Mean Inter-Tumor\nFeature Difference", "Mean Feature Difference\nVirtual Biopsy vs Digital Slide", "Mean Intra-Tumor\nFeature Difference"]
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry.
cmap = matplotlib.colormaps['Blues']
norm = matplotlib.colors.Normalize(vmin=0, vmax = 5)

for x, (height, shade) in enumerate([(0.074, 1), (0.078, 3), (0.092, 5)]):
    ax.bar(x, height, edgecolor = 'black', color = cmap(norm(shade)))
    ax.errorbar(x, height, yerr=0.001, fmt='none', color='black', capsize=3)
ax.set_ylim(0.05, 0.10)
ax.set_ylabel("Mean Absolute Difference")
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(labels, rotation=90)

subfigs[0].subplots_adjust(left = 0, right = 0.8, top = 1.0, bottom = 2/5, wspace = 0.1, hspace = 0.1)

# ---- Panel D: per-signature correlations, histology route vs MRI route ----
ax = subfigs[2].subplots()
clust_num = 5
# Hard-coded FDR-corrected significance thresholds on each axis.
ax.axhline(y = 0.071, color='C0', linestyle='dashed', alpha = 0.5)
ax.axvline(x = 0.081, color='C0', linestyle='dashed', alpha = 0.5, label = 'Significance Threshold (FDR corrected)')

for i in range(clust_num):
    in_cluster = saved_clusters == i + 1
    ax.scatter(df_corr['gen_r'].iloc[in_cluster], df_corr['mri_r'].iloc[in_cluster], color = cmap(norm(i + 1)))


def _annotate_signature(name, idx, dx, dy, relpos):
    """Label the point for signature row idx, with the text offset by (dx, dy)."""
    x, y = df_corr['gen_r'].iloc[idx], df_corr['mri_r'].iloc[idx]
    ax.annotate(name, xy = (x, y), xytext = (x + dx, y + dy),
                bbox=dict(boxstyle='square,pad=0.01', fc='none', ec='none'),
                arrowprops=dict(arrowstyle="->", connectionstyle="arc3", relpos=relpos))


_annotate_signature("IFN3", 190, -0.01, -0.02, (0.5, 1))        # PMID 24516633
_annotate_signature("Trop2", 679, 0.0, 0.015, (0.5, 0))         # PMID 34711841
_annotate_signature("Oncotype", 670, 0.02, -0.02, (0.1, 0.8))   # PMID 15591335
_annotate_signature("Prosigna", 666, 0.01, 0.02, (0.1, 0.2))    # PMID 19204204

legs = [
    Line2D([0], [0], marker = 'o', color = 'w', markerfacecolor = cmap(norm(i)), markersize = 8, label = c)
    for c, i in zip(["Hypoxia / Angiogenesis", "Immune", "Basal", "Luminal", "Other"], [5,4,3,2,1])
]
legs.append(Line2D([0], [0], color='C0', linestyle = 'dashed', alpha = 0.5, label='Significant (FDR corrected)'))

leg = ax.legend(handles=legs, loc='lower right', framealpha=1)
# Legend.legendHandles was renamed legend_handles in Matplotlib 3.7 and removed in 3.9.
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

ax.set_xlabel("Pearson r, Prediction from Generated Histology")
ax.set_ylabel("Pearson r, Prediction from MRI")
subfigs[2].subplots_adjust(left = 0, right = 1, top = 1.0, bottom = +1/5, wspace = 0.1, hspace = 0.1)

# ---- Panel C: signature correlation matrix in saved `xind` order ----
axs = subfigs[1].subplots(2, 2, height_ratios = [8, 0.5], width_ratios = [8, 0.5])
data = pd.DataFrame(df_composite_corr)
cm = axs[0][0].imshow(data.iloc[xind,xind].T, cmap='magma', vmin = -1, vmax = 1)

cbar = subfigs[1].colorbar(cm, cax = axs[0][1], shrink = 0.5)
cbar.set_label('Pearson Correlation', rotation=90, zorder = 100, labelpad = -10)
cbar.ax.set_yticks([-1, 1])
cbar.ax.set_yticklabels(['-1', '1'], zorder = 100)

axs[0][0].set_title("Correlation Matrix of Predicted Expression Signatures")
_hide_ticks(axs[0][0])

# Cluster colour strip under the matrix.
axs[1][0].imshow(np.vstack([saved_clusters[xind],saved_clusters[xind]]), extent=[0,len(saved_clusters),0,1], aspect = 'auto', cmap='Blues', vmin = 0, vmax = 5)
# Tick at the midpoint of each cluster's run.
# NOTE(review): assumes xind groups signatures by cluster label 1..5 in order.
tick_locs = []
run_start = 0
for i in range(clust_num):
    run_length = sum(saved_clusters == (i + 1))
    tick_locs.append(run_start + run_length / 2)
    run_start += run_length

axs[1][0].set_xticks(tick_locs)
axs[1][0].set_xticklabels(["Other", "Luminal", "Basal", "Immune", "Hypoxia / Angiogenesis"])
axs[1][0].get_yaxis().set_visible(False)

axs[0][1].get_yaxis().set_visible(True)
subfigs[1].subplots_adjust(left = 0, right = 0.8, top = 1.0, bottom = +1/5, wspace = 0.1, hspace = 0.1)
subfigs[1].delaxes(axs[1][1])

subfigs_overall[0].text(-0.03,1.02, "A", zorder = 30, clip_on = False, weight='bold', size = 20 )
subfigs[0].text(-0.13,1.02, "B", zorder = 30, clip_on = False, weight='bold', size = 20 )
subfigs[1].text(-0.03,1.02, "C", zorder = 30, clip_on = False, weight='bold', size = 20 )
subfigs[2].text(-0.03 ,1.02, "D", zorder = 30, clip_on = False, weight='bold', size = 20 )
plt.show()
print("Done")