In [21]:
import os
from os.path import join
# import argparse

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import load

from mvmm.simulation.sim_viz import save_fig
# from mvmm.data_analysis.utils import load_data
# from mvmm.simulation.utils import make_and_get_dir
from mvmm.data_analysis.multi_view.viz_resuls import plot_Pi

# Figure geometry: figure sizes below are derived from `inches`, and every
# saved figure uses `dpi`.
inches = 15
dpi = 200

# Global plot styling for the whole notebook.
plt.rcParams["axes.labelsize"] = 25
sns.set_style("whitegrid")

# Output folder for the saved figures (relative to the working directory).
# Create it up front so saving cannot fail on a missing folder.
# NOTE(review): save_fig is not shown here; this assumes it does not create
# missing directories itself -- exist_ok=True makes the call harmless either way.
save_dir = 'mouse_et_figures'
os.makedirs(save_dir, exist_ok=True)

# Load data

In [22]:
# Root folder of the mouse_et experiment. Every path below is built from this
# single location, so it is the only line to edit when the data lives elsewhere.
mouse_et_dir = '/Users/iaincarmichael/Dropbox/Research/mvmm/simulations/mouse_et'
pro_data_dir = join(mouse_et_dir, 'pro_data')

# PCA feature files: transcriptomic first, ephys second.
fpaths = [join(pro_data_dir, 'transcriptomic_select_markers_pca_feats.csv'),
          join(pro_data_dir, 'ephys_pca_feats.csv')]

# The empty last component keeps the trailing path separator.
results_dir = join(mouse_et_dir, 'analysis_select_markers_pca', '')

vars2compare_fpath = join(pro_data_dir, 'vars2compare.csv')
metadata_fpath = join(pro_data_dir, 'metadata.csv')

# Saved interpretation output for the bd_mvmm model.
interp_dir = join(results_dir, 'interpret', 'bd_mvmm')



In [7]:
# Table of variables to compare against; the first CSV column is the index.
vars2compare = pd.read_csv(vars2compare_fpath, index_col=0)

# For each transcriptomic super type, the distinct subtypes listed under it.
subtypes_by_super = vars2compare.groupby('transcr_super_type')['transcr_subtype']
super2sub = subtypes_by_super.unique()

# Saved Pi-matrix data for the bd_mvmm interpretation.
pi_data_fpath = join(interp_dir, 'pi_data')
pi_data = load(pi_data_fpath)

# Final fitted bd_mvmm estimator taken from the selected models.
selected_models_fpath = join(results_dir, 'model_fitting', 'selected_models')
bd_mvmm = load(selected_models_fpath)['bd_mvmm'].final_

# BD Pi matrix

In [23]:
# Block-permuted Pi matrix. The axis names are set on the loaded frame itself.
# NOTE(review): presumably plot_Pi shows these names as axis labels -- confirm.
D = pi_data['Pi_block_perm']
D.index.name, D.columns.name = "RNA clusters", "Ephys clusters"

# Mask passed to plot_Pi alongside the matrix.
zero_mask = pi_data['Pi_block_perm_zero_mask']

# Square figure, `inches` on each side.
plt.figure(figsize=(inches,) * 2)
plot_Pi(D, mask=zero_mask, cmap="Blues", cbar=True)

pi_fig_fpath = join(save_dir, 'mouse_et_bd_pi.png')
save_fig(pi_fig_fpath, dpi=dpi)

# Block labels vs. transcriptomic subtypes

In [11]:
# Saved comparisons of the block labels against the metadata variables.
block_compare = load(join(interp_dir, 'block', 'metadata_comparisons_block'))

compare_subtype = block_compare.comparisons_.loc['block', 'transcr_subtype']
compare_supertype = block_compare.comparisons_.loc['block', 'transcr_super_type']

# Work on a copy so naming the axes does not modify the loaded comparison object.
cross_counts = compare_subtype.cross_.copy()
cross_counts.index.name = 'MVMM Block'
cross_counts.columns.name = 'Transcriptomic subtype'

# Subtype columns ranked by total count. The super-type reindex below fully
# determines the final column order, so this ranking is not applied to
# `cross_counts`; the name is kept for any code that still refers to it.
sorted_cols = cross_counts.sum(axis=0).sort_values(ascending=False).index

# Order the super types by how often they occur in vars2compare (most frequent
# first), then lay the subtype columns out super type by super type.
ordered_super_labels = vars2compare['transcr_super_type'].value_counts().index
super2sub = super2sub.loc[ordered_super_labels]
subtype_super_ordering = np.concatenate(super2sub.tolist())
cross_counts = cross_counts[subtype_super_ordering]

# Cumulative subtype counts: the positions where the super type changes.
super_sizes = [len(subtypes) for subtypes in super2sub]
break_idxs = np.cumsum(super_sizes)

# One colour per super type, repeated once for each of its subtypes.
super_colors = sns.color_palette("Set2", super2sub.shape[0])
tick_colors = np.concatenate([[super_colors[i]] * super_sizes[i] for i in range(len(super_sizes))])

# Turn zeros into NaN so the heatmap leaves those cells blank. `where` builds a
# new frame instead of assigning through a boolean mask on a column selection.
cross_counts = cross_counts.where(cross_counts != 0)

In [24]:
# Tall, narrow figure: one row per transcriptomic subtype, one column per block.
fig, ax = plt.subplots(figsize=(inches / 3, inches))

# yticklabels=True forces a label on every row. With seaborn's default "auto"
# the labels may be thinned out, and the i-th drawn label would then no longer
# be the i-th row, so the per-row colours below would land on the wrong subtypes.
sns.heatmap(cross_counts.T, vmin=0, cmap='Blues',
            linewidths=.5,
            annot=True, fmt='1.0f', cbar=False,
            yticklabels=True,
            ax=ax)

# Horizontal rule wherever the super type changes; the last break index is
# skipped, so no rule is drawn at the bottom edge.
for idx in break_idxs[:-1]:
    ax.axhline(idx, color='black')

# Colour each subtype label with the colour of its super type.
for tick, tick_color in zip(ax.get_yticklabels(), tick_colors):
    tick.set_color(tick_color)

# Hand-tuned data coordinates for the super-type names; the negative x values
# place them left of the heatmap. One entry is needed per super type -- the
# lookup raises IndexError if there are more super types than entries.
pos_y = [10, 28, 40, 53, 57, 60]
pos_x = [-1.3, -1.3, -1.3, -1.3, -1.1, -1.3]
for idx, super_label in enumerate(super2sub.index):
    ax.annotate(super_label, xy=(0, pos_y[idx]), xytext=(pos_x[idx], pos_y[idx]),
                rotation=90, color=super_colors[idx], fontweight='bold', fontsize=15)

save_fig(join(save_dir, 'mouse_et_block_vs_transcr.png'), dpi=dpi)

In [None]:
# Disabled alternative: the same cross-count heatmap without the transpose
# (blocks as rows, subtypes as columns), with vertical rules at the super-type
# breaks and coloured x tick labels. Unlike the active cell, this loop draws a
# rule at every break index, including the last one.
# plt.figure(figsize=(20, 3))
# sns.heatmap(cross_counts, vmin=0, cmap='Blues',
#             linewidths=.5,
#             annot=True, fmt='1.0f', cbar=False)

# for idx in break_idxs:
#     plt.axvline(idx, color='black')
    
    
# for idx, tick in enumerate(plt.gca().get_xticklabels()):
#     tick.set_color(tick_colors[idx])