In [67]:
import pickle
import torch
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from node2vec import Node2Vec
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score

## Data loading and preprocessing
Load data

In [95]:
node2vec_embedding = pd.read_pickle("../data/embeddings/node2vec_embedding.pkl")
spectral_embedding = pd.read_pickle("../data/embeddings/spectral_embedding.pkl")
node_info = pd.read_pickle("../data/preprocessed/node_info.pkl")

In [5]:
node2vec_embedding.head()

Unnamed: 0,node2vec_ACH_0,node2vec_ACH_1,node2vec_ACH_2,node2vec_ACH_3,node2vec_ACH_4,node2vec_ACH_5,node2vec_ACH_6,node2vec_ACH_7,node2vec_SER_0,node2vec_SER_1,...,node2vec_DA_6,node2vec_DA_7,node2vec_OCT_0,node2vec_OCT_1,node2vec_OCT_2,node2vec_OCT_3,node2vec_OCT_4,node2vec_OCT_5,node2vec_OCT_6,node2vec_OCT_7
720575940621280688,-2.1919,0.0495,-0.072872,2.529746,-0.076199,-0.797629,0.910458,2.138495,0.024376,0.049488,...,2.257592,-1.052635,-0.086376,-0.016342,-0.075253,0.132493,0.043961,-0.032677,0.089628,0.128388
720575940629174889,-1.859022,1.64102,0.985564,2.125586,-0.00706,1.693392,-0.502335,0.876726,0.076958,-1.087816,...,1.963795,-0.368879,0.061196,0.024918,0.113039,0.013703,0.092897,-0.082858,-0.096461,-0.012831
720575940637132389,-1.602327,0.61501,0.530065,1.076127,-0.032424,2.00513,-0.967994,0.755368,0.10047,-0.249987,...,1.295472,-1.272117,0.060951,0.014239,-0.058449,-0.046567,-0.092479,0.09028,0.040239,-0.106398
720575940654777505,-6.589922,-6.693934,2.004244,0.055977,-4.032084,3.622172,8.605543,3.299473,-0.115199,0.002871,...,0.115346,-0.042695,0.112006,0.021041,0.08648,0.051845,-0.034308,-0.00033,-0.061942,0.119736
720575940616159371,-6.747396,-7.012646,2.138878,0.055543,-4.097299,3.801588,8.909692,3.433254,-0.042529,0.076106,...,-0.012369,-0.027677,0.123636,0.08671,0.055061,0.047037,-0.03915,-0.10972,-0.061458,-0.098871


Merge node info (containing morphological features and target labels) with node features

In [96]:
node_info_cols = [
    "root_id", "super_class", "class", "hemilineage",
    "length_nm", "area_nm", "size_nm"
]
node_info_merged = pd.merge(
    node_info[node_info_cols].set_index("root_id"),
    node2vec_embedding,
    how="right",
    left_index=True,
    right_index=True,
)
node_info_merged = pd.merge(
    node_info_merged,
    spectral_embedding,
    how="left",
    left_index=True,
    right_index=True,
)
node_info_merged = node_info_merged.rename(
    columns={
        "length_nm": "morph_length", "area_nm": "morph_area", "size_nm": "morph_size"
    }
)

In [99]:
node_info_merged.head()

Unnamed: 0,super_class,class,hemilineage,morph_length,morph_area,morph_size,node2vec_ACH_0,node2vec_ACH_1,node2vec_ACH_2,node2vec_ACH_3,...,spectral_DA_6,spectral_DA_7,spectral_OCT_0,spectral_OCT_1,spectral_OCT_2,spectral_OCT_3,spectral_OCT_4,spectral_OCT_5,spectral_OCT_6,spectral_OCT_7
720575940621280688,central,MBIN,,42474828,170716759168,11388240896000,-2.1919,0.0495,-0.072872,2.529746,...,-0.000459,-0.000101,-0.004504,0.000367,-0.000205,-2.3e-05,-8.8e-05,0.000124,-6.2e-05,9.9e-05
720575940629174889,central,MBIN,,42979456,169932248320,11757734236160,-1.859022,1.64102,0.985564,2.125586,...,-0.000853,-0.000122,-0.004436,0.000358,-0.000198,-2.2e-05,-8.5e-05,0.000119,-5.9e-05,9.5e-05
720575940637132389,central,MBIN,,20778086,112883269248,6248738539520,-1.602327,0.61501,0.530065,1.076127,...,-0.000359,0.000273,-0.004382,0.000351,-0.000193,-2.2e-05,-8.2e-05,0.000115,-5.7e-05,9.2e-05
720575940654777505,central,,,3049104,14418566016,1265078210560,-6.589922,-6.693934,2.004244,0.055977,...,0.000108,9.1e-05,-0.004404,0.000341,-0.000164,-6.7e-05,-5.6e-05,0.000115,-6e-05,8.6e-05
720575940616159371,central,,VLPl2_medial,3678799,16298574976,1078699100160,-6.747396,-7.012646,2.138878,0.055543,...,0.000108,9.1e-05,-0.004382,0.000351,-0.000193,-2.2e-05,-8.2e-05,0.000115,-5.7e-05,9.2e-05


For each classification task, remove minor classes

In [100]:
class_size_thrs = {"super_class": 100, "class": 100, "hemilineage": 200}
filtered_dfs = {}
for pred_target, thr in class_size_thrs.items():
    counts = node_info_merged[pred_target].value_counts()
    valid_classes = counts[counts >= thr].index
    mask = node_info_merged[pred_target].isin(valid_classes)
    sel_df = node_info_merged[mask].copy()
    cols_to_drop = list(class_size_thrs.keys())
    cols_to_drop.remove(pred_target)
    filtered_dfs[pred_target] = sel_df.drop(columns=cols_to_drop)

In [101]:
dict(enumerate(filtered_dfs["class"]["class"].value_counts().index))

{0: 'Kenyon_Cell',
 1: 'CX',
 2: 'visual',
 3: 'olfactory',
 4: 'mechanosensory',
 5: 'AN',
 6: 'DN',
 7: 'ALPN',
 8: 'LHLN',
 9: 'ALLN',
 10: 'gustatory',
 11: 'DAN',
 12: 'TuBu',
 13: 'unknown_sensory',
 14: 'motor'}

Prepare dataset for ML tasks: Encode class labels as integers, standardize by column, and split train/validate/test sets

In [102]:
preprocessed_dfs = {}
id2name_lookups = {}
name2id_lookups = {}
for pred_target, raw_df in filtered_dfs.items():
    new_df = raw_df.copy()
    
    # Encode class strings with integers
    id2name = dict(enumerate(raw_df[pred_target].value_counts().index))
    name2id = {v: k for k, v in id2name.items()}
    new_df[pred_target] = [name2id[x] for x in new_df[pred_target]]
    id2name_lookups[pred_target] = id2name
    name2id_lookups[pred_target] = name2id

    # Standardize columns
    cols_to_standardize = [
        col for col in new_df.columns
        if col.split("_")[0] in ["morph", "node2vec", "spectral"]
    ]
    scaler = StandardScaler()
    scaler.fit(new_df[cols_to_standardize])
    new_df[cols_to_standardize] = scaler.transform(new_df[cols_to_standardize])

    # Split train/val/test sets:
    trainval_df, test_df = train_test_split(new_df, test_size=0.2, random_state=0)
    train_df, val_df = train_test_split(trainval_df, test_size=0.25, random_state=0)
    preprocessed_dfs[pred_target] = {
        "train": train_df, "val": val_df, "test": test_df
    }

In [103]:
preprocessed_dfs["class"]["train"].head()

Unnamed: 0,class,morph_length,morph_area,morph_size,node2vec_ACH_0,node2vec_ACH_1,node2vec_ACH_2,node2vec_ACH_3,node2vec_ACH_4,node2vec_ACH_5,...,spectral_DA_6,spectral_DA_7,spectral_OCT_0,spectral_OCT_1,spectral_OCT_2,spectral_OCT_3,spectral_OCT_4,spectral_OCT_5,spectral_OCT_6,spectral_OCT_7
720575940603580960,6,-0.187439,-0.272891,-0.275083,0.855002,-0.118813,0.11606,-0.979234,0.261596,-0.62537,...,0.015069,0.016606,0.065578,-0.070286,-0.004452,0.007246,-0.006862,-0.013832,0.002393,-0.022164
720575940617030585,6,0.618015,0.442175,0.172435,-1.835875,0.783314,0.715892,-0.881517,0.398301,0.729199,...,0.015069,0.016606,0.065578,-0.070286,-0.004452,0.007246,-0.006862,-0.013832,0.002393,-0.022164
720575940629644666,0,-0.187676,-0.233088,-0.199933,-0.678925,-0.555258,-0.714586,0.677049,-0.586057,-1.313155,...,0.011767,0.016477,0.065578,-0.070286,-0.004452,0.007246,-0.006862,-0.013832,0.002393,-0.022164
720575940638681845,4,-0.516843,-0.40894,-0.281376,0.165645,-0.954239,0.102962,-0.5981,-0.474486,-0.86437,...,0.015069,0.016606,0.065578,-0.070286,-0.004452,0.007246,-0.006862,-0.013832,0.002393,-0.022164
720575940638928035,4,-0.577595,-0.525436,-0.482841,0.661459,-0.568705,-1.496727,0.228866,0.997119,-0.52479,...,0.022439,0.02245,0.065578,-0.070286,-0.004452,0.007246,-0.006862,-0.013832,0.002393,-0.022164


## Non-GNN baselines
Logistic regression

In [110]:
results_all = []

In [111]:
_columns_all = list(preprocessed_dfs["class"]["train"].columns)
feature_columns = {
    "node2vec": [
        col for col in _columns_all
        if col.split("_")[0] in ("morph", "node2vec")
    ],
    "spectral": [
        col for col in _columns_all
        if col.split("_")[0] in ("morph", "spectral")
    ],
    "both": [
        col for col in _columns_all
        if col.split("_")[0] in ("morph", "node2vec", "spectral")
    ]
}

In [112]:
from sklearn.linear_model import LogisticRegression
for pred_target in ["super_class", "class", "hemilineage"]:
    for features_group in ["node2vec", "spectral", "both"]:
        model = LogisticRegression(max_iter=1000)
        x_cols = feature_columns[features_group]
        train_x = preprocessed_dfs[pred_target]["train"][x_cols]
        train_y = preprocessed_dfs[pred_target]["train"][pred_target]
        assert train_x.shape[1] == len(x_cols)
        val_x = preprocessed_dfs[pred_target]["val"][x_cols]
        val_y = preprocessed_dfs[pred_target]["val"][pred_target]
        model.fit(train_x, train_y)
        val_pred = model.predict(val_x)
        acc = accuracy_score(val_y, val_pred)
        f1 = f1_score(val_y, val_pred, average='macro')
        results_all.append(["lr", pred_target, features_group, acc, f1])
        print(
            f"Predicting {pred_target} from {features_group} features: "
            f"acc={acc:.4f}, f1={f1:.4f}"
        )

Predicting super_class from node2vec features: acc=0.8334, f1=0.5106
Predicting super_class from spectral features: acc=0.7582, f1=0.3509
Predicting super_class from both features: acc=0.8504, f1=0.5608
Predicting class from node2vec features: acc=0.8895, f1=0.8158
Predicting class from spectral features: acc=0.7325, f1=0.5780
Predicting class from both features: acc=0.9023, f1=0.8334
Predicting hemilineage from node2vec features: acc=0.4802, f1=0.5272
Predicting hemilineage from spectral features: acc=0.3654, f1=0.3580
Predicting hemilineage from both features: acc=0.5006, f1=0.5653


MLP

In [113]:
from sklearn.neural_network import MLPClassifier
for pred_target in ["super_class", "class", "hemilineage"]:
    for features_group in ["node2vec", "spectral", "both"]:
        model = MLPClassifier(hidden_layer_sizes=[16, 16], max_iter=1000)
        x_cols = feature_columns[features_group]
        train_x = preprocessed_dfs[pred_target]["train"][x_cols]
        train_y = preprocessed_dfs[pred_target]["train"][pred_target]
        assert train_x.shape[1] == len(x_cols)
        val_x = preprocessed_dfs[pred_target]["val"][x_cols]
        val_y = preprocessed_dfs[pred_target]["val"][pred_target]
        model.fit(train_x, train_y)
        val_pred = model.predict(val_x)
        acc = accuracy_score(val_y, val_pred)
        f1 = f1_score(val_y, val_pred, average='macro')
        results_all.append(["mlp", pred_target, features_group, acc, f1])
        print(
            f"Predicting {pred_target} from {features_group} features: "
            f"acc={acc:.4f}, f1={f1:.4f}"
        )

Predicting super_class from node2vec features: acc=0.8940, f1=0.5900
Predicting super_class from spectral features: acc=0.8513, f1=0.5045
Predicting super_class from both features: acc=0.9082, f1=0.6286
Predicting class from node2vec features: acc=0.9086, f1=0.8239
Predicting class from spectral features: acc=0.8824, f1=0.7610
Predicting class from both features: acc=0.9136, f1=0.8229
Predicting hemilineage from node2vec features: acc=0.4940, f1=0.5587
Predicting hemilineage from spectral features: acc=0.4759, f1=0.5036
Predicting hemilineage from both features: acc=0.5087, f1=0.5883


In [114]:
results_all_df = pd.DataFrame(
    results_all, columns=["model", "pred_target", "feature_group", "accuracy", "f1"]
)
results_all_df

Unnamed: 0,model,pred_target,feature_group,accuracy,f1
0,lr,super_class,node2vec,0.833447,0.510575
1,lr,super_class,spectral,0.758157,0.350862
2,lr,super_class,both,0.850394,0.560846
3,lr,class,node2vec,0.889453,0.815778
4,lr,class,spectral,0.732463,0.577957
5,lr,class,both,0.902274,0.833403
6,lr,hemilineage,node2vec,0.480169,0.527247
7,lr,hemilineage,spectral,0.365422,0.357977
8,lr,hemilineage,both,0.500578,0.565308
9,mlp,super_class,node2vec,0.894029,0.59001


In [115]:
results_all_df.to_pickle("../data/classification_stats.pkl")

## GNN models

### Build PyG graph object