In [1]:
from ref_query_arm import RefQueryArm
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import utilities

In [2]:
def _load_simulation(tag):
    """Load one simulation folder and return (count matrix, integer group codes).

    Reads `counts.csv` and `meta.csv` from `simulations/splat_<tag>_de/`, using
    the first CSV column as the index in both files. The 'Group' column of the
    metadata is factorized with sort=True, so the codes follow the sorted
    order of the group names.
    """
    folder = "simulations/splat_" + str(tag) + "_de/"
    counts_frame = pd.read_csv(folder + "counts.csv", index_col=0)
    meta_frame = pd.read_csv(folder + "meta.csv", index_col=0)
    group_codes, _group_names = pd.factorize(meta_frame['Group'], sort=True)
    return np.array(counts_frame), group_codes


# Each step value is rounded to 5 decimals so float noise from np.arange
# cannot leak into the folder name.
datasets = [_load_simulation(round(step, 5)) for step in np.arange(.4, .5, .1)]

In [62]:
# How many samples of the first dataset carry the integer label 2.
labels = datasets[0][1]
np.count_nonzero(labels == 2)

219

In [55]:
# Split the first dataset by row position: the first 500 rows form the
# reference set, everything after them forms the query set.
counts_all, labels_all = datasets[0]
ref_X, query_X = counts_all[:500, :], counts_all[500:, :]
ref_y, query_y = labels_all[:500], labels_all[500:]

# Both splits go through the same preprocessing call, with scaling turned off.
ref_X = utilities.preprocess(ref_X, scale=False)
query_X = utilities.preprocess(query_X, scale=False)

# mask_labels returns two label arrays: the first is used later as the query
# test labels, the second as the query training labels.
# NOTE(review): presumably the second argument (0) selects the label to mask — confirm in utilities.
query_y, query_y_masked = utilities.mask_labels(query_y, 0)

215

In [56]:
def _tensor_dataset(features, targets):
    """Convert both array-likes to tensors and pair them in one TensorDataset."""
    return torch.utils.data.TensorDataset(torch.tensor(features), torch.tensor(targets))


def _loader(dataset, shuffle):
    """Build a DataLoader with the batch size of 35 shared by every loader in this cell."""
    return torch.utils.data.DataLoader(dataset, batch_size=35, shuffle=shuffle)


# The query features are wrapped twice: once with the masked labels (training)
# and once with the labels kept in query_y (evaluation).
ref_dataset = _tensor_dataset(ref_X, ref_y)
query_train_dataset = _tensor_dataset(query_X, query_y_masked)
query_test_dataset = _tensor_dataset(query_X, query_y)

# Training loaders shuffle; the evaluation loaders keep the stored row order.
ref_dataloader = _loader(ref_dataset, shuffle=True)
ref_test_dataloader = _loader(ref_dataset, shuffle=False)
query_train_dataloader = _loader(query_train_dataset, shuffle=True)
query_test_dataloader = _loader(query_test_dataset, shuffle=False)

In [59]:
# Instantiate the model from the config file. The meaning of the positional
# arguments 2 and 1 is set by RefQueryArm's constructor, which is not shown
# in this notebook — TODO: document them here or pass them by keyword.
arm = RefQueryArm("configs/semi_basic_linear.txt", 2, 1)

In [60]:
# Train on the shuffled reference loader and the shuffled query loader whose
# labels come from query_y_masked.
# NOTE(review): the trailing 10 is presumably the number of epochs (the
# recorded log reads "Loss in epoch 0") — confirm against RefQueryArm.train.
arm.train(ref_dataloader, query_train_dataloader, 10)

Loss in epoch 0 = 21.183672


In [52]:
# Evaluate on the unshuffled reference and query loaders.
# In the recorded output the result is a 4-tuple: a score and a 4x4 matrix
# whose row sums match the reference class counts, followed by a score and a
# 4x4 matrix whose row sums match the query class counts.
# NOTE(review): presumably these are (ref accuracy, ref confusion matrix,
# query accuracy, query confusion matrix) — confirm in RefQueryArm.
# NOTE(review): in the recorded query matrix, rows 1 and 2 fall mostly in each
# other's column (106 and 96 off-diagonal), which is what drives the 0.594
# query score — worth investigating as a possible label swap.
arm.validation_metrics(ref_test_dataloader, query_test_dataloader)

tensor([1, 0, 2, 2, 0, 2, 0, 3, 2, 1, 1, 0, 1, 1, 0, 3, 0, 0, 0, 1, 0, 0, 1, 0,
        2, 3, 0, 1, 1, 2, 0, 1, 3, 0, 0, 0, 3, 3, 0, 2, 3, 0, 1, 2, 0, 0, 1, 0,
        3, 0, 0, 0, 0, 0, 3, 0, 2, 0, 0, 2, 3, 3, 2, 1, 0, 0, 0, 0, 3, 0, 2, 1,
        2, 2, 0, 1, 1, 3, 2, 2, 1, 2, 0, 2, 1, 0, 0, 0, 2, 0, 0, 2, 0, 1, 1, 0,
        1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 1, 2, 2, 2, 2, 2, 0, 0, 0, 1, 0, 3, 2, 3,
        2, 0, 1, 0, 0, 0, 0, 0, 0, 3, 0, 2, 0, 2, 3, 2, 2, 1, 0, 0, 0, 0, 2, 3,
        2, 0, 0, 3, 1, 3, 2, 2, 0, 1, 2, 1, 3, 0, 0, 0, 1, 2, 0, 2, 0, 0, 3, 0,
        0, 1, 3, 0, 1, 3, 0, 3, 2, 2, 0, 2, 2, 0, 3, 0, 0, 2, 0, 0, 0, 0, 2, 3,
        1, 2, 1, 2, 0, 1, 1, 2, 1, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1,
        1, 0, 2, 1, 1, 0, 0, 1, 2, 0, 0, 0, 3, 2, 3, 1, 1, 0, 3, 0, 0, 3, 0, 3,
        0, 2, 0, 0, 0, 0, 3, 1, 2, 2, 1, 3, 0, 2, 1, 3, 0, 1, 2, 3, 3, 1, 1, 1,
        0, 2, 0, 1, 0, 2, 2, 1, 2, 3, 2, 0, 3, 1, 0, 1, 3, 0, 3, 1, 2, 0, 0, 1,
        1, 0, 0, 0, 2, 0, 2, 2, 0, 0, 0,

(1.0,
 array([[216,   0,   0,   0],
        [  0, 106,   0,   0],
        [  0,   0, 111,   0],
        [  0,   0,   0,  67]]),
 0.593999981880188,
 array([[215,   0,   0,   0],
        [  1,  10, 106,   0],
        [  0,  96,  12,   0],
        [  0,   0,   0,  60]]))

In [38]:
# Class balance check: print the number of samples with label 0, 1, 2 and 3,
# first for the reference labels and then for the query labels (one per line).
for split_labels in (ref_y, query_y):
    for group in range(4):
        print(len(split_labels[split_labels == group]))

216
106
111
67
215
117
108
60


In [6]:
# Hand-built sanity check for the metric code, using one-hot prediction rows
# over 4 classes.
# NOTE: this cell rebinds query_y and ref_y to tiny tensors, replacing the
# label arrays created by the data-split cell.

# Query: predicted classes 0, 0, 3 against true labels 0, 0, 2.
query_preds = torch.nn.functional.one_hot(torch.tensor([0, 0, 3]), num_classes=4)
query_y = torch.tensor([0, 0, 2])

# Reference: a single sample predicted as class 1 whose true label is 0.
ref_preds = torch.nn.functional.one_hot(torch.tensor([1]), num_classes=4)
ref_y = torch.tensor([0])

arm.validation_metrics_test(query_preds, ref_preds, query_y, ref_y)

tensor([0, 0, 3])


(0.6666666865348816,
 array([[2, 0, 0],
        [0, 0, 1],
        [0, 0, 0]]),
 0.0,
 array([[0, 1],
        [0, 0]]))