In [80]:
# dataloader.py

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.sparse as sp
import scanpy as sc
import random

class RegimeDataset(Dataset):
    """Per-cell dataset that pairs each cell with its K nearest neighbours.

    Every item is built from three pieces of ``adata``:

    * ``adata.layers[unspliced_key]`` and ``adata.layers[spliced_key]``,
      densified to float32 and concatenated column-wise into ``self.x``
      (unspliced columns first, spliced columns second);
    * ``adata.uns[nn_key]``, a 2-D integer table with one row per cell whose
      column 0 is skipped (treated as the cell itself) and whose columns
      ``1..K`` are used as neighbour indices;
    * ``adata.obsm[latent_key]``, loaded only when the dataset is switched to
      the second regime via :meth:`set_regime`.

    In the first regime an item is ``(x, idx, x_neigh)``; in the second it is
    ``(x, idx, x_neigh, z, z_neigh)``.

    Args:
        adata: AnnData-like object exposing ``uns``, ``layers`` and ``obsm``.
        K: number of neighbours returned per cell (must be >= 1).
        unspliced_key: layer name of the unspliced counts.
        spliced_key: layer name of the spliced counts.
        latent_key: ``obsm`` key of the latent embedding.
        nn_key: ``uns`` key of the neighbour index table.

    Raises:
        KeyError: if ``adata.uns[nn_key]`` is missing (or a layer is missing).
        ValueError: if ``K < 1``, if the neighbour table is not 2-D, has fewer
            than ``K + 1`` columns, or does not have one row per cell.
    """

    def __init__(
        self,
        adata: "sc.AnnData",
        K: int,
        unspliced_key: str = 'unspliced',
        spliced_key: str = 'spliced',
        latent_key: str = 'z',
        nn_key: str = 'indices',
    ):
        self.adata         = adata
        self.K             = K
        self.unspliced_key = unspliced_key
        self.spliced_key   = spliced_key
        self.latent_key    = latent_key

        # Neighbour index table stored under adata.uns[nn_key].
        indices = adata.uns.get(nn_key)
        if indices is None:
            raise KeyError(f"adata.uns['{nn_key}'] not found")
        nn_table = np.asarray(indices, dtype=np.int64)

        # Without these checks the slice 1:K+1 in __getitem__ would silently
        # return fewer than K neighbours (or none at all).
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        if nn_table.ndim != 2:
            raise ValueError(
                f"adata.uns['{nn_key}'] must be 2-D (cells x neighbours), "
                f"got shape {nn_table.shape}"
            )
        if nn_table.shape[1] < K + 1:
            raise ValueError(
                f"adata.uns['{nn_key}'] has {nn_table.shape[1]} columns but "
                f"K={K} requires at least {K + 1} (column 0 is skipped)"
            )
        self.nn_indices = torch.from_numpy(nn_table)

        # Dense float32 expression matrix: [unspliced | spliced].
        u = self._dense_float32(adata.layers[unspliced_key])
        s = self._dense_float32(adata.layers[spliced_key])
        self.x = torch.from_numpy(np.concatenate([u, s], axis=1))

        # A neighbour table built for a different set of cells would index the
        # wrong rows of self.x, so insist on one table row per cell.
        if nn_table.shape[0] != self.x.shape[0]:
            raise ValueError(
                f"adata.uns['{nn_key}'] has {nn_table.shape[0]} rows but the "
                f"expression matrix has {self.x.shape[0]} cells"
            )

        # The latent matrix is loaded lazily on the first switch to regime 2.
        self.latent_data = None
        self.first_regime = True

    @staticmethod
    def _dense_float32(matrix):
        """Return ``matrix`` as a dense float32 ndarray.

        Sparse input is cast before densifying so that only one dense copy is
        allocated.
        """
        if sp.issparse(matrix):
            return matrix.astype(np.float32).toarray()
        return np.asarray(matrix, dtype=np.float32)

    def set_regime(self, first: bool):
        """Switch between expression-based regime (first) and latent-based (second).

        On the first switch to the second regime the latent matrix is read
        from ``adata.obsm[latent_key]`` and cached; later changes to
        ``adata.obsm`` are not picked up by subsequent calls.

        Raises:
            KeyError: if ``adata.obsm[latent_key]`` is missing.
            ValueError: if the latent matrix does not have one row per cell.
        """
        if not first and self.latent_data is None:
            z = np.asarray(self.adata.obsm[self.latent_key], dtype=np.float32)
            if z.shape[0] != len(self.x):
                raise ValueError(
                    f"adata.obsm['{self.latent_key}'] has {z.shape[0]} rows "
                    f"but the dataset has {len(self.x)} cells"
                )
            self.latent_data = torch.from_numpy(z)
        # Flip the flag only after the latent data loaded successfully, so a
        # failed switch leaves the dataset in its previous, usable state.
        self.first_regime = first

    def __len__(self):
        """Number of cells (rows of the expression matrix)."""
        return len(self.x)

    def __getitem__(self, idx: int):
        """Return the sample for cell ``idx``.

        Returns:
            ``(x, idx, x_neigh)`` in the first regime, or
            ``(x, idx, x_neigh, z, z_neigh)`` in the second, where the
            ``*_neigh`` tensors stack the K neighbours along dimension 0.
        """
        # Skip column 0 (the cell itself) and keep the next K neighbours.
        neigh_idx = self.nn_indices[idx, 1:self.K + 1]
        x         = self.x[idx]
        x_neigh   = self.x[neigh_idx]  # (K, G)

        if self.first_regime:
            return x, idx, x_neigh

        z       = self.latent_data[idx]
        z_neigh = self.latent_data[neigh_idx]
        return x, idx, x_neigh, z, z_neigh

def make_dataloader(
    adata: sc.AnnData,
    first_regime: bool = True,
    K: int = 10,
    unspliced_key: str = 'unspliced',
    spliced_key: str = 'spliced',
    latent_key: str = 'z',
    nn_key: str = 'indices',
    batch_size: int = 64,
    shuffle: bool = True,
    num_workers: int = 0,
    seed: int = 0,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``RegimeDataset`` created from ``adata``.

    Args:
        adata: annotated data object handed to ``RegimeDataset``.
        first_regime: regime the dataset is put into before it is wrapped.
        K: number of neighbours per cell.
        unspliced_key, spliced_key, latent_key, nn_key: forwarded unchanged to
            ``RegimeDataset``.
        batch_size, shuffle, num_workers: forwarded unchanged to ``DataLoader``.
        seed: when not ``None``, seeds the global ``random``, NumPy and torch
            RNGs (a process-wide side effect) and, if ``shuffle`` is true,
            a dedicated ``torch.Generator`` passed to the loader.

    Returns:
        The configured ``DataLoader``.
    """
    seeded = seed is not None
    if seeded:
        # Same order as before: Python, NumPy, then torch.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)

    dataset = RegimeDataset(
        adata,
        K,
        unspliced_key=unspliced_key,
        spliced_key=spliced_key,
        latent_key=latent_key,
        nn_key=nn_key,
    )
    dataset.set_regime(first_regime)

    # A private generator is only needed when the sampling order is random
    # and a seed was supplied; otherwise the loader gets generator=None.
    shuffle_rng = None
    if seeded and shuffle:
        shuffle_rng = torch.Generator()
        shuffle_rng.manual_seed(seed)

    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        generator=shuffle_rng,
    )
    return DataLoader(dataset, **loader_options)



In [75]:
# Input AnnData file read by this notebook.
# NOTE(review): this is an absolute, machine-specific path; change this one
# constant to point at your own copy of the file before running.
PANCREAS_H5AD_PATH = '/home/lgolinelli/git/lineageVI/input_processed_anndata/pancreas.h5ad'
adata = sc.read_h5ad(PANCREAS_H5AD_PATH)

In [76]:
# Random stand-in for the latent embedding (30 columns per cell) so the
# latent-based regime can be exercised. A locally seeded generator is used so
# that re-running the notebook reproduces the same values without touching
# NumPy's global random state.
adata.obsm['z'] = np.random.default_rng(0).standard_normal((adata.shape[0], 30))

In [77]:
# Display the shape of the latent matrix stored in adata.obsm['z'].
adata.obsm['z'].shape

(3696, 30)

In [81]:
# Loader for the first (expression-based) regime.
dl_first = make_dataloader(
    adata,
    first_regime=True,
    K=10,
    unspliced_key='unspliced',
    spliced_key='spliced',
    latent_key='z',
    batch_size=64,
    shuffle=True,
    num_workers=0,
    seed=0,
)

# Smoke test: pull a single batch and print the shape of each returned tensor.
for epoch in range(1):
    for batch in dl_first:
        x_batch, idx_batch, x_neigh_batch = batch
        for tensor in (x_batch, idx_batch, x_neigh_batch):
            print(tensor.shape)
        break


torch.Size([64, 3610])
torch.Size([64])
torch.Size([64, 10, 3610])


In [82]:
# Loader for the second (latent-based) regime; K is left at its default.
dl_second = make_dataloader(
    adata,
    first_regime=False,
    unspliced_key='unspliced',
    spliced_key='spliced',
    latent_key='z',
    batch_size=64,
    shuffle=True,
    num_workers=0,
    seed=1,
)

# Smoke test: pull a single batch and print the shape of each returned tensor.
for epoch in range(1):
    for batch in dl_second:
        x_batch, idx_batch, x_neigh_batch, z_batch, z_neigh_batch = batch
        for tensor in (x_batch, idx_batch, x_neigh_batch, z_batch, z_neigh_batch):
            print(tensor.shape)
        break


torch.Size([64, 3610])
torch.Size([64])
torch.Size([64, 10, 3610])
torch.Size([64, 30])
torch.Size([64, 10, 30])
