In [1]:
import numpy as np
import pandas as pd
import torch
from gcn_model import GCNModel
import utilities
from test_model import test_model
import os
import statistics
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import random

  from .autonotebook import tqdm as notebook_tqdm


In [57]:
data_folder = "/home/groups/ConradLab/daniel/sharp_data/sharp_sims/splat_0.7_de_rq_v3/"
# get labels
data_path = data_folder + "query_counts.csv"
tools = ["sctype","scsorter","scina","singler", "scpred"]
#tools = ["scsorter","scina","singler", "sctype"]
ref_path = data_folder + "ref_counts.csv"
ref_label_path = data_folder + "ref_labels.csv"
marker_path = data_folder + "markers.txt"
if os.path.exists(data_folder + "preds.csv"):
    all_labels = pd.read_csv(data_folder + "preds.csv", index_col=0)
    if all_labels.shape[1] != len(tools): 
        all_labels = all_labels[tools]
        #raise Exception("wrong amount of tools in file")
else:
    all_labels = utilities.label_counts(data_path,tools,ref_path,ref_label_path,marker_path)

# read in dataset
X = pd.read_csv(data_path, index_col=0)
X, keep_cells, keep_genes, pca_obj = utilities.preprocess(np.array(X), scale=False, comps=500)

all_labels = all_labels.loc[keep_cells,:]

_,marker_names = utilities.read_marker_file(marker_path)

all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)

meta_path = data_folder + "query_meta.csv"
metadata = pd.read_csv(meta_path, index_col=0)
real_y = pd.factorize(metadata['Group'], sort=True)[0]
real_y = real_y[keep_cells]

confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = .51)

train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]

In [58]:
len(test_nodes)

98

In [48]:
encoded_labels[real_y==3]

array([[0., 1., 2., 1.],
       [0., 2., 0., 1.],
       [0., 1., 0., 3.],
       [0., 0., 1., 2.],
       [0., 1., 0., 3.],
       [1., 0., 1., 1.],
       [2., 0., 2., 0.],
       [1., 1., 0., 2.],
       [0., 2., 0., 1.],
       [0., 2., 0., 1.],
       [2., 0., 0., 2.],
       [0., 2., 1., 0.],
       [1., 0., 0., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 3.],
       [0., 1., 2., 0.],
       [0., 2., 1., 1.],
       [0., 0., 2., 2.],
       [1., 0., 0., 2.],
       [0., 2., 0., 1.],
       [0., 1., 0., 2.],
       [0., 0., 2., 1.],
       [1., 0., 0., 2.],
       [0., 0., 2., 2.],
       [0., 1., 0., 3.],
       [0., 4., 0., 0.],
       [1., 0., 0., 2.],
       [0., 0., 1., 3.],
       [0., 2., 1., 1.],
       [0., 1., 0., 2.],
       [0., 1., 0., 2.],
       [0., 2., 1., 0.],
       [2., 0., 0., 2.],
       [0., 1., 0., 2.],
       [3., 0., 0., 1.],
       [1., 1., 2., 0.],
       [2., 0., 1., 0.],
       [0., 1., 0., 2.],
       [0., 2., 0., 1.],
       [0., 1., 0., 2.],


In [49]:
confusion_matrix(real_y[train_nodes], confident_labels[train_nodes])

array([[342,  20,   6,   2],
       [  1, 283,   1,   0],
       [  1,   1, 187,   0],
       [  3,  15,   4,  42]])

In [50]:
confusion_matrix(real_y, all_labels_factored["scina"])

array([[  0,   0,   0,   0,   0],
       [236, 126,  20,  37,   5],
       [ 99,   6, 165,  14,   5],
       [ 49,   8,   7, 133,   2],
       [ 51,   7,   5,  10,  15]])

In [7]:
confusion_matrix(real_y, all_labels_factored["sctype"])

array([[400,  13,  11,   0],
       [  7, 270,  12,   0],
       [  9,   8, 182,   0],
       [ 19,  48,  21,   0]])

In [8]:
confusion_matrix(real_y, all_labels_factored["singler"])

array([[280, 106,  30,   8],
       [  0, 289,   0,   0],
       [  0,   3, 196,   0],
       [  2,  16,   6,  64]])

In [9]:
confusion_matrix(real_y, all_labels_factored["scpred"])

array([[  0,   0,   0,   0,   0],
       [ 34, 351,  18,  21,   0],
       [ 35,  13, 210,  31,   0],
       [ 28,   6,  17, 148,   0],
       [ 36,  13,  19,  20,   0]])

In [10]:
confusion_matrix(real_y, all_labels_factored["scsorter"])

array([[333,  46,  23,  22],
       [  3, 279,   4,   3],
       [  4,   8, 175,  12],
       [  5,  14,   6,  63]])

In [59]:
print(utilities.pred_accuracy(all_labels_factored['scina'], real_y))
print(utilities.pred_accuracy(all_labels_factored['sctype'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scsorter'], real_y))
print(utilities.pred_accuracy(all_labels_factored['singler'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scpred'], real_y))
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
utilities.pred_accuracy(max_pred, real_y)

0.4390000104904175
0.8519999980926514
0.8500000238418579
0.8289999961853027
0.7089999914169312


  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())


0.9150000214576721

In [53]:
print(utilities.pred_accuracy(np.array(all_labels_factored['scina'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['sctype'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['scsorter'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['singler'][test_nodes]), real_y[test_nodes]))
#print(utilities.pred_accuracy(np.array(all_labels_factored['scpred'][test_nodes]), real_y[test_nodes]))
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
print(utilities.pred_accuracy(max_pred[test_nodes], real_y[test_nodes]))

0.1304347813129425
0.5652173757553101
0.3586956560611725
0.5652173757553101
0.6195651888847351


In [54]:
dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(confident_labels))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)

test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=False)

In [55]:
m = GCNModel("configs/2_40.txt", 2, dropout=0.0)
m.train(dataloader, 150)

Loss in epoch 0 = 24.592543
Loss in epoch 10 = 0.062802
Loss in epoch 20 = 0.016492
Loss in epoch 30 = 0.011047
Loss in epoch 40 = 0.004784
Loss in epoch 50 = 0.002630
Loss in epoch 60 = 0.001744
Loss in epoch 70 = 0.001202
Loss in epoch 80 = 0.001645
Loss in epoch 90 = 0.000840
Loss in epoch 100 = 0.000533
Loss in epoch 110 = 0.000649
Loss in epoch 120 = 0.000449
Loss in epoch 130 = 0.000303
Loss in epoch 140 = 0.000302


In [56]:
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.925000011920929,
 array([[390,  25,   7,   2],
        [  1, 286,   2,   0],
        [  1,   1, 197,   0],
        [  6,  22,   8,  52]]),
 0.9405286312103271,
 array([[342,  20,   6,   2],
        [  1, 283,   1,   0],
        [  1,   1, 187,   0],
        [  3,  15,   4,  42]]),
 0.77173912525177,
 array([[48,  5,  1,  0],
        [ 0,  3,  1,  0],
        [ 0,  0, 10,  0],
        [ 3,  7,  4, 10]]))

In [None]:
# add copies of confidently labelled cell type 4 to boost numbers

In [28]:
X.shape

(1000, 500)

In [29]:
conf_four_cells = np.where(confident_labels == 3)[0]
X[conf_four_cells,:]

array([[ 2.540359  ,  0.32933518,  4.0460124 , ...,  1.7700831 ,
        -0.56088966, -0.5383885 ],
       [-0.24325468, -0.06944561,  2.0636573 , ...,  0.48108217,
        -0.1444801 , -0.9558514 ],
       [-5.470358  , -0.8024255 ,  0.4330117 , ..., -0.23332869,
         0.29994816,  0.54539615],
       ...,
       [ 3.7671275 ,  2.113782  ,  0.7602452 , ..., -0.8004527 ,
         1.1425638 ,  0.40496314],
       [-2.958261  , -1.6363332 ,  1.5049648 , ...,  0.4195115 ,
        -0.30215836, -1.1410402 ],
       [-1.6742705 , -0.24646087,  0.88787675, ..., -0.167398  ,
        -0.0692627 ,  0.33483595]], dtype=float32)

In [30]:
repeated_four = np.tile(X[conf_four_cells,:], (2,1))
repeated_four.shape

(48, 500)

In [31]:
test = np.array(X[confident_labels==0,18], dtype="float64")
statistics.stdev(test)

1.9222024599210852

In [32]:
random_mat = np.zeros(repeated_four.shape)
for i in range(X.shape[1]):
    temp_sd = statistics.stdev(np.array(X[confident_labels==0,i], dtype="float64"))
    random_column = np.random.normal(0,temp_sd, repeated_four.shape[0])
    random_mat[:,i] = random_column

In [33]:
random_mat

array([[-1.24235393e+00, -8.59261115e-01, -1.40901609e+00, ...,
         7.59488836e-01, -1.59615285e-03,  1.31556647e+00],
       [ 2.70436280e+00, -1.32602598e-01, -1.90457422e+00, ...,
         1.08755049e+00, -3.43271236e-01,  5.18055733e-01],
       [ 2.00321669e+00, -1.38146048e+00,  8.41548703e-01, ...,
         3.73329329e-01,  1.26370590e+00, -2.38368101e-01],
       ...,
       [-6.69574155e+00, -4.92453626e-01,  7.10593196e-01, ...,
        -4.96285943e-01, -2.25956939e-01, -7.90175762e-02],
       [-5.88930010e+00,  6.40525581e-01,  7.85815853e-01, ...,
        -4.62410951e-01, -1.02095044e+00, -3.44951853e-01],
       [-6.55653351e+00,  1.53147037e+00, -3.71055408e+00, ...,
        -6.52639707e-01, -1.06145691e-01, -9.54565514e-01]])

In [34]:
#repeated_four = repeated_four + np.random.normal(0, 4, repeated_four.shape)
repeated_four = repeated_four + random_mat
repeated_four

array([[ 1.29800509, -0.52992593,  2.63699631, ...,  2.52957191,
        -0.56248581,  0.77717798],
       [ 2.46110812, -0.20204821,  0.15908307, ...,  1.56863266,
        -0.48775133, -0.43779564],
       [-3.46714121, -2.18388599,  1.27456041, ...,  0.14000064,
         1.56365405,  0.30702805],
       ...,
       [-2.92861404,  1.6213283 ,  1.4708384 , ..., -1.29673865,
         0.91660688,  0.32594556],
       [-8.84756111, -0.99580765,  2.29078068, ..., -0.04289945,
        -1.32310879, -1.48599206],
       [-8.23080402,  1.2850095 , -2.82267733, ..., -0.82003771,
        -0.17540839, -0.61972957]])

In [35]:
X_extended = np.concatenate((X, repeated_four), axis=0)
X_extended.shape

(1048, 500)

In [36]:
confident_labels.shape

(1000,)

In [38]:
extended_conf_labels = np.concatenate((confident_labels, np.array([3]*48)))
extended_conf_labels.shape

(1048,)

In [39]:
extended_real_y = np.concatenate((real_y, np.array([3]*48)))
extended_real_y.shape

(1048,)

In [40]:
# mix in fake nodes
random.seed(8)
shuffled = list(range(len(extended_real_y)))
random.shuffle(shuffled)
extended_real_y = extended_real_y[shuffled]
extended_conf_labels = extended_conf_labels[shuffled]
X_extended = X_extended[shuffled,:]
train_nodes = np.where(extended_conf_labels != -1)[0]
test_nodes = np.where(extended_conf_labels == -1)[0]

In [41]:
len(test_nodes)

98

In [42]:
extended_conf_labels[1:100]

array([ 0.,  0.,  1.,  3.,  2.,  0.,  2.,  2.,  1.,  2.,  2.,  2.,  1.,
        0.,  2.,  1.,  1.,  1.,  1.,  1.,  0., -1.,  0.,  1.,  0.,  3.,
        1.,  3.,  2.,  0.,  0.,  2.,  2.,  1.,  0.,  2.,  1.,  1.,  1.,
        0.,  1.,  0.,  0.,  0.,  1.,  0., -1.,  0.,  1.,  1.,  1.,  0.,
        1.,  2.,  3.,  1.,  2.,  0.,  0.,  0.,  1.,  1.,  0.,  2.,  0.,
        2., -1.,  0., -1.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,
        2.,  1.,  1.,  0.,  0.,  2.,  1., -1.,  2., -1.,  0.,  1.,  0.,
        1.,  0., -1.,  0.,  2.,  1.,  0.,  1.])

In [43]:
dataset  = torch.utils.data.TensorDataset(torch.tensor(X_extended), torch.tensor(extended_conf_labels))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)

test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X_extended), torch.tensor(extended_real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=50, shuffle=False)

In [44]:
m = GCNModel("configs/2_40.txt", 2, dropout=0.0)
m.train(dataloader, 150)

Loss in epoch 0 = 26.924913
Loss in epoch 10 = 0.080819
Loss in epoch 20 = 0.017143
Loss in epoch 30 = 0.006573
Loss in epoch 40 = 0.004491
Loss in epoch 50 = 0.003160
Loss in epoch 60 = 0.002770
Loss in epoch 70 = 0.001824
Loss in epoch 80 = 0.000823
Loss in epoch 90 = 0.000830
Loss in epoch 100 = 0.000607
Loss in epoch 110 = 0.000862
Loss in epoch 120 = 0.000488
Loss in epoch 130 = 0.000279
Loss in epoch 140 = 0.000210


In [45]:
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.9179389476776123,
 array([[397,  18,   8,   1],
        [  1, 286,   2,   0],
        [  3,   0, 196,   0],
        [  9,  24,  20,  83]]),
 0.9547368288040161,
 array([[363,  13,   7,   1],
        [  0, 283,   1,   0],
        [  1,   0, 190,   0],
        [  3,  11,   6,  71]]),
 0.5612244606018066,
 array([[34,  5,  1,  0],
        [ 1,  3,  1,  0],
        [ 2,  0,  6,  0],
        [ 6, 13, 14, 12]]))