In [1]:
import numpy as np
import pandas as pd
import torch
from gcn_model import GCNModel
import utilities
from test_model import test_model
import os
import statistics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Input locations -------------------------------------------------------
# Every file for this simulation lives under a single folder.
data_folder = "/home/groups/ConradLab/daniel/sharp_data/sharp_sims/splat_0.7_de_rq/"
data_path = data_folder + "query_counts.csv"
ref_path = data_folder + "ref_counts.csv"
ref_label_path = data_folder + "ref_labels.csv"
marker_path = data_folder + "markers.txt"
meta_path = data_folder + "query_meta.csv"
preds_path = data_folder + "preds.csv"

# Annotation tools whose per-cell predictions are combined below.
tools = ["sctype","scsorter","scina","singler", "scpred"]

# --- Per-tool label predictions ---------------------------------------------
# Reuse the predictions file when it exists; otherwise compute the labels
# through utilities.label_counts.
if not os.path.exists(preds_path):
    all_labels = utilities.label_counts(
        data_path, tools, ref_path, ref_label_path, marker_path
    )
else:
    all_labels = pd.read_csv(preds_path, index_col=0)
    # When the file's column count differs from the tool list, keep only
    # the columns named in `tools`.
    if all_labels.shape[1] != len(tools):
        all_labels = all_labels[tools]

# --- Query counts -------------------------------------------------------------
X = pd.read_csv(data_path, index_col=0)
# preprocess returns the processed matrix plus a selector for retained cells.
# NOTE(review): `comps=500` presumably sets the number of components kept and
# `scale=False` presumably disables feature scaling -- confirm in utilities.
X, keep_cells = utilities.preprocess(np.array(X), scale=False, comps=500)

# Restrict the predictions to the cells that survived preprocessing.
all_labels = all_labels.loc[keep_cells, :]

# --- Encode the tool predictions ---------------------------------------------
_, marker_names = utilities.read_marker_file(marker_path)
all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)

# --- Ground truth ---------------------------------------------------------------
# Integer codes of the 'Group' column (sorted categories), restricted to the
# same retained cells as X.
metadata = pd.read_csv(meta_path, index_col=0)
real_y = pd.factorize(metadata["Group"], sort=True)[0][keep_cells]

# --- Consensus labels and node split --------------------------------------------
# NOTE(review): necessary_vote=3 presumably means at least three tools must
# agree on a label -- confirm in utilities.get_consensus_labels.
confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = 3)

# Cells with a consensus label (anything other than -1) are the training
# nodes; cells labelled -1 are the test nodes.
train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]



In [3]:
# Both loaders use the same mini-batch size.
batch_size = 65

# Training data: features paired with the consensus labels. Cells whose
# consensus label is -1 are still present here; they are the ones listed
# in `test_nodes`.
train_features = torch.tensor(X)
train_targets = torch.tensor(confident_labels)
dataset = torch.utils.data.TensorDataset(train_features, train_targets)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation data: the same features paired with the ground-truth codes.
# Order is preserved (shuffle=False).
eval_features = torch.tensor(X)
eval_targets = torch.tensor(real_y)
test_dataset = torch.utils.data.TensorDataset(eval_features, eval_targets)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


In [4]:
# Single training run without dropout.
config_path = "configs/4_60.txt"
n_epochs = 150  # the training log below reports the loss every 10 epochs up to 140

# NOTE(review): the meaning of the positional argument `2` is not visible
# from this notebook -- confirm against GCNModel's constructor.
m = GCNModel(config_path, 2, dropout=0.0)
m.train(dataloader, n_epochs)

Loss in epoch 0 = 21.533558
Loss in epoch 10 = 0.016814
Loss in epoch 20 = 0.003902
Loss in epoch 30 = 0.001146
Loss in epoch 40 = 0.000203
Loss in epoch 50 = 0.000174
Loss in epoch 60 = 0.000103
Loss in epoch 70 = 0.000127
Loss in epoch 80 = 0.000067
Loss in epoch 90 = 0.000067
Loss in epoch 100 = 0.000063
Loss in epoch 110 = 0.000029
Loss in epoch 120 = 0.000033
Loss in epoch 130 = 0.000034
Loss in epoch 140 = 0.000021


In [5]:
# Evaluate the trained model against the ground-truth labels.
# The result is a 6-tuple alternating a scalar and a square matrix.
# NOTE(review): these look like (accuracy, confusion matrix) pairs for all
# nodes, train nodes and test nodes, in that order -- confirm in GCNModel.
metrics = m.validation_metrics(test_dataloader, train_nodes, test_nodes)
metrics

(0.955955982208252,
 array([[214,   8,   5,   5],
        [  0, 266,   2,   0],
        [  1,   4, 225,   4],
        [  2,   7,   6, 250]]),
 0.9724972248077393,
 array([[178,   6,   2,   2],
        [  0, 264,   2,   0],
        [  1,   2, 211,   1],
        [  0,   5,   4, 231]]),
 0.7888888716697693,
 array([[36,  2,  3,  3],
        [ 0,  2,  0,  0],
        [ 0,  2, 14,  3],
        [ 2,  2,  2, 19]]))

In [7]:
# Repeat the dropout=0.1 training run several times and report the mean and
# sample standard deviation of the accuracy taken from position 4 of
# validation_metrics' result.
#
# Changes from the earlier version of this cell:
#   * the number of runs is a single constant instead of a literal repeated
#     in the list size and the loop bound;
#   * each run is seeded so the reported statistics can be reproduced;
#   * accuracies are appended rather than written into a zero-filled list,
#     so an interrupted or shortened loop cannot leave placeholder zeros
#     that bias the mean.
N_RUNS = 5      # must be >= 2, statistics.stdev needs at least two values
N_EPOCHS = 150

test_accuracy = []
for i in range(N_RUNS):
    print(i)
    # Use the run index as the seed: runs differ from each other but are
    # repeatable across kernel restarts.
    # NOTE(review): assumes GCNModel and the shuffling DataLoader draw their
    # randomness from the torch / numpy global generators -- confirm.
    torch.manual_seed(i)
    np.random.seed(i)

    m = GCNModel("configs/4_60.txt", 2, dropout=0.1)
    m.train(dataloader, N_EPOCHS, verbose=False)

    # validation_metrics returns six values; only the fifth is used here.
    _, _, _, _, acc, _ = m.validation_metrics(test_dataloader, train_nodes, test_nodes)
    # float() turns a 0-d tensor or numpy scalar into a plain Python float
    # so the statistics module handles it uniformly.
    test_accuracy.append(float(acc))

print(statistics.mean(test_accuracy))
print(statistics.stdev(test_accuracy))

0
1
2
3
4
0.5400000035762786
0.06507355297733373
