In [1]:
import numpy as np
import pandas as pd
import torch
from gcn_model import GCNModel
import utilities
from test_model import test_model

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
data_folder = "simulations/splat_0.5_de_rq/"

In [5]:
# get labels
data_path = data_folder + "query_counts.csv"
tools = ["sctype","scsorter","scina","singler", "scpred"]
ref_path = data_folder + "ref_counts.csv"
ref_label_path = data_folder + "ref_labels.csv"
marker_path = data_folder + "markers.txt"
all_labels = utilities.label_counts(data_path,tools,ref_path,ref_label_path,marker_path)

In [10]:
all_labels

Unnamed: 0,scina,scsorter,sctype,singler,scpred
Cell1001,Group3,Group3,Group3,Group3,Group2
Cell1002,Group2,Group2,Group4,Group2,Group2
Cell1003,Group4,Group4,Group3,Group4,Group3
Cell1004,Group3,Group3,Group3,Group3,Group3
Cell1005,Group4,Group4,Group3,Group4,Group1
...,...,...,...,...,...
Cell1996,Group4,Group4,Group4,Group4,Group1
Cell1997,Group3,Group3,Group3,Group3,Group2
Cell1998,Group1,Group1,Group1,Group1,Group1
Cell1999,Group2,Group2,Group2,Group2,Group1


In [9]:
_,marker_names = utilities.read_marker_file(marker_path)
marker_names

['Group1', 'Group2', 'Group3', 'Group4']

In [12]:
all_labels['sctype'].unique()

array(['Group3', 'Group4', 'Group2', 'Group1'], dtype=object)

In [44]:
all_labels['sctype']

Cell1001    Group1
Cell1002    Group2
Cell1003    Group2
Cell1004    Group3
Cell1005    Group2
             ...  
Cell1996    Group2
Cell1997    Group1
Cell1998    Group3
Cell1999    Group1
Cell2000    Group2
Name: sctype, Length: 1000, dtype: object

In [13]:
all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)
encoded_labels

array([[0., 1., 4., 0.],
       [0., 4., 0., 1.],
       [0., 0., 2., 3.],
       ...,
       [5., 0., 0., 0.],
       [1., 4., 0., 0.],
       [5., 0., 0., 0.]])

In [14]:
# read in dataset
X = pd.read_csv(data_path, index_col=0)
X = utilities.preprocess(np.array(X), scale=False)
X.shape

(1000, 500)

In [15]:
meta_path = data_folder + "query_meta.csv"
metadata = pd.read_csv(meta_path, index_col=0)
real_y = pd.factorize(metadata['Group'], sort=True)[0]
real_y.shape

(1000,)

In [16]:
print(utilities.pred_accuracy(all_labels_factored['scina'], real_y))
print(utilities.pred_accuracy(all_labels_factored['sctype'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scsorter'], real_y))
print(utilities.pred_accuracy(all_labels_factored['singler'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scpred'], real_y))


0.9169999957084656
0.6769999861717224
0.9520000219345093
0.9710000157356262
0.5379999876022339


In [17]:
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
utilities.pred_accuracy(max_pred, real_y)

  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())


0.9570000171661377

In [18]:
confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = 3)
confident_labels

array([ 2.,  1.,  3.,  2.,  3.,  0.,  2., -1.,  1.,  0.,  1.,  3.,  1.,
        1.,  1.,  2.,  0.,  0.,  1.,  3.,  0.,  2.,  1.,  0.,  1.,  3.,
        2.,  0.,  0.,  0.,  3.,  1.,  1.,  1.,  3.,  1.,  1.,  3.,  0.,
        2.,  0., -1.,  3., -1.,  3.,  2.,  2.,  2., -1.,  1.,  3.,  1.,
        1.,  2.,  3.,  3.,  3., -1.,  1.,  2.,  3.,  0.,  0.,  2.,  0.,
       -1.,  1.,  1.,  0.,  0.,  1.,  1.,  3.,  2.,  0.,  2.,  1.,  2.,
        0.,  1.,  3.,  1.,  3.,  3.,  2.,  0.,  3.,  2., -1., -1.,  1.,
        0.,  0.,  1.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,  0.,  2.,
        1.,  0.,  3., -1.,  2.,  1.,  2.,  1.,  2.,  3.,  1.,  0.,  1.,
        0.,  2.,  3.,  1., -1.,  0.,  2.,  3.,  1.,  0.,  0.,  2.,  3.,
        1.,  2.,  3.,  2.,  3.,  3.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,
        3.,  1.,  1.,  3.,  0., -1.,  1.,  1.,  0.,  2.,  0.,  3.,  3.,
        1.,  2.,  1.,  1.,  3.,  2.,  1.,  0.,  2.,  3.,  0.,  3.,  1.,
        3.,  2.,  1.,  2.,  1.,  0.,  3.,  1.,  0.,  2.,  0.,  1

In [19]:
train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]
print(np.unique(confident_labels))
print(np.unique(confident_labels[train_nodes]))
print(np.unique(confident_labels[test_nodes]))

[-1.  0.  1.  2.  3.]
[0. 1. 2. 3.]
[-1.]


In [20]:
print(utilities.pred_accuracy(confident_labels[train_nodes], real_y[train_nodes]))

0.9841938614845276


In [21]:
# tool accuracy on test
print(utilities.pred_accuracy(np.array(all_labels_factored['scina'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['sctype'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['scsorter'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['singler'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['scpred'][test_nodes]), real_y[test_nodes]))
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
print(utilities.pred_accuracy(max_pred[test_nodes], real_y[test_nodes]))

0.19607843458652496
0.29411765933036804
0.47058823704719543
0.7647058963775635
0.0784313753247261
0.45098039507865906


  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())


In [22]:
dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(confident_labels))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=35, shuffle=True)

test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=35, shuffle=False)

In [26]:
m = GCNModel("configs/2_8.txt", 2, dropout=0.8)

In [29]:
m.train(dataloader, 100)

Loss in epoch 0 = 0.001826
Loss in epoch 10 = 0.001181
Loss in epoch 20 = 0.000946
Loss in epoch 30 = 0.000834
Loss in epoch 40 = 0.000576
Loss in epoch 50 = 0.000492
Loss in epoch 60 = 0.000438
Loss in epoch 70 = 0.000288
Loss in epoch 80 = 0.000287
Loss in epoch 90 = 0.000271


In [25]:
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.9649999737739563,
 array([[243,   0,   0,   2],
        [  0, 253,   1,   2],
        [ 11,   0, 232,   3],
        [  6,   5,   5, 237]]),
 0.9841938614845276,
 array([[241,   0,   0,   0],
        [  0, 252,   1,   0],
        [  2,   0, 224,   1],
        [  2,   4,   5, 217]]),
 0.6078431606292725,
 array([[ 2,  0,  0,  2],
        [ 0,  1,  0,  2],
        [ 9,  0,  8,  2],
        [ 4,  1,  0, 20]]))

In [30]:
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.9750000238418579,
 array([[244,   0,   0,   1],
        [  1, 254,   1,   0],
        [  4,   2, 239,   1],
        [  3,   5,   7, 238]]),
 0.9841938614845276,
 array([[241,   0,   0,   0],
        [  0, 252,   1,   0],
        [  2,   0, 224,   1],
        [  2,   4,   5, 217]]),
 0.8039215803146362,
 array([[ 3,  0,  0,  1],
        [ 1,  2,  0,  0],
        [ 2,  2, 15,  0],
        [ 1,  1,  2, 21]]))

In [2]:
data_folders = ["/home/groups/ConradLab/daniel/sharp_sims/splat_0.5_de_rq/", "/home/groups/ConradLab/daniel/sharp_sims/splat_0.6_de_rq/", "/home/groups/ConradLab/daniel/sharp_sims/splat_0.7_de_rq/"]
tools = ["sctype","scsorter","scina","singler", "scpred"]
votes_necessary = 3
model_file = "configs/2_8.txt"
neighbors = 2
batch_size=35
training_epochs=100
random_inits = 3

In [3]:
results = test_model(data_folders, tools, votes_necessary, model_file, neighbors, batch_size, training_epochs, random_inits)

NotImplementedError: Conversion 'rpy2py' not defined for objects of type '<class 'rpy2.rinterface.SexpClosure'>'

In [6]:
results

Unnamed: 0,data_name,method,total_accuracy,train_accuracy,test_accuracy,total_sd,train_sd,test_sd
0,splat_0.4_de_rq,GCN,0.717667,0.972222,0.524061,0.033081,0.0,0.058241
1,splat_0.4_de_rq,Max Col.,0.821,0.972222,0.705986,0.0,0.0,0.0
2,splat_0.4_de_rq,Confident Labels,,0.972222,,0.0,0.0,0.0
3,splat_0.4_de_rq,sctype,0.0,0.0,0.0,0.0,0.0,0.0
4,splat_0.4_de_rq,scsorter,0.82,0.972222,0.704225,0.0,0.0,0.0
5,splat_0.4_de_rq,scina,0.527,0.972222,0.18838,0.0,0.0,0.0
6,splat_0.4_de_rq,singler,0.806,0.972222,0.679577,0.0,0.0,0.0
7,splat_0.4_de_rq,scpred,0.0,0.0,0.0,0.0,0.0,0.0
0,splat_0.5_de_rq,GCN,0.895333,0.957831,0.590196,0.010066,0.0,0.059214
1,splat_0.5_de_rq,Max Col.,0.87,0.957831,0.441176,0.0,0.0,0.0
