In [1]:
from joblib import load
from os.path import join
# import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from mvmm_sim.tcga.TCGAPaths import TCGAPaths
from mvmm_sim.simulation.sim_viz import save_fig
from mvmm_sim.data_analysis.multi_view.viz_resuls import plot_Pi
from mvmm_sim.data_analysis.survival import plot_survival


# --- Figure output settings shared by the plotting cells below ---
# Directory that every saved figure path is joined onto.
save_dir = 'tcga_figures'

# Side length handed to plt.figure(figsize=...) and the dpi handed to save_fig.
inches, dpi = 10, 200

# Global matplotlib / seaborn styling.
plt.rcParams.update({"axes.labelsize": 25})
sns.set_style("whitegrid")

In [6]:
# --- Which analysis to visualize ---
cancer_type, feat_list = 'BRCA', 'icluster'

# Identifiers of the two data views. The original author listed
# 'cp', 'mi_rna' and 'dna_meth' as the choices for v1.
v0, v1 = 'rna', 'cp'

# --- Input locations ---
results_dir = ''  # Set the results directory!
interp_dir = join(results_dir, 'interpret', 'bd_mvmm')

# Processed-data folder for the selected cancer type.
pro_data_dir = join(TCGAPaths().pro_data_dir, cancer_type)



In [7]:
# Load the saved `pi_data` object from the interpretation directory; the
# plotting cell below reads its 'Pi_block_perm' and
# 'Pi_block_perm_zero_mask' entries.
pi_data_path = join(interp_dir, 'pi_data')
pi_data = load(pi_data_path)

# Other saved artifacts, currently disabled:
# bd_mvmm = load(join(results_dir, 'model_fitting', 'selected_models'))['bd_mvmm'].final_
# survival = load(join(interp_dir, 'survival'))

# Block diagonal Pi matrix

In [8]:
# Plot the block-permuted Pi matrix and save it as '<v1>_bd_pi.png'.
#
# The axis labels are derived from the selected views so they stay correct
# when v1 is changed (the output filename already depends on v1). With the
# defaults v0='rna', v1='cp' the labels are exactly the previously hard-coded
# "RNA Cluster" / "Copy Number cluster".
# NOTE(review): the 'mi_rna' and 'dna_meth' display names are assumed
# expansions of the identifiers -- confirm the preferred wording. Unknown
# identifiers fall back to the raw view name.
view_display_names = {
    'rna': 'RNA',
    'cp': 'Copy Number',
    'mi_rna': 'miRNA',
    'dna_meth': 'DNA Methylation',
}

# Work on a copy so that naming the axes does not alter the frame stored
# inside `pi_data`.
D = pi_data['Pi_block_perm'].copy()
# Rows are labeled with the first view and columns with the second, as in the
# original hard-coded labels.
D.index.name = "{} Cluster".format(view_display_names.get(v0, v0))
D.columns.name = "{} cluster".format(view_display_names.get(v1, v1))

plt.figure(figsize=(inches, inches))
plot_Pi(D, mask=pi_data['Pi_block_perm_zero_mask'], cmap="Blues", cbar=True, square=False)

save_fig(join(save_dir, '{}_bd_pi.png'.format(v1)), dpi=dpi)

# Metadata comparison

In [9]:
# Heatmap of the cross-tabulation between block labels and the
# 'Subtype_mRNA' metadata variable, saved as
# '<v1>_block_vs_pam50_subtype.png'.
block_compare = load(join(interp_dir, 'block', 'metadata_comparisons_block'))

compare_subtype = block_compare.comparisons_.loc['block', 'Subtype_mRNA']

# Copy the table before editing it: renaming the axes and blanking cells
# directly on `compare_subtype.cross_` would silently change the loaded
# comparison object for any later use.
cross_counts = compare_subtype.cross_.copy()
cross_counts.index.name = 'MVMM Block'
cross_counts.columns.name = 'PAM50 Subtype'

# Replace zero counts with NaN so the heatmap leaves those cells blank
# instead of annotating them with "0".
cross_counts = cross_counts.mask(cross_counts == 0)

plt.figure(figsize=(4, 8))
# Transposed so subtypes run down the rows and blocks across the columns.
sns.heatmap(cross_counts.T, vmin=0, cmap='Blues',
            linewidths=.5,
            annot=True, fmt='1.0f', cbar=False)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

save_fig(join(save_dir, '{}_block_vs_pam50_subtype.png'.format(v1)), dpi=dpi)


# Survival

In [None]:
# Survival curves by block label, with the stored p-value shown in the title.
#
# `survival` was never defined in this notebook (its load was commented out),
# so this cell raised NameError on a fresh kernel. Load it here, from the
# path given in that commented-out line, so the cell runs on its own.
survival = load(join(interp_dir, 'survival'))
block_survival = survival['block']

pval = block_survival['pval']
plt.figure(figsize=(inches, inches))
plot_survival(df=block_survival['df'], cat_col='cluster')
plt.xlabel("Time (days)")
plt.ylabel("Progression Free Interval")
plt.title('{} vs. block label, p = {:1.3f}'.format('PFI', pval))
# NOTE(review): 'survial' in the filename looks like a typo for 'survival';
# kept unchanged so existing references to the output file keep working.
save_fig(join(save_dir, '{}_block_vs_survial.png'.format(v1)), dpi=dpi)
