In [1]:
import sys
sys.path.append("..")
import argparse
import numpy as np
import dgl
from dgl import DGLGraph
import torch
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from collections import Counter
import pickle
import h5py
import random
import glob2
import seaborn as sns

import train
import models

%load_ext autoreload
%autoreload 2

# Fix the seed of every random number generator used in this notebook
# (Python's `random`, NumPy, PyTorch CPU and PyTorch CUDA) so that a fresh
# "Restart & Run All" starts from the same state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Device is chosen by the project helper `train.get_device()`
# (presumably CUDA when available, else CPU — confirm in train.py).
device = train.get_device()

Using backend: pytorch

In a future version of Scanpy, `scanpy.api` will be removed.
Simply use `import scanpy as sc` and `import scanpy.external as sce` instead.

  warn("Tensorflow not installed; ParametricUMAP will be unavailable")


In [3]:
# ---------------------------------------------------------------------------
# Experiment: effect of the number of graph edges (`nb_edges`) and of the
# edge-weight normalisation (`normalize_weights`) on clustering quality.
#
# For every dataset of every category, a graph is built for each
# (normalize_weights, nb_edges) combination, a GCN auto-encoder is trained
# three times (seeds 0, 1, 2) and the scores returned by `train.train` are
# stored, one row per run, in
# ../output/pickle_results/<category>/<category>_edges.pkl
# ---------------------------------------------------------------------------

# --- Configuration ---------------------------------------------------------
pca_size = 50
epochs = 10
batch_size = 128

model_name = "GraphConv"
normalize_weights = "log_per_cell"  # placeholder; overridden by the sweep below
node_features = "scale"
same_edge_values = False
edge_norm = True
hidden_relu = False
hidden_bn = False
n_layers = 1
hidden_dim = 200
hidden = [300]
nb_genes = 3000
activation = F.relu

path = ".."

# --- Sweep -----------------------------------------------------------------
for category in ["balanced_data", "imbalanced_data", # "real_data",
                ]:
    # Simulated datasets live under R/simulated_data/<category>, real ones
    # under real_data; everything else in the loop is identical.
    if category in ["balanced_data", "imbalanced_data"]:
        data_dir = f"{path}/R/simulated_data/{category}"
    else:
        data_dir = f"{path}/real_data"

    # Dataset name = file name without the directory prefix and ".h5" suffix.
    files = glob2.glob(f"{data_dir}/*.h5")
    files = [f[len(f"{data_dir}/"):-3] for f in files]

    # Rows that were already available before this sweep started.
    prior_results = pd.DataFrame()
    if category == "balanced_data":
        # NOTE(review): this seeds the results with the pickle of a
        # *different* experiment (graph_creation) and skips the first five
        # files returned by glob, whose order is not guaranteed. It looks
        # like a leftover "resume after crash" hack — confirm it is intended
        # before relying on the balanced_data results.
        prior_results = pd.read_pickle(
            f"../output/pickle_results/{category}/{category}_graph_creation.pkl")
        files = files[5:]
    print(files)

    results = prior_results
    records = []  # one dict per training run, turned into a frame on save
    results_path = f"../output/pickle_results/{category}/{category}_edges.pkl"

    for dataset in files:
        # Read both arrays eagerly so the HDF5 handle can be closed at once.
        with h5py.File(f"{data_dir}/{dataset}.h5", "r") as data_mat:
            print(f">> {dataset}")
            Y = np.array(data_mat['Y'])
            X = np.array(data_mat['X'])

        genes_idx, cells_idx = train.filter_data(X, highly_genes=nb_genes)
        X = X[cells_idx][:, genes_idx]
        Y = Y[cells_idx]
        n_clusters = len(np.unique(Y))

        for normalize_weights in ["log_per_cell", "per_cell", "none"
                                 ]:
            for nb_edges in [-0.75, -0.5, -0.25, 0.25, 0.5, 0.75, 1]:
                graph = train.make_graph(
                    X,
                    Y,
                    dense_dim=pca_size,
                    node_features=node_features,
                    normalize_weights=normalize_weights,
                    edge_norm =edge_norm,
                    nb_edges = nb_edges
                )

                labels = graph.ndata["label"]
                # Nodes labelled -1 are excluded from the training ids.
                train_ids = np.where(labels != -1)[0]

                sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers)

                dataloader = dgl.dataloading.NodeDataLoader(
                    graph,
                    train_ids,
                    sampler,
                    batch_size=batch_size,
                    shuffle=True,
                    drop_last=False,
                    num_workers=1,
                )
                print(
                    f"INPUT: {model_name}  {hidden_dim}, {hidden}, {same_edge_values}, {edge_norm}"
                )

                for run in range(3):
                    # Each repetition uses its own seed so runs are
                    # reproducible yet different from one another.
                    torch.manual_seed(run)
                    torch.cuda.manual_seed_all(run)
                    np.random.seed(run)
                    random.seed(run)

                    model = models.GCNAE(
                        in_feats=pca_size,
                        n_hidden=hidden_dim,
                        n_layers=n_layers,
                        activation=activation,
                        dropout=0.1,
                        hidden=hidden,
                        hidden_relu=hidden_relu,
                        hidden_bn=hidden_bn,
                    ).to(device)
                    if run == 0:
                        print(">", model)

                    optim = torch.optim.Adam(model.parameters(), lr=1e-5)

                    # assumes `train.train` returns a flat dict-like of
                    # metrics (one value per key) — TODO confirm in train.py
                    scores = train.train(model, optim, epochs, dataloader, n_clusters, plot=False,
                                        cluster=["KMeans", "Leiden"])
                    scores["dataset"] = dataset
                    scores["run"] = run
                    scores["nb_genes"] = nb_genes
                    scores["node_features"] = node_features
                    scores["nb_edges"] = nb_edges
                    scores["edge_norm"] = edge_norm
                    # The sweep varies this setting, so it has to be part of
                    # the row; otherwise the three variants are
                    # indistinguishable in the saved results.
                    scores["normalize_weights"] = normalize_weights

                    # Store a copy: `scores` is mutated again just below.
                    records.append(dict(scores))
                    if nb_edges == 1: # add record for -1
                        scores["nb_edges"] = -1
                        records.append(dict(scores))

                    # Checkpoint after every run so a crash loses at most one
                    # training. The frame is rebuilt from the record list
                    # because DataFrame.append no longer exists in pandas 2.
                    frames = [prior_results, pd.DataFrame(records)]
                    results = pd.concat(
                        [frame for frame in frames if not frame.empty],
                        ignore_index=True,
                    )
                    results.to_pickle(results_path)
                    print("Done")

In [6]:
# Reload the sweep results from disk. `category` is not set in this cell: it
# is whatever value the sweep loop left behind (its last iteration), so this
# only works after that cell has been executed.
edges_pickle = f"../output/pickle_results/{category}/{category}_edges.pkl"
results = pd.read_pickle(edges_pickle)

In [8]:
# Average every numeric score per `nb_edges` value. `numeric_only=True` is
# required on pandas >= 2.0, where non-numeric columns (e.g. "dataset",
# "node_features") make a plain `.mean()` raise instead of being dropped.
results.groupby("nb_edges").mean(numeric_only=True)

Unnamed: 0_level_0,ae_end,edge_norm,kmeans_ari,kmeans_cal,kmeans_nmi,kmeans_sil,kmeans_time,leiden_ari,leiden_cal,leiden_nmi,leiden_sil,leiden_time,nb_genes,run
nb_edges,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
-1.0,1615540000.0,1.0,0.780047,1766.693707,0.814784,0.477235,0.797753,0.526849,1104.269641,0.725382,0.298962,8.556631,3000.0,1.0
-0.75,1615540000.0,1.0,0.658802,1175.354617,0.724009,0.352466,0.869494,0.503127,796.705447,0.689204,0.215191,10.124943,3000.0,1.0
-0.5,1615540000.0,1.0,0.520958,1036.699416,0.610596,0.304917,0.890218,0.473302,698.553581,0.646131,0.18822,8.677346,3000.0,1.0
-0.25,1615540000.0,1.0,0.394776,1016.103518,0.49194,0.300991,0.890076,0.367753,570.390373,0.535782,0.182485,8.706435,3000.0,1.0
0.25,1615540000.0,1.0,0.668769,1450.803264,0.716458,0.409788,0.865254,0.475709,944.506492,0.675653,0.243917,8.831817,3000.0,1.0
0.5,1615540000.0,1.0,0.744782,1618.033278,0.788502,0.458543,0.816464,0.502738,1043.778183,0.708191,0.287102,8.17972,3000.0,1.0
0.75,1615540000.0,1.0,0.773236,1732.041243,0.81332,0.471161,0.788315,0.517009,1070.593886,0.719673,0.29143,8.412686,3000.0,1.0
1.0,1615540000.0,1.0,0.780047,1766.693707,0.814784,0.477235,0.797753,0.526849,1104.269641,0.725382,0.298962,8.556631,3000.0,1.0
