In [2]:
import pandas as pd 
import numpy as np 
import scanpy as sc 
import anndata as ann 
from mudata import MuData
import muon as mu
import mudata as md 
from sklearn.model_selection import train_test_split
import torch



In [3]:
## Load the MuData object (this is a big file - you should have at least 16GB of RAM to prevent memory issues)
mdata = md.read("../data/multi.h5mu")



In [4]:
from torch.utils.data import Dataset, DataLoader

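Create a `CITEData` class to store the data and the operations to be performed on it. The class has the following methods:

- `__init__`: extract the counts and the optional embeddings from the MuData object
- `format`: optionally normalize and standardize the data
- `__len__` and `__getitem__`: make the data indexable so it can be wrapped in a `DataLoader`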

In [None]:
class CITEData(Dataset):
    """
    Dataset class for storing the CITE-seq data, and creating dataloaders for training and testing the model.

    The wrapped object must expose an "ADT" and an "SCT" modality via ``mudata[...]``.
    One item corresponds to one observation (row) of the ADT modality.

    Attributes set by ``__init__``:
        mudata: the object passed in (kept by reference, not copied).
        esm_embeddings / orthrus_embeddings: arrays with one row per ADT variable,
            built from the ADT ``.var`` columns prefixed "esm_" / "orthrus_", or None
            when not requested.
        scgpt_embeddings: array with one row per SCT observation, built from the SCT
            ``.obs`` columns prefixed "scgpt_", or None when not requested.
        hvg_indices: ``hvg_values`` when ``hvgs`` is truthy, otherwise None.
        adt_counts / sct_counts: the ``.X`` matrices served by ``__getitem__``.
    """

    def __init__(self, mudata, esm_embeddings = False, scgpt_embeddings = False,
                 orthrus_embeddings = False, hvgs = False, hvg_values = None):
        """
        Args:
            mudata: container holding the "ADT" and "SCT" modalities.
            esm_embeddings: if True, collect the "esm_*" columns of the ADT ``.var``.
            scgpt_embeddings: if True, collect the "scgpt_*" columns of the SCT ``.obs``.
            orthrus_embeddings: if True, collect the "orthrus_*" columns of the ADT ``.var``.
            hvgs: if True, keep ``hvg_values`` on the instance as ``hvg_indices``.
            hvg_values: optional per-variable flags for the SCT modality; variables
                whose flag equals 1 are kept. As in the original code, the subset is
                applied whenever ``hvg_values`` is given, independently of ``hvgs``.
        """
        self.mudata = mudata

        adt = self.mudata["ADT"]

        # Always define the three attributes (None when disabled) so that
        # __getitem__ never hits a missing attribute.
        self.esm_embeddings = (
            self._embedding_matrix(adt.var, "esm_") if esm_embeddings else None
        )
        self.orthrus_embeddings = (
            self._embedding_matrix(adt.var, "orthrus_") if orthrus_embeddings else None
        )
        self.scgpt_embeddings = (
            self._embedding_matrix(self.mudata["SCT"].obs, "scgpt_")
            if scgpt_embeddings else None
        )

        # Get the HVG indices if indicated
        self.hvg_indices = hvg_values if hvgs else None

        # Positional boolean mask; np.asarray lets lists, arrays and Series all work.
        self._hvg_mask = None if hvg_values is None else np.asarray(hvg_values) == 1

        # Extract the counts for ADT and SCT - SCT based on highly variable genes
        self._refresh_counts()

    @staticmethod
    def _embedding_matrix(frame, prefix):
        """Stack every column of ``frame`` whose name starts with ``prefix`` into a float array."""
        cols = [col for col in frame.columns if str(col).startswith(prefix)]
        # Selecting all columns at once avoids indexing a numpy array with a column name.
        return frame[cols].to_numpy(dtype=float)

    def _refresh_counts(self):
        """(Re)read the ADT and SCT matrices from the modalities, applying the HVG mask to SCT."""
        self.adt_counts = self.mudata["ADT"].X
        sct = self.mudata["SCT"]
        if self._hvg_mask is None:
            self.sct_counts = sct.X
        else:
            self.sct_counts = sct[:, self._hvg_mask].X

    def format(self, normalize_sct = False, normalize_adt = False, standardize = False):
        """
        Normalize the data in the MuData object, and format it for training the model.

        Args:
            normalize_sct: total-count normalize (target sum 1e4) and log1p the SCT modality.
            normalize_adt: total-count normalize (target sum 1000) and log1p the ADT modality.
            standardize: scale both modalities with ``sc.pp.scale``.

        The modalities are modified in place; ``adt_counts`` and ``sct_counts`` are
        re-read afterwards so that ``__getitem__`` serves the formatted values.
        """
        # Save the raw data in .raw attributes for both adt and sct.
        # ``.raw`` only accepts an AnnData object, so wrap a copy of X in one. Skip
        # modalities that already have a raw so a second call to format() does not
        # overwrite the untouched data with already-normalized values.
        for modality in ("ADT", "SCT"):
            adata = self.mudata[modality]
            if adata.raw is None:
                adata.raw = ann.AnnData(
                    X=adata.X.copy(),
                    obs=adata.obs.iloc[:, :0],  # index only; avoids copying the obs columns
                    var=adata.var,
                )

        # Normalize the data if indicated
        if normalize_sct:
            sc.pp.normalize_total(self.mudata["SCT"], target_sum=1e4)
            sc.pp.log1p(self.mudata["SCT"])

        if normalize_adt:
            sc.pp.normalize_total(self.mudata["ADT"], target_sum=1000)
            sc.pp.log1p(self.mudata["ADT"])

        # Standardize the data if indicated - ideally we don't do this
        # to prevent data leakage. The normalize_total and log1p functions
        # are invariant to test/train splits.
        if standardize:
            sc.pp.scale(self.mudata["SCT"])
            sc.pp.scale(self.mudata["ADT"])

        # The matrices cached in __init__ may be stale now (the HVG subset in
        # particular is a separate copy), so read them again.
        self._refresh_counts()

    def __len__(self):
        """Number of observations (rows) in the ADT modality."""
        return self.mudata["ADT"].shape[0]

    def __getitem__(self, idx):
        """
        Get the data for a given index

        Returns:
            A 5-tuple ``(sct, adt, esm, orthrus, scgpt)``. ``sct``, ``adt`` and
            ``scgpt`` are the rows at ``idx``. ``esm`` and ``orthrus`` have one row per
            ADT variable rather than per observation, so they are returned whole
            instead of being indexed by ``idx``. Any embedding that was not requested
            at construction is None in its slot.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # NOTE(review): None slots presumably need a custom collate_fn when this is
        # wrapped in a DataLoader - confirm against the training code.
        scgpt = None if self.scgpt_embeddings is None else self.scgpt_embeddings[idx]
        return (
            self.sct_counts[idx],
            self.adt_counts[idx],
            self.esm_embeddings,
            self.orthrus_embeddings,
            scgpt,
        )
    
            

In [26]:
# The highly-variable flag for the SCT features is already stored in .var; pull that column out
highly_variable = mdata["SCT"].var["highly_variable"]


In [32]:
# Keep only the SCT variables whose highly_variable flag equals 1
sct_subset = mdata["SCT"][:, highly_variable.eq(1)]

In [33]:
# Show the summary of the subset built from the highly_variable == 1 mask
sct_subset

View of AnnData object with n_obs × n_vars = 161764 × 5000
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'

In [20]:
# Count how many ADT var names are (True) or are not (False) also SCT var names - the length is not the same
np.unique(mdata["ADT"].var_names.isin(mdata["SCT"].var_names), return_counts=True)

(array([False,  True]), array([184,  44]))

In [31]:
# Inspect the highly_variable column taken from the SCT .var
highly_variable

AL627309.1    0
AL669831.5    0
LINC00115     0
FAM41C        0
NOC2L         0
             ..
AC016588.1    0
FAM83E        0
Z82244.2      0
AP001468.1    0
AP001469.2    0
Name: highly_variable, Length: 20729, dtype: uint8