In [2]:
import pandas as pd
import json
import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout

from plot_helpers import plot_plotly_graph

In [5]:
def load_allen_brain_hierarchy(path_to_allen_json):
    """Load the Allen Mouse Brain ontology JSON into one flat DataFrame of regions.

    The ontology is stored as a tree: the top-level ``'msg'`` key holds a list of
    structures, and each structure nests its sub-structures under ``'children'``.
    The tree is flattened one depth level at a time with ``pd.json_normalize``
    (record path ``['msg']``, then ``['msg', 'children']``, ...) until a level
    yields no rows. All levels are then stacked top-down, so every parent row
    precedes its children.

    Parameters
    ----------
    path_to_allen_json : str or path-like
        Path to the ontology JSON file (read as UTF-8).

    Returns
    -------
    pandas.DataFrame
        One row per structure with the JSON fields as columns; ``'name'`` is
        renamed to ``'region_name'``. Row labels are NOT unique: each depth level
        keeps its own 0-based index, so labels repeat across levels. Use a column
        such as ``'id'`` as the key.

    Notes
    -----
    Assumes every structure carries a ``'children'`` list (possibly empty);
    ``json_normalize`` raises ``KeyError`` otherwise. TODO confirm for all
    ontology versions.
    """
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(path_to_allen_json, encoding='utf-8') as file:
        allen_data = json.load(file)

    # Collect one frame per depth level and concatenate once at the end.
    # Concatenating inside the loop would re-copy all previous rows every iteration.
    level_frames = []
    depth = 0
    while True:
        record_path = ['msg'] + ['children'] * depth
        level_df = pd.json_normalize(allen_data, record_path=record_path)
        if level_df.empty:
            # No structures at this depth: the whole tree has been visited.
            break
        level_frames.append(level_df)
        depth += 1

    if level_frames:
        allen_df = pd.concat(level_frames)
    else:
        # Even the top level was empty; return that (empty) frame.
        allen_df = level_df

    return allen_df.rename(columns={'name': 'region_name'})

In [6]:
# Location of the Allen Mouse Brain ontology JSON, relative to this notebook.
path_to_allen_json = '../AllenMouseBrainOntology.json'
allen_df = load_allen_brain_hierarchy(path_to_allen_json)

# Preview: the first three rows (regions) of the flattened hierarchy.
allen_df.iloc[:3]

Unnamed: 0,id,atlas_id,ontology_id,acronym,region_name,color_hex_triplet,graph_order,st_level,hemisphere_id,parent_structure_id,children
0,997,-1.0,1,root,root,FFFFFF,0,0,3,,"[{'id': 8, 'atlas_id': 0, 'ontology_id': 1, 'a..."
0,8,0.0,1,grey,Basic cell groups and regions,BFDAE3,1,1,3,997.0,"[{'id': 567, 'atlas_id': 70, 'ontology_id': 1,..."
1,1009,691.0,1,fiber tracts,fiber tracts,CCCCCC,1101,1,3,997.0,"[{'id': 967, 'atlas_id': 686, 'ontology_id': 1..."


In [27]:
# Build the region hierarchy as a directed graph with one edge per child -> parent link.
# NOTE(review): this cell previously read `brain_df` before any visible cell defined it
# (hidden state from a deleted or out-of-order cell). It now derives everything from
# `allen_df`, so Restart & Run All works. If `brain_df` was meant to be a filtered
# subset of `allen_df`, restore that filtering explicitly here.
G = nx.DiGraph()

# Add every region first. This keeps nodes in allen_df's row order, and regions with
# no parent (the root) are still present even though they contribute no edge.
G.add_nodes_from(allen_df['id'])

# Only rows that have a parent contribute edges. In the loaded frame the root's missing
# parent is NaN (blank in the preview output), not None, so the former
# `G.remove_node(None)` could not find a `None` node and raised. Dropping these rows
# avoids creating a placeholder node at all. Parent ids are cast back to int so they
# match the integer region ids.
edge_df = (
    allen_df
    .dropna(subset=['parent_structure_id'])
    .astype({'parent_structure_id': int})
)
G.add_edges_from(zip(edge_df['id'], edge_df['parent_structure_id']))

# Index regions by 'id' for attribute lookup. Built from allen_df (instead of
# reassigning brain_df from itself) so re-running this cell does not fail.
brain_df = allen_df.set_index('id')

In [28]:
# Copy display metadata from brain_df onto the graph nodes. brain_df is indexed by
# region id, so each column maps id -> value.
attribute_columns = ['acronym', 'region_name', 'color_hex_triplet']
for col_name in attribute_columns:
    id_to_value = brain_df[col_name].to_dict()
    nx.set_node_attributes(G, id_to_value, name=col_name)

In [29]:
# Lay the nodes out with Graphviz's layered `dot` engine; `pos` maps node -> (x, y).
# This can be slow on the full ontology graph.
pos = graphviz_layout(G, prog='dot')

In [31]:
# Draw the region hierarchy with the project's Plotly helper, using the layout
# positions computed above, then render the figure.
fig = plot_plotly_graph(G, pos)
fig.show()