-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
85 lines (75 loc) · 2.17 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import time
import torch
import numpy as np
import torch.nn as nn
from pathlib import Path
import torch.optim as optim
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from SNNomics.model import SNN
from SNNomics.utils import check_dir
from SNNomics.dataset import PredictDataset
from SNNomics.predictor import Predictor
def rm_query(query_id: str, database: np.ndarray, database_ids: np.ndarray):
    """Pull the query sample out of the database.

    Parameters
    ----------
    query_id : str
        Sample identifier (GSM) to remove from the database.
    database : np.ndarray
        2-D samples x genes expression matrix, rows aligned with database_ids.
    database_ids : np.ndarray
        1-D array of sample identifiers, one per row of ``database``.

    Returns
    -------
    tuple of np.ndarray
        (query row(s) as a 2-D slice, database with the query row(s) deleted,
        database_ids with the query id(s) deleted).
    """
    query_ind = np.where(database_ids == query_id)[0]
    query_vector = database[query_ind, :]
    # Fix: original referenced undefined name `databse` (NameError), and
    # omitted axis=0 — np.delete without an axis flattens the 2-D matrix.
    database_queryRm = np.delete(database, query_ind, axis=0)
    ids_queryRm = np.delete(database_ids, query_ind)
    return query_vector, database_queryRm, ids_queryRm
if __name__ == '__main__':
    # Command-line entry point: load the expression database, remove the
    # query sample(s), and run the Predictor over the remaining database.
    parser = ArgumentParser()
    parser.add_argument(
        '-query',
        help='Path to .txt file containing GSMs to predict',
        type=str,
        required=True,
    )
    parser.add_argument(
        '-database',
        help='Path to .npz file containing a samples x genes expression matrix',
        type=str,
        required=True,
    )
    parser.add_argument(
        '-batch_size',
        help='size of each batch',
        type=int,
        default=128,
    )
    parser.add_argument(
        '-outdir',
        help='directory to save results to',
        type=str,
        default='results',
    )
    parser.add_argument(
        '-out_prefix',
        help='prefix of results outfiles',
        type=str,
        default=None,
    )
    args = parser.parse_args()
    # Set paths.
    # Fix: original read `args.samples`, which is never defined by the
    # parser (AttributeError) — the corresponding argument is `-query`.
    query_file = Path(args.query)
    database_file = Path(args.database)
    outdir = Path(args.outdir)
    check_dir(outdir)
    # Load data: .npz archive with 'expression' (samples x genes),
    # 'gsms' (row ids), and 'genes' (column ids).
    data = np.load(database_file)
    database = data['expression']
    database_ids = data['gsms']
    genes = data['genes']
    # Remove query from database.
    # Fix: original called rm_query with only two arguments; it requires
    # (query_id, database, database_ids).
    # NOTE(review): args.query is documented as a *path* to a .txt file of
    # GSMs, but rm_query matches it against database_ids as a single id —
    # confirm whether the file should be read and each GSM removed.
    query, database, database_ids = rm_query(args.query, database, database_ids)
    # Predict for queries.
    predict_data = PredictDataset(database, database_ids)
    # Fix: original used bare `batch_size`, which is undefined.
    loader = DataLoader(
        predict_data,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=False,
    )
    # Fix: `device` was undefined in the original; pick GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): `model` and `criterion` are undefined in this script —
    # they presumably need to be built from SNN (and a loss from torch.nn)
    # and loaded from a checkpoint; cannot be reconstructed from this file.
    # TODO: instantiate/load `model` and `criterion` before running.
    predictor = Predictor(
        query,
        model,
        criterion,
        loader,
        device,
        outdir,
    )