In [7]:
from datasets import load_from_disk
from transformers import DataCollatorWithPadding
from encoders import BioZorroCollator

ModuleNotFoundError: No module named 'encoders'

In [2]:
!ls ../data
# Load the dataset saved under ../data and have it return torch tensors when indexed.
ds = load_from_disk('../data/filtered_protein_mrna_genes').with_format('torch')

filtered_protein_mrna_genes  pbmc_protein.h5ad
pbmc_gene.h5ad		     pbmc_w3_teaseq.h5mu


In [8]:
import torch
from torch import nn
from torch.nn.functional import pad
from torch import Tensor
from typing import Optional


class BioZorroCollator:
    """Collate a list of per-sample dicts of tensors into one dict of stacked tensors.

    Each tensor is padded on the right of its last dimension with ``pad_token``
    until that dimension has size ``pad_len``; the padded tensors that share a
    key are then stacked along a new leading (batch) dimension.
    """

    def __init__(self, pad_token=0, pad_len=2048):
        self.pad_token, self.pad_len = pad_token, pad_len

    def __call__(self, data):
        # One bucket per key, taken from the first sample so key order is preserved.
        buckets = {key: [] for key in data[0].keys()}
        for sample in data:
            for key, tensor in sample.items():
                missing = self.pad_len - tensor.shape[-1]
                buckets[key].append(
                    pad(tensor, (0, missing), mode='constant', value=self.pad_token)
                )
        return {key: torch.stack(tensors) for key, tensors in buckets.items()}


In [13]:
from torch.utils.data import DataLoader
# Shuffled training loader: batches of 16 samples, padded and stacked by BioZorroCollator.
train_dataloader = DataLoader(
    ds,
    batch_size=16,
    shuffle=True,
    collate_fn=BioZorroCollator(),
)

In [14]:
print(len(train_dataloader))

89649


In [11]:
#from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import pad
# Pull a single collated batch from the loader and print it to inspect the padded tensors.
data = next(iter(train_dataloader))
print(data)

{'total_index': tensor([[ 9, 26, 46,  ...,  0,  0,  0]]), 'spliced_index': tensor([[  9,  46, 275,  ...,   0,   0,   0]]), 'unspliced_index': tensor([[ 26, 136, 196,  ...,   0,   0,   0]]), 'ambiguous_index': tensor([[132, 304, 397,  ...,   0,   0,   0]]), 'total_data': tensor([[1., 1., 1.,  ..., 0., 0., 0.]]), 'spliced_data': tensor([[1., 1., 1.,  ..., 0., 0., 0.]]), 'unspliced_data': tensor([[1., 1., 1.,  ..., 0., 0., 0.]]), 'ambiguous_data': tensor([[1., 1., 1.,  ..., 0., 0., 0.]])}


In [4]:
import muon as mu
# Change directory to the root folder of the repository
import os
os.chdir("../../")

In [5]:
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
data_dir = "/efs-private/st_perceiver"
# Read the .h5mu file with muon.
mdata = mu.read(f"{data_dir}/pbmc_w3_teaseq.h5mu")

In [26]:
# Compares the RNA matrix of a sliced view with itself and sums the element-wise matches.
# NOTE(review): presumably a sanity check on view slicing — confirm intent or remove.
(mdata[:,1].mod['rna'].X == mdata[:,1].mod['rna'].X).sum()

ArrayView(5805)

In [16]:
# Smoke test: build a one-element tensor and move it to the CPU.
torch.Tensor([0.0]).to("cpu")
# The experimental anndata loaders below are kept for reference only: they are not
# currently directly compatible with MuData objects.

#from anndata.experimental.multi_files import AnnCollection
#from anndata.experimental.pytorch import AnnLoader

tensor([0.])

In [None]:
import numpy as np
import torch
from scipy.sparse import csr_matrix
import anndata
def sparse_csr_to_tensor(csr:csr_matrix):
    """
    Transform a scipy CSR matrix into a float32 PyTorch sparse COO tensor.

    Args:
        csr: scipy sparse matrix to convert.

    Returns:
        A sparse COO ``torch.Tensor`` of dtype float32 with the same shape as ``csr``.
    """
    # Take coordinates and values from the same COO view so they always line up:
    # ``csr.nonzero()`` drops explicitly stored zeros while ``csr.data`` keeps them,
    # so pairing the two could produce arrays of different lengths.
    coo = csr.tocoo()
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.as_tensor(coo.data, dtype=torch.float32)

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(csr.shape))
    
def _stack_modality(batch:list, modality:str):
    """Stack the ``.X`` matrices of one modality across all samples into one float32 tensor."""
    # Imported locally so this cell needs no extra top-level imports.
    from scipy.sparse import issparse, vstack as sparse_vstack

    matrices = [sample[modality].X for sample in batch]
    if issparse(matrices[0]):
        # np.vstack cannot concatenate scipy sparse matrices; scipy's own vstack
        # keeps the result sparse so it can be converted to a sparse tensor.
        return sparse_csr_to_tensor(sparse_vstack(matrices, format='csr'))
    return torch.as_tensor(np.asarray(np.vstack(matrices)), dtype=torch.float32)


def sparse_batch_collate(batch:list):
    """
    Collate function that turns a list of multimodal samples into batched tensors.

    Each sample is indexed by modality name ('atac', 'rna', 'prot') and the
    resulting object's ``.X`` matrix is stacked row-wise with the others.
    Sparse matrices become sparse tensors; anything else becomes a dense
    float32 tensor. The first sample of the batch decides which path is taken.

    Args:
        batch: list of samples supporting ``sample[modality].X``.

    Returns:
        Tuple ``(atac_batch, rna_batch, prot_batch)``.
    """
    atac_batch = _stack_modality(batch, 'atac')
    rna_batch = _stack_modality(batch, 'rna')
    prot_batch = _stack_modality(batch, 'prot')

    return atac_batch, rna_batch, prot_batch


# Batch the MuData object directly, collating every 10 samples with sparse_batch_collate.
# NOTE(review): DataLoader is not imported in this cell; it relies on the import made in an earlier cell.
loader = DataLoader(
    mdata,
    batch_size=10,
    collate_fn = sparse_batch_collate,
)

In [None]:
import matplotlib.pyplot as plt
sample = next(iter(loader))

In [None]:
sample[1].shape

In [None]:
# Scratch copy of the ATAC branch of sparse_batch_collate, re-indented so the cell parses.
# It expects a `batch` list to already exist in the kernel.
# NOTE(review): np.vstack does not concatenate scipy sparse matrices; scipy.sparse.vstack is likely intended.
if type(batch[0]['atac'].X) == anndata._core.views.SparseCSRView:
    atac_batch = sparse_csr_to_tensor(np.vstack([x['atac'].X for x in batch]))
else:
    atac_batch = torch.FloatTensor(np.vstack([x['atac'].X for x in batch]))


In [1]:
import sys
sys.path.append('/efs-private/st_perceiver')
from biozorromodel import BioZorro, TokenTypes as T
from mudataloader import get_dataloader

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


In [48]:
# Instantiate BioZorro; the first four positional arguments follow the inline labels
# (dim, depth, rna_input_dim, atac_input_dim).
# NOTE(review): the two input dims are swapped relative to the commented-out values, and the
# first tensor of a batch is 96162 wide (see the shape printed further down) — confirm which
# modality the loader yields first.
model = BioZorro(
                        512, #dim,
                        6, #depth,
                        96162, #16381, #rna_input_dim,
                        16381, #96162, #atac_input_dim,
                        dim_head = 64,
                        heads = 8,
                        ff_mult = 4,
                        num_fusion_tokens = 16,
                        return_token_types = (
                            T.RNA,
                            T.FUSION,
                            T.GLOBAL,
                            T.ATAC,
                            )
                        )

In [49]:
dataloader = get_dataloader('/efs-private/st_perceiver/pbmc_w3_teaseq.h5mu', batch_size=1024)

In [50]:
# NOTE(review): rebinds `loader`, which an earlier cell used for a DataLoader, to an iterator.
loader = iter(dataloader)
# Each next() call consumes a batch, so the two prints below inspect different batches.
print(len(next(loader)))
print(next(loader)[0].shape)

3
torch.Size([1024, 96162])


In [51]:
rna, atac, prot = next(loader)
res = model(rna=rna, atac=atac)

In [52]:
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ContrastiveLossWithTemperature

In [53]:
# Run one batch through the model and compute a contrastive loss between two returned tokens.
rna, atac, prot = next(loader)
res = model(rna=rna, atac=atac)
print(res[:,0,:].squeeze().shape)
loss = ContrastiveLossWithTemperature()
# NOTE(review): indices 0 and 3 on axis 1 presumably select the RNA and ATAC tokens, matching
# the order of return_token_types passed to the model — confirm.
loss(res[:,0,:].squeeze(), res[:,3,:].squeeze())

torch.Size([1024, 512])


tensor(24.6985, grad_fn=<DivBackward0>)

In [54]:
# Move the model to the GPU and run a forward pass with the inputs moved to the same device.
model = model.to('cuda')
rna, atac, prot = next(loader)
res = model(rna=rna.to('cuda'), atac=atac.to('cuda'))

In [96]:
!ls ../

Notebooks	  encoders.py	   neuron_utils.py  run_trn.slurm
biozorromodel.py  env.yaml	   run_trn.py	    train.py
data		  mudataloader.py  run_trn.sh


In [None]:
# Training setup: epoch count, optimizer, and a progress bar sized to the total step count.
# NOTE(review): AdamW, tqdm, strftime/gmtime, train_dl, train_device_loader, device, wandb and
# logger are not defined in any visible cell — this cell relies on kernel state or on imports
# outside this view.
epochs=100
optimizer = AdamW(model.parameters(), lr=0.0001)
# NOTE(review): the step count uses train_dl while the loop iterates train_device_loader —
# confirm both have the same length.
num_training_steps = epochs * len(train_dl)
progress_bar = tqdm(range(num_training_steps))

print("Start training: {}".format(strftime("%Y-%m-%d %H:%M:%S", gmtime())))
## Start model training and defining the training loop
model.train()
for epoch in range(epochs):
    for batch in train_device_loader:
        # Move every tensor in the batch dict to the target device before the forward pass.
        batch = {k: v.to(device) for k, v in batch.items()}
        # NOTE(review): expects the model output to expose `.loss`; earlier cells index the
        # model output as a tensor — confirm this loop targets the same model.
        outputs = model(**batch)
        optimizer.zero_grad()
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        progress_bar.update(1)
    # `loss` here is the loss of the last batch of the epoch, not an epoch average.
    print(f"Epoch {epoch}: loss: {loss.detach()}")
        #
    #if xm.is_master_ordinal(local=False):
    wandb.log({"epoch_loss":loss.detach().to("cpu")})

# NOTE(review): the start time is reported with print() but the end time with logger.info().
logger.info("End training: {}".format(strftime("%Y-%m-%d %H:%M:%S", gmtime())))
