# Tutorial: Model training

In [None]:
# Install dependencies. Quote the version spec so the shell does not treat `<` as a
# redirect:
# pip install "zarr<3" lamindb lightning modlyn
# NOTE(review): os, join and tqdm appear unused in the cells of this notebook.
import warnings
import os
from os.path import join
import lamindb as ln
import anndata as ad
import lightning as L
from tqdm import tqdm
from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear
from modlyn.io.loading import read_lazy

# Start tracking this notebook run in LaminDB under the "DataLoader v2" project;
# the run is closed by `ln.finish()` in the last cell.
ln.track("UMQFXo0vs0Z6", project="DataLoader v2")

## Cache the pre-shuffled zarr store

In [None]:
# The stores are looked up in the laminlabs/arrayloader-benchmarks instance. If you are
# not running this in that instance, query it explicitly, e.g.:
#   ln.Artifact.using("laminlabs/arrayloader-benchmarks").get(uid)
# Full store (100M cells x 60k genes) has uid "BQ6RplqNcT0akokn0000".
tahoe_store_uid = "TuhkPw0wkzlUXN5k0000"  # subsampled store: 2k cells x 200 genes
artifact_tahoe_store = ln.Artifact.get(tahoe_store_uid)
artifact_tahoe_store

In [None]:
%%time
# Materialize the artifact in the local cache and keep the resulting local path.
# For the full 100M-cell store this downloads ~320GB spread over ~36k zarr fragments
# (files), so it runs for a while even on AWS because of the sheer number of files.
store_path = artifact_tahoe_store.cache()

In [None]:
# Show the local path of the cached store.
# To inspect its top-level contents instead, run: list(store_path.iterdir())
store_path

## Train a linear model

In [None]:
# Show the installed anndata version. This reuses the `ad` alias from the setup cell
# rather than importing the same package a second time under another name.
ad.__version__

In [None]:
def _read_lazy_quietly(path):
    """Open the store at *path* with modlyn's ``read_lazy`` while muting warnings.

    The intent is to hide zarr's notice that the zarr v3 codecs are not final yet.
    Note that ``simplefilter("ignore")`` silences every warning category raised
    during the call, not only zarr's.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return read_lazy(path)


adata = _read_lazy_quietly(store_path)
adata

In [None]:
# Encode each cell's `cell_line` as an int64 class label in obs["y"]; these codes are
# the classification targets. NOTE: pandas encodes missing values as code -1.
adata.obs["y"] = (
    adata.obs["cell_line"]
    .astype("category")
    .cat.codes
    .to_numpy()
    .astype("i8")
)

In [None]:
# Split the pre-shuffled store into a training head and a validation tail.
BATCH_SIZE = 2048
TRAIN_FRACTION = 0.8  # used when the store is too small for the fixed cut-off
FULL_STORE_TRAIN_CUTOFF = 80_527_360  # split point used for the full 100M-cell store

if adata.n_obs > FULL_STORE_TRAIN_CUTOFF:
    n_train = FULL_STORE_TRAIN_CUTOFF
else:
    # e.g. the 2k-cell subsample: the fixed cut-off lies past the end of the store
    # and would leave the validation set empty.
    n_train = int(adata.n_obs * TRAIN_FRACTION)

adata_train = adata[:n_train]
adata_val = adata[n_train:]

# With drop_last=True, a training set smaller than one batch would produce zero
# training steps, so shrink the training batch for tiny (subsampled) stores.
train_batch_size = max(1, min(BATCH_SIZE, n_train))

datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    train_dataloader_kwargs={
        "batch_size": train_batch_size,
        "drop_last": True,
    },
    val_dataloader_kwargs={
        "batch_size": BATCH_SIZE,
        "drop_last": False,
    },
)

In [None]:
# Size the output by the largest label code rather than by nunique(). If `cell_line`
# was already categorical with unused categories (plausible after subsampling; TODO
# confirm), the codes are not contiguous. In that case nunique() would give fewer
# outputs than the largest label needs.
# Assumes Linear's `n_covariates` is the number of output classes; confirm in modlyn.
n_classes = int(adata.obs["y"].max()) + 1
linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=n_classes,
    learning_rate=1e-2,
)

In [None]:
# Training stops at whichever limit is reached first: 3000 steps or 3 epochs.
# The step cap keeps this tutorial short. Metrics are logged every 100 steps.
trainer = L.Trainer(
    max_steps=3000,
    max_epochs=3,
    log_every_n_steps=100,
)

In [None]:
# Train the linear classifier on the data served by `datamodule`.
trainer.fit(linear, datamodule=datamodule)

In [None]:
# Close the LaminDB run that `ln.track(...)` started in the setup cell.
ln.finish()