# script.py — scVI integration component
from scanpy._utils import check_nonnegative_integers
import mudata
import scvi
### VIASH START
# Development fallback values only: everything between the VIASH START/END
# markers is replaced with the real component configuration at build time.
par = {
    # Input/output locations and modality selection
    "input": "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu",
    "modality": "rna",
    "input_layer": None,
    "obs_batch": "sample_id",
    "var_input": None,
    "output": "foo.h5mu",
    "obsm_output": "X_scvi_integrated",
    # Early-stopping configuration for model training
    "early_stopping": True,
    "early_stopping_monitor": "elbo_validation",
    "early_stopping_patience": 45,
    "early_stopping_min_delta": 0,
    # Learning-rate schedule and epoch budget
    "reduce_lr_on_plateau": True,
    "lr_factor": 0.6,
    "lr_patience": 30,
    "max_epochs": 500,
    # Minimum dataset-size thresholds used by the sanity checks
    "n_obs_min_count": 10,
    "n_var_min_count": 10,
    # Trained-model output directory and h5mu compression
    "output_model": "test/",
    "output_compression": "gzip",
}
meta = {
    "resources_dir": 'src/integrate/scvi'
}
### VIASH END
import sys
# Make helper modules shipped alongside this component importable.
sys.path.append(meta['resources_dir'])
# START TEMPORARY WORKAROUND subset_vars
# reason: resources aren't available when using Nextflow fusion
# from subset_vars import subset_vars
def subset_vars(adata, subset_col):
    """Return a copy of *adata* restricted to the variables selected by ``adata.var[subset_col]``.

    Note: ``subset_col`` is used directly as a column index mask — presumably a
    boolean var column (e.g. a highly-variable-gene flag); confirm with callers.
    """
    keep = adata.var[subset_col]
    return adata[:, keep].copy()
# END TEMPORARY WORKAROUND subset_vars
#TODO: optionally, move to qa
# https://github.com/openpipelines-bio/openpipeline/issues/435
def check_validity_anndata(adata, layer, obs_batch,
                           n_obs_min_count, n_var_min_count):
    """Sanity-check an AnnData object before scVI training.

    Verifies that:
      * the training matrix (``adata.layers[layer]`` if *layer* is given,
        otherwise ``adata.X``) contains non-negative integer (raw) counts;
      * variable names are unique;
      * every batch category in ``adata.obs[obs_batch]`` has more than
        *n_obs_min_count* cells;
      * the object has more than *n_var_min_count* variables.

    Raises AssertionError with an explanatory message on the first failed check.
    """
    # scVI's count likelihoods require raw (non-negative integer) counts.
    assert check_nonnegative_integers(
        adata.layers[layer] if layer else adata.X
    ), "Make sure input adata contains raw_counts"
    assert len(set(adata.var_names)) == len(
        adata.var_names
    ), "Dataset contains multiple genes with same gene name."
    # Per-batch check: value_counts() yields the cell count of each batch
    # category, so this fails if ANY batch is too small — the message must
    # say so (the old one wrongly claimed the whole dataset was too small).
    assert min(adata.obs[[obs_batch]].value_counts()) > n_obs_min_count, \
        f"At least one batch in obs['{obs_batch}'] has {n_obs_min_count} or fewer cells."
    assert adata.n_vars > n_var_min_count, \
        f"Anndata has fewer than {n_var_min_count} genes."
def main():
    """Integrate the selected modality with scVI and write the result to a MuData file."""
    mdata = mudata.read(par["input"].strip())
    adata = mdata.mod[par["modality"]]

    # Optionally restrict training to a pre-computed variable subset (e.g. HVGs).
    if par["var_input"]:
        train_adata = subset_vars(adata, subset_col=par["var_input"]).copy()
    else:
        train_adata = adata.copy()

    check_validity_anndata(
        train_adata, par["input_layer"], par["obs_batch"],
        par["n_obs_min_count"], par["n_var_min_count"]
    )

    # Register the training data with scVI.
    scvi.model.SCVI.setup_anndata(
        train_adata,
        batch_key=par["obs_batch"],
        layer=par["input_layer"],
        labels_key=par["obs_labels"],
        size_factor_key=par["obs_size_factor"],
        categorical_covariate_keys=par["obs_categorical_covariate"],
        continuous_covariate_keys=par["obs_continuous_covariate"],
    )

    # Build the VAE.
    model = scvi.model.SCVI(
        train_adata,
        n_hidden=par["n_hidden_nodes"],
        n_latent=par["n_dimensions_latent_space"],
        n_layers=par["n_hidden_layers"],
        dropout_rate=par["dropout_rate"],
        dispersion=par["dispersion"],
        gene_likelihood=par["gene_likelihood"],
        use_layer_norm=par["use_layer_normalization"],
        use_batch_norm=par["use_batch_normalization"],
        encode_covariates=par["encode_covariates"],  # Default (True) is for better scArches performance -> maybe don't use this always?
        deeply_inject_covariates=par["deeply_inject_covariates"],  # Default (False) for better scArches performance -> maybe don't use this always?
        use_observed_lib_size=par["use_observed_lib_size"],  # When size_factors are not passed
    )

    train_plan_kwargs = dict(
        reduce_lr_on_plateau=par["reduce_lr_on_plateau"],
        lr_patience=par["lr_patience"],
        lr_factor=par["lr_factor"],
    )

    # Train the model.
    # Note: train_size=1.0 should give better results, but then can't do early_stopping on validation set
    model.train(
        max_epochs=par["max_epochs"],
        early_stopping=par["early_stopping"],
        early_stopping_monitor=par["early_stopping_monitor"],
        early_stopping_patience=par["early_stopping_patience"],
        early_stopping_min_delta=par["early_stopping_min_delta"],
        plan_kwargs=train_plan_kwargs,
        check_val_every_n_epoch=1,
        accelerator="auto",
    )

    # Store the latent embedding on the FULL modality and write everything out.
    adata.obsm[par["obsm_output"]] = model.get_latent_representation()
    mdata.mod[par["modality"]] = adata
    mdata.write_h5mu(par["output"].strip(), compression=par["output_compression"])

    if par["output_model"]:
        model.save(par["output_model"], overwrite=True)
# Run the integration when executed as a script.
if __name__ == "__main__":
    main()