In [1]:
from models import scLightning
from ad_data import setup_ad_anndata_module

from anndata_to_pytorch_dataloader.dataset import setup_anndata_datamodule, setup_simple_datamodule
import anndata as ad
import scanpy as sc

import pytorch_lightning as pl

import gdown
import os
import timeit

# Download anndata

In [2]:
url = 'https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd'
output = 'pancreas.h5ad'

if not os.path.exists(output):
    gdown.download(url, output, quiet=False)
    print(f"File '{output}' downloaded successfully.")
else:
    print(f"Found {output}.")

Found pancreas.h5ad.


In [3]:
adata = sc.read_h5ad("pancreas.h5ad")
# Convert back to raw counts 
adata.X = adata.raw.X # put raw counts to .X
adata.obs['size_factors'] = adata.X.sum(1)
adata.var = adata.var.reset_index()
adata.var.columns = ["gene_name"]
# encode cell types for conversion to tensors 
ct_to_id_dict = {c : i for i, c in enumerate(adata.obs["cell_type"].unique())}
adata.obs["label"] = adata.obs["cell_type"].map(ct_to_id_dict)
                                             

In [4]:
ad_dm = setup_ad_anndata_module(adata=adata, train_frac=0.7, test_frac=0.2, val_frac=0.1)
pt_dm = setup_simple_datamodule(adata=adata, train_frac=0.7, test_frac=0.2, val_frac=0.1, include_exprs=True,
obs_fields=["label"])





In [5]:
# we first initialize a class 
model = scLightning(n_vars = adata.n_vars,
                    n_classes=adata.obs["cell_type"].nunique(), 
                    feature_var="X", 
                    label_var="obs_label")

In [6]:
# define a pytorch trainer
trainer = pl.Trainer(devices=1, max_epochs=2, enable_checkpointing=False)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/adr/miniconda3/envs/scvi_hack/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [7]:
# lightning handles all of the training
trainer.fit(model, pt_dm)


  | Name  | Type | Params
-------------------------------
0 | model | MLP  | 139 K 
-------------------------------
139 K     Trainable params
0         Non-trainable params
139 K     Total params
0.557     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/adr/miniconda3/envs/scvi_hack/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/adr/miniconda3/envs/scvi_hack/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/adr/miniconda3/envs/scvi_hack/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
/Users/adr/miniconda3/envs/scvi_hack/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (22) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 1: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s, v_num=13, train/loss=0.612, val/loss=0.525, val/acc=0.934, train/acc=0.715]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s, v_num=13, train/loss=0.612, val/loss=0.525, val/acc=0.934, train/acc=0.715]


In [8]:
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
import numpy as np

In [9]:
encoder_study = OneHotEncoder(sparse=False, dtype=np.float32)
encoder_study.fit(adata.obs['study'].to_numpy()[:, None])
encoder_celltype = LabelEncoder()
encoder_celltype.fit(adata.obs['cell_type'])



In [10]:
encoders = {
    'obs': {
        'study': lambda s: encoder_study.transform(s.to_numpy()[:, None]),
        'cell_type': encoder_celltype.transform
    }
}

In [11]:
t = ad.experimental.AnnLoader(adata, batch_size=128, shuffle=True, convert=encoders)

In [12]:
for b in t:
    print(t)

<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>
<anndata.experimental.pytorch._annloader.AnnLoader object at 0x7fcc99becc40>