In [1]:
import numpy as np
import pandas as pd
import torch
from gcn_model import GCNModel
import utilities
from test_model import test_model
import os
import statistics
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_folder = "/home/groups/ConradLab/daniel/sharp_data/pbmc_proto/"

# Annotation tools whose per-cell predictions are combined.
# The order matters: later cells index per-tool scores by position in this list.
tools = ["sctype","scsorter","scina","singler", "scpred"]
# Smaller alternative tool set: ["scsorter","scina","sctype"]

# Input files
data_path = data_folder + "query.csv"
ref_path = data_folder + "ref.csv"
ref_label_path = data_folder + "ref_labels_tcombined.csv"
marker_path = data_folder + "markers_cd4-8.txt"
preds_path = data_folder + "preds_3kref_cd4-8.csv"

# get labels: reuse cached predictions when present, otherwise compute them
if os.path.exists(preds_path):
    all_labels = pd.read_csv(preds_path, index_col=0)
    # Always select by tool name, not only when the column count differs:
    # a cache with the right number of columns but a different order (or
    # different tools) would otherwise be used silently. Selecting by name
    # puts the columns in `tools` order and raises KeyError if a tool is missing.
    all_labels = all_labels[tools]
else:
    all_labels = utilities.label_counts(data_path,tools,ref_path,ref_label_path,marker_path)

# read in dataset
counts = pd.read_csv(data_path, index_col=0)
X, keep_cells, keep_genes, pca_obj = utilities.preprocess(np.array(counts), scale=False, comps=500)

# Keep only the cells retained by preprocessing so labels stay aligned with X.
all_labels = all_labels.loc[keep_cells,:]

_,marker_names = utilities.read_marker_file(marker_path)

# NOTE(review): presumably maps each tool's label strings to integer codes
# based on marker_names — confirm in utilities.factorize_df.
all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)

# Consensus label per cell; -1 marks cells without a consensus (see the
# train/test split below).
confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = .51)

# Ground-truth labels, used for evaluation only.
# NOTE(review): assumes the metadata rows are in the same order as `counts` — verify.
meta_path = data_folder + "query_labels_cd4-8.csv"
metadata = pd.read_csv(meta_path, index_col=0)
real_y,cell_names = pd.factorize(metadata['labels'], sort=True)
real_y = real_y[keep_cells]

# Cells with a consensus label are "train" nodes; the rest are "test" nodes.
train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]
print(cell_names)

Index(['b_cells', 'cd14_monocytes', 'cd4_t_cell', 'cd56_nk', 'cd8_t_cell'], dtype='object')


In [3]:
# Score every tool against the consensus labels, restricted to the cells that
# have a consensus (train_nodes). One entry per tool, in `tools` order.
# NOTE(review): utilities.pred_accuracy presumably returns the fraction of
# matching labels — confirm in utilities.
consensus_on_train = confident_labels[train_nodes]
scores = np.fromiter(
    (
        utilities.pred_accuracy(
            all_labels_factored[name].to_numpy()[train_nodes],
            consensus_on_train,
        )
        for name in tools
    ),
    dtype=float,
    count=len(tools),
)
scores

array([0.94068003, 0.68073308, 0.78659272, 0.77260673, 0.81697613])

In [4]:
# Normalise the per-tool scores so they sum to 1, then move to log space.
# Gotcha: this rebinds `scores`, so re-running this cell without first
# re-running the previous one applies the transform a second time.
scores = np.log(scores / scores.sum())
scores

array([-1.44684359, -1.77027635, -1.62573602, -1.64367647, -1.58783675])

In [5]:
results = all_labels_factored.to_numpy()

# One-hot encode every (cell, tool) prediction over the 5 class codes 0..4.
# Comparing against arange(5) along a new trailing axis sets a 1.0 in the
# matching class slot; any code outside 0..4 yields an all-zero vector.
results_exp = (results[:, :, np.newaxis] == np.arange(5)).astype(float)

tY = torch.tensor(results_exp).float()

In [6]:
# Training set: features X paired with the one-hot tool predictions tY.
# tY is already a torch tensor, so it is passed as-is; wrapping it in
# torch.tensor() again makes a needless copy and triggers a UserWarning.
dataset = torch.utils.data.TensorDataset(torch.tensor(X), tY)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)

# Evaluation set: same features paired with the ground-truth labels.
# shuffle=False keeps the rows in their original order.
test_dataset = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=False)

  dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(tY))


In [13]:
# Seed the RNGs before building the model so that weight initialisation and
# the shuffled DataLoader order are repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# `scores` holds the log-normalised per-tool scores computed above.
# NOTE(review): the meaning of the config file and of the positional `2`
# is defined in gcn_model.GCNModel — confirm there.
m = GCNModel("configs/2_40_5.txt", 2, scores, weights_mode = True, learn_weights = True, dropout=0.0)
# 250 is presumably the number of epochs (the log prints "Loss in epoch N") — confirm.
m.train(dataloader, 250, verbose=True)

Loss in epoch 0 = 29.444733
Loss in epoch 10 = 13.568172
Loss in epoch 20 = 12.605890
Loss in epoch 30 = 12.393337
Loss in epoch 40 = 12.325214
Loss in epoch 50 = 12.284460
Loss in epoch 60 = 12.265996
Loss in epoch 70 = 12.318281
Loss in epoch 80 = 12.230808
Loss in epoch 90 = 12.224720
Loss in epoch 100 = 12.439528
Loss in epoch 110 = 12.198706
Loss in epoch 120 = 12.216996
Loss in epoch 130 = 12.207728
Loss in epoch 140 = 12.228580
Loss in epoch 150 = 12.188997
Loss in epoch 160 = 12.194435
Loss in epoch 170 = 12.186616
Loss in epoch 180 = 12.192631
Loss in epoch 190 = 12.193334
Loss in epoch 200 = 12.181057
Loss in epoch 210 = 12.183413
Loss in epoch 220 = 12.177344
Loss in epoch 230 = 12.180355
Loss in epoch 240 = 12.182647


In [14]:
# Evaluate the trained model on test_dataloader (features with ground-truth labels).
# NOTE(review): judging by the displayed tuple, this returns three
# (accuracy, confusion matrix) pairs — presumably for all cells, for
# train_nodes (cells with a consensus label) and for test_nodes (cells
# without one); confirm against GCNModel.validation_metrics.
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.9173339605331421,
 array([[ 483,    0,    0,    0,    0],
        [   1,  123,    0,    0,    0],
        [   0,    1, 2055,    1,   17],
        [   0,    0,    7,  378,   31],
        [   0,    0,  292,    1,  856]]),
 0.9312756061553955,
 array([[ 483,    0,    0,    0,    0],
        [   1,  123,    0,    0,    0],
        [   0,    1, 2049,    0,   14],
        [   0,    0,    2,  368,   13],
        [   0,    0,  253,    1,  839]]),
 0.3333333432674408,
 array([[ 6,  1,  3],
        [ 5, 10, 18],
        [39,  0, 17]]))