## Create a graph of the age associated features across broad and specific cell-types

In [None]:
!date

#### import libraries

In [None]:
from pandas import read_csv, concat
from igraph import Graph, Plot, union
from igraph.drawing.colors import ClusterColoringPalette
import matplotlib.pyplot as plt
from IPython.display import Image
from matplotlib.image import imread
import leidenalg
from json import dump as json_dump
from matplotlib.pyplot import rc_context
import igraph as ig
from matplotlib.patches import Patch

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# parameters
# project name; used as the prefix of every input and output file name below
project = 'aging_phase2'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
# per modality/category age association result tables are read from here
results_dir = f'{wrk_dir}/results'
# all graph files, the community JSON and the rendered image are written here
figures_dir = f'{wrk_dir}/figures'

# in files
# (none fixed here; input file names are built per modality and category when loading)

# out files
# consensus graph in GraphML format
graphml_file = f'{figures_dir}/{project}.association.graphml'
# NOTE(review): gml_file is defined but not used in the visible cells - confirm it is still needed
gml_file = f'{figures_dir}/{project}.association.gml'
# rendered graph image; later overwritten by the annotated version of the same figure
image_file = f'{figures_dir}/{project}.association_graph.png'
# community id -> cell-type names, as JSON
communities_file = f'{figures_dir}/{project}.association.partitioned_factors.json'

# constants and variables
# when True, cells print/display extra summaries of intermediate objects
DEBUG = True
# modalities and categories are combined to locate each input result file
modalities = ['GEX', 'ATAC']
categories = ['broad', 'specific']

### load input data

#### load age associated feature results
get the age associated GEX and ATAC features detected per cell-type

In [None]:
%%time
# read one filtered result table per (modality, category) pair and tag each
# table with the pair it came from before stacking them together
age_results = []
for modality, category in ((m, c) for m in modalities for c in categories):
    print(modality, category)
    in_file = f'{results_dir}/{project}.{modality}.{category}.glm_tweedie_fdr_filtered.age.csv'
    this_df = read_csv(in_file).assign(modality=modality, category=category)
    age_results.append(this_df)
age_results_df = concat(age_results)
print(f'shape of the age results is {age_results_df.shape}')
if DEBUG:
    # random peek at the rows, then the row counts per grouping column
    display(age_results_df.sample(5))
    for column in ('modality', 'tissue', 'category'):
        display(age_results_df[column].value_counts())

### convert the feature age associations into a weighted graph per cell-type

In [None]:
%%time
# Build one star-shaped graph per cell-type: the cell-type vertex is linked to
# every feature that is age associated in that cell-type.
# Vertices and edges are added in bulk; igraph re-indexes the graph on every
# structural change, so adding them one at a time in a loop is needlessly slow.
age_graphs = {}
for cell_type in age_results_df.tissue.unique():
    cell_results_df = age_results_df.loc[(age_results_df.tissue == cell_type)]
    # collect the vertices: the cell-type first, then its features by modality
    vertex_names = [cell_type]
    vertex_types = ['cell_type']
    for modality in cell_results_df.modality.unique():
        modality_results = cell_results_df.loc[(cell_results_df.modality == modality)]
        features = list(modality_results.feature.unique())
        vertex_names.extend(features)
        vertex_types.extend([f'{modality}_feature'] * len(features))
    # create the graph and add all vertices at once
    age_graph = Graph()
    age_graph.add_vertices(vertex_names)
    age_graph.vs['type'] = vertex_types
    # add the age associations as edges (one per result row, in row order);
    # the edge 'effect' is the absolute value of the z column
    age_graph.add_edges(list(zip(cell_results_df.tissue, cell_results_df.feature)))
    age_graph.es['type'] = cell_results_df.modality.tolist()
    age_graph.es['category'] = cell_results_df.category.tolist()
    age_graph.es['effect'] = cell_results_df.z.abs().tolist()
    # save this cell-type's graph
    cell_graphml_file = f'{figures_dir}/{project}.{cell_type}.association_graph.graphml'
    age_graph.write_graphml(cell_graphml_file)
    # add to dict of graphs
    age_graphs[cell_type] = age_graph
    if DEBUG:
        print(f'{cell_type} has {age_graph.vcount()} vertices')
        print(f'{cell_type} has {age_graph.ecount()} edges')
print(f'{len(age_graphs)} graphs created and saved')
if DEBUG:
    print(f'a graph for each of these cell-types was created: {age_graphs.keys()}')

### create consensus graph across cell-types by taking the union of the graphs

In [None]:
# Merge the per-cell-type graphs into one graph, matching vertices by their
# 'name' attribute so a feature shared by several cell-types becomes one vertex.
# A real list is passed because a dict view supports len() and iteration but
# not indexing; igraph's union may index its argument (e.g. when only a single
# graph is supplied) - TODO confirm against the installed igraph version.
consensus_graph = union(list(age_graphs.values()), byname=True)
if DEBUG:
    print(f'consensus_graph has {consensus_graph.vcount()} vertices')
    print(f'consensus_graph has {consensus_graph.ecount()} edges')

### partition the graph

In [None]:
%%time
# Partition the consensus graph with the Leiden algorithm, optimising modularity.
# NOTE(review): no `weights` argument is given here, so the `effect` edge
# attribute does not appear to influence the partition - confirm that is intended.
graph_cluster = leidenalg.find_partition(
    consensus_graph,
    leidenalg.ModularityVertexPartition,
    n_iterations=25,
)
# store each vertex's community id on the graph itself
consensus_graph.vs['membership'] = graph_cluster.membership
if DEBUG:
    # size of the partition object, length of the membership list, and the partition's type
    for summary in (len(graph_cluster), len(graph_cluster.membership), type(graph_cluster)):
        print(summary)

#### inspect the partitioned latent factors

In [None]:
# keep only the vertices whose 'type' attribute equals 'cell_type'
cell_nodes = consensus_graph.vs.select(type_eq='cell_type')
print(len(cell_nodes))

In [None]:
# group the cell-type vertex names by the community they were assigned to
community_factors = {}
for vertex in cell_nodes:
    vertex_attrs = vertex.attributes()
    community_id = vertex_attrs.get('membership')
    community_factors.setdefault(community_id, []).append(vertex_attrs.get('name'))

In [None]:
# show which cell-types were grouped into each community
display(community_factors)

#### save the partitioned latent factor communities

In [None]:
# write the community -> cell-types mapping to disk as indented JSON
with open(communities_file, 'w') as json_handle:
    json_dump(community_factors, json_handle, indent=4)

### save the graph

In [None]:
# write the consensus graph, including the 'membership' vertex attribute set above, as GraphML
consensus_graph.write_graphml(graphml_file)

### draw the graph visualization

In [None]:
%%time
# drawing options handed to the plot when the graph is added
visual_style = {'bbox': (1600, 1200), 'margin': 50}
layout_algorithm = 'drl' # 'fruchterman_reingold', 'drl', 'lgl', others available but much slower

# compute vertex positions and colour every vertex by its community
layout = consensus_graph.layout(layout_algorithm)
pal = ClusterColoringPalette(len(graph_cluster))
consensus_graph.vs['color'] = pal.get_many(graph_cluster.membership)
# light, mostly transparent grey edges
consensus_graph.es['color'] = 'rgba(192, 192, 192, 0.3)'

# render onto a white canvas backed by the output image file
p = Plot(image_file, bbox=visual_style['bbox'], background='white')
p.add(consensus_graph, layout=layout, **visual_style)
p.redraw()

# p.show() is left disabled; the image is written to disk instead
p.save()

#### annotate the visualization

In [None]:
%%time
# one legend patch per community, coloured with the palette used for the vertices
unique_clusters = set(graph_cluster.membership)
legend_elements = []
for cluster in unique_clusters:
    legend_elements.append(Patch(facecolor=pal[cluster], label=f'Aging-{cluster}'))

# read back the rendered graph image so a title and legend can be drawn around it
img = imread(image_file)
figure_settings = {'figure.figsize': (12, 12), 'figure.dpi': 200}
with rc_context(figure_settings):
    # NOTE(review): the style sheet is applied after the rc overrides above, so any
    # setting it defines (possibly the figure size) takes precedence - confirm intended
    plt.style.use('seaborn-v0_8-talk')
    plt.figure()
    plt.imshow(img)
    plt.title('Partitioned graph of cell-types and their age associated features')
    # anchor the legend just outside the right edge of the axes (loc=2 is 'upper left')
    plt.legend(
        handles=legend_elements,
        bbox_to_anchor=(1.05, 1),
        loc=2,
        borderaxespad=0,
        prop={'size': 8},
    )
    plt.axis('off')
    # overwrite the plain graph image with the annotated figure
    plt.savefig(image_file)

In [None]:
# show the annotated image file inline in the notebook
display(Image(image_file))

In [None]:
!date