In [1]:
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from sklearn.manifold import TSNE
from sklearn.linear_model import LinearRegression
from skbio.diversity import beta_diversity, alpha_diversity
from skbio.stats.ordination import pcoa, pcoa_biplot
from skbio import DistanceMatrix
from scipy.stats import spearmanr, pearsonr
import statsmodels.api as sm 
import umap
from io import StringIO
from os.path import join
import pandas as pd
import xlsxwriter
import openpyxl
import os
import numpy as np
import itertools
import itertools as it
import kaleido
from pandas import Series, ExcelWriter
import scipy.io as sio
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
import seaborn as sns
from IPython.display import display, HTML
from fpdf import FPDF
import scanpy as sc 
from anndata import AnnData
import csv
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D  # for legend handle


#### Variables

In [2]:
# Machine selector: the configuration cell below picks the data directory for "lab" or "home".
place = "home"

In [3]:
## Cros data
# Machine-specific directory that holds every workbook loaded by this notebook.
if place == "lab":
    cros_file_path = "D:/user/OneDrive/Mor_TAU/BV_project/Data/abundance_tables/"
elif place == "home":
    cros_file_path = "C:/Users/morts/Documents/TAU/BV_Project/Pseudo_time_analysis/Ravel11/PCA_noTransitions_10012022/"
else:
    # Fail here instead of with a NameError on the first use of cros_file_path.
    raise ValueError("Unknown place {!r}; expected 'lab' or 'home'".format(place))

# Workbook (without extension) and sheet names for the pseudotime data.
cros_file_name = "ps_results_10012022"
cros_sheet_name_data = 'abundance_pseudo'
cros_sheet_name_meta = 'metadata_pseudo'
cros_sheet_name_plot = 'graph_pseudo'

## Temp data
temp_file_name = "abundance_table_Gajer2012"
temp_sheet_name_data = 'abundance'
temp_sheet_name_meta = 'metadata'
temp_sheet_name_plot = 'temp_UMAP_projection'

## PCA projected data
pca_file = 'pca_projection_dfs_04042022'
cros_pca_sheet_name = 'cros_pca_df'
temp_pca_sheet_name = 'temp_pca_df'

## KNN tables
knn_file_name = 'knn_dfs'
euc_sheet_name = 'euc_knn_res_df'
bc_sheet_name = 'bc_knn_res_df'

#### Load data

In [4]:
def get_data(all_file_path, file_name, sheet_name):
    """Load one sheet of an .xlsx workbook into a DataFrame.

    Parameters
    ----------
    all_file_path : str
        Directory containing the workbook; a trailing separator is optional.
    file_name : str
        Workbook name without the '.xlsx' extension.
    sheet_name : str
        Name of the sheet to read.

    Returns
    -------
    pandas.DataFrame
        The sheet with its first column used as the index. When a
        'subjectID' column is present it is cast to object dtype.
    """
    # os.path.join inserts a separator only when the directory lacks one, so
    # both "dir/" and "dir" resolve to the same workbook.
    file_full_name = os.path.join(all_file_path, file_name + '.xlsx')
    df = pd.read_excel(file_full_name, sheet_name = sheet_name, index_col = 0)

    if "subjectID" in df.columns:
        df["subjectID"] = df["subjectID"].astype(object)

    return df

In [9]:
## Pseudotime DF
cros_df = get_data(cros_file_path, cros_file_name, cros_sheet_name_data)

# The metadata sheet is split in two: columns prefixed 'branch_' go to
# branch_df, everything else stays in cros_meta_df.
cros_meta_df = get_data(cros_file_path, cros_file_name, cros_sheet_name_meta)
is_branch_col = cros_meta_df.columns.str.startswith('branch_')
branch_df = cros_meta_df.loc[:, is_branch_col]
cros_meta_df = cros_meta_df.loc[:, ~ is_branch_col]

cros_fa_df = get_data(cros_file_path, cros_file_name, cros_sheet_name_plot)

## Temp DF
temp_df = get_data(cros_file_path, temp_file_name, temp_sheet_name_data)

# Tag the rows with their data set and discard the 'age' column.
temp_meta_df = get_data(cros_file_path, temp_file_name, temp_sheet_name_meta)
temp_meta_df['DB_type'] = 'temp'
temp_meta_df = temp_meta_df.drop(columns = 'age')

temp_fa_df = get_data(cros_file_path, temp_file_name, temp_sheet_name_plot)

## PCA DF
cros_pca, temp_pca = (get_data(cros_file_path, pca_file, sheet)
                      for sheet in (cros_pca_sheet_name, temp_pca_sheet_name))

## KNN DF
euc_knn_res_df, bc_knn_res_df = (get_data(cros_file_path, knn_file_name, sheet)
                                 for sheet in (euc_sheet_name, bc_sheet_name))

  warn(msg)


In [32]:
# Columns of branch_df whose name starts with 'mt_branch_'.
# NOTE(review): branch_df is built from the columns that start with 'branch_',
# and such a name cannot also start with 'mt_branch_', so this list (and
# therefore branch_markers) looks like it is always empty — confirm which frame
# the branch columns should be read from.
branch_cols = [col for col in branch_df.columns if col.startswith('mt_branch_')]
# Matplotlib marker codes, paired positionally with branch_cols below; zip()
# stops at the shorter sequence, so an eighth branch column would get no marker.
markers_lst = ['o', '^', 's', 'D', '*', 'P', 'X']

branch_markers = dict(zip(branch_cols, markers_lst))

# Colour assigned to each community label in the figures.
community_colors = {'I':'darkred', 'II':'orange', 'III':'chocolate', 'V':'yellow', 
          'IV':'mediumturquoise', 'IV-A':'mediumpurple', 'IV-B':'royalblue'}

# Default Plotly template for figures created after this point.
pio.templates.default = 'seaborn'

#### Add nearest neighbor pseudotime

In [25]:
# Expose the index labels as a regular 'closest_sample' column so they can be used as a lookup key.
cros_meta_df['closest_sample'] = cros_meta_df.index

In [50]:
# Attach each temporal sample's nearest neighbour by index label; a left join
# keeps exactly the rows (and the index) of temp_meta_df.
ord_temp_meta_df = temp_meta_df.join(bc_knn_res_df[['closest_sample']])

# Look the neighbour's pseudotime up by label. The sample index is never
# discarded, so nothing has to be re-attached by row position, and a neighbour
# without a pseudotime yields NaN instead of silently dropping the row.
pseudotime_by_sample = cros_meta_df.set_index('closest_sample')['mt_pseudotime']
ord_temp_meta_df['mt_pseudotime'] = ord_temp_meta_df['closest_sample'].map(pseudotime_by_sample)

In [52]:
# Community labels in the order used for the categorical sort below.
list(community_colors)

['I', 'II', 'III', 'V', 'IV', 'IV-A', 'IV-B']

In [56]:
# Ordered categorical copy of 'community'; the category order follows the key
# order of community_colors, and the rows are then sorted by it.
ord_temp_meta_df['community_cat'] = pd.Categorical(
    ord_temp_meta_df['community'],
    categories = list(community_colors),
    ordered = True,
)
ord_temp_meta_df = ord_temp_meta_df.sort_values('community_cat')
display(ord_temp_meta_df)

Unnamed: 0_level_0,time,subjectID,ethnicity,nugent,nugent_score,community,DB_type,closest_sample,mt_pseudotime,community_cat
sampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
416_042606,43,24,White,0.0,Low,I,temp,S352,0.628217,I
414_060906,29,28,White,0.0,Low,I,temp,S001,0.528216,I
414_061206,89,28,White,0.0,Low,I,temp,S001,0.528216,I
418_061006,110,24,White,1.0,Low,I,temp,S212,0.478722,I
418_060706,106,24,White,0.0,Low,I,temp,S212,0.478722,I
...,...,...,...,...,...,...,...,...,...,...
410_051206,110,3,White,7.0,High,IV-B,temp,S344,0.984778,IV-B
410_050806,106,3,White,7.0,High,IV-B,temp,S344,0.984778,IV-B
410_050506,99,3,White,4.0,Int,IV-B,temp,S344,0.984778,IV-B
411_032406,5,2,Black,8.0,High,IV-B,temp,S344,0.984778,IV-B


#### Plots

In [57]:
# One scatter panel per subject: chronological time on x, pseudotime on y,
# points coloured by community.
fig1 = px.scatter(
    ord_temp_meta_df,
    x = "time", y = "mt_pseudotime",
    color = "community", color_discrete_map = community_colors,
    opacity = 0.8,
    facet_col = 'subjectID', facet_col_wrap = 4,
    facet_row_spacing = 0.02, facet_col_spacing = 0.02,
    width = 800, height = 700,
)

# Blank every per-panel y-axis title; a single shared label is added instead.
fig1.update_yaxes(title_text = '')
fig1.add_annotation(
    x = -0.07, y = 0.5, xref = "paper", yref = "paper",
    text = "Pseudotime", textangle = -90,
    font = dict(size = 14), showarrow = False,
)

# Larger markers with a dark outline.
fig1.update_traces(
    marker = dict(size = 8, line = dict(width = 2, color = 'DarkSlateGrey')),
    selector = dict(mode = 'markers'),
)

fig1.show()

In [49]:
# One line panel per subject: pseudotime against chronological time.
# px.line joins points in row order, so the rows are sorted by time first;
# category_orders keeps the panels in the order the subjects appear in
# ord_temp_meta_df, regardless of that sort.
fig2 = px.line(ord_temp_meta_df.sort_values('time', kind = 'stable'),
               x = "time", y = "mt_pseudotime",
               facet_col = 'subjectID', facet_col_wrap = 4,
               category_orders = {'subjectID': list(ord_temp_meta_df['subjectID'].unique())},
               facet_row_spacing = 0.02, facet_col_spacing = 0.02,
               width = 800, height = 700)

# Blank every per-panel y-axis title; a single shared label is added instead.
fig2.update_yaxes(title_text = '')

# The shared label belongs to fig2 itself (not to the scatter figure fig1).
fig2.add_annotation(x = -0.07, y = 0.5, font = dict(size = 14),
                    showarrow = False, text = "Pseudotime", textangle = -90,
                    xref = "paper", yref = "paper")

fig2.show()

#### Split pseudotime into one column per branch for the chronological-time plots

In [None]:
def get_partial_pseudotime(df, leiden_lst, column_name):
    """Add a branch-specific pseudotime column to ``df``.

    The new column is named ``'mt_branch_' + column_name``. It holds the row's
    'mt_pseudotime' where 'leiden_anno' is one of ``leiden_lst`` and NaN
    everywhere else.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'mt_pseudotime' and 'leiden_anno'. Modified in place.
    leiden_lst : list
        'leiden_anno' values that belong to the branch.
    column_name : str
        Suffix of the new column name.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with the new column added.
    """
    full_column_name = 'mt_branch_' + column_name
    # Vectorised mask instead of one .loc lookup per row: faster, and it also
    # works when index labels repeat.
    in_branch = df['leiden_anno'].isin(leiden_lst)
    df[full_column_name] = df['mt_pseudotime'].where(in_branch)

    return df

In [None]:
def get_unique(df):
    """Return a reduced copy of ``df`` with one pseudotime column per branch.

    Keeps 'leiden_anno', 'mt_pseudotime', 'time', 'subjectID', 'nugent' and
    'community', then adds one 'mt_branch_<name>' column per branch via
    get_partial_pseudotime. The input frame is left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the six columns listed above.

    Returns
    -------
    pandas.DataFrame
        New frame with the six base columns followed by the branch columns.
    """
    # Explicit copy: the branch columns are added in place, which must not
    # happen on a slice of the caller's frame.
    new_df = df[['leiden_anno', 'mt_pseudotime', 'time', 'subjectID', 'nugent', 'community']].copy()

    ## Branch columns: ('leiden_anno' values, branch name)
    branches = (
        ([3, 0], 'I/IV'),
        ([6], 'II/IV'),
        ([9, 8, 1, 7], 'III/IV'),
        ([12], 'V/IV'),
        ([10], 'III/V'),
        ([11, 13], 'I/III'),
        ([4, 5, 2, 14], 'IV'),
    )
    for leiden_lst, branch_name in branches:
        new_df = get_partial_pseudotime(new_df, leiden_lst, branch_name)

    return new_df

In [None]:
# NOTE(review): bc_ps_df is not defined in the visible part of this notebook —
# confirm where it is created before running on a fresh kernel.
bc_ps_unique_df = get_unique(bc_ps_df)
# DataFrame.info() prints its summary itself and returns None, so it is called
# directly rather than wrapped in print().
bc_ps_unique_df.info()

display(bc_ps_unique_df)

In [None]:
# NOTE(review): euc_ps_df is not defined in the visible part of this notebook —
# confirm where it is created before running on a fresh kernel.
euc_ps_unique_df = get_unique(euc_ps_df)
# DataFrame.info() prints its summary itself and returns None, so it is called
# directly rather than wrapped in print().
euc_ps_unique_df.info()

#### Plot with shape as branch

In [None]:
def get_ps_to_ch_shape(df):
    """Draw pseudotime against chronological time, one panel per subject.

    Each panel shows the subject's 'mt_pseudotime' over 'time' as a thin black
    line, overlaid with one scatter series per 'mt_branch_*' column: marker
    shape encodes the branch, colour encodes the community.

    Reads the notebook-level ``branch_markers``, ``markers_lst`` and
    ``community_colors``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'subjectID', 'time', 'mt_pseudotime', 'community' and the
        'mt_branch_*' columns to plot.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_cols = 4
    grouped = df.groupby('subjectID')
    # Round up so a subject count that is not a multiple of n_cols still gets
    # a panel for every subject.
    n_rows = max(1, int(np.ceil(grouped.ngroups / n_cols)))

    # Use the notebook-level marker mapping for the branch columns this frame
    # actually has; if it covers none of them, pair the frame's own branch
    # columns with the marker list.
    branch_cols = [col for col in df.columns if col.startswith('mt_branch_')]
    markers_by_branch = {col: branch_markers[col] for col in branch_cols if col in branch_markers}
    if not markers_by_branch:
        markers_by_branch = dict(zip(branch_cols, markers_lst))

    # squeeze=False keeps axs two-dimensional even for a single row.
    fig, axs = plt.subplots(figsize = (9, 15), nrows = n_rows, ncols = n_cols,
                            sharey = True, sharex = True, squeeze = False)
    panels = axs.flatten()

    for key, ax in zip(grouped.groups.keys(), panels):
        subject_df = grouped.get_group(key)
        by_time = subject_df.sort_values('time')
        ax.plot(by_time['time'], by_time['mt_pseudotime'], color = 'black', zorder = -1, linewidth = 0.7)

        for col, shape in markers_by_branch.items():
            branch_points = subject_df[['time', 'community', col]].dropna()
            if branch_points.empty:
                continue
            ax.scatter(x = branch_points['time'], y = branch_points[col],
                       c = branch_points.loc[:, 'community'].map(community_colors),
                       alpha = 0.6, s = 90, edgecolor = 'black', marker = shape)

        # str() so that non-string subject IDs (e.g. integers) are labelled too.
        ax.text(0.07, 0.1, str(key).replace('subID_', ''), ha = 'center', va = 'center', transform = ax.transAxes)
        ax.tick_params(axis = 'both', which = 'major', labelsize = 12)

    # Hide the panels of the last row that have no subject.
    for ax in panels[grouped.ngroups:]:
        ax.set_visible(False)

    handles_community = [Line2D([0], [0], marker = 'o', color = 'w', markerfacecolor = v, label = k, markersize = 15)
                         for k, v in community_colors.items()]
    handles_branch = [Line2D([0], [0], marker = v, color = 'black', linestyle = 'None',
                             label = k.replace('mt_branch_', ''), markersize = 11.5)
                      for k, v in markers_by_branch.items()]
    fig.legend(title = 'Community \n state type', handles = handles_community, bbox_to_anchor = (0.79, 0.85),
               loc = 'upper left', fontsize = 13, title_fontsize = 15)
    fig.legend(title = 'Branch', handles = handles_branch, bbox_to_anchor = (0.79, 0.6),
               loc = 'upper left', fontsize = 13, title_fontsize = 15)

    fig.text(0.5, 0.02, 'Chronological time (days)', ha = 'center', size = 18)
    fig.text(0.03, 0.45, 'Pseudotime', va = 'center', rotation = 'vertical', size = 18)

    # Leave room on the right for the two figure-level legends.
    fig.subplots_adjust(left = 0.125, right = 0.79, bottom = 0.05, top = 0.9, wspace = 0.08, hspace = 0.08)

    return fig

In [None]:
# Per-subject pseudotime figure built from bc_ps_unique_df.
bc_fig_shape = get_ps_to_ch_shape(bc_ps_unique_df)

In [None]:
# Per-subject pseudotime figure built from euc_ps_unique_df.
euc_fig_shape = get_ps_to_ch_shape(euc_ps_unique_df)

In [None]:
# Save both branch-shape figures as PNG at 500 dpi.
# NOTE(review): ps_file_path is not defined in the visible part of this
# notebook; the concatenation below also assumes it ends with a path separator — confirm.
bc_shape_path = ps_file_path + 'Gajer2012_knn_figures/' + 'bc_ps_by_branch.png'
bc_fig_shape.savefig(bc_shape_path, dpi = 500)

euc_shape_path = ps_file_path + 'Gajer2012_knn_figures/' + 'euc_ps_by_branch.png'
euc_fig_shape.savefig(euc_shape_path, dpi = 500)

#### Shannon diversity index vs. pseudotime

In [None]:
def get_shannon(df):
    """Compute the Shannon diversity of every row of ``df``.

    Returns a one-column DataFrame ('shannon_index') that carries the index
    of ``df``.
    """
    per_row_shannon = alpha_diversity(metric = 'shannon', counts = df.values)

    return pd.DataFrame(per_row_shannon, columns = ['shannon_index']).set_index(df.index)

In [None]:
def get_reg(x, y):
    """Fit a straight line to (x, y).

    Returns a tuple ``(line, r_squared)``: ``line`` is a numpy.poly1d of the
    degree-1 least-squares fit, ``r_squared`` is the squared Pearson
    correlation of x and y rounded to four decimals.
    """
    line = np.poly1d(np.polyfit(x, y, deg = 1))

    pearson_r = np.corrcoef(x, y)[0][1]
    r_squared = round(pearson_r ** 2, 4)

    return line, r_squared

In [None]:
def get_pseudo_vs_time_plots(df, y_col_name, ps_x_col_name, ch_x_col_name):
    """Scatter ``y_col_name`` against branch pseudotime, one panel per branch.

    A panel is drawn for every column of ``df`` whose name contains
    'mt_branch_' and 'IV' but not 'IV/IV'. Rows with a missing branch value or
    a missing y value, and rows whose community is 'IV-A' or 'IV-B', are left
    out. Each panel also gets a least-squares line and its R^2 when at least
    two points remain.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'community', ``y_col_name`` and the branch columns.
        It is not modified.
    y_col_name : str
        Column plotted on the y axis.
    ps_x_col_name, ch_x_col_name : str
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``df`` has no matching branch column.
    """
    ## Variables
    figsize_x = 8
    figsize_y = 25
    ax_size = 15
    lab_size = 18
    point_size = 130
    leg_size = 12

    ## DF variables
    # Non in-place replace: the caller's frame stays untouched.
    df = df.replace([np.inf, -np.inf], np.nan)
    branch_col_lst = [col for col in df.columns if 'mt_branch_' in col and 'IV' in col and 'IV/IV' not in col]
    nrow_subplots = len(branch_col_lst)
    if nrow_subplots == 0:
        raise ValueError("df has no 'mt_branch_' column containing 'IV' to plot")

    # squeeze=False keeps axs indexable even when there is a single branch.
    figure, axs = plt.subplots(nrow_subplots, 1, sharey = True, figsize = (figsize_x, figsize_y),
                               squeeze = False)
    axs = axs[:, 0]
    # Palette local to this figure (it differs from the notebook-level one).
    community_colors = {'I':'crimson', 'II':'orange', 'III':'chocolate', 'V':'gold', 
              'IV':'steelblue', 'IV-A':'aqua', 'IV-B':'springgreen'}

    ## Plots
    for i, branch_col in enumerate(branch_col_lst):
        # Rows usable for this branch
        keep = (df[branch_col].notna()
                & df[y_col_name].notna()
                & ~ df['community'].isin(['IV-A', 'IV-B']))
        curr_df = df[keep]

        ps_y_label_name = y_col_name.replace('_', ' ') + ' ' + branch_col.replace('mt_branch_', '')
        ps_x = np.array(curr_df.loc[:, branch_col])
        y = np.array(curr_df.loc[:, y_col_name])

        ######### Pseudotime plot
        ax1 = axs[i]
        ax1.scatter(ps_x, y,
                    c = curr_df['community'].map(community_colors),
                    s = point_size, edgecolor = 'black', alpha = 0.6)
        ax1.set_ylabel(ps_y_label_name, size = lab_size)
        ax1.tick_params(axis = 'both', which = 'major', labelsize = ax_size)

        if i == nrow_subplots - 1:
            ax1.set_xlabel('Pseudotime', size = lab_size)
        if i == 0:
            handles = [Line2D([0], [0], marker = 'o', color = 'w', markerfacecolor = v, label = k, markersize = leg_size) 
                       for k, v in community_colors.items()]
            ax1.legend(title = 'Community state type', handles = handles, loc = 'upper left')

        # A line fit needs at least two points.
        if len(ps_x) < 2:
            continue

        # Add regression line and R squared. The colour is given once, by
        # keyword, so the format string carries only line style and marker.
        ps_p, ps_r = get_reg(ps_x, y)
        ax1.plot(ps_x, ps_p(ps_x), "-o", color = 'black', linewidth = 0.8, markersize = 2)

        ps_r_squared = 'R^2 = ' + str(ps_r)
        ax1.annotate(ps_r_squared, xy = (1, 0), xycoords = 'axes fraction', fontsize = lab_size,
                     xytext = (-5, 5), textcoords = 'offset points',
                     ha = 'right', va = 'bottom')

    return figure

In [None]:
# Per-sample Shannon index of the temporal abundance table, joined onto each
# metadata frame by index (default inner merge: only shared index labels remain).
# NOTE(review): bc_ps_df and euc_ps_df are not defined in the visible part of
# this notebook — confirm where they are created.
shannon_df = get_shannon(temp_df)
bc_meta_df = bc_ps_df.merge(shannon_df, left_index = True, right_index = True)
euc_meta_df = euc_ps_df.merge(shannon_df, left_index = True, right_index = True)

In [None]:
# Shannon index vs. branch pseudotime for bc_meta_df, saved as PNG at 500 dpi.
# NOTE(review): the two trailing column-name arguments are not read by
# get_pseudo_vs_time_plots as defined in this notebook, and ps_file_path is not
# defined in the visible part of the notebook — confirm both.
bc_shannon_fig = get_pseudo_vs_time_plots(bc_meta_df, 'shannon_index', 'mt_pseudotime', 'time')
bc_path = ps_file_path + 'Shannon_Figures_16082021/' + 'shannon_fig_matplotlib_bc_knn_30082021.png'
bc_shannon_fig.savefig(bc_path, dpi = 500)

In [None]:
# Shannon index vs. branch pseudotime for euc_meta_df, saved as PNG at 500 dpi.
# NOTE(review): the two trailing column-name arguments are not read by
# get_pseudo_vs_time_plots as defined in this notebook, and ps_file_path is not
# defined in the visible part of the notebook — confirm both.
euc_shannon_fig = get_pseudo_vs_time_plots(euc_meta_df, 'shannon_index', 'mt_pseudotime', 'time')
euc_path = ps_file_path + 'Shannon_Figures_16082021/' + 'shannon_fig_matplotlib_euc_knn_30082021.png'
euc_shannon_fig.savefig(euc_path, dpi = 500)