In [1]:
import numpy as np
import pandas as pd
import torch
from gcn_model import GCNModel
import utilities
from test_model import test_model
import os

In [2]:
#data_folder = "/home/groups/ConradLab/daniel/sharp_sims/splat_0.7_de_rq/"
data_folder = "simulations/splat_0.7_de_rq/"

In [3]:
os.path.exists(data_folder + "preds.csv")

True

In [3]:
# get labels
data_path = data_folder + "query_counts.csv"
tools = ["sctype","scsorter","scina","singler", "scpred"]
#tools = ["scsorter","scina","singler"]
ref_path = data_folder + "ref_counts.csv"
ref_label_path = data_folder + "ref_labels.csv"
marker_path = data_folder + "markers.txt"
if os.path.exists(data_folder + "preds.csv"):
    all_labels = pd.read_csv(data_folder + "preds.csv", index_col=0)
    if all_labels.shape[1] != len(tools): 
        all_labels = all_labels[tools]
        #raise Exception("wrong amount of tools in file")
else:
    all_labels = utilities.label_counts(data_path,tools,ref_path,ref_label_path,marker_path)

In [43]:
all_labels.shape[1]

3

In [6]:
all_labels

Unnamed: 0,scina,scsorter,sctype,singler,scpred
Cell1001,Group2,Group2,Group1,Group2,Group2
Cell1002,Group4,Group2,Group1,Group4,Group2
Cell1003,Group1,Group4,Group1,Group2,
Cell1004,Group4,Group4,Group4,Group4,Group1
Cell1005,Group1,Group1,Group1,Group1,Group1
...,...,...,...,...,...
Cell1996,Group2,Group2,Group4,Group2,Group1
Cell1997,Group1,Group1,Group1,Group1,Group1
Cell1998,Group3,Group3,Group4,Group3,Group3
Cell1999,Group4,Group4,Group1,Group4,Group4


In [4]:
# read in dataset
X = pd.read_csv(data_path, index_col=0)
X, keep_cells = utilities.preprocess(np.array(X), scale=False)
X.shape

(999, 500)

In [8]:
all_labels = all_labels.loc[keep_cells,:]

In [9]:
_,marker_names = utilities.read_marker_file(marker_path)
marker_names

['Group1', 'Group2', 'Group3', 'Group4']

In [10]:
all_labels['scsorter']

Cell1001    Group2
Cell1002    Group2
Cell1003    Group4
Cell1004    Group4
Cell1005    Group1
             ...  
Cell1996    Group2
Cell1997    Group1
Cell1998    Group3
Cell1999    Group4
Cell2000    Group2
Name: scsorter, Length: 999, dtype: object

In [11]:
all_labels_factored = utilities.factorize_df(all_labels, marker_names)
encoded_labels = utilities.encode_predictions(all_labels_factored)
encoded_labels

array([[1., 4., 0., 0.],
       [1., 2., 0., 2.],
       [2., 1., 0., 1.],
       ...,
       [0., 0., 4., 1.],
       [1., 0., 0., 4.],
       [1., 3., 0., 0.]])

In [12]:
meta_path = data_folder + "query_meta.csv"
metadata = pd.read_csv(meta_path, index_col=0)
real_y = pd.factorize(metadata['Group'], sort=True)[0]
real_y = real_y[keep_cells]
real_y.shape

(999,)

In [61]:
len(real_y[real_y==0])

245

In [22]:
print(utilities.pred_accuracy(all_labels_factored['scina'], real_y))
print(utilities.pred_accuracy(all_labels_factored['sctype'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scsorter'], real_y))
print(utilities.pred_accuracy(all_labels_factored['singler'], real_y))
print(utilities.pred_accuracy(all_labels_factored['scpred'], real_y))


0.792792797088623
0.4114114046096802
0.826826810836792
0.8408408164978027
0.6386386156082153


In [23]:
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
utilities.pred_accuracy(max_pred, real_y)

0.9389389157295227

In [24]:
confident_labels = utilities.get_consensus_labels(encoded_labels, necessary_vote = 3)
confident_labels.shape

(999,)

In [25]:
train_nodes = np.where(confident_labels != -1)[0]
test_nodes = np.where(confident_labels == -1)[0]
print(np.unique(confident_labels))
print(np.unique(confident_labels[train_nodes]))
print(np.unique(confident_labels[test_nodes]))

[-1.  0.  1.  2.  3.]
[0. 1. 2. 3.]
[-1.]


In [26]:
real_y[test_nodes]

array([3, 0, 3, 0, 0, 0, 1, 2, 2, 2, 3, 0, 1, 1, 3, 2, 0, 1, 2, 3, 2, 1,
       0, 3, 3, 0, 3, 1, 2, 1, 1, 0, 0, 3, 2, 0, 0, 1, 2, 2, 0, 1, 1, 3,
       2, 0, 0, 0, 0, 3, 3, 3, 0, 3, 1, 0, 3, 3, 0, 1, 2, 2, 0, 2, 1, 0,
       3, 1, 0, 0, 2, 1, 3, 2, 0, 0, 0, 0, 2, 1, 0, 0, 2, 3, 0, 2, 3, 0,
       1, 2, 3, 2, 2, 2, 2, 0, 3, 0, 2, 3, 3, 2, 2, 0, 0, 2, 2, 0, 0, 0,
       3, 1, 2, 2, 0, 0, 2, 0])

In [27]:
print(utilities.pred_accuracy(confident_labels[train_nodes], real_y[train_nodes]))

0.9807037711143494


In [28]:
len(test_nodes)

118

In [21]:
# tool accuracy on test
print(utilities.pred_accuracy(np.array(all_labels_factored['scina'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['sctype'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['scsorter'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['singler'][test_nodes]), real_y[test_nodes]))
print(utilities.pred_accuracy(np.array(all_labels_factored['scpred'][test_nodes]), real_y[test_nodes]))
max_pred = torch.tensor(encoded_labels).max(dim=1)[1]
print(utilities.pred_accuracy(max_pred[test_nodes], real_y[test_nodes]))

0.19491524994373322
0.3644067943096161
0.43220338225364685
0.5932203531265259
0.17796610295772552
0.6271186470985413


In [29]:
dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(confident_labels))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=35, shuffle=True)

test_dataset  = torch.utils.data.TensorDataset(torch.tensor(X), torch.tensor(real_y))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=35, shuffle=False)

In [30]:
m = GCNModel("configs/2_8.txt", 2, dropout=0.0)

In [31]:
m.train(dataloader, 100)

Loss in epoch 0 = 36.780918
Loss in epoch 10 = 0.142983
Loss in epoch 20 = 0.038523
Loss in epoch 30 = 0.018872
Loss in epoch 40 = 0.008604
Loss in epoch 50 = 0.006198
Loss in epoch 60 = 0.003754
Loss in epoch 70 = 0.002780
Loss in epoch 80 = 0.002784
Loss in epoch 90 = 0.001744


In [32]:
m.validation_metrics(test_dataloader, train_nodes, test_nodes)

(0.9699699878692627,
 array([[222,   4,   1,   5],
        [  0, 267,   1,   0],
        [  2,   4, 223,   5],
        [  2,   5,   1, 257]]),
 0.9807037711143494,
 array([[183,   3,   1,   3],
        [  0, 249,   0,   0],
        [  1,   2, 195,   4],
        [  1,   1,   1, 237]]),
 0.8898305296897888,
 array([[39,  1,  0,  2],
        [ 0, 18,  1,  0],
        [ 1,  2, 28,  1],
        [ 1,  4,  0, 20]]))

In [22]:
#data_folders = ["/home/groups/ConradLab/daniel/sharp_sims/splat_0.5_de_rq/", "/home/groups/ConradLab/daniel/sharp_sims/splat_0.6_de_rq/", "/home/groups/ConradLab/daniel/sharp_sims/splat_0.7_de_rq/"]
data_folders = ["simulations/splat_0.6_de_rq/", "simulations/splat_0.7_de_rq/", "simulations/splat_0.8_de_rq/"]
tools = ["sctype","scsorter","scina","singler", "scpred"]
votes_necessary = 3
model_file = "configs/2_8.txt"
neighbors = 2
batch_size=35
training_epochs=100
random_inits = 3

In [23]:
results = test_model(data_folders, tools, votes_necessary, model_file, neighbors, batch_size, training_epochs, random_inits)

  view_to_actual(adata)
  view_to_actual(adata)


[0.828000009059906, 0.8309999704360962, 0.8180000185966492]


  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())
  view_to_actual(adata)
  view_to_actual(adata)


[0.965965986251831, 0.9529529809951782, 0.9679679870605469]


  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())
  view_to_actual(adata)
  view_to_actual(adata)


[0.9829829931259155, 0.9829829931259155, 0.9859859943389893]


  return float((torch.tensor(preds) == torch.tensor(real)).type(torch.FloatTensor).mean().numpy())


In [24]:
results

Unnamed: 0,data_name,method,total_accuracy,train_accuracy,test_accuracy,total_sd,train_sd,test_sd
0,splat_0.6_de_rq,GCN,0.825667,0.944708,0.620345,0.006807,0.0,0.018547
1,splat_0.6_de_rq,Max Col.,0.798,0.944708,0.544959,0.0,0.0,0.0
2,splat_0.6_de_rq,Confident Labels,,0.944708,,0.0,0.0,0.0
3,splat_0.6_de_rq,sctype,0.232,0.235387,0.226158,0.0,0.0,0.0
4,splat_0.6_de_rq,scsorter,0.677,0.846761,0.384196,0.0,0.0,0.0
5,splat_0.6_de_rq,scina,0.467,0.668246,0.119891,0.0,0.0,0.0
6,splat_0.6_de_rq,singler,0.84,0.913112,0.713896,0.0,0.0,0.0
7,splat_0.6_de_rq,scpred,0.503,0.663507,0.226158,0.0,0.0,0.0
0,splat_0.7_de_rq,GCN,0.962296,0.978186,0.854167,0.008153,0.0,0.063629
1,splat_0.7_de_rq,Max Col.,0.935936,0.978186,0.648438,0.0,0.0,0.0


In [34]:
# check grid search results
grid_results = pd.read_csv("grid_search_0.5_output/combined.csv", header=None, index_col=0)
grid_results.columns = ["Total Accuracy", "Train Accuracy", "Test Accuracy"]

In [42]:
grid_results.sort_values("Test Accuracy", ascending=False).head(50)

Unnamed: 0_level_0,Total Accuracy,Train Accuracy,Test Accuracy
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2345,0.846,0.976492,0.684564
1996,0.839,0.976492,0.668904
1801,0.839,0.976492,0.668904
1621,0.838,0.976492,0.666667
1269,0.837,0.976492,0.66443
1103,0.835,0.976492,0.659955
1261,0.835,0.976492,0.659955
721,0.834,0.976492,0.657718
901,0.833,0.976492,0.655481
2165,0.831,0.976492,0.651007
