Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## New functionality

* Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
* Added `method/drvi` component (PR #61).

## Minor changes

Expand Down
44 changes: 44 additions & 0 deletions src/methods/drvi/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
__merge__: ../../api/comp_method.yaml
name: drvi
label: DRVI
summary: "DrVI is an unsupervised generative model capable of learning non-linear interpretable disentangled latent representations from single-cell count data."
description: |
Disentangled Representation Variational Inference (DRVI) is an unsupervised deep generative model designed for integrating single-cell RNA sequencing (scRNA-seq) data across different batches.
It extends the variational autoencoder (VAE) framework by learning a latent representation that captures biological variation while disentangling and correcting for batch effects.
DRVI conditions both the encoder and decoder on batch covariates, allowing it to explicitly model and mitigate batch-specific variations during training.
By incorporating a KL-divergence regularization term, it balances data reconstruction with latent space structure, resulting in a unified embedding where similar cells cluster together regardless of batch.
references:
doi:
- 10.1101/2024.11.06.622266
links:
documentation: https://drvi.readthedocs.io/latest/index.html
repository: https://github.com/theislab/DRVI?tab=readme-ov-file
info:
preferred_normalization: counts
arguments:
- name: --n_hvg
type: integer
default: 2000
description: Number of highly variable genes to use.
- name: --n_epochs
type: integer
default: 100
description: Number of epochs
resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py
engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pypi:
- drvi-py==0.1.7
- torch==2.3.0
- torchvision==0.18.0
runners:
- type: executable
- type: nextflow
directives:
label: [midtime,midmem,lowcpu,gpu]
82 changes: 82 additions & 0 deletions src/methods/drvi/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import anndata as ad
import scanpy as sc
import drvi
from drvi.model import DRVI
from drvi.utils.misc import hvg_batch
import pandas as pd
import numpy as np
import warnings
import sys
import scipy.sparse

## VIASH START
par = {
'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad',
'output': 'output.h5ad',
'n_hvg': 2000,
'n_epochs': 400
}
meta = {
'name': 'drvi'
}
## VIASH END

sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

print('Reading input files', flush=True)
adata = read_anndata(
par['input'],
X='layers/counts',
obs='obs',
var='var',
uns='uns'
)

if par["n_hvg"]:
print(f"Select top {par['n_hvg']} high variable genes", flush=True)
idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
adata = adata[:, idx].copy()

print('Train model with DRVI', flush=True)

DRVI.setup_anndata(
adata,
categorical_covariate_keys=["batch"],
is_count_data=False,
)

model = DRVI(
adata,
categorical_covariates=["batch"],
n_latent=128,
encoder_dims=[128, 128],
decoder_dims=[128, 128],
)
model

model.train(
max_epochs=par["n_epochs"],
early_stopping=False,
early_stopping_patience=20,
plan_kwargs={
"n_epochs_kl_warmup": par["n_epochs"],
},
)

print("Store outputs", flush=True)
output = ad.AnnData(
obs=adata.obs.copy(),
var=adata.var.copy(),
obsm={
"X_emb": model.get_latent_representation(),
},
uns={
"dataset_id": adata.uns.get("dataset_id", "unknown"),
"normalization_id": adata.uns.get("normalization_id", "unknown"),
"method_id": meta["name"],
},
)

print("Write output AnnData to file", flush=True)
output.write_h5ad(par['output'], compression='gzip')