# Integrative analysis of ATAC & RNA - Time series analysis
- recommendation: use multiple cores for the clustering step (e.g., 32)
- goal: find gene clusters with similar temporal behaviour
- input: INT-DEA results (epigenetic potential/transcriptional abundance)
- output: gene clusters

In [1]:
# IPython magic: move one directory up (the project root, per the output below) so that
# the relative 'results/' and 'src/' paths used in later cells resolve.
# NOTE(review): re-running this cell moves up yet another level — run it once per kernel.
cd ../

/home/sreichl/projects/bmdm-stim


In [None]:
# libraries
import pandas as pd
import os

import matplotlib.pyplot as plt
import numpy as np
import math

import tslearn
from tslearn.utils import to_time_series_dataset
from tslearn.clustering import TimeSeriesKMeans
from tslearn.clustering import silhouette_score
from tslearn.metrics import cdist_dtw

In [5]:
# import util functions
import sys
sys.path.insert(1, os.path.join('src'))

import utils_dimred_UMAP_PCA

In [None]:
# configs: inputs are read from the integrative (INT) results folder,
# outputs of this notebook are written to its 'time_series' subfolder
dir_data = os.path.join('results', 'INT')
dir_results = os.path.join('results', 'INT', 'time_series')

In [None]:
# Create the output folder, including any missing parent folders (os.mkdir would fail if
# 'results/INT' did not exist yet); exist_ok makes re-running this cell a no-op.
os.makedirs(dir_results, exist_ok=True)

# load annotation data

In [None]:
# Load the sample annotation table (first CSV column becomes the row index)
annot = pd.read_csv(
    os.path.join(dir_data, 'INT_annotations.csv'),
    header=0,
    index_col=0,
)
print(annot.shape)
annot.head()

In [None]:
# Load gene annotations produced by the RNA analysis (tab-separated, first column is the index)
gene_annot = pd.read_csv(
    os.path.join('results', 'RNA', 'counts', 'gene_annotation.tsv'),
    sep='\t',
    header=0,
    index_col=0,
)
print(gene_annot.shape)
gene_annot.head()

In [None]:
# Load all DEA results; rows are indexed by the 'rn' column
# NOTE(review): 'rn' presumably holds gene ids (R row names written to CSV) — confirm
data = pd.read_csv(
    os.path.join(dir_data, 'DEA', 'INT_DEA_all.csv'),
    index_col='rn',
    header=0,
)
print(data.shape)
data.head()

In [None]:
# Time points, in the order used for the time-table columns and plot x-axes
times = ['0h', '2h', '4h', '6h', '8h', '24h']
# All treatments found in the sample annotation except the untreated control
# (list.remove raises ValueError if no 'untreated' treatment is present)
treatments = annot['treatment'].unique().tolist()
treatments.remove('untreated')

# prepare data, perform cluster analysis, plot & save results for each treatment

In [None]:
def _save_and_show(fname):
    """Save the current matplotlib figure to `fname` as SVG, display it, then close it.

    Closing after `plt.show()` keeps figures from piling up when the notebook is
    executed with a non-interactive backend.
    """
    plt.savefig(fname=fname, format="svg", dpi=300, bbox_inches="tight")
    plt.show()
    plt.close()


def _build_time_table(treatment):
    """Return a genes x time table of logFC values for one treatment.

    The '0h' column comes from the shared 'untreated_0h' group of the global DEA table
    `data`; every later time point from its '<treatment>_<time>' group. `pd.concat`
    aligns the columns on the row index (outer join), so a gene absent from a group
    gets NaN in that column.
    """
    columns = []
    for tp in times:
        group = 'untreated_0h' if tp == '0h' else treatment + '_' + tp
        columns.append(data.loc[data['group'] == group, 'logFC'])
    table = pd.concat(columns, axis=1)
    table.columns = times
    return table


# gene selection thresholds (combined with adj.P.Val < 0.05 and the treatment below);
# less restrictive variants used before: adj.P.Val only, or adj.P.Val + AveExpr > 1
LFC_filter = 1
AveExpr_filter = 1  # before: 0

# clustering configs: TimeSeriesKMeans with DTW took ~30 min/treatment, euclidean is much faster
# metric = "dtw"
metric = "euclidean"
ks = list(range(2, 11)) + [15, 20]

for treatment in treatments:

    ### make results directory
    dir_treatment_results = os.path.join(dir_results, treatment)
    os.makedirs(dir_treatment_results, exist_ok=True)

    ### generate time table = DEGs x time -> LFC values
    time_table = _build_time_table(treatment)
    print(time_table.shape)

    ### select genes of interest (most restrictive criteria)
    is_sig = (
        (data['adj.P.Val'] < 0.05)
        & (data['AveExpr'] > AveExpr_filter)
        & (data['logFC'].abs() > LFC_filter)
        & (data['treatment'] == treatment)
    )
    genes_sig = data.loc[is_sig, ].index.unique()
    print("{} : {}".format(treatment, len(genes_sig)))
    if len(genes_sig) == 0:
        # nothing to plot or cluster (plotting an empty frame would raise)
        print("no genes pass the filters for {}; skipping".format(treatment))
        continue

    ### make plot of LFCs over time of all selected genes
    plot_df = time_table.loc[genes_sig, ].T
    plot_df.plot.line(legend=False, alpha=0.1)
    _save_and_show(os.path.join(dir_treatment_results, "timecourse_allGenes_" + treatment + ".svg"))

    ### perform clustering with tslearn for every tested k
    silh_scores = pd.DataFrame(index=ks, columns=['silhouette'], dtype=float)

    # prepare data as a tslearn time-series dataset (shape printed below)
    ts_data = to_time_series_dataset(time_table.loc[genes_sig, ])
    print(ts_data.shape)

    tick_pos = list(range(time_table.shape[1]))
    tick_labels = time_table.columns.to_list()

    for k in ks:
        print(k)

        # k-means needs at least k samples and the silhouette needs 2 <= #labels <= n-1,
        # so k values that are not smaller than the number of genes cannot be evaluated
        if k >= len(genes_sig):
            print("k={} skipped: only {} genes selected".format(k, len(genes_sig)))
            continue

        # make result folder per tested k
        dir_treatment_results_k = os.path.join(dir_treatment_results, "k_{}".format(k))
        os.makedirs(dir_treatment_results_k, exist_ok=True)

        km = TimeSeriesKMeans(n_clusters=k, metric=metric, random_state=42, n_jobs=-1, verbose=False)
        km.fit(ts_data)
        silh = silhouette_score(ts_data, km.labels_, metric=metric, n_jobs=-1, verbose=False)
        silh_scores.loc[k, 'silhouette'] = silh

        unique_labels = np.unique(km.labels_)
        cluster_ids = km.labels_.astype(int) + 1  # 1-based cluster ids used in plots and outputs

        ### plot LFC over time for each gene cluster as visual validation
        # one panel per cluster: member genes in black, cluster center in red
        n_rows = math.ceil(k / 3)
        plt.figure(figsize=(6, n_rows * 2), dpi=300)
        for yi in unique_labels:
            members = km.labels_ == yi
            plt.subplot(n_rows, 3, yi + 1)
            for xx in ts_data[members]:
                plt.plot(xx.ravel(), "k-", alpha=.1)
            # center, title and ticks drawn once per cluster, after the members so the center is on top
            plt.plot(km.cluster_centers_[yi].ravel(), "r-")
            plt.title('Cluster {}\n(n={})'.format(yi + 1, members.sum()))
            plt.xticks(ticks=tick_pos, labels=tick_labels)
        plt.tight_layout()
        _save_and_show(os.path.join(dir_treatment_results_k, "timecourse_clusters_" + treatment + ".svg"))

        # cluster centers as a table (one row per cluster, one column per time point) and plot
        centers = pd.DataFrame(
            [km.cluster_centers_[label].ravel() for label in unique_labels],
            columns=time_table.columns,
        )
        plt.figure()
        for label in unique_labels:
            plt.plot(
                km.cluster_centers_[label].ravel(),
                label='Cluster {} (n={})'.format(label + 1, (km.labels_ == label).sum()),
            )
        plt.xticks(ticks=tick_pos, labels=tick_labels)
        plt.xlabel('time')
        plt.ylabel('LFC')
        plt.title('Cluster Centers silh={}'.format(round(silh, 3)))
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.tight_layout()
        _save_and_show(os.path.join(dir_treatment_results_k, "timecourse_clustercenters_" + treatment + ".svg"))

        ### plot dimensionality reduced data (PCA & UMAP)
        # keep 'cluster' a string column so labels are always '1', '2', ... (creating it via
        # .loc made it float first, giving '1.0' on the first k); genes without a cluster
        # hold the string 'nan', as astype(str) of a missing value did before
        if 'cluster' not in gene_annot.columns:
            gene_annot['cluster'] = np.nan
        gene_annot['cluster'] = gene_annot['cluster'].astype(str)
        gene_annot.loc[genes_sig, 'cluster'] = cluster_ids.astype(str)

        # plot PCA & UMAP again with final cluster labels and gene_biotype
        utils_dimred_UMAP_PCA.dimred_plot(
            data=time_table.loc[genes_sig, ],
            annot=gene_annot.loc[genes_sig, ],
            variables=['cluster', 'gene_biotype'],
            label='{}_{}_timeseries'.format(treatment, len(genes_sig)),
            results_dir=dir_treatment_results_k,
        )

        ### save clustering results
        # gene id, gene name and 1-based cluster id (columns written as 0, 1, 2)
        pd.DataFrame(
            [genes_sig, gene_annot.loc[genes_sig, 'external_gene_name'], cluster_ids]
        ).T.to_csv(os.path.join(dir_treatment_results_k, "clustering_{}.csv".format(treatment)))
        # save cluster centers to csv
        centers.to_csv(os.path.join(dir_treatment_results_k, "clustercenters_{}.csv".format(treatment)))
        # save model
        km.to_pickle(os.path.join(dir_treatment_results_k, "model_{}.pickle".format(treatment)))

    # save silhouette scores of treatment across tested ks (skipped ks stay empty)
    silh_scores.to_csv(os.path.join(dir_treatment_results, "silhouette_scores_" + treatment + ".csv"))

    # plot silh scores across ks, marking the best (^) and worst (v) k
    scores = silh_scores['silhouette']
    silh_scores.plot.line(
        figsize=(5, 4),
        legend=False,
        xlabel='# of clusters (k)',
        ylabel='silhouette score',
        title='Silhouette scores of {} time-series clusters'.format(treatment),
    )
    if scores.notna().any():
        plt.scatter(scores.idxmax(), scores.max(), marker='^', color='k')
        plt.scatter(scores.idxmin(), scores.min(), marker='v', color='k')
    plt.xticks(ticks=ks)
    plt.tight_layout()
    _save_and_show(os.path.join(dir_treatment_results, "silhouette_scores_" + treatment + ".svg"))