# How it works

**AnnCollectionDataset** reads h5ad files directly through [AnnCollection](https://anndata.readthedocs.io/en/latest/tutorials/notebooks/anncollection.html) and serves the data through the [LitData](https://github.com/Lightning-AI/litdata) frontend. Before the dataset can be used, we need to prepare a dataset index folder that contains one subfolder per split (for example, `train` and `dev`).

# Building LitData index

In [1]:
import os

from bmfm_targets.datasets.anncollection import get_ann_collection
from bmfm_targets.datasets.data_conversion.litdata_indexing import build_index

  warn(


In [2]:
# Root folder of the bulk-RNA data. It defaults to the shared cluster path;
# set the BMFM_BULKRNA_ROOT environment variable to run the notebook against
# a copy of the data stored elsewhere.
root_dir = os.environ.get(
    "BMFM_BULKRNA_ROOT",
    "/dccstor/bmfm-targets/data/omics/transcriptome/bulkRNA",
)

## Build the index and data folder with `datamodule.prepare_data()`

In [3]:
from bmfm_targets import config
from bmfm_targets.datasets.anncollection import AnnCollectionDataModule
from bmfm_targets.tokenization import get_gene2vec_tokenizer

In [4]:
def gene2vec_fields(tokenizer=None):
    """Build the ``genes`` / ``expressions`` field list used with the gene2vec tokenizer.

    Both fields have no pretrained embedding and use the ``"static"`` vocab
    update strategy. Only ``expressions`` is masked. Each field's vocab size
    is then updated from the tokenizer.

    Args:
        tokenizer: Tokenizer used to size the field vocabularies. When omitted,
            a new one is created with ``get_gene2vec_tokenizer()``, which is the
            original behaviour. Pass a tokenizer you already built to avoid
            constructing a second one.

    Returns:
        A list of ``config.FieldInfo`` objects, in the order ``genes``,
        ``expressions``.
    """
    field_dicts = [
        {
            "field_name": "genes",
            "pretrained_embedding": None,
            "is_masked": False,
            "vocab_update_strategy": "static",
        },
        {
            "field_name": "expressions",
            "pretrained_embedding": None,
            "is_masked": True,
            "vocab_update_strategy": "static",
        },
    ]

    # Named `fields` rather than reusing the function's own name as a local.
    fields = [config.FieldInfo(**fd) for fd in field_dicts]
    if tokenizer is None:
        tokenizer = get_gene2vec_tokenizer()
    for field in fields:
        # The return value is ignored, so update_vocab_size is assumed to
        # update the FieldInfo in place.
        field.update_vocab_size(tokenizer)
    return fields

In [5]:
def prepare_index(dataset_dir=None, index_dir=None):
    """Build the LitData index for a dataset folder and set the datamodule up for fitting.

    The function constructs an ``AnnCollectionDataModule`` over ``dataset_dir``
    and runs ``prepare_data()``, then ``setup("fit")``.

    Args:
        dataset_dir: Folder holding the h5ad files. Defaults to
            ``{root_dir}/ALL/GEO_large_dataset``. The module-level ``root_dir``
            is read at call time.
        index_dir: Folder where the index is written. Defaults to
            ``{dataset_dir}/index_test``.

    Returns:
        The prepared ``AnnCollectionDataModule``, so the caller can reuse it
        (for example, to take a dataloader) instead of rebuilding it.
    """
    if dataset_dir is None:
        dataset_dir = f"{root_dir}/ALL/GEO_large_dataset"
    if index_dir is None:
        index_dir = f"{dataset_dir}/index_test"

    dataset_kwargs = {
        "dataset_dir": dataset_dir,
        "index_dir": index_dir,
    }
    tokenizer = get_gene2vec_tokenizer()
    pars = {
        "tokenizer": tokenizer,
        "batch_size": 2,
        "fields": gene2vec_fields(),
        "num_workers": 8,
        "mlm": True,
        "sequence_order": "sorted",
        "shuffle": True,
        "collation_strategy": "language_modeling",
        "dataset_kwargs": dataset_kwargs,
        "transform_datasets": True,
    }
    datamodule = AnnCollectionDataModule(**pars)
    datamodule.prepare_data()
    datamodule.setup("fit")
    return datamodule


prepare_index()



### Reading h5ad files into an AnnCollection
See https://anndata.readthedocs.io/en/latest/tutorials/notebooks/anncollection.html

In [3]:
# The h5ad files for this notebook live in the "ALL" sub-folder of root_dir.
# get_ann_collection presumably gathers the h5ad files found there into one
# AnnCollection; verify against its implementation.
dataset_dir = os.path.join(root_dir, "ALL")
collection = get_ann_collection(
    input_dir=dataset_dir,
)

### Create the LitData index folder with `train` and `dev` subfolders, each holding a LitData index

The `build_index` function takes an **index** parameter: an iterable of observation indices into the collection, such as a `range` or a Python generator.

In [4]:
index_dir = os.path.join(root_dir, "bulkRNA_litdata_index")
# os.mkdir raised FileExistsError when this cell was re-run and failed if the
# parent folder was missing; makedirs with exist_ok avoids both.
# NOTE(review): how build_index treats an already-populated split folder is
# not shown here. Confirm before re-running over an existing index.
os.makedirs(index_dir, exist_ok=True)

TRAIN_FRACTION = 0.9  # share of observations that go to the "train" split
CHUNK_SIZE = 5000  # chunk_size forwarded to build_index for every split

n_cells = collection.n_obs
n_train_split = int(n_cells * TRAIN_FRACTION)

# Contiguous split: the first n_train_split observations go to train, the
# remainder to dev. The dict preserves the original call order (train, then dev).
splits = {
    "train": range(0, n_train_split),
    "dev": range(n_train_split, n_cells),
}
for split_name, split_index in splits.items():
    build_index(
        output_dir=os.path.join(index_dir, split_name),
        index=split_index,
        chunk_size=CHUNK_SIZE,
    )

# Testing the dataset

### Helper function that is needed only for tests

### Parameters that are normally set in a YAML config file (see the PanglaoDB YAML files).

In [7]:
# Data-module settings that would normally come from a YAML config.
# The dataset reads from dataset_dir and uses the index built in index_dir.
dataset_kwargs = dict(dataset_dir=dataset_dir, index_dir=index_dir)
tokenizer = get_gene2vec_tokenizer()
pars = dict(
    tokenizer=tokenizer,
    batch_size=2,
    fields=gene2vec_fields(),
    num_workers=0,
    mlm=True,
    collation_strategy="language_modeling",
    dataset_kwargs=dataset_kwargs,
)

In [8]:
# Build the data module from the parameters above, then run its preparation
# step and the "fit" setup stage.
datamodule = AnnCollectionDataModule(**pars)
datamodule.prepare_data()
datamodule.setup("fit")

# Smoke test: take the first batch from the training dataloader and show it.
train_dataloader = datamodule.train_dataloader()
batch_iter = iter(train_dataloader)
item = next(batch_iter)
print(item)

{'input_ids': tensor([[[    3,     0,  7402,  ...,  5681,  9529,     1],
         [    3,     4,     0,  ...,    13,     0,     1]],

        [[    3,     0,  7402,  ...,  9826, 22087,     1],
         [    3,    13,     0,  ...,     0,     0,     1]]]), 'labels': tensor([[[-100,    0, -100,  ...,    0, -100, -100]],

        [[-100, -100, -100,  ..., -100, -100, -100]]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}
