## Use a Sankey diagram to visualize the relationships and sharing between aging features in broad and specific cell-types, and the graph partitioning of their age-associated feature connectivity

In [None]:
!date

#### import libraries

In [None]:
from pandas import read_csv, concat, DataFrame
from itertools import product
import plotly.offline as pyoff
from json import load as json_load

#### set notebook variables

In [None]:
# constants and variables
LINK_COLUMNS = ['source', 'target', 'weight']  # column layout shared by every link table
REGRESSION_TYPE = 'glm_tweedie'
modalities = ['GEX', 'ATAC']
categories = ['broad', 'specific']
DEBUG = True

# parameters: the project name is embedded in every file name below
project = 'aging_phase2'

# directories, all rooted at the working directory
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
figures_dir = f'{wrk_dir}/figures'
results_dir = f'{wrk_dir}/results'

# out files: the Sankey diagram written as HTML
figure_file = f'{figures_dir}/{project}.association.partitions.sankey.html'

# in files: both inputs are read from the figures directory
gsea_file = f'{figures_dir}/{project}.features.gsea_enrichr.csv'
part_file = f'{figures_dir}/{project}.association.partitioned_factors.json'

#### Sankey diagramming function

In [None]:
# function adapted from Viraj Deshpande at https://virajdeshpande.wordpress.com/portfolio/sankey-diagram/
def genSankey(df: DataFrame, cat_cols:list=[], value_cols:str='', title:str='Sankey Diagram'):
    """Build a plotly Sankey figure dictionary from categorical columns.

    Every consecutive pair of columns in `cat_cols` is treated as one
    (source, target) layer. Rows sharing the same (source, target) pair have
    their `value_cols` entries summed into a single link.

    Args:
        df: frame holding the categorical columns and the value column.
        cat_cols: ordered column names; at least two are required.
        value_cols: name of the single column summed into the link values.
        title: title placed in the figure layout.

    Returns:
        dict with `data` (a one-element list holding the sankey trace) and
        `layout`.

    Raises:
        ValueError: if `cat_cols` names fewer than two columns.
    """
    columns = list(cat_cols)  # local copy, the shared default list is only read
    if len(columns) < 2:
        raise ValueError('cat_cols must name at least two columns to form source-target pairs')

    # node labels, de-duplicated in first-appearance order so the node
    # ordering is reproducible from run to run
    label_list = list(dict.fromkeys(
        label for column in columns for label in df[column].tolist()))
    label_index = {label: position for position, label in enumerate(label_list)}

    # stack each consecutive column pair into one long source/target/count frame
    pair_frames = []
    for source_col, target_col in zip(columns, columns[1:]):
        pair_df = df[[source_col, target_col, value_cols]]
        pair_df.columns = ['source', 'target', 'count']
        pair_frames.append(pair_df)
    source_target_df = concat(pair_frames)
    print(source_target_df.shape)
    # collapse duplicate pairs by summing their values
    source_target_df = (source_target_df.groupby(['source', 'target'])
                        .agg({'count': 'sum'}).reset_index())
    print(source_target_df.shape)

    # plotly addresses nodes by their position in the label list
    source_target_df['sourceID'] = [label_index[label] for label in source_target_df['source']]
    source_target_df['targetID'] = [label_index[label] for label in source_target_df['target']]

    # creating the sankey diagram
    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(
                color='black',
                width=0.5
            ),
            label=label_list,
            color='purple'
        ),
        link=dict(
            source=source_target_df['sourceID'],
            target=source_target_df['targetID'],
            value=source_target_df['count'],
        )
    )

    layout = dict(
        title=title,
        font=dict(
            size=10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

### load age associated features

In [None]:
# read one FDR-filtered age result table per category/modality pairing and
# tag every row with the pairing it came from
results = []
for category, modality in [(cat, mod) for cat in categories for mod in modalities]:
    print(category, modality)
    in_file = (f'{results_dir}/{project}.{modality}.{category}.'
               f'{REGRESSION_TYPE}_fdr_filtered.age.csv')
    tagged_df = read_csv(in_file)
    tagged_df['category'] = category
    tagged_df['modality'] = modality
    results.append(tagged_df)

# stack all pairings into a single frame
age_glm_df = concat(results)
print(f'shape of all age associated features {age_glm_df.shape}')
if DEBUG:
    display(age_glm_df.sample(4))
    for tag_column in ('modality', 'category'):
        display(age_glm_df[tag_column].value_counts())

### feature modality to cell-types links

In [None]:
# distinct modalities and cell-types present in the loaded results
modality_types = age_glm_df['modality'].unique()
cell_types = age_glm_df['tissue'].unique()

# every (modality, cell-type) pairing
modality_cell_pairs = [*product(modality_types, cell_types)]
print(f'found {len(modality_cell_pairs)} combinations of modalities and cell-types')
if DEBUG:
    print(modality_types, cell_types, modality_cell_pairs, sep='\n')

In [None]:
# link weight: percentage of a modality's distinct features that also occur
# among the features of the cell-type
modality_cell_shared = []
for modality, cell_type in modality_cell_pairs:
    in_modality = age_glm_df[age_glm_df['modality'] == modality]
    in_cell_type = age_glm_df[age_glm_df['tissue'] == cell_type]
    overlap = set(in_modality['feature']) & set(in_cell_type['feature'])
    normalized_weight = (len(overlap) / in_modality['feature'].nunique()) * 100
    modality_cell_shared.append([modality, cell_type, normalized_weight])

modality_cell_links = DataFrame(modality_cell_shared, columns=LINK_COLUMNS)
print(f'modality_cell_links shape is {modality_cell_links.shape}')
if DEBUG:
    display(modality_cell_links.sample(n=5))

### broad to specific cell-type links
sharing of age associated features between broad and specific cell-types

#### build list of possible pairings between broad and specific

In [None]:
# cell-types seen in each category of results
broad_cell_types = age_glm_df[age_glm_df['category'] == 'broad']['tissue'].unique()
specific_cell_types = age_glm_df[age_glm_df['category'] == 'specific']['tissue'].unique()

# every (broad, specific) cell-type pairing
broad_specific_pairs = [*product(broad_cell_types, specific_cell_types)]
print(f'found {len(broad_specific_pairs)} combinations of broad and specific cell-types')
if DEBUG:
    print(broad_cell_types, specific_cell_types, broad_specific_pairs, sep='\n')

#### for each possible broad/specific pairing find shared age associated features

In [None]:
# split the results by category once, outside the pairing loop
broad_glm_df = age_glm_df[age_glm_df['category'] == 'broad']
specific_glm_df = age_glm_df[age_glm_df['category'] == 'specific']

# link weight: shared features over the union of both feature sets, as a percentage
broad_specific_shared = []
for broad_cell, specific_cell in broad_specific_pairs:
    broad_features = set(broad_glm_df.loc[broad_glm_df['tissue'] == broad_cell, 'feature'])
    specific_features = set(specific_glm_df.loc[specific_glm_df['tissue'] == specific_cell, 'feature'])
    shared = broad_features & specific_features
    combined = broad_features | specific_features
    normalized_weight = len(shared) / len(combined) * 100
    broad_specific_shared.append([broad_cell, specific_cell, normalized_weight])

broad_specific_links = DataFrame(broad_specific_shared, columns=LINK_COLUMNS)
print(f'broad_specific_links shape is {broad_specific_links.shape}')
if DEBUG:
    display(broad_specific_links.sample(n=5))

### cell-type to partitions

In [None]:
# parse the partitions JSON file into `partitions`
with open(part_file, mode='r') as part_handle:
    partitions = json_load(part_handle)
partition_count = len(partitions)
print(f'length of partitions is {partition_count}')

#### add links between cell-types and graph partitions

In [None]:
# one unit-weight link from each distinct cell-type found in a partition
# (the text before the first ':' of every member) to that partition's node
age_latents = []
for part_index, latents in partitions.items():
    latent_name = f'Aging-{part_index}'
    member_cell_types = {element.partition(':')[0] for element in latents}
    age_latents.extend([cell_type, latent_name, 1] for cell_type in member_cell_types)

cell_partitions_links = DataFrame(age_latents, columns=LINK_COLUMNS)
print(f'cell_partitions_links shape is {cell_partitions_links.shape}')
if DEBUG:
    display(cell_partitions_links.sample(n=5))

### partitions to GSEA links

In [None]:
# load the GSEA table, using the first CSV column as the index
gsea_df = read_csv(gsea_file, index_col=0)
gsea_shape = gsea_df.shape
print(f'gsea_df shape is {gsea_shape}')
if DEBUG:
    display(gsea_df.iloc[:5])

In [None]:
# keep the factor, term and odds-ratio columns and relabel them to the
# shared link column layout
gsea_link_fields = ['factor', 'Term', 'Odds Ratio']
gsea_links = gsea_df.loc[:, gsea_link_fields]
gsea_links.columns = LINK_COLUMNS
print(f'gsea_links shape is {gsea_links.shape}')
if DEBUG:
    display(gsea_links.iloc[:5])

#### add a 'No Enrichments' link for partitions with no GSEA enrichment

In [None]:
# partitions that are a link target of a cell-type but never a source in the
# GSEA links get a placeholder link to 'No Enrichments' with weight 1
missing_parts = set(cell_partitions_links.target) - set(gsea_links.source)
print(missing_parts)
# sorted so the placeholder rows come out in a reproducible order
lists_to_add = [[partition, 'No Enrichments', 1] for partition in sorted(missing_parts)]
misssing_df = DataFrame(data=lists_to_add, columns=LINK_COLUMNS)
print(f'shape of misssing_df {misssing_df.shape}')
# concatenate only when there is something to add: concatenating an empty
# frame is deprecated in recent pandas (FutureWarning) and may change the
# resulting column dtypes in a future version
if not misssing_df.empty:
    gsea_links = concat([gsea_links, misssing_df])
print(f'updated gsea_links shape {gsea_links.shape}')
if DEBUG:
    display(gsea_links.head())
    display(gsea_links.tail())

### visualize as Sankey diagram

### combine the link data

In [None]:
# stack the link tables in diagram order, left to right
link_layers = [
    modality_cell_links,    # modality -> cell-type
    broad_specific_links,   # broad -> specific cell-type
    cell_partitions_links,  # cell-type -> partition
    gsea_links,             # partition -> GSEA term
]
links_df = concat(link_layers)
print(f'shape of all links to include {links_df.shape}')
if DEBUG:
    display(links_df.sample(n=5))

In [None]:
# build the figure dictionary from the combined links and write it to the HTML file
sankey_title = 'Sharing of age associated features and their graph partitions with GSEA'
fig = genSankey(
    links_df,
    cat_cols=['source', 'target'],
    value_cols='weight',
    title=sankey_title,
)
pyoff.plot(fig, filename=figure_file, validate=False)

In [None]:
!date