## Use a Sankey diagram to visualize the relationships and sharing of aging features between broad and specific cell-types, and the graph partitioning of latent factors based on their age-associated feature connectivity

In [None]:
# IPython shell escape: print the current date/time so the run is timestamped in the notebook output
!date

#### import libraries

In [None]:
from pandas import read_csv, concat, DataFrame
from itertools import product
import plotly.offline as pyoff

#### set notebook variables

In [None]:
# parameters: project tag used as the prefix of every input/output file name
project = 'aging_phase2'

# directory layout: inputs are read from results_dir, the figure is written to figures_dir
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
results_dir = f'{wrk_dir}/results'
figures_dir = f'{wrk_dir}/figures'

# output: HTML file passed to plotly as the plot filename
figure_file = f'{figures_dir}/{project}.latent.partitioned_factors.sankey.html'

# analysis constants
DEBUG = True  # when True, later cells print/display intermediate tables
REGRESSION_TYPE = 'glm_tweedie'  # embedded in the input result file names
categories = ['broad', 'specific']  # cell-type granularities to load
modalities = ['GEX', 'ATAC']  # assay modalities to load
LINK_COLUMNS = ['source', 'target', 'weight']  # column names of the Sankey link table

### load age associated features

In [None]:
# Read the age-regression result table for every category x modality
# combination (file names indicate these are FDR-filtered) and stack them.
results = []
for category in categories:
    for modality in modalities:
        print(category, modality)
        in_file = (f'{results_dir}/{project}.{modality}.{category}.'
                   f'{REGRESSION_TYPE}_fdr_filtered.age.csv')
        this_df = read_csv(in_file)
        # tag each table so rows keep their origin after concatenation
        this_df['category'] = category
        this_df['modality'] = modality
        results.append(this_df)
# ignore_index: every read_csv frame starts its own index at 0, so without it
# the combined frame would carry duplicate row labels
age_glm_df = concat(results, ignore_index=True)
print(f'shape of all age associated features {age_glm_df.shape}')
if DEBUG:
    # display() is the IPython/Jupyter rich-output builtin;
    # cap the sample size so small tables do not raise in DataFrame.sample
    display(age_glm_df.sample(min(4, len(age_glm_df))))
    display(age_glm_df.modality.value_counts())
    display(age_glm_df.category.value_counts())

### sharing of age associated features between broad and specific cell-types

#### build list of possible pairings between broad and specific

In [None]:
# unique tissues (cell-types) seen in each category, in order of first appearance
broad_cell_types = age_glm_df.loc[age_glm_df.category == 'broad', 'tissue'].unique()
specific_cell_types = age_glm_df.loc[age_glm_df.category == 'specific', 'tissue'].unique()
# cartesian product in broad-major order: every broad type paired with every specific type
broad_specific_pairs = [
    (broad_type, specific_type)
    for broad_type in broad_cell_types
    for specific_type in specific_cell_types
]
print(f'found {len(broad_specific_pairs)} combinations of broad and specific cell-types')
if DEBUG:
    for debug_item in (broad_cell_types, specific_cell_types, broad_specific_pairs):
        print(debug_item)

#### for each possible broad/specific pairing, find the shared age-associated features

In [None]:
# For every broad/specific cell-type pair, weight the link by the Jaccard
# similarity (as a percentage) of their age-associated feature sets.
broad_specific_shared = []
broad_glm_df = age_glm_df.loc[(age_glm_df.category == 'broad')]
specific_glm_df = age_glm_df.loc[(age_glm_df.category == 'specific')]
# build each tissue's feature set once, instead of re-filtering the frames and
# rebuilding both sets for every pair
broad_features = {tissue: set(tissue_df.feature)
                  for tissue, tissue_df in broad_glm_df.groupby('tissue')}
specific_features = {tissue: set(tissue_df.feature)
                     for tissue, tissue_df in specific_glm_df.groupby('tissue')}
for broad_cell, specific_cell in broad_specific_pairs:
    broad_set = broad_features.get(broad_cell, set())
    specific_set = specific_features.get(specific_cell, set())
    union_size = len(broad_set | specific_set)
    # an empty union would divide by zero; treat it as no sharing
    normalized_weight = (len(broad_set & specific_set) / union_size * 100
                         if union_size else 0.0)
    broad_specific_shared.append([broad_cell, specific_cell, normalized_weight])
broad_specific_links = DataFrame(data=broad_specific_shared, columns=LINK_COLUMNS)
print(f'broad_specific_links shape is {broad_specific_links.shape}')
if DEBUG:
    # cap the sample size so small link tables do not raise in DataFrame.sample
    display(broad_specific_links.sample(min(5, len(broad_specific_links))))

### visualize as Sankey diagram

#### Sankey diagramming function

In [None]:
# function adapted from Viraj Deshpande at https://virajdeshpande.wordpress.com/portfolio/sankey-diagram/
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a plotly Sankey figure (as a plain dict) from categorical columns.

    Every adjacent pair of columns in ``cat_cols`` contributes
    ``cat_cols[i] -> cat_cols[i + 1]`` links weighted by ``value_cols``;
    weights of repeated (source, target) pairs are summed.

    Parameters
    ----------
    df : DataFrame
        Table holding the category columns and the weight column.
    cat_cols : sequence of str
        Ordered column names (at least two) defining the Sankey stages.
    value_cols : str
        Name of the numeric column used as the link weight.
    title : str
        Figure title.

    Returns
    -------
    dict
        ``{'data': [sankey_trace], 'layout': layout}``, suitable for
        ``plotly.offline.plot``.

    Raises
    ------
    ValueError
        If fewer than two category columns are given.
    """
    # None instead of a shared mutable [] default; with fewer than two columns
    # there are no links to build, so fail loudly and early
    cat_cols = list(cat_cols) if cat_cols is not None else []
    if len(cat_cols) < 2:
        raise ValueError('genSankey needs at least two category columns '
                         f'in cat_cols, got {cat_cols!r}')

    # node labels in first-seen order (deterministic, unlike iterating a set);
    # NOTE: a value present in more than one column is merged into one node
    labels = list(dict.fromkeys(
        label for col in cat_cols for label in df[col].tolist()))
    label_ids = {label: idx for idx, label in enumerate(labels)}

    # stack every adjacent column pair into one source/target/count table
    pair_frames = []
    for src_col, tgt_col in zip(cat_cols[:-1], cat_cols[1:]):
        pair_df = df[[src_col, tgt_col, value_cols]].copy()
        pair_df.columns = ['source', 'target', 'count']
        pair_frames.append(pair_df)
    # sum weights of identical (source, target) pairs once, after stacking
    links = (concat(pair_frames, ignore_index=True)
             .groupby(['source', 'target'])
             .agg({'count': 'sum'})
             .reset_index())

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color='purple', width=0.5),
            label=labels,
            color='purple',
        ),
        link=dict(
            # O(1) dict lookups map each label to its node index
            source=links['source'].map(label_ids).tolist(),
            target=links['target'].map(label_ids).tolist(),
            value=links['count'].tolist(),
        ),
    )

    layout = dict(title=title, font=dict(size=10))
    return dict(data=[data], layout=layout)

In [None]:
# render the broad -> specific feature-sharing links to the HTML figure file
sankey_title = 'Sharing of features and partitioned latent factors associated with age'
fig = genSankey(
    broad_specific_links,
    cat_cols=['source', 'target'],
    value_cols='weight',
    title=sankey_title,
)
pyoff.plot(fig, validate=False, filename=figure_file)