In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

In [None]:
# Dataset directories analysed in this notebook. Each one contains a
# sub-directory named after the gene-family definition in `families`.
datasets = [
    "Gaublomme_GSE75109_TPM_clean",
    "Gaublomme_GSE75110_TPM_clean",
    "Gaublomme_GSE75111_TPM_clean",
    "kakadarov_tpm",
    "somatosensory_converted_into_tpm",
    "Dopaminergic_TPM_clean",
    "Rbp4_positive_cells",
    "Cheng_ES_TPM",
    "Alveolar_cells_Type_II_Merged_Batches",
    "Alveolar_cells_Type_I_Merged_Batches",
    # "Alveolar_cells_both_types" is deliberately left out of this analysis.
    "klein",
    "hepat_TPM_yang_clean",
    "Yu_First_wave_endocrine_cells",
]

# Gene-family definition; used as the per-dataset sub-directory name.
families = "clean_panther4march"

# Per-family results for every dataset, indexed by family_id.
dfs = {
    name: pd.read_csv("/".join([name, families, "results", "family_IC.csv"]), index_col="family_id")
    for name in datasets
}
# Dichotomised per-gene tables for every dataset, indexed by gene_id.
dichotomised_dfs = {
    name: pd.read_csv("/".join([name, families, "intermediate", "dichotomised_genes.csv"]), index_col="gene_id")
    for name in datasets
}

In [None]:
# Short display labels for each dataset, used in plot titles.
# ("Alveolar_cells_both_types" has no label because it is not analysed.)
cell_type_names = dict([
    ("somatosensory_converted_into_tpm", "Somatosensory N"),
    ("Dopaminergic_TPM_clean", "Dopaminergic N"),
    ("kakadarov_tpm", "CD8+ T cell"),
    ("Cheng_ES_TPM", "Isolated ESC"),
    ("Gaublomme_GSE75109_TPM_clean", "Th17 A"),
    ("Gaublomme_GSE75110_TPM_clean", "Th17 B"),
    ("Gaublomme_GSE75111_TPM_clean", "Th17 C"),
    ("Rbp4_positive_cells", "Corticostriatal N"),
    ("Alveolar_cells_Type_I_Merged_Batches", "Lung ACI"),
    ("Alveolar_cells_Type_II_Merged_Batches", "Lung ACII"),
    ("klein", "Cultured ESC"),
    ("hepat_TPM_yang_clean", "Liver HB/HC"),
    ("Yu_First_wave_endocrine_cells", "Pancreatic EC"),
])

In [None]:
# Gene-family membership table (family_id -> gene_symbol) for the chosen family definition.
family_df = pd.read_csv(families + ".csv")

In [None]:
# Inspect the family_IC row of the histone H2A family (PTHR23430) in the Th17 A dataset.
df = dfs["Gaublomme_GSE75109_TPM_clean"]
df.loc["PTHR23430", :]

In [None]:
# Not referenced anywhere else in this notebook; kept so existing sessions that may read it still work.
_x = None


def _family_on_genes(k, dataset):
    """Return the dichotomised rows of family ``k`` in ``dataset`` that are ON in at least one cell.

    Gene symbols listed for ``k`` in ``family_df`` but absent from the dataset's
    dichotomised table are skipped. Rows containing NaN are dropped, and genes
    whose row sum is 0 are removed.
    """
    dichotomised_df = dichotomised_dfs[dataset]
    symbols = family_df.loc[family_df.family_id == k, "gene_symbol"]
    # `.loc` with any missing label raises KeyError in pandas >= 1.0; keeping only
    # present symbols reproduces the old "NaN row, then dropna" behaviour.
    present = symbols[symbols.isin(dichotomised_df.index)].tolist()
    genes = dichotomised_df.loc[present].dropna().copy()
    return genes[genes.sum(axis=1) > 0]


def _first_on_position(row):
    """Position of the first entry equal to 1.0 in ``row``, or ``len(row)`` if there is none."""
    for pos, value in enumerate(row):
        if value == 1.0:
            return pos
    return len(row)


def _sort_for_display(genes):
    """Sort cells by ascending ON-gene count, then genes by the position of their first ON cell."""
    genes = genes.loc[:, genes.sum(axis=0).sort_values().index]
    if genes.empty:
        # Nothing to order; also avoids DataFrame.apply on an empty frame.
        return genes
    return genes.loc[genes.apply(_first_on_position, axis=1).sort_values().index, :]


def plot_heatmaps(k, k_name, v, font_scale=1.4):
    """Plot ON/OFF heatmaps and ON-gene-count histograms of family ``k``, one panel per dataset.

    Two figures are created, each with one panel per dataset in ``v``:

    * a heatmap of the dichotomised values of the family's genes that are ON in
      at least one cell (cells sorted by ON-gene count, genes by first ON cell);
    * a histogram of the number of ON family genes per cell.

    Panel titles show ``cell_type_names[dataset]`` and the family's ``ic`` value
    from ``dfs[dataset]``. The ON-gene counts per dataset and their maximum are
    printed, followed by each dataset's mean ON fraction.

    Parameters
    ----------
    k : str
        Family id: a row label of ``dfs[dataset]`` and a value of ``family_df.family_id``.
    k_name : str
        Human-readable family name. Not used by the plots; kept for caller compatibility.
    v : list of str
        Dataset keys shared by ``dfs``, ``dichotomised_dfs`` and ``cell_type_names``.
    font_scale : float, optional
        Passed to ``sns.set``.

    Returns
    -------
    pandas.DataFrame
        The sorted ON/OFF matrix of the last dataset in ``v``.

    Raises
    ------
    ValueError
        If ``v`` is empty.
    """
    if len(v) == 0:
        raise ValueError("plot_heatmaps needs at least one dataset")

    sns.set(font_scale=font_scale, style="ticks", font="Arial")
    panel_width = 5.2
    row_height = 0.22

    # Filter each dataset once. The dataset with the most ON genes sets the heatmap figure height.
    on_genes = {dataset: _family_on_genes(k, dataset) for dataset in v}
    gene_counts = [on_genes[dataset].shape[0] for dataset in v]
    print(gene_counts, max(gene_counts))
    # Matplotlib rejects a zero figure height, so keep room for at least one row.
    fig_height = row_height * max(max(gene_counts), 1)
    fig_width = panel_width * len(v) if len(v) > 1 else panel_width * 0.9

    # squeeze=False always returns a 2-D axes array, so single-dataset calls need no special case.
    fig, axs = plt.subplots(1, len(v), figsize=(fig_width, fig_height), squeeze=False)
    fig_hist, axs_hist = plt.subplots(1, len(v), figsize=(fig_width, panel_width), squeeze=False)
    axs, axs_hist = axs[0], axs_hist[0]

    x = None
    for i, dataset in enumerate(v):
        x = on_genes[dataset]
        title = "{} (IC={:.2f})".format(cell_type_names[dataset], dfs[dataset].loc[k].ic)

        # The dichotomised values are ON/OFF (0/1), so this is the fraction of ON entries, not a TPM.
        print(dataset, "Mean ON fraction:", x.mean().mean())
        exp_per_cell = x.sum(axis=0)

        x = _sort_for_display(x)
        sns.heatmap(x, ax=axs[i], xticklabels=False, yticklabels=True, cbar=False,
                    cmap="gray_r", vmin=-0.03, vmax=1.0)
        axs[i].set_title(title)
        axs[i].set_ylabel("")

        # One unit-width bin per integer ON-gene count. Axes.hist replaces seaborn's
        # deprecated distplot; alpha=0.4 matches distplot's default histogram styling.
        low, high = exp_per_cell.min(), exp_per_cell.max()
        axs_hist[i].hist(exp_per_cell, bins=int(high - low + 1), range=(low - 0.5, high + 0.5),
                         color="grey", alpha=0.4)
        axs_hist[i].set_xlabel("Number of ON genes")
        axs_hist[i].set_title(title)

    fig.subplots_adjust(wspace=0.4)
    return x

In [None]:
def get_genes(k, dataset, dichotomised_df=None, gene_families=None):
    """Return family ``k``'s dichotomised ON/OFF matrix for ``dataset``, sorted for display.

    Genes of family ``k`` that are missing from the dichotomised table are
    skipped. Rows containing NaN are dropped, and genes that are never ON
    (row sum 0) are removed. Cells (columns) are ordered by ascending ON-gene
    count. Genes (rows) are ordered by the position of their first 1.0 in that
    column order.

    Parameters
    ----------
    k : str
        Family id to select in ``gene_families.family_id``.
    dataset : str
        Key into the global ``dichotomised_dfs``. Only used when ``dichotomised_df`` is None.
    dichotomised_df : pandas.DataFrame, optional
        Gene-by-cell table to use instead of ``dichotomised_dfs[dataset]``.
    gene_families : pandas.DataFrame, optional
        Table with ``family_id`` and ``gene_symbol`` columns. Defaults to the global ``family_df``.

    Returns
    -------
    pandas.DataFrame
        The filtered, sorted matrix. It may be empty.
    """
    if dichotomised_df is None:
        dichotomised_df = dichotomised_dfs[dataset]
    if gene_families is None:
        gene_families = family_df

    symbols = gene_families.loc[gene_families.family_id == k, "gene_symbol"]
    # `.loc` with any missing label raises KeyError in pandas >= 1.0; keep only
    # present symbols (equivalent to the old "NaN row, then dropna" behaviour).
    present = symbols[symbols.isin(dichotomised_df.index)].tolist()
    genes = dichotomised_df.loc[present].dropna().copy()
    genes = genes[genes.sum(axis=1) > 0]

    genes = genes.loc[:, genes.sum(axis=0).sort_values().index]
    if genes.empty:
        # Nothing to order; also avoids DataFrame.apply on an empty frame.
        return genes

    def first_on_position(row):
        # Position of the first 1.0 in the row, or len(row) if the gene is never ON.
        for pos, value in enumerate(row):
            if value == 1.0:
                return pos
        return len(row)

    order = genes.apply(first_on_position, axis=1).sort_values().index
    return genes.loc[order, :]

In [None]:
# Single-dataset view of the T-cell receptor beta chain family (PTHR23268) in the Th17 C dataset.
example_family = "PTHR23268"
example_dataset = "Gaublomme_GSE75111_TPM_clean"
x = get_genes(example_family, example_dataset)

# Heatmap of the sorted ON/OFF matrix.
sns.set(font_scale=1.2, style="ticks", font="Arial")
fig, ax = plt.subplots(figsize=(4.2 * 0.9, x.shape[0] * 0.19))
sns.heatmap(x, ax=ax, xticklabels=False, yticklabels=True, cbar=False,
            cmap="gray_r", vmin=-0.03, vmax=1.0)
ic = dfs[example_dataset].loc[example_family].ic
ax.set_title("{} (IC={:.2f})".format(cell_type_names[example_dataset], ic))
ax.set_ylabel("")
plt.show()

# Histogram of ON family genes per cell: one unit-width bin per integer count.
# Axes.hist replaces seaborn's deprecated distplot (alpha=0.4 matches its default styling).
sns.set(font_scale=1.4, style="ticks", font="Arial")
fig, ax = plt.subplots(figsize=(1.2, x.shape[0] * 0.19))
exp_per_cell = x.sum(axis=0)
low, high = exp_per_cell.min(), exp_per_cell.max()
ax.hist(exp_per_cell, bins=int(high - low + 1), range=(low - 0.5, high + 0.5),
        color="grey", alpha=0.4)
ax.set_xlabel("Number of ON genes")
# Tick every integer count from 0 to the observed maximum (previously hard-coded to 0..3).
ax.set_xticks(list(range(0, int(high) + 1)))
plt.show()

In [None]:
# Histone H2A family across Th17 C and liver hepatoblast/hepatocyte cells.
x = plot_heatmaps(
    k="PTHR23430",
    k_name="HISTONE H2A (PTHR23430)",
    v=["Gaublomme_GSE75111_TPM_clean", "hepat_TPM_yang_clean"],
)

In [None]:
# T-cell receptor beta chain family in Th17 C cells only.
x = plot_heatmaps(
    k="PTHR23268",
    k_name="T-CELL RECEPTOR BETA CHAIN (PTHR23268)",
    v=["Gaublomme_GSE75111_TPM_clean"],
)

In [None]:
# Carbonic anhydrase family across three datasets, with a larger font for the wider figure.
x = plot_heatmaps(
    k="PTHR18952",
    k_name="CARBONIC ANHYDRASE (PTHR18952)",
    v=[
        "Gaublomme_GSE75111_TPM_clean",
        "somatosensory_converted_into_tpm",
        "hepat_TPM_yang_clean",
    ],
    font_scale=2.1,
)

In [None]:
def get_first_index(row):
    """Return the position of the first value equal to 1.0 in ``row``.

    Positions are counted from 0 and ignore the row's labels. If no value
    equals 1.0, the row length is returned, so rows without an ON entry sort last.
    """
    return next((pos for pos, value in enumerate(row) if value == 1.0), len(row))

# Position of each gene's first ON cell in the matrix returned by the previous plot.
x.apply(get_first_index, axis="columns")

In [None]:
# Hexokinase family in isolated ESCs and pancreatic endocrine cells.
plot_heatmaps(
    k="PTHR19443",
    k_name="HEXOKINASE (PTHR19443)",
    v=["Cheng_ES_TPM", "Yu_First_wave_endocrine_cells"],
)