In [None]:
import os

# Probe for a pre-existing `snakemake` object; when the name is undefined
# (presumably when the notebook is run by hand rather than by the workflow),
# build an equivalent object from the Snakefile rule.
try:
    snakemake
except NameError:
    # Import before changing directory so the helper is resolved from the
    # notebook's original working directory.
    from snakemk_util import load_rule_args

    # Move two levels up and use that directory as the base for the Snakefile path.
    os.chdir('../..')
    project_dir = os.getcwd()

    snakemake = load_rule_args(
        snakefile=f"{project_dir}/workflow/Snakefile",
        rule_name='differential_apa_c2c12',
        root='..',
    )

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from lapa.result import LapaResult
from met_brewer import met_brew
import matplotlib.pyplot as plt

# Apply seaborn's 'poster' plotting context to every figure drawn below.
sns.set_context('poster')

In [None]:
# Load the LAPA result from the path given by the rule's 'long_read' input
# and keep its sample list at hand.
long_read_path = snakemake.input['long_read']
result = LapaResult(long_read_path)
samples = result.samples

In [None]:
# Two sample groups compared by the Fisher exact test: 'undif' vs 'dif'.
sample_groups = {
    'undif': ['PB154', 'PB155'],
    'dif': ['PB213', 'PB214'],
}
df = result.fisher_exact_test(sample_groups, min_gene_count=50)

In [None]:
# Negative log10 of the adjusted p-value; used as the y-axis of the volcano plot.
neg_log10_padj = -np.log10(df['pval_adj'])
df['-log10(p-value corrected)'] = neg_log10_padj

In [None]:
# Turn the index into a regular 'polya_site' column, then attach the
# columns of the mapping table to each row by joining on 'gene_id'.
df = df.reset_index()
df = df.rename(columns={'index': 'polya_site'})
gene_mapping = pd.read_csv(snakemake.input['mapping'])
df = df.merge(gene_mapping, on='gene_id')

In [None]:
from adjustText import adjust_text

# Classify every poly(A) site: significant sites (|delta_usage| > 0.3 and
# adjusted p-value < 0.05) are split into 'Up' / 'Down' by the sign of the
# usage change; everything else is 'No Sig.'.
df['significant'] = np.where(
    (df['delta_usage'].abs() > 0.3) & (df['pval_adj'] < 0.05),
    np.where(df['delta_usage'] > 0.3, 'Up', 'Down'),
    'No Sig.'
)

colors = met_brew('Cassatt1', n=7, brew_type="continuous")

# Pin each class to a color explicitly. A plain list palette is assigned in
# order of first appearance in the data, so which class ended up gray used to
# depend on row order. NOTE(review): the Up/Down color assignment below is a
# choice made here; swap the two entries if the opposite mapping is wanted.
hue_order = ['No Sig.', 'Up', 'Down']
palette = {'No Sig.': 'gray', 'Up': colors[-1], 'Down': colors[0]}

plt.figure(dpi=300, figsize=(10, 10))

# Significant points are drawn larger (60 vs 20) and fully opaque (1.0 vs 0.5).
is_sig = df['significant'] != 'No Sig.'
sns.scatterplot(data=df, x='delta_usage', y='-log10(p-value corrected)', hue='significant',
                hue_order=hue_order,
                s=40 * is_sig + 20,
                palette=palette,
                alpha=list(0.5 + is_sig * 0.5))

plt.legend(title=None, bbox_to_anchor=(0.95, 0.6))
# Raw strings: '\D' and '\l' are not valid Python escapes and trigger a
# warning in regular string literals; the runtime text is unchanged.
plt.xlabel(r'$\Delta usage$')
plt.ylabel(r'$-\log_{10}(P_{corrected})$')

# Only label the most extreme sites: large usage change and -log10(p) above 150.
_df = df[(df['delta_usage'].abs() > 0.3) & (df['-log10(p-value corrected)'] > 150)]

text = [
    plt.text(row['delta_usage'], row['-log10(p-value corrected)'], row['gene_name'], fontsize=20)
    for _, row in _df.iterrows()
]
# Nudge the labels apart so they do not overlap.
adjust_text(text)
sns.despine(offset=10, trim=True)

# The output key is looked up by name at runtime; keep its spelling as is.
plt.savefig(snakemake.output['volcona_plot'], dpi=400, bbox_inches='tight', transparent=True)

In [None]:
# For each gene keep only the row with the smallest adjusted p-value ...
best_row_per_gene = df.groupby('gene_id')['pval_adj'].idxmin()
df = df.loc[best_row_per_gene]

# ... and of those, keep the ones passing both the significance and the
# effect-size cut-offs.
passes_pval = df['pval_adj'] < 0.05
passes_effect = df['delta_usage'].abs() > 0.3
df = df[passes_pval & passes_effect]

In [None]:
# Display the retained sites ordered by usage change (ascending); the result
# is shown as cell output and not stored.
df.sort_values(by='delta_usage')

In [None]:
# Usage values for the retained poly(A) sites, restricted to four samples
# in a fixed column order.
sample_columns = ['PB154', 'PB155', 'PB213', 'PB214']
site_usage = result.attribute('usage').loc[df['polya_site']]
usage = site_usage[sample_columns]

In [None]:
import seaborn as sns
import matplotlib.patches as mpatches

# Reversed continuous Cassatt1 ramp, used as the heatmap colormap.
colors = met_brew('Cassatt1', n=9, brew_type="continuous")[::-1]

# Two Nattier colors for the column annotation strip.
nattier = met_brew('Nattier')
col_colors = [nattier[2], nattier[4]]
first_pair_color, second_pair_color = col_colors

# Rows are clustered, columns keep their given order (col_cluster=False);
# missing usage values are drawn as 0. The first two columns share one
# annotation color and the last two share the other.
fig = sns.clustermap(
    usage.fillna(0),
    figsize=(9, 12),
    col_colors=[first_pair_color, first_pair_color,
                second_pair_color, second_pair_color],
    xticklabels=False,
    yticklabels=False,
    col_cluster=False,
    cmap=sns.color_palette(colors, as_cmap=True),
    cbar_pos=(-0.085, 0.8, 0.05, 0.18),
    dendrogram_ratio=(0.1, 0),
)

# A single x-label with padding spaces stands in for two group captions.
heatmap_ax = fig.ax_heatmap
heatmap_ax.set_xlabel('Undifferentiated' + ' ' * 7 +'Differentiated')
heatmap_ax.set_ylabel('Poly(A)-site in each gene')

plt.savefig(snakemake.output['heatmap_plot'], dpi=400, bbox_inches='tight', transparent=True)