In [57]:
import pickle
import torch
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

import torch.nn as nn
import torch.optim as optim
import torch_geometric.nn as gnn
import torch_geometric.transforms
import torch_geometric.utils
import torch.nn.functional as F


In [15]:
data_dir = Path("../data/preprocessed").absolute()
node_info_raw = pd.read_pickle(data_dir / "node_info.pkl")
edge_info_raw = pd.read_pickle(data_dir / "edge_info.pkl")

In [27]:
nt_types = ["ACH", "GABA", "GLUT", "SER", "DA", "OCT"]
superclass_id2name = dict(enumerate(node_info_raw["super_class"].unique()))
superclass_name2id = {name: id for id, name in superclass_id2name.items()}
class_id2name = dict(enumerate(node_info_raw["class"].unique()))
class_name2id = {name: id for id, name in class_id2name.items()}

In [32]:
node_info = node_info_raw[["root_id"]].copy()
node_info.loc[:, "super_class"] = [
    superclass_name2id[name] for name in node_info_raw["super_class"]
]
node_info.loc[:, "class"] = [class_name2id[name] if not pd.isna(name) else -1 for name in node_info_raw["class"]]

edge_info = edge_info_raw[["pre_root_id", "post_root_id"]].copy()
for nt in nt_types:
    edge_info[f"weight_{nt}"] = np.zeros(len(edge_info))
edge_info = edge_info.set_index(["pre_root_id", "post_root_id"]).sort_index()
for _, etr in edge_info_raw.iterrows():
    key = (etr["pre_root_id"], etr["post_root_id"])
    edge_info.loc[key, f"weight_{etr['nt_type']}"] = etr["syn_count"]
edge_info = edge_info.reset_index()

In [34]:
node_info.to_pickle(data_dir / "node_info_encoded.pkl")
edge_info.to_pickle(data_dir / "edge_info_encoded.pkl")

In [92]:
nx_graph = nx.from_pandas_edgelist(
    edge_info,
    source="pre_root_id",
    target="post_root_id",
    edge_attr=[f"weight_{nt}" for nt in nt_types],
    create_using=nx.DiGraph,
)
node_attr_names = ["super_class", "class"]
node_info_sel = node_info[["root_id", *node_attr_names]].set_index("root_id")
node_info_sel = node_info_sel.rename(
    columns={"class": "neuron_class", "super_class": "neuron_super_class"}
)
nx.set_node_attributes(nx_graph, node_info_sel.to_dict(orient="index"))
pg_graph = torch_geometric.utils.from_networkx(
    nx_graph,
    group_edge_attrs=[f"weight_{nt}" for nt in nt_types],
)

In [105]:
split = torch_geometric.transforms.RandomNodeSplit(
    num_val=0.1, num_test=0.2, key="neuron_super_class"
)
pg_graph = split(pg_graph)

In [107]:
pg_graph

Data(edge_index=[2, 2393286], neuron_super_class=[124733], neuron_class=[124733], edge_attr=[2393286, 6], num_nodes=124733, train_mask=[124733], val_mask=[124733], test_mask=[124733])

In [111]:
class MyGCN(nn.Module):
    def __init__(self, num_edge_features, hidden_channels, num_classes):
        super(MyGCN, self).__init__()
        self.conv1 = gnn.GCNConv(num_edge_features, hidden_channels)
        self.conv2 = gnn.GCNConv(hidden_channels, hidden_channels)
        self.conv3 = gnn.GCNConv(hidden_channels, num_classes)
    
    def forward(self, data):
        features = data.edge_attr
        edge_index = data.edge_index
        print(features.shape, edge_index.shape)
        x = self.conv1(features, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        output = self.conv3(x, edge_index)
        return output

In [113]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert torch.cuda.is_available()

gcn_model = MyGCN(
    num_edge_features=pg_graph.num_edge_features,
    hidden_channels=16,
    num_classes=len(superclass_id2name),
).to(device)

optimizer = optim.Adam(gcn_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


In [114]:
def train(model, graph, optimizer, criterion, num_epochs, target, device):
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out[graph.train_mask], graph[target][graph.train_mask])
        loss.backward()
        optimizer.step()
        acc = get_accuracy(model, graph, graph.train_mask, target)

        if epoch % 10 == 0:
            print(
                f"Epoch {epoch} | Train loss {loss.item():.4f} | Train accuracy {acc:.4f}"
            )
    return model


def get_accuracy(model, graph, mask, target):
    model.eval()
    pred = model(graph).argmax(dim=1)
    return (pred[mask] == graph[target][mask]).sum() / mask.sum()

In [115]:
train(gcn_model, pg_graph.to(device), optimizer, criterion, 100, "neuron_super_class", device)

torch.Size([2393286, 6]) torch.Size([2, 2393286])


IndexError: The shape of the mask [124733] at index 0 does not match the shape of the indexed tensor [2393286, 9] at index 0

In [40]:
nx_graph.nodes[720575940660863105]

{'super_class': 4, 'class': -1}

In [46]:
for root_id in node_info_raw['root_id'].sample(10):
    print(nx_graph.nodes[root_id])

{'super_class': 0, 'class': 0}
{'super_class': 2, 'class': -1}
{'super_class': 2, 'class': -1}
{'super_class': 0, 'class': 9}
{'super_class': 0, 'class': 0}
{'super_class': 0, 'class': 0}
{'super_class': 2, 'class': -1}
{'super_class': 5, 'class': 6}
{'super_class': 0, 'class': 0}
{'super_class': 0, 'class': 0}


In [45]:
pd.isna(node_info_raw['class']).sum() / len(node_info_raw)

0.24203699101280335

In [5]:
list(nx_graph.nodes(data=True))[0]

(720575940619238582,
 {'name': 'AVLP_L.AVLP_R.12',
  'group': 'AVLP_L.AVLP_R',
  'nt_type': 'ACH',
  'nt_type_score': 0.94,
  'cluster': 'C3128.2',
  'flow': 'intrinsic',
  'super_class': 'central',
  'class': nan,
  'sub_class': nan,
  'cell_type': nan,
  'hemibrain_type': nan,
  'hemilineage': nan,
  'side': 'left',
  'nerve': nan,
  'length_nm': 2669622,
  'area_nm': 8905502208,
  'size_nm': 595439001600})

In [None]:
nodelist

In [10]:
pg_graph = torch_geometric.utils.from_networkx(
    nx_graph,
    group_node_attrs=["super_class"],
    group_edge_attrs=["syn_count"],
)

AttributeError: 'list' object has no attribute 'dim'

In [11]:
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
graph = dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [51]:
graph

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])