In [None]:
import re
from pathlib import Path

import numpy as np
import pandas as pd

from sklearn import preprocessing

import seaborn as sns
import matplotlib.pyplot as plt

from bioinf_common.plotting import corrplot

In [None]:
# Use seaborn's "talk" plotting context for every figure produced below.
sns.set_context(context='talk')

# Parameters

In [None]:
# Input/output paths come from the `snakemake` object injected by Snakemake.
expr_fname = snakemake.input.expr_fname
info_fname = snakemake.input.info_fname

out_dir = Path(snakemake.output.out_dir)
# All figures below are written into out_dir. Create it up front: it is not
# guaranteed to exist (Snakemake does not necessarily pre-create directory
# outputs), and savefig would otherwise fail with FileNotFoundError.
out_dir.mkdir(parents=True, exist_ok=True)

# Load data

In [None]:
# Expression table indexed by node name. 'node' is read as str so that
# numeric-looking names stay labels instead of being parsed as numbers.
df_expr = (
    pd.read_csv(expr_fname, dtype={'node': str})
    .set_index('node')
)
df_expr.head()

In [None]:
# Sample sheet. The histogram section groups it by 'condition' and uses its
# 'sample' column to pick expression columns.
df_info = pd.read_csv(filepath_or_buffer=info_fname)
df_info.head()

# Count heatmap

In [None]:
# Raw expression values as a heatmap (rows: nodes, columns: df_expr columns).
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(df_expr, square=False, ax=ax)

fig.tight_layout()
fig.savefig(out_dir / 'count_heatmap.pdf')

# Correlation clustermap

In [None]:
# Colour each node by the part of its name before the first '_', so nodes
# sharing that prefix get the same annotation colour in the clustermap.
# LabelEncoder numbers the prefixes in sorted order, and each code indexes
# into the palette.
prefix_encoder = preprocessing.LabelEncoder()
prefix_codes = prefix_encoder.fit_transform([name.split('_')[0] for name in df_expr.index])

palette = sns.color_palette(palette='tab10', n_colors=len(prefix_encoder.classes_))
df_node_colors = pd.DataFrame(
    {'color': [palette[code] for code in prefix_codes]},
    index=pd.Index(df_expr.index, name='node'),
)

df_node_colors.head()

In [None]:
# Node-by-node correlation across the df_expr columns, hierarchically
# clustered and annotated on both axes with the prefix colours.
node_corr = df_expr.T.corr()
grid = sns.clustermap(
    node_corr,
    row_colors=df_node_colors,
    col_colors=df_node_colors,
    xticklabels=True,
    yticklabels=True,
)

# Every node is labelled, so shrink the tick labels to keep them readable.
heatmap_ax = grid.ax_heatmap
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), fontsize=8)
heatmap_ax.set_yticklabels(heatmap_ax.get_ymajorticklabels(), fontsize=8)

grid.savefig(out_dir / 'clustermap.pdf')

# Compare expression levels between conditions

In [None]:
_PATHWAY_NODE_RE = re.compile(r'[a-zA-Z]+\d+([a-zA-Z]+)_[a-zA-Z]+')


def get_pathway_type(n):
    """Return the pathway type encoded in node name ``n``.

    The pattern is matched at the start of the name only. A name of the form
    ``<letters><digits><letters>_<letters>...`` yields the letter run that
    follows the digits (e.g. ``'path1abc_def'`` -> ``'abc'``). Any other name
    falls back to the text before its first underscore (the whole name if
    there is none).
    """
    found = _PATHWAY_NODE_RE.match(n)
    if found is None:
        return n.split('_')[0]
    return found.group(1)

# Prefix every node with its pathway type, giving a ('pathway_type', 'node')
# MultiIndex. Node names are read from the 'node' level, which exists both
# on the flat index from loading and on the MultiIndex built here. Re-running
# this cell therefore rebuilds the same index instead of passing tuples to
# get_pathway_type.
node_names = df_expr.index.get_level_values('node')
df_expr = df_expr.set_index(
    pd.MultiIndex.from_arrays(
        [[get_pathway_type(n) for n in node_names], node_names],
        names=['pathway_type', 'node'],
    )
)
df_expr.head()

In [None]:
# One histogram per pathway type. Every condition in df_info is overlaid on a
# shared set of 30 bin edges so the bars of different conditions line up.
# Group by the level *name* (a scalar) rather than `level=[0]`: pandas is
# moving to tuple keys for length-1 list groupers, which would turn the title
# and file name below into "('xyz',)".
for pathway, group_expr in df_expr.groupby(level='pathway_type'):
    group_values = group_expr.to_numpy()
    # nan-aware bounds: a single missing value would otherwise make every
    # bin edge NaN.
    bins = np.linspace(np.nanmin(group_values), np.nanmax(group_values), 30)

    fig, ax = plt.subplots()
    for condition, group_cond in df_info.groupby('condition'):
        values = group_expr[group_cond['sample']].to_numpy().ravel()
        values = values[~np.isnan(values)]
        # Same bars the deprecated sns.distplot(kde=False) drew: NaNs
        # removed, semi-transparent (alpha=0.4) so overlaps stay visible.
        ax.hist(values, bins=bins, alpha=0.4, label=condition)

    ax.set_xlabel('Gene expression')
    ax.set_ylabel('Frequency')

    ax.set_yscale('log')
    ax.set_title(f'Pathway type: {pathway}')
    ax.legend(loc='best')

    fig.tight_layout()
    fig.savefig(out_dir / f'expression_histogram_{pathway}.pdf')