In [1]:
import torch 
import torch.nn as nn 

import gnn
import data 
data.sc.settings.verbosity = 0
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

scanpy==1.8.2 anndata==0.8.0 umap==0.5.2 numpy==1.21.5 scipy==1.8.0 pandas==1.3.5 scikit-learn==1.0.2 statsmodels==0.13.2 python-igraph==0.9.9 pynndescent==0.5.6


In [3]:
adata = data.preprocess_pbmc3k_preprocessed(data.get_pbmc3k_preprocessed())
target_col_categories = 'louvain'
target_col = 'louvain_ind'

n_train_perc = 0.7
n_train = int(n_train_perc * adata.X.shape[0])

adata

AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain', 'louvain_ind'
    var: 'n_cells'
    uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

baselines

In [4]:
# SVM baseline 
import svm 
x_train, x_test, y_train, y_test = svm.split_data_cellclass(adata, n_train, target_col=target_col)
clf = svm.svm_cellclass(x_train, y_train, x_test, y_test, kernel='linear')
# MLP baseline 
import mlp 
NUM_EPOCHS = 1000
LR = 0.001 
x = torch.tensor(adata.X)
y = data.target2onehot(adata, target_col=target_col)
input_dim = x.shape[1]
output_dim = data.num_categories(adata, target_col=target_col_categories)
hidden_dim = 64
x_train, y_train = x[:n_train], y[:n_train]
x_test, y_test = x[n_train:], y[n_train:]

for layers in [
    [(input_dim, output_dim, hidden_dim),], 
    [(input_dim, hidden_dim, hidden_dim),(hidden_dim, output_dim, hidden_dim),], 
    [(input_dim, hidden_dim, hidden_dim),(hidden_dim, hidden_dim, hidden_dim),(hidden_dim, output_dim, hidden_dim),], 
]: 
    model = nn.Sequential(*[mlp.SimpleMLP(a,b,c) for a,b,c in layers])
    losses = mlp.train(model, x_train, y_train, lr=LR, num_epochs=NUM_EPOCHS)
    print(f'MLP-{len(layers)}', end=' ')
    test_acc = mlp.eval(model, x_test, y_test)


SVM test accuracy: 94.1919191919192
MLP-1 accuracy: 89.52020263671875
MLP-2 accuracy: 88.76262664794922
MLP-3 accuracy: 87.1212158203125


edge index

In [5]:
adata

AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain', 'louvain_ind'
    var: 'n_cells'
    uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [6]:
adata = data.preprocess_leiden(adata, n_neighbors=30)
edge_index = data.connectivities2edge_index(adata)

edge_index.shape[1] / adata.X.shape[1]**2
edge_index = edge_index.to(device)

In [7]:
x = torch.tensor(adata.X).to(device)
y = data.target2onehot(adata, target_col=target_col).to(device)

In [8]:
from gnn import GNN_DGN_wrapper 

RUNS = 3
NUM_EPOCHS = 2001
LR = 0.001

models = [
    gnn.GNN('GCN', input_dim, output_dim, num_layers=1, hidden_dim=64), 
    gnn.GNN('GCN', input_dim, output_dim, num_layers=2, hidden_dim=64), 
    gnn.GNN('GCN', input_dim, output_dim, num_layers=4, hidden_dim=64), 
    GNN_DGN_wrapper(
        gnn.GNN('GCN', input_dim, 64, num_layers=1, hidden_dim=64), 
        output_dim, k=10, hidden_dim=64
    ), 
    GNN_DGN_wrapper(
        gnn.GNN('GCN', input_dim, 64, num_layers=2, hidden_dim=64), 
        output_dim, k=10, hidden_dim=64
    ), 
    GNN_DGN_wrapper(
        gnn.GNN('GCN', input_dim, 64, num_layers=4, hidden_dim=64), 
        output_dim, k=10, hidden_dim=64
    )

]

In [9]:
objs = []
for model in models: 
    model = model.to(device)
    obj = gnn.train(model, x, edge_index, y, n_train, lr=LR, num_epochs=NUM_EPOCHS, log_interval=1000)
    objs.append(obj)
    torch.cuda.empty_cache()

[EPOCH 0] train loss: 2.094449996948242
[EPOCH 1000] train loss: 1.2794954776763916
[EPOCH 2000] train loss: 1.2784538269042969
[EPOCH 0] train loss: 2.0805649757385254
[EPOCH 1000] train loss: 1.283355951309204
[EPOCH 2000] train loss: 1.2832362651824951
[EPOCH 0] train loss: 2.081427574157715
[EPOCH 1000] train loss: 1.2935792207717896
[EPOCH 2000] train loss: 1.2974096536636353
[EPOCH 0] train loss: 2.0764737129211426
[EPOCH 1000] train loss: 1.2866332530975342
[EPOCH 2000] train loss: 1.279139757156372
[EPOCH 0] train loss: 2.0771775245666504
[EPOCH 1000] train loss: 1.29063081741333
[EPOCH 2000] train loss: 1.281067967414856
[EPOCH 0] train loss: 2.0782394409179688
[EPOCH 1000] train loss: 1.2989500761032104
[EPOCH 2000] train loss: 1.290362000465393


In [19]:
print(
    torch.tensor(objs[0]['test_accs']).max(), 
    torch.tensor(objs[1]['test_accs']).max(), 
    torch.tensor(objs[2]['test_accs']).max(), 
)

tensor(96.0859) tensor(96.0859) tensor(95.2020)


In [11]:
for obj in objs: 
    print(torch.tensor(obj['test_accs']).max(), obj['test_accs'][-1])

tensor(96.0859) 95.32828330993652
tensor(96.4646) 95.58081030845642
tensor(96.0859) 95.45454382896423
tensor(95.8333) 87.62626051902771
tensor(96.4646) 94.57070827484131
tensor(96.3384) 95.20202279090881


In [None]:
print(
    objs[0]['test_accs'][-1], 
    objs[1]['test_accs'][-1], 
    objs[2]['test_accs'][-1], 
)

95.58081030845642 95.58081030845642 95.20202279090881
