In [None]:
import sys
sys.path.append("..")
import argparse
import numpy as np
import dgl
from dgl import DGLGraph
import torch
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from collections import Counter
import pickle
import h5py
import random
import glob2
import seaborn as sns

import train
import models

%load_ext autoreload
%autoreload 2

# Seed every random number generator the notebook relies on (Python's
# `random`, NumPy and PyTorch) from a single constant so the base seed can be
# changed in one place.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed all GPUs rather than only the current one, consistent with the
# per-run reseeding in the experiment loop, which uses `manual_seed_all`.
torch.cuda.manual_seed_all(SEED)
# Device chosen by the project's helper; models are moved to it with `.to(device)`.
device = train.get_device()

In [None]:

# --- Labels that are only echoed in the "INPUT: ..." log line -------------
model_name = "GraphConv"
same_edge_values = False

# --- Graph-construction defaults -------------------------------------------
# The grid search further down re-binds these three names on every iteration,
# so the values here only matter if the graph is built outside that loop.
normalize_weights, node_features, edge_norm = "log_per_cell", "scale", True

# --- Data preparation --------------------------------------------------------
nb_genes = 3000  # passed to train.filter_data as `highly_genes`
pca_size = 50    # passed to make_graph as `dense_dim` and to the model as `in_feats`

# --- Model architecture ------------------------------------------------------
n_layers = 1     # shared by the neighbour sampler and the model
hidden_dim = 200
hidden = [300]
hidden_relu = False
hidden_bn = False
activation = F.relu

# --- Optimisation ------------------------------------------------------------
epochs = 10
batch_size = 128
# Grid search over graph-construction options.
#
# For every dataset of every category, each combination of
# (normalize_weights, node_features, edge_norm) is used to build a graph, a
# fresh models.GCNAE is trained on it `n_runs` times with different seeds, and
# the scores returned by train.train are appended to a per-category results
# frame that is re-pickled after every run (so an interrupted session loses at
# most one run).

n_runs = 3     # independently seeded repetitions per grid point
# When True, rows already present in the results pickle are kept and the
# matching (dataset, graph settings, run) combinations are not recomputed.
# This replaces the former manual resume (`files = files[1:]` plus an
# unconditional `pd.read_pickle`), which depended on glob ordering and crashed
# when no pickle existed yet. Set to False to recompute and overwrite.
# NOTE(review): if the first real-data file was skipped on purpose (not because
# it was already done), it will now be processed unless it is in the pickle.
resume = True
path = ".."

# Columns that together identify one training run inside the results frame.
key_cols = ["dataset", "normalize_weights", "node_features", "edge_norm", "run"]

# Same iteration order as three nested loops: weights outermost, edge_norm innermost.
graph_grid = [
    (weights_mode, features_mode, use_edge_norm)
    for weights_mode in ["log_per_cell", "per_cell", "none"]
    for features_mode in ["scale", "none", "scale_by_cell"]
    for use_edge_norm in [True, False]
]


def _load_previous_results(results_path):
    """Return the frame pickled at ``results_path``, or an empty frame if the file is missing."""
    try:
        # Only ever read pickles written by this notebook: unpickling untrusted
        # files can execute arbitrary code.
        return pd.read_pickle(results_path)
    except FileNotFoundError:
        print(f"No previous results at {results_path}; starting from scratch")
        return pd.DataFrame()


def _completed_keys(frame, columns):
    """Return the set of ``columns`` value tuples that already occur in ``frame``."""
    if frame.empty or not set(columns).issubset(frame.columns):
        return set()
    return set(frame[columns].itertuples(index=False, name=None))


def _append_scores(frame, scores):
    """Return ``frame`` with ``scores`` appended as new row(s) and a fresh 0..n-1 index.

    ``DataFrame.append`` was removed in pandas 2.0, so the rows are added with
    ``pd.concat``. ``scores`` is presumably a dict (one row); a Series or a
    DataFrame is handled as well -- TODO confirm against train.train.
    """
    rows = scores if isinstance(scores, pd.DataFrame) else pd.DataFrame([scores])
    if frame.empty:
        # Avoid concatenating with an empty frame (dtype warnings on recent pandas).
        return rows.reset_index(drop=True)
    return pd.concat([frame, rows], ignore_index=True)


for category in ["real_data"]:  # add "balanced_data", "imbalanced_data" to run the simulated sets
    if category in ["balanced_data", "imbalanced_data"]:
        data_dir = f"{path}/R/simulated_data/{category}"
    else:
        data_dir = f"{path}/real_data"
    results_path = f"{path}/output/pickle_results/{category}/{category}_graph_creation.pkl"

    results = _load_previous_results(results_path) if resume else pd.DataFrame()
    done = _completed_keys(results, key_cols)

    # Dataset names are the file names without directory prefix and ".h5"
    # suffix; sorted so the processing order does not depend on the filesystem.
    prefix = f"{data_dir}/"
    files = sorted(f[len(prefix):-3] for f in glob2.glob(f"{data_dir}/*.h5"))
    print(files)

    for dataset in files:
        # Map each grid point to the runs that still have to be computed.
        pending = {}
        for grid_point in graph_grid:
            missing_runs = [
                run for run in range(n_runs)
                if (dataset, *grid_point, run) not in done
            ]
            if missing_runs:
                pending[grid_point] = missing_runs
        if not pending:
            print(f">> {dataset} (already in results, skipped)")
            continue
        print(f">> {dataset}")

        # Copy both arrays into memory, then close the HDF5 handle right away.
        with h5py.File(f"{data_dir}/{dataset}.h5", "r") as data_mat:
            Y = np.array(data_mat['Y'])
            X = np.array(data_mat['X'])

        # Keep only the cells (rows) and genes (columns) selected by the filter.
        genes_idx, cells_idx = train.filter_data(X, highly_genes=nb_genes)
        X = X[cells_idx][:, genes_idx]
        Y = Y[cells_idx]
        n_clusters = len(np.unique(Y))

        for (normalize_weights, node_features, edge_norm), runs in pending.items():
            graph = train.make_graph(
                X,
                Y,
                dense_dim=pca_size,
                node_features=node_features,
                normalize_weights=normalize_weights,
                edge_norm=edge_norm,
            )

            # Nodes labelled -1 are left out of the training ids
            # (presumably the non-cell nodes -- confirm in train.make_graph).
            labels = graph.ndata["label"]
            train_ids = np.where(labels != -1)[0]

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(n_layers)

            dataloader = dgl.dataloading.NodeDataLoader(
                graph,
                train_ids,
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=1,
            )
            print(
                f"INPUT: {model_name}  {hidden_dim}, {hidden}, {same_edge_values}, {edge_norm}"
            )
            t1 = time.time()  # wall-clock start of this grid point (not read in this cell)

            for run in runs:
                t_start = time.time()  # wall-clock start of this run (not read in this cell)
                # Reseed every generator so that run `i` starts from the same
                # random state regardless of what was executed before it.
                torch.manual_seed(run)
                torch.cuda.manual_seed_all(run)
                np.random.seed(run)
                random.seed(run)

                model = models.GCNAE(
                    in_feats=pca_size,
                    n_hidden=hidden_dim,
                    n_layers=n_layers,
                    activation=activation,
                    dropout=0.1,
                    hidden=hidden,
                    hidden_relu=hidden_relu,
                    hidden_bn=hidden_bn,
                ).to(device)
                # Show the architecture once per grid point, on the first run executed.
                if run == runs[0]:
                    print(">", model)

                optim = torch.optim.Adam(model.parameters(), lr=1e-5)

                scores = train.train(model, optim, epochs, dataloader, n_clusters, plot=False,
                                     cluster=["KMeans", "Leiden"])
                # Tag the scores with everything needed to identify this run.
                scores["dataset"] = dataset
                scores["run"] = run
                scores["nb_genes"] = nb_genes
                scores["node_features"] = node_features
                scores["normalize_weights"] = normalize_weights
                scores["edge_norm"] = edge_norm

                results = _append_scores(results, scores)

                # Checkpoint after every run.
                results.to_pickle(results_path)
                print("Done")

In [None]:
# Summary table per category: mean KMeans ARI for every combination of
# edge_norm / node_features / normalize_weights, rounded to two decimals.

# Rank used to order the weight-normalisation schemes inside each group.
order = {"none": 0, "per_cell": 1, "log_per_cell": 2}


def _weights_rank(scheme):
    """Return the display rank of a ``normalize_weights`` value."""
    return order[scheme]


for category in ["real_data", "balanced_data", "imbalanced_data"]:
    print(category)
    pickle_file = f"../output/pickle_results/{category}/{category}_graph_creation.pkl"
    raw_scores = pd.read_pickle(pickle_file)

    by_edge_norm = raw_scores.sort_values(by="edge_norm", ascending=False)
    grouped = by_edge_norm.groupby(["edge_norm", "node_features", "normalize_weights"])
    mean_ari = grouped[["kmeans_ari"]].mean()
    results = mean_ari.round(2).reset_index()

    # Helper column so the schemes appear as none -> per_cell -> log_per_cell.
    results["order"] = results["normalize_weights"].apply(_weights_rank)

    # edge_norm=True rows first, then node_features alphabetically, then by rank.
    sort_columns = ["edge_norm", "node_features", "order"]
    sort_ascending = [False, True, True]
    display(results.sort_values(by=sort_columns, ascending=sort_ascending))