In [1]:
import pandas as pd 
import numpy as np 
import scanpy as sc 
import anndata as ann 
from mudata import MuData
import mudata as md 
from sklearn.model_selection import train_test_split
import torch

In [2]:
## Load the MuData object (This is a big file - should have at last 16GB of RAM to prevent memory issues)
mdata = md.read("../data/multi.h5mu")



First function - Scenario 1. 

In this scenario, the MuonData object is split such that there is a 60% training set, 10% validation set, and 30% test set. 

In [37]:

def scenario_1_split(mudata_object, train_size = 0.6, val_size = 0.1, test_size = 0.3, seed = 42):
    """
    Split the MuonData object into training, validation, and test sets. In this scenario, the split is 
    based on the rows (cells) of the MuonData object.
    
    Parameters:
    mudata_object (MuonData): MuonData object to split
    train_size (float): Proportion of the dataset to include in the training set
    val_size (float): Proportion of the dataset to include in the validation set
    test_size (float): Proportion of the dataset to include in the test set
    
    Returns:
    train_data (MuonData): MuonData object with annotations for train, val, test splits
    """
    # Get the indices of the mdata by enumerating over shape
    mudata_indices = [i for i in range(mudata_object.shape[0])]
    
    # Set the seed for reproducibility
    np.random.seed(seed)
    
    # Split the indices based on the specified test/train/validation sizes
    train_indices, test_indices = train_test_split(mudata_indices, train_size=train_size)
    val_indices, test_indices = train_test_split(test_indices, train_size=(val_size/(val_size + test_size)))
    
    # Add the indices to the obs of the ADT and SCT aspects of the MuonData object
    mudata_object["ADT"].obs["Split"] = "None"
    mudata_object["ADT"].obs["Split"][train_indices] = "Train"
    mudata_object["ADT"].obs["Split"][val_indices] = "Validation"
    mudata_object["ADT"].obs["Split"][test_indices] = "Test"
    
    mudata_object["SCT"].obs["Split"] = "None"
    mudata_object["SCT"].obs["Split"][train_indices] = "Train"
    mudata_object["SCT"].obs["Split"][val_indices] = "Validation"
    mudata_object["SCT"].obs["Split"][test_indices] = "Test"
    
    # Return the annotated 
    return mudata_object

In [38]:
train_data, val_data, test_data = scenario_1_split(mdata)

In [39]:
train_data

In [40]:
test_data

In [41]:
mdata["SCT"]

View of AnnData object with n_obs × n_vars = 250 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'

Second function - Scenario 2. 

In this scenario, the MuonData object is split along the ADT axis based on the features, such that there is a 70% training set, 15% validation set, and 15% test set.

In [42]:
def scenario_2_split(mudata_object, train_size = 0.7, val_size = 0.15, test_size = 0.15, seed = 42):
    """
    Split the MuonData object into training, validation, and test sets. In this case, the split is 
    amongst the vars of the ADT aspect of the MuonData object.
    
    Parameters:
    mudata_object (MuonData): MuonData object to split
    train_size (float): Proportion of the dataset to include in the training set
    val_size (float): Proportion of the dataset to include in the validation set
    test_size (float): Proportion of the dataset to include in the test set
    
    Returns:
    mudata_object (MuonData): MuonData object with the ADT vars annotated with the split
    """
    # Get the indices of the ADT vars by enumerating over shape
    adt_var_indices = [i for i in range(mudata_object["ADT"].var.shape[0])]
    
    # Set the seed for reproducibility
    np.random.seed(seed)
    
    # Split the indices based on the specified test/train/validation sizes
    train_indices, test_indices = train_test_split(adt_var_indices, train_size=train_size)
    val_indices, test_indices = train_test_split(test_indices, train_size=(val_size/(val_size + test_size)))
    
    
    # Subset the mudata objects based on indices and return train/test/validation
    mudata_object["ADT"].var["Split"] = "None"
    mudata_object["ADT"].var["Split"][train_indices] = "Train"
    mudata_object["ADT"].var["Split"][val_indices] = "Validation"
    mudata_object["ADT"].var["Split"][test_indices] = "Test"
    
    # Return the annodated split data
    return mudata_object

In [43]:
mdata_test = scenario_2_split(mdata)

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  mudata_object["ADT"].var["Split"][train_indices] = "Train"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  muda

Now, define a function to extract the relevant arrays and features of the MuonData split objects. Three options will be used:

- Option 1: Return AnnData objects of the Train, Validation and Test sets.
- Option 2: Return the arrays of the Train, Validation and Test sets.
- Option 3: Return torch tensors of the Train, Validation and Test sets.

In both cases, the SCT and ADT subsets will be used.

In [44]:
def adt_split_returns(scenario = 1, mudata_train = None, mudata_val = None, mudata_test = None, return_type = "AnnData"):
    # Scenario 1 conditions 
    if scenario == 1:
        if return_type == "AnnData":
            train_adt_adata = mudata_train["ADT"].copy()
            train_sct_adata = mudata_train["SCT"].copy()
            
            val_adt_adata = mudata_val["ADT"].copy()
            val_sct_adata = mudata_val["SCT"].copy()
            
            test_adt_adata = mudata_test["ADT"].copy()
            test_sct_adata = mudata_test["SCT"].copy()
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
            
        elif return_type == "Numpy":
            train_adt_adata = mudata_train["ADT"].X
            train_sct_adata = mudata_train["SCT"].X.todense()
            
            val_adt_adata = mudata_val["ADT"].X
            val_sct_adata = mudata_val["SCT"].X.todense()
            
            test_adt_adata = mudata_test["ADT"].X
            test_sct_adata = mudata_test["SCT"].X.todense()
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
            
        elif return_type == "Torch":
            train_adt_adata = torch.tensor(mudata_train["ADT"].X)
            train_sct_adata = torch.tensor(mudata_train["SCT"].X.todense())
            
            val_adt_adata = torch.tensor(mudata_val["ADT"].X)
            val_sct_adata = torch.tensor(mudata_val["SCT"].X.todense())
            
            test_adt_adata = torch.tensor(mudata_test["ADT"].X)
            test_sct_adata = torch.tensor(mudata_test["SCT"].X.todense())
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
        else:
            raise ValueError("Return type must be 'AnnData', 'Numpy', or 'Torch'")
        
    # Scenario 2 conditions 
    elif scenario == 2:
        if return_type == "AnnData":
            train_sct_adata = mudata_train["SCT"].copy()
            val_sct_adata = mudata_train["SCT"].copy()
            test_sct_adata = mudata_train["SCT"].copy()
            
            train_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Train"].copy()
            val_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Validation"].copy()
            test_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Test"].copy()
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
            
        elif return_type == "Numpy":
            train_sct_adata = mudata_train["SCT"].X.todense()
            val_sct_adata = mudata_train["SCT"].X.todense()
            test_sct_adata = mudata_train["SCT"].X.todense()
            
            train_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Train"].X
            val_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Validation"].X
            test_adt_adata = mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Test"].X
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
            
        elif return_type == "Torch":
            train_sct_adata = torch.tensor(mudata_train["SCT"].X.todense())
            val_sct_adata = torch.tensor(mudata_train["SCT"].X.todense())
            test_sct_adata = torch.tensor(mudata_train["SCT"].X.todense())
            
            train_adt_adata = torch.tensor(mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Train"].X)
            val_adt_adata = torch.tensor(mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Validation"].X)
            test_adt_adata = torch.tensor(mudata_train["ADT"][:, mudata_train["ADT"].var["Split"] == "Test"].X)
            
            return train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata
            
        else:
            raise ValueError("Return type must be 'AnnData', 'Numpy', or 'Torch'")

Test out all combinations of scenarios and options for returns

In [45]:
# Subset the mudata object to manageable size to test the functions 
mdata_full = mdata.copy()
mdata = mdata_full[0:250]



In [46]:
scenario_1_train, scenario_1_val, scenario_1_test = scenario_1_split(mdata)

In [47]:
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 1, mudata_train = scenario_1_train, mudata_val = scenario_1_val, mudata_test = scenario_1_test, return_type = "AnnData")

In [48]:
# Print all the returns from previous 
print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)

AnnData object with n_obs × n_vars = 150 × 228
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 150 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'
AnnData object with n_obs × n_vars = 25 × 228
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 25 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'
AnnData object with n_obs × n_vars = 75 × 228
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 75 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs'

In [49]:
# Run the same function but with numpy return type
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 1, mudata_train = scenario_1_train, mudata_val = scenario_1_val, mudata_test = scenario_1_test, return_type = "Numpy")

In [50]:
# Print all the returns from previous and their shapes 
print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)

print(train_adt_adata.shape)
print(train_sct_adata.shape)
print(val_adt_adata.shape)
print(val_sct_adata.shape)
print(test_adt_adata.shape)
print(test_sct_adata.shape)


[[1.28239354 0.91628971 1.28239354 ... 0.5368004  0.94738028 2.07283937]
 [1.52854319 0.57317092 0.82846665 ... 0.22949215 0.57317092 1.27525384]
 [0.50795782 0.33442464 0.65578345 ... 0.42495074 1.1372111  0.72223506]
 ...
 [1.61703022 0.64635629 1.13783482 ... 0.69789913 0.99923203 1.98225094]
 [0.48059421 0.95716045 0.90861493 ... 0.95716045 0.90861493 1.37669036]
 [0.51141641 0.7892629  1.05411149 ... 0.51141641 0.58844308 1.05411149]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0.60106031 1.11845726 1.02440787 ... 0.67365535 1.07253782 1.35691226]
 [0.3748766  0.92268608 0.92268608 ... 0.794273   1.08884861 1.03647001]
 [0.99063706 0.87304651 2.32839216 ... 0.99063706 0.87304651 1.88354447]
 ...
 [2.95675998 0.58189022 1.46222571 ... 0.43992601 0.7061833  0.66644439]
 [1.41009148 0.79851371 1.2128301  ... 0.63990783 0.63990783 2.09050096]
 [0.28392384 0.83868676 0.6856

In [51]:
# Test the function with torch return type
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 1, mudata_train = scenario_1_train, mudata_val = scenario_1_val, mudata_test = scenario_1_test, return_type = "Torch")

In [52]:
# Print all the returns from previous and their shapes 
print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)

print(train_adt_adata.shape)
print(train_sct_adata.shape)
print(val_adt_adata.shape)
print(val_sct_adata.shape)
print(test_adt_adata.shape)
print(test_sct_adata.shape)


tensor([[1.2824, 0.9163, 1.2824,  ..., 0.5368, 0.9474, 2.0728],
        [1.5285, 0.5732, 0.8285,  ..., 0.2295, 0.5732, 1.2753],
        [0.5080, 0.3344, 0.6558,  ..., 0.4250, 1.1372, 0.7222],
        ...,
        [1.6170, 0.6464, 1.1378,  ..., 0.6979, 0.9992, 1.9823],
        [0.4806, 0.9572, 0.9086,  ..., 0.9572, 0.9086, 1.3767],
        [0.5114, 0.7893, 1.0541,  ..., 0.5114, 0.5884, 1.0541]],
       dtype=torch.float64)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
tensor([[0.6011, 1.1185, 1.0244,  ..., 0.6737, 1.0725, 1.3569],
        [0.3749, 0.9227, 0.9227,  ..., 0.7943, 1.0888, 1.0365],
        [0.9906, 0.8730, 2.3284,  ..., 0.9906, 0.8730, 1.8835],
        ...,
        [2.9568, 0.5819, 1.4622,  ..., 0.4399, 0.7062, 0.6664],
        [1.4101, 0.7985, 1.212

In [53]:
# Do the same for scenario 2
mdata_test = scenario_2_split(mdata)

  mudata_object["ADT"].var["Split"] = "None"
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  mudata_object["ADT"].var["Split"][train_indices] = "Train"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexi

In [54]:
# Get the returns for scenario 2
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 2, mudata_train = mdata_test, return_type = "AnnData")

print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)


AnnData object with n_obs × n_vars = 250 × 159
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 250 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'
AnnData object with n_obs × n_vars = 250 × 34
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 250 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PCs', 'spca'
    layers: 'counts'
    obsp: 'wknn', 'wsnn'
AnnData object with n_obs × n_vars = 250 × 35
    var: 'highly_variable', 'Split'
    uns: 'apca'
    obsm: 'X_apca'
    varm: 'apca'
    layers: 'counts'
AnnData object with n_obs × n_vars = 250 × 20729
    var: 'highly_variable'
    uns: 'pca', 'spca'
    obsm: 'X_pca', 'X_spca'
    varm: 'PC

In [55]:
# Get the returns for scenario 2 with numpy return type
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 2, mudata_train = mdata_test, return_type = "Numpy")

# Print all the returns from previous and their shapes 
print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)

print(train_adt_adata.shape)
print(train_sct_adata.shape)
print(val_adt_adata.shape)
print(val_sct_adata.shape)
print(test_adt_adata.shape)
print(test_sct_adata.shape)


[[1.95916424 0.86914159 1.48523314 ... 0.55307644 1.04608444 1.72565693]
 [0.4322284  1.01422751 0.79594998 ... 0.66587988 0.85514588 1.37971736]
 [0.61381759 1.30390619 0.75610373 ... 0.6874892  0.75610373 1.04246048]
 ...
 [0.516672   0.79642125 1.23406963 ... 0.79642125 0.73347434 1.23406963]
 [0.34911635 0.81245902 1.17196872 ... 0.44281582 0.60738681 0.52848296]
 [0.71403457 0.78440454 1.07689159 ... 0.55643281 0.46721908 1.02484115]]
[[0.         0.69314718 0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 ...
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.69314718 0.         ... 0.         0.         0.        ]]
[[0.31429728 0.08830771 0.31429728 ... 0.08830771 1.16807008 0.37954604]
 [0.23928732 0.34040395 0.34040395 ... 

In [56]:
# Get the returns for scenario 2 with torch return type
train_adt_adata, train_sct_adata, val_adt_adata, val_sct_adata, test_adt_adata, test_sct_adata = adt_split_returns(scenario = 2, mudata_train = mdata_test, return_type = "Torch")

# Print all the returns from previous and their shapes 
print(train_adt_adata)
print(train_sct_adata)
print(val_adt_adata)
print(val_sct_adata)
print(test_adt_adata)
print(test_sct_adata)

print(train_adt_adata.shape)
print(train_sct_adata.shape)
print(val_adt_adata.shape)
print(val_sct_adata.shape)
print(test_adt_adata.shape)
print(test_sct_adata.shape)



tensor([[1.9592, 0.8691, 1.4852,  ..., 0.5531, 1.0461, 1.7257],
        [0.4322, 1.0142, 0.7959,  ..., 0.6659, 0.8551, 1.3797],
        [0.6138, 1.3039, 0.7561,  ..., 0.6875, 0.7561, 1.0425],
        ...,
        [0.5167, 0.7964, 1.2341,  ..., 0.7964, 0.7335, 1.2341],
        [0.3491, 0.8125, 1.1720,  ..., 0.4428, 0.6074, 0.5285],
        [0.7140, 0.7844, 1.0769,  ..., 0.5564, 0.4672, 1.0248]],
       dtype=torch.float64)
tensor([[0.0000, 0.6931, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6931, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       dtype=torch.float64)
tensor([[0.3143, 0.0883, 0.3143,  ..., 0.0883, 1.1681, 0.3795],
        [0.2393, 0.3404, 0.3404,  ..., 0.0000, 3.0702, 0.1268],
        [0.1321, 0.1

In [57]:

print(train_adt_adata.shape)
print(val_adt_adata.shape)
print(test_adt_adata.shape)


torch.Size([250, 159])
torch.Size([250, 34])
torch.Size([250, 35])
