## Use a Sankey diagram to visualize the relationships between individual cell-type latent factors and their GSEA results

In [None]:
!date

#### import libraries

In [None]:
from pandas import read_csv, concat, DataFrame
from itertools import product
import plotly.offline as pyoff
from json import load as json_load

#### set notebook variables

In [None]:
# parameters
# Both default to the empty string and are expected to be overridden per run
# (presumably injected by a notebook runner such as papermill — TODO confirm).
# cell_type is used as a prefix filter on the cell_type column; an empty prefix
# matches every row.
cell_type = ''
# latent_type is compared against the model_type column; the special value
# 'all' skips the model_type filter entirely.
latent_type = ''

In [None]:
# parameters
# project name; used as the prefix of every input and output file name below
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'

# in files
# latent factor age association results
assoc_file = f'{results_dir}/{project}.latent.age_glm.csv'
# NOTE(review): the GSEA results CSV is read from figures_dir rather than
# results_dir — confirm this location is intended
gsea_file = f'{figures_dir}/{project}.cell_type_latents.all.gsea_enrichr.csv'

# out files
# HTML Sankey figure, named by the cell_type and latent_type being visualized
figure_file = f'{figures_dir}/{project}.{cell_type}.{latent_type}.cell_type_latents.sankey.html'

# constants and variables
# when True, cells print/display extra intermediate output
DEBUG = False
# common column names applied to every link table before they are combined
LINK_COLUMNS = ['source', 'target', 'weight']
# threshold applied to the fdr_bh column when keeping age-associated latent factors
ALPHA = 0.05
# gene set libraries kept from the GSEA results (matched against Gene_set)
marker_sets = ['MSigDB_Hallmark',
               'KEGG']

In [None]:
# echo the resolved input and output paths, one per line, when debugging
if DEBUG:
    print(assoc_file, gsea_file, figure_file, sep='\n')

#### Sankey diagramming function

In [None]:
# function from Viraj Deshpande at https://virajdeshpande.wordpress.com/portfolio/sankey-diagram/
def genSankey(df: DataFrame, cat_cols:list=[], value_cols:str='', title:str='Sankey Diagram'):
    """Build a plotly Sankey figure dictionary from a table of links.

    Each adjacent pair of columns in ``cat_cols`` is treated as a
    (source, target) level of the diagram; rows sharing the same source and
    target are collapsed into one link whose value is the sum of ``value_cols``.

    Args:
        df: table holding the category columns and the value column.
        cat_cols: ordered column names defining the diagram levels; at least
            two are required.
        value_cols: name of the numeric column summed into each link's value.
        title: figure title placed in the layout.

    Returns:
        dict with a ``data`` list containing a single ``sankey`` trace and a
        ``layout`` dict. The trace's link ``source``/``target``/``value``
        entries are pandas Series; ``source`` and ``target`` hold positions
        into the trace's node ``label`` list.

    Raises:
        ValueError: if fewer than two category columns are given.
    """
    if len(cat_cols) < 2:
        raise ValueError('cat_cols must name at least two columns (a source and a target)')

    # node labels, de-duplicated in order of first appearance (level by level)
    # so the node numbering is reproducible from run to run
    labelList = list(dict.fromkeys(
        label for catCol in cat_cols for label in df[catCol].tolist()))

    # stack every adjacent pair of category columns as source-target rows
    pairFrames = []
    for sourceCol, targetCol in zip(cat_cols[:-1], cat_cols[1:]):
        pairDf = df[[sourceCol, targetCol, value_cols]].copy()
        pairDf.columns = ['source', 'target', 'count']
        pairFrames.append(pairDf)

    # one aggregation over the stacked pairs; summing once at the end gives the
    # same totals as re-aggregating after each level is appended
    sourceTargetDf = (concat(pairFrames, ignore_index=True)
                      .groupby(['source', 'target'])
                      .agg({'count': 'sum'})
                      .reset_index())

    # add index for source-target pair (dict lookup instead of list.index scans)
    labelIndex = {label: position for position, label in enumerate(labelList)}
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].map(labelIndex)
    sourceTargetDf['targetID'] = sourceTargetDf['target'].map(labelIndex)

    # creating the sankey diagram
    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(
                color='black',
                width=0.5
            ),
            label=labelList,
            color='purple'
        ),
        link=dict(
            source=sourceTargetDf['sourceID'],
            target=sourceTargetDf['targetID'],
            value=sourceTargetDf['count'],
        )
    )

    layout = dict(
        title=title,
        font=dict(
            size=10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

### load input data

#### load the latent factor age associations

In [None]:
# read the latent factor age association results; first CSV column is the index
age_glm_df = read_csv(assoc_file, index_col=0)
print(f'shape of age_glm_df is {age_glm_df.shape}')
# key each latent factor as '<cell_type>:<feature>'
age_glm_df = age_glm_df.assign(
    key_name=lambda frame: frame['cell_type'] + ':' + frame['feature'])
if DEBUG:
    display(age_glm_df.head())
    print(f'age_glm_df has {age_glm_df.key_name.nunique()} keys')

### subset the latent factors to only those that have a statistically significant age association and are of the specified latent model type

In [None]:
# always require a significant age association; additionally require the
# requested model type unless every model type was asked for
keep_rows = age_glm_df['fdr_bh'] <= ALPHA
if latent_type != 'all':
    keep_rows = keep_rows & (age_glm_df['model_type'] == latent_type)
age_glm_df = age_glm_df.loc[keep_rows]
print(f'shape of age_glm_df is {age_glm_df.shape}')
if DEBUG:
    display(age_glm_df.head())
    display(age_glm_df['model_type'].value_counts())

### subset the latent factors to only the broad and specific cell-types that match the cell-type being visualized

In [None]:
# keep rows whose cell_type label begins with the requested cell_type prefix
cell_type_matches = age_glm_df['cell_type'].str.startswith(cell_type)
age_glm_df = age_glm_df.loc[cell_type_matches]
print(f'shape of age_glm_df is {age_glm_df.shape}')
if DEBUG:
    display(age_glm_df.head())
    display(age_glm_df['model_type'].value_counts())
    display(age_glm_df['cell_type'].value_counts())

### create the celltype to celltype latent factor links

In [None]:
# cell-type -> latent factor links, weighted by the magnitude of the z statistic
cell_latent_links = age_glm_df[['cell_type', 'key_name', 'z']].copy()
cell_latent_links['z'] = cell_latent_links['z'].abs()
cell_latent_links.columns = LINK_COLUMNS
print(f'cell_latent_links shape is {cell_latent_links.shape}')
if DEBUG:
    display(cell_latent_links.head())

### latent aging factor to GSEA links

In [None]:
# read the GSEA enrichment results; first CSV column is the index
gsea_df = read_csv(filepath_or_buffer=gsea_file, index_col=0)
print(f'gsea_df shape is {gsea_df.shape}')
if DEBUG:
    display(gsea_df.head())

#### subset to only cell-type latent factors needed for this cell-type's visualization

In [None]:
# keep enrichment rows only for the latent factors retained above
is_retained_factor = gsea_df['factor'].isin(age_glm_df['key_name'])
gsea_df = gsea_df.loc[is_retained_factor]
print(f'gsea_df shape is {gsea_df.shape}')
if DEBUG:
    display(gsea_df.head())

#### subset on the marker set specified

In [None]:
# keep enrichment rows only from the configured gene set libraries
in_marker_sets = gsea_df['Gene_set'].isin(marker_sets)
gsea_df = gsea_df.loc[in_marker_sets]
print(f'gsea_df shape is {gsea_df.shape}')
if DEBUG:
    display(gsea_df.head())

In [None]:
# latent factor -> enriched term links, renamed to the shared link columns
gsea_links = gsea_df[['factor', 'Term', 'Odds Ratio']].copy()
gsea_links.columns = LINK_COLUMNS
# some of the odds ratios are huge, so for now every link gets a weight of 1
# and the summed weight is simply a count
gsea_links['weight'] = 1
print(f'gsea_links shape is {gsea_links.shape}')
if DEBUG:
    display(gsea_links.head())

#### add a 'No Enrichments' link for latent aging factors with no GSEA enrichment

In [None]:
# latent factors that are link targets of a cell-type but never appear as a
# source in the GSEA links get a single placeholder link of weight 1
missing_latents = set(cell_latent_links.target) - set(gsea_links.source)
print(missing_latents)
lists_to_add = [[latent, 'No Enrichments', 1] for latent in missing_latents]
misssing_df = DataFrame(data=lists_to_add, columns=LINK_COLUMNS)
print(f'shape of misssing_df {misssing_df.shape}')
# only stack non-empty frames: concatenating an empty frame triggers a pandas
# FutureWarning (2.1+) and, once empty frames stop being ignored for dtype
# resolution, can turn the weight column into object dtype
frames_to_stack = [frame for frame in (gsea_links, misssing_df) if not frame.empty]
if frames_to_stack:
    gsea_links = concat(frames_to_stack)
print(f'updated gsea_links shape {gsea_links.shape}')
if DEBUG:
    display(gsea_links.head())
    display(gsea_links.tail())

### visualize as Sankey diagram

### combine the link data

In [None]:
# stack both link levels: cell-type -> latent factor, then latent factor -> term
link_levels = [cell_latent_links, gsea_links]
links_df = concat(link_levels)
print(f'shape of all links to include {links_df.shape}')
if DEBUG:
    display(links_df.head())

In [None]:
# build the figure from the combined links and write it out as an HTML file
sankey_title = ('Sharing of features and partitioned latent aging '
                f'factors associated with age for {cell_type} using {latent_type.upper()}')
fig = genSankey(links_df, cat_cols=['source', 'target'],
                value_cols='weight', title=sankey_title)
pyoff.plot(fig, validate=False, filename=figure_file)

In [None]:
!date