In [None]:

# Data loading (anndata / scanpy), training (pytorch_lightning) and the project's VAE
import anndata as ad
import scanpy as sc
import pytorch_lightning as pl

from wcd_vae.data import get_dataloader_from_adata
from wcd_vae.model import VAE, VAEConfig

In [None]:
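# Download all the datasets (one-off step, left disabled; `get_data` is not imported in this notebook and must be imported before uncommenting)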
# get_data("Vu")
# get_data("Ji")
# get_data("Mascharak")

In [None]:
# Load the AnnData object from disk.
anndata = ad.read_h5ad("data/vu_2022_ay_wh.h5ad")
# Keep the matrix the file ships in X under its own layer. A copy is stored so
# the layer is a snapshot rather than an alias of X.
# NOTE(review): the layer name implies X is already normalised -- confirm.
anndata.layers["normalized"] = anndata.X.copy()

# Flag the 3000 most variable genes, using "sample" as the batch key, and plot
# the selection diagnostics.
sc.pp.highly_variable_genes(anndata, n_top_genes=3000, batch_key="sample")
sc.pl.highly_variable_genes(anndata)
# Normalisation / log-transform are intentionally left disabled in this notebook:
# sc.pp.normalize_total(anndata, target_sum=1e4)  # scale each cell to 1e4 total counts
# sc.pp.log1p(anndata)  # logarithmize the data

# QC scatter plots on the full gene set, before subsetting.
# Assumes the obs columns below already exist in the loaded file -- TODO confirm.
sc.pl.scatter(anndata, x="total_counts", y="pct_counts_mt")
sc.pl.scatter(anndata, x="total_counts", y="n_genes_by_counts")

# Restrict to the highly variable genes. Slicing an AnnData returns a view, and
# assigning .X on a view triggers an implicit copy plus a warning, so materialise
# the subset explicitly first.
anndata = anndata[:, anndata.var["highly_variable"]].copy()
# Swap X to the raw counts layer (copied so X and the layer stay independent).
anndata.X = anndata.layers["counts"].copy()

In [None]:
# PCA scatter plot of the cells, coloured by the "age" column of `anndata.obs`.
# NOTE(review): no PCA is computed in the visible cells of this notebook, so this
# presumably relies on an embedding already stored in the loaded file -- confirm.
sc.pl.pca(anndata, color="age")

In [None]:
# Relative frequency of each age group among the cells.
age_fractions = anndata.obs["age"].value_counts(normalize=True)
print(age_fractions)

In [None]:
# Absolute number of cells in each age group.
age_counts = anndata.obs["age"].value_counts(normalize=False)
print(age_counts)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs before anything stochastic runs, so a
# top-to-bottom re-run of the notebook is repeatable. This covers the model
# initialisation and training in later cells, and any random splitting or
# shuffling done inside get_dataloader_from_adata (not visible here -- TODO confirm).
SEED = 42
pl.seed_everything(SEED)

# Build the train/test loaders from the HVG-subset AnnData, keyed by the "age"
# column; the third return value is kept as `domain_encoder`.
train_loader, test_loader, domain_encoder = get_dataloader_from_adata(anndata, by="age")

In [None]:
# Model input width = number of genes (columns) left in the AnnData object.
n_input_features = anndata.shape[1]

# Hyperparameters for the VAE.
# NOTE(review): the data loaders are built in an earlier cell without reference
# to `batchsize`, so this value may not be what sized their batches -- confirm.
config = VAEConfig(
    input_dim=n_input_features,
    latent_dim=10,
    encoder_hidden_dims=[128, 64],
    decoder_hidden_dims=[128, 64],
    dropout=0.1,
    batchsize=64,
    num_epochs=100,
    lr=1e-3,
    weight_decay=1e-5,
)

# Instantiate the model from the configuration.
vae = VAE(config)

In [None]:
# Lightning trainer: run for the configured number of epochs, let Lightning pick
# the accelerator and device count, and log every 10 steps.
trainer = pl.Trainer(
    max_epochs=config.num_epochs,
    accelerator="auto",  # set to "gpu" or "cpu" to force a backend
    devices="auto",  # or a fixed count, e.g. devices=1
    log_every_n_steps=10,
)

# Fit the VAE on the training loader; the test loader is used for validation.
trainer.fit(
    vae,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)