In [133]:
import pickle
import torch
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim
import torch_geometric.nn as gnn
import torch_geometric.transforms
import torch_geometric.utils
import torch.nn.functional as F

In [2]:
data_dir = Path().absolute().parent / "data"

## Build dataframe with all node features and target classes

In [144]:
n2v_features_df = pd.read_pickle(data_dir / "embeddings/node2vec_embedding.pkl")
node_info_df = pd.read_pickle(data_dir / "preprocessed/node_info.pkl")
edge_info_df = pd.read_pickle(data_dir / "preprocessed/edge_info.pkl")
node_info_df = node_info_df[node_info_df["super_class"] != "optic"]
edge_info_df = edge_info_df[
    edge_info_df["pre_root_id"].isin(node_info_df["root_id"])
    & edge_info_df["post_root_id"].isin(node_info_df["root_id"])
]
additional_node_features = ["side", "length_nm", "area_nm", "size_nm"]
classes = ["super_class", "class", "hemilineage"]
node_info_df = node_info_df[["root_id"] + additional_node_features + classes]
node_info_df = pd.merge(
    n2v_features_df, node_info_df, how="left", left_index=True, right_on="root_id"
).set_index("root_id")

Encode discrete attributes/classes: side, super_class, class, hemilineage:

In [145]:
def encode_str(series, threshold=1):
    value_counts = series.value_counts()
    if pd.isna(series).any() or value_counts.min() < threshold:
        id2name = {0: "Other"}
        offset = 1
    else:
        id2name = {}
        offset = 0
    for i, (name, count) in enumerate(value_counts.items()):
        if count < threshold:
            break
        id2name[i + offset] = name
    return id2name, {v: k for k, v in id2name.items()}

In [146]:
sclass_id2name, sclass_name2id = encode_str(node_info_df["super_class"], 100)
class_id2name, class_name2id = encode_str(node_info_df["class"], 100)
hemilineage_id2name, hemilineage_name2id = encode_str(node_info_df["hemilineage"], 200)

In [147]:
print(f"superclass: {len(sclass_id2name)} classes")
print(f"class: {len(class_id2name)} classes")
print(f"hemilineage: {len(hemilineage_id2name)} classes")

superclass: 8 classes
class: 16 classes
hemilineage: 31 classes


In [148]:
for col, id2name, name2id in [("super_class", sclass_id2name, sclass_name2id),
                              ("class", class_id2name, class_name2id),
                              ("hemilineage", hemilineage_id2name, hemilineage_name2id)]:
    node_info_df[col] = node_info_df[col].map(name2id)
    node_info_df.loc[node_info_df[col].isna(), col] = 0
    node_info_df[col] = node_info_df[col].astype(np.int64)


In [149]:
node_info_df["is_left"] = (node_info_df["side"] == "left").astype(np.int64)
node_info_df["is_right"] = (node_info_df["side"] == "right").astype(np.int64)
node_info_df.drop(columns=["side"], inplace=True)

In [150]:
node_info_df.sample(5)

Unnamed: 0_level_0,ACH_0,ACH_1,ACH_2,ACH_3,ACH_4,ACH_5,ACH_6,ACH_7,SER_0,SER_1,...,OCT_6,OCT_7,length_nm,area_nm,size_nm,super_class,class,hemilineage,is_left,is_right
root_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
720575940629730575,-0.566653,1.215151,-0.539002,0.527027,0.026241,0.309878,2.317807,1.446208,0.105612,-0.111307,...,0.063825,0.055493,138746,366689536,14756843520,3,5,0,0,1
720575940630942506,-2.064043,0.148701,-1.770268,1.378116,-1.586561,1.643034,-0.217206,-0.405092,-0.06869,-0.253149,...,-0.031883,-0.10419,949925,2990539776,196513402880,1,2,0,0,1
720575940609395042,-1.885002,-0.02507,-0.638797,0.184277,-1.074937,0.564175,-0.189009,0.828821,-0.103525,-0.12456,...,0.044243,0.029568,1162953,2754063360,168185272320,1,0,26,1,0
720575940616921300,-0.66692,-0.391602,0.170298,0.362353,0.266168,1.909746,0.99916,-0.638981,-0.061791,-0.029292,...,0.079921,0.057925,1232977,3225652736,218947799040,1,0,0,0,1
720575940615682955,-0.093363,0.120704,0.06931,0.059378,-0.041162,-0.055812,-0.045702,-0.011661,-0.058572,0.051419,...,-0.002613,-0.05844,275727,669456128,43327170560,1,2,0,0,1


In [158]:
num_dims_per_nt = 8
nt_types = ["ACH", "GABA", "GLUT", "SER", "DA", "OCT"]
num_node_features = len(nt_types) * num_dims_per_nt + 2 + 3

## Convert data into a PyG graph

First, make a NetworkX graph:

In [131]:
edge_info_df.head()

Unnamed: 0,pre_root_id,post_root_id,neuropil,syn_count,nt_type
0,720575940619238582,720575940634854554,AVLP_R,35,ACH
1,720575940634034839,720575940660217473,AL_R,38,SER
2,720575940612615570,720575940604789676,EB,26,GABA
3,720575940614901215,720575940626983952,PRW,70,GABA
4,720575940627312104,720575940625498512,SMP_R,28,GLUT


In [134]:
edge_info_new = edge_info_df[["pre_root_id", "post_root_id"]].copy()
for nt in nt_types:
    edge_info_new[f"weight_{nt}"] = np.zeros(len(edge_info_new))
edge_info_new = edge_info_new.set_index(["pre_root_id", "post_root_id"]).sort_index()
for _, etr in tqdm(edge_info_df.iterrows(), total=len(edge_info_df)):
    key = (etr["pre_root_id"], etr["post_root_id"])
    edge_info_new.loc[key, f"weight_{etr['nt_type']}"] = etr["syn_count"]
edge_info_new = edge_info_new.reset_index()

100%|██████████| 1301936/1301936 [04:18<00:00, 5035.94it/s]


In [151]:
nx_graph = nx.from_pandas_edgelist(
    edge_info_new,
    source="pre_root_id",
    target="post_root_id",
    edge_attr=[f"weight_{nt}" for nt in nt_types],
    create_using=nx.DiGraph,
)
nx.set_node_attributes(nx_graph, node_info_df.to_dict(orient="index"))

In [152]:
list(nx_graph.edges(data=True))[0]

(720575940600433181,
 720575940605214636,
 {'weight_ACH': 9.0,
  'weight_GABA': 0.0,
  'weight_GLUT': 0.0,
  'weight_SER': 0.0,
  'weight_DA': 0.0,
  'weight_OCT': 0.0})

In [153]:
list(nx_graph.nodes(data=True))[0]

(720575940600433181,
 {'ACH_0': -0.6283556818962097,
  'ACH_1': -0.33705607056617737,
  'ACH_2': 0.553609311580658,
  'ACH_3': 0.7189285755157471,
  'ACH_4': -0.060541942715644836,
  'ACH_5': 0.21442939341068268,
  'ACH_6': 2.3636906147003174,
  'ACH_7': -0.6410021781921387,
  'SER_0': -0.016590997576713562,
  'SER_1': 0.0708422064781189,
  'SER_2': -0.08429275453090668,
  'SER_3': -0.08012261986732483,
  'SER_4': 0.11299248039722443,
  'SER_5': -0.10283048450946808,
  'SER_6': 0.0822170078754425,
  'SER_7': -0.12428461015224457,
  'GABA_0': 0.18916143476963043,
  'GABA_1': 0.03238673135638237,
  'GABA_2': 0.12384393811225891,
  'GABA_3': -0.2576524615287781,
  'GABA_4': 0.052079446613788605,
  'GABA_5': 0.08941518515348434,
  'GABA_6': 0.7003392577171326,
  'GABA_7': 0.2536327540874481,
  'GLUT_0': 0.10493053495883942,
  'GLUT_1': -0.08961445093154907,
  'GLUT_2': -0.04550528526306152,
  'GLUT_3': -0.10579411685466766,
  'GLUT_4': 0.1064734011888504,
  'GLUT_5': 0.10387170314788818,
 

Convert NetworkX graph to PyG graph:

In [160]:
node_attributes_in = (
    [f"{nt}_{i}" for nt in nt_types for i in range(num_dims_per_nt)] + 
    ["is_left", "is_right", "length_nm", "area_nm", "size_nm"]
)
pg_graph = torch_geometric.utils.from_networkx(
    nx_graph,
    group_node_attrs=node_attributes_in,
    group_edge_attrs=[f"weight_{nt}" for nt in nt_types],
)

Save NX and GyG graphs:

In [164]:
graphs_dir = data_dir / "graphs"
graphs_dir.mkdir(exist_ok=True)
with open(graphs_dir / "nx_graph.pkl", "wb") as f:
    pickle.dump(nx_graph, f)
with open(graphs_dir / "pg_graph.pkl", "wb") as f:
    pickle.dump(pg_graph, f)

## Define GCN

In [117]:
class MyGCN(nn.Module):
    def __init__(self, n_in, n_hidden, n_out):
        super(MyGCN, self).__init__()
        self.conv1 = gnn.GCNConv(n_in, n_hidden)
        self.conv2 = gnn.GCNConv(n_hidden, n_hidden)
        self.conv3 = gnn.GCNConv(n_hidden, n_out)
    
    def forward(self, data):
        features = data.edge_attr
        edge_index = data.edge_index
        print(features.shape, edge_index.shape)
        x = self.conv1(features, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        output = self.conv3(x, edge_index)
        return output

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert torch.cuda.is_available()

gcn_model = MyGCN(
    num_edge_features=pg_graph.num_edge_features,
    hidden_channels=16,
    num_classes=len(class_id2name),
).to(device)

optimizer = optim.Adam(gcn_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()