The aim of this notebook is to demonstrate how to use a custom DataLoader class in PyTorch to create a dataloader that returns samples from only one "single-cell batch"/sample at a time, i.e. each minibatch contains only cells from one sample/single-cell batch.

From there, batch normalization (batchnorm) can be applied in the manner outlined in https://www.biorxiv.org/content/10.1101/2022.10.14.512286v1.full.pdf
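A minimal illustration of why single-sample minibatches matter for batchnorm, using only standard `torch.nn.BatchNorm1d` behaviour (this is not the method from the linked preprint): in training mode, the layer normalizes each feature with the mean and variance of the current minibatch. The two synthetic "samples" and their ±5 offset are assumptions chosen to mimic a batch effect.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two synthetic "samples" with a large per-feature offset between them,
# standing in for a batch effect between single-cell samples (illustrative assumption)
sample_a = torch.randn(32, 10) + 5.0
sample_b = torch.randn(32, 10) - 5.0

bn = nn.BatchNorm1d(num_features=10)
bn.train()  # training mode: normalize with the current minibatch's mean and variance

# One sample per minibatch: each is normalized with its own statistics,
# so the per-feature means of both outputs end up near 0
out_a = bn(sample_a)
out_b = bn(sample_b)

# Mixed minibatch: the pooled statistics leave the offset between samples in place
out_mixed = bn(torch.cat([sample_a, sample_b]))

print(out_a.mean(dim=0).abs().max().item())  # near 0
print(out_b.mean(dim=0).abs().max().item())  # near 0
print(out_mixed[:32].mean().item())          # clearly positive (roughly 1)
```

With one sample per minibatch, the between-sample offset is removed by the normalization itself; with mixed minibatches it survives.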

In [121]:
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

First, create some synthetic data: 3 classes (in our case, the classes stand in for single-cell batches/samples) and normally distributed data with 10 features. Be sure to convert each array to a tensor.

In [37]:
classes = np.array([0, 1, 2])
X = np.random.normal(size=(1000, 10))
y = np.random.choice(classes, 1000)

X = torch.from_numpy(X)
y = torch.from_numpy(y)
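One detail worth knowing here: `torch.from_numpy` keeps NumPy's dtype, so `X` above is float64, while `torch.nn` layers such as `nn.Linear` create float32 parameters by default and raise a dtype mismatch error on float64 input. A minimal sketch of the same data creation with an explicit cast; the seeded `np.random.default_rng(0)` generator is an assumption added for reproducibility.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)  # seeded generator so reruns give the same data
classes = np.array([0, 1, 2])
X = torch.from_numpy(rng.normal(size=(1000, 10)))
y = torch.from_numpy(rng.choice(classes, 1000))

print(X.dtype)  # torch.float64, inherited from the NumPy array

# Cast the features to float32 to match default layer parameters, and the
# labels to int64, which losses such as nn.CrossEntropyLoss expect for class indices
X = X.float()
y = y.long()
print(X.dtype, y.dtype)  # torch.float32 torch.int64
```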

Next, we create a custom Dataset object (which is passed into the DataLoader class). I realized it would be easy to modify the DataLoader class for our purposes, so we only need to slightly modify the Dataset class to return both the data and the labels, as well as the counts for each class and the indices associated with those classes.

In [112]:
class TestDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    # Get the total counts (number of cells/samples) from each
    # class/batch
    def get_class_counts(self):
        unique_classes = torch.unique(self.y)
        class_counts = {}
        for i in unique_classes:
            class_counts[i.item()] = len(self.X[self.y == i])
        return class_counts

    # Get the indices corresponding to each class in the 
    # given dataset 
    def get_class_indices(self):
        unique_classes = torch.unique(self.y)
        class_indices = {}
        for i in unique_classes:
            class_indices[i.item()] = torch.where(self.y == i)[0]
        return class_indices

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Index the dataset's own tensors rather than the global X and y
        X_sample = self.X[idx]
        y_sample = self.y[idx]
        return X_sample, y_sample

The code above should be fairly straightforward: it returns both the data and the class/sample label, and adds some custom methods that return the class counts and indices.
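As an aside, both helper methods can also be computed from a single `torch.unique(..., return_counts=True)` call, which avoids the copy of `X` that `len(self.X[self.y == i])` creates through boolean-mask indexing. A minimal sketch written as a standalone function; the name `class_counts_and_indices` is hypothetical.

```python
import torch


def class_counts_and_indices(y):
    # torch.unique returns the sorted distinct labels and, with
    # return_counts=True, how many times each one occurs
    labels, counts = torch.unique(y, return_counts=True)
    class_counts = {int(label): int(count) for label, count in zip(labels, counts)}
    # nonzero(as_tuple=True)[0] gives the positions where y equals each label,
    # the same result as torch.where(y == label)[0]
    class_indices = {int(label): (y == label).nonzero(as_tuple=True)[0] for label in labels}
    return class_counts, class_indices


y = torch.tensor([0, 2, 1, 0, 2, 2])
counts, indices = class_counts_and_indices(y)
print(counts)   # {0: 2, 1: 1, 2: 3}
print(indices)  # {0: tensor([0, 3]), 1: tensor([2]), 2: tensor([1, 4, 5])}
```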

In [113]:
test_ds = TestDataset(X=X, y=y)

Now I'm going to add a custom **DataLoader** that subclasses the torch DataLoader class but adds some extra bells and whistles. I'll add code comments explaining each aspect.

In [192]:
class SingleBatchLoader(torch.utils.data.DataLoader):
    # In __init__, we store our dataset, call the get_class_counts method
    # from our dataset object (to get the class counts for this DataLoader), and call the
    # get_class_indices method from our dataset object to get the corresponding indices.
    # We also call the internal method _create_class_iter (marked with a leading _) to get our
    # iterator, which does the special magic for our dataset and returns mini-batches that
    # contain only one class/sample type
    def __init__(self, dataset, batch_size=1):
        super().__init__(dataset, batch_size=batch_size)
        self.class_counts = dataset.get_class_counts()
        self.class_indices = dataset.get_class_indices()
        self.class_iter = self._create_class_iter()

    # This internal method creates our iterator (a generator). __iter__ calls it again each
    # time we loop over our DataLoader in a for-loop, and each iteration of the loop calls
    # __next__, defined below
    def _create_class_iter(self):
        # Shuffle the indices within each class. Because __iter__ calls this method at the
        # start of every epoch, each epoch gets freshly shuffled data
        for label, indices in self.class_indices.items():
            index_len = indices.shape[0]
            index_perm = torch.randperm(index_len)
            self.class_indices[label] = indices[index_perm]
        
        # Create a copy of the labels and indices in each of these - remember 
        # we don't want to modify anything globally and we're removing indices
        # from the dictionaries, so we don't want to retain those changes for 
        # the next epoch
        class_indices_copy = self.class_indices.copy()
        class_counts_copy = self.class_counts.copy()
        
        # Now, in our actual iterator, we randomly sample a class and take indices from that class
        # until the dataset is exhausted. The while condition checks that at least one class (each
        # class is a key in the class_counts dictionary) still remains
        while len(class_counts_copy) >= 1:
            # Randomly select a class and its indices 
            label, indices = random.choice(list(class_indices_copy.items()))
            # If fewer indices than the batch size remain for that class, pop (remove) that class
            # from the counts and indices dictionaries, dropping its leftover indices (similar to
            # drop_last=True), and go to the next iteration
            if len(indices) < self.batch_size:
                class_counts_copy.pop(label)
                class_indices_copy.pop(label)
                continue
            # If we have enough data for that class, get data up to the batch size 
            yield label, indices[:self.batch_size]
            # Remove the selected data from the current indices of that class
            class_indices_copy[label] = indices[self.batch_size:]

    # This function defines our actual __iter__ call - what is called when we run the DataLoader
    # in a for-loop. We're going to shuffle the classes in the counts and indices dictionaries,
    # and then use our internal _create_class_iter() function to get the iterator 
    def __iter__(self):
        # Shuffle the order of the classes (this is probably not necessary as we're randomly
        # selecting each class anyway, but adds another layer of randomness)
        labels = list(self.class_indices.keys())
        random.shuffle(labels)
        self.class_indices = {label:self.class_indices[label] for label in labels}
        self.class_counts = {label:self.class_counts[label] for label in labels}
        # Get and return the class iterable (our dataloader iterable)
        self.class_iter = self._create_class_iter()
        return self

    # Each step of the for-loop calls __next__. When the generator is exhausted,
    # next() raises StopIteration, which propagates out and ends the loop
    def __next__(self):
        _, indices = next(self.class_iter)
        return self.dataset[indices]

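    # Number of full single-class minibatches per epoch (leftover indices smaller
    # than batch_size are dropped by _create_class_iter)
    def __len__(self):
        return sum(count // self.batch_size for count in self.class_counts.values())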
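Because `SingleBatchLoader` overrides `__iter__` and `__next__`, it bypasses DataLoader's own iteration machinery, so options such as `num_workers` or `collate_fn` would not take effect. An alternative design, sketched below under that assumption, keeps the standard DataLoader and passes it a custom batch sampler through the existing `batch_sampler` argument. The class name `SingleClassBatchSampler` is hypothetical (not part of PyTorch), and `TensorDataset` stands in for `TestDataset`.

```python
import random

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SingleClassBatchSampler(Sampler):
    """Hypothetical batch sampler: yields lists of indices, each from a single class.

    Leftover indices smaller than batch_size are dropped, as in SingleBatchLoader.
    """

    def __init__(self, labels, batch_size):
        self.batch_size = batch_size
        self.class_indices = {
            int(c): torch.where(labels == c)[0] for c in torch.unique(labels)
        }

    def __iter__(self):
        batches = []
        for indices in self.class_indices.values():
            # Reshuffle within each class at the start of every epoch
            perm = indices[torch.randperm(len(indices))]
            # Keep only full batches
            n_full = len(perm) // self.batch_size
            for b in range(n_full):
                start = b * self.batch_size
                batches.append(perm[start:start + self.batch_size].tolist())
        # Shuffle the batch order so classes are interleaved across the epoch
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(len(ix) // self.batch_size for ix in self.class_indices.values())


X = torch.randn(1000, 10)
y = torch.randint(0, 3, (1000,))
ds = TensorDataset(X, y)
loader = DataLoader(ds, batch_sampler=SingleClassBatchSampler(y, batch_size=10))

for X_batch, y_batch in loader:
    # every minibatch holds a single label
    assert torch.unique(y_batch).numel() == 1
```

One behavioural difference: this sampler shuffles all full batches uniformly, so each class appears in proportion to its number of batches, whereas `SingleBatchLoader` picks a random remaining class at every step, so smaller classes tend to run out earlier in the epoch.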

Now let's try `SingleBatchLoader` with a minibatch size of 10.

In [193]:
ds_loader = SingleBatchLoader(test_ds, batch_size=10)

Go ahead and iterate over this and look at the batch outputs. Remember that each batch contains both the data corresponding to one class and the class labels (as returned by the TestDataset class).

In [194]:
for batch in ds_loader:
    print(batch)

(tensor([[-1.2019,  0.2572,  1.2555,  0.0743, -0.2201,  0.0190,  0.7776, -0.2258,
          0.2286, -0.6700],
        [-0.4160, -0.1224,  0.0073,  1.0665,  1.7345, -0.3340, -0.6029,  1.3700,
          0.0373, -1.3256],
        [ 1.6986, -0.2955,  0.8458, -0.0628, -1.2133,  0.9620,  0.0458,  0.7612,
         -0.4856,  0.3673],
        [-0.2259, -0.9570, -0.4016, -0.7993, -0.2500,  0.8118,  1.0402, -0.3954,
         -1.7297, -1.1281],
        [-1.4883,  0.4619, -1.7759,  1.6886,  0.7106,  0.9437,  0.3989, -0.5394,
          1.3072,  0.1519],
        [ 0.9417,  0.4657,  1.9614, -0.0810,  0.8562,  0.0799, -0.0155, -0.2217,
          0.7630, -0.3516],
        [ 0.4878,  0.6722, -1.4734,  0.2221,  1.2016, -0.5593, -0.2113, -0.6407,
          0.5540,  1.5406],
        [ 0.5790, -0.1961,  0.2515, -1.4647, -0.4651, -1.1377, -0.0721,  0.1540,
         -1.3218, -0.2014],
        [ 0.4862, -0.4409, -0.4686,  0.1159,  0.0962, -1.1628,  1.1368, -1.3923,
         -0.2764,  0.5102],
        [-1.1073, 

Things look like they are working as intended, as each mini-batch contains data from one class. We can replace these 'classes' with single-cell batches from our data. 

I'll leave the more thorough checks (e.g., does it work over many epochs?) to you guys. Happy DataLoading!
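As a starting point for those checks, here is a minimal sketch of a multi-epoch sanity check. The function name `check_single_class_batches` is hypothetical; it only assumes the loader yields `(X_batch, y_batch)` pairs, as `SingleBatchLoader` does with `TestDataset`, and that iterating over it again starts a new epoch.

```python
import torch


def check_single_class_batches(loader, n_epochs=5):
    """Iterate over `loader` for n_epochs and verify that every
    (X_batch, y_batch) minibatch contains exactly one class label.

    Hypothetical helper; returns the number of batches seen in each epoch.
    """
    batches_per_epoch = []
    for epoch in range(n_epochs):
        n_batches = 0
        for X_batch, y_batch in loader:
            labels = torch.unique(y_batch)
            if labels.numel() != 1:
                raise ValueError(f"epoch {epoch}: batch mixes labels {labels.tolist()}")
            if X_batch.shape[0] != y_batch.shape[0]:
                raise ValueError(f"epoch {epoch}: data and labels have different lengths")
            n_batches += 1
        batches_per_epoch.append(n_batches)
    return batches_per_epoch
```

In this notebook it could be called as `check_single_class_batches(ds_loader, n_epochs=5)`. Since `SingleBatchLoader` yields every full batch of every class, the returned per-epoch batch counts should all be equal.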