# Exploratory Data Analysis

---

This few-shot benchmark tests various meta-learning methods in the context of biomedical applications. In particular, we work with the [Tabula Muris](https://tabula-muris.ds.czbiohub.org/) and SwissProt datasets. The former is a cell type classification task based on single-cell gene expression, and the latter is a protein function prediction task based on protein sequences. The goal of this notebook is to explore basic statistics of the two datasets and to understand how data loading is implemented for episodic training during meta-training.

## Setup

---

First, let's import the required modules.

In [None]:
# ruff: noqa: E402
# Reload modules automatically
%load_ext autoreload
%autoreload 2

# Standard library imports
import os
import sys
import time
import collections

# External imports
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

In [None]:
# Add path to load local modules
sys.path.append("..")

# Set styles
sns.set_style("whitegrid")

## Base Classes

---

Both datasets are implemented as subclasses of the `FewShotDataset` class and rely on a few generic utility classes, all of which are defined in the `datasets.dataset` module. We explore each of them in detail below.


### FewShotDataset

The `FewShotDataset(torch.utils.data.Dataset)` is the base class for all few-shot datasets. It implements the `__getitem__` and `__len__` methods and provides some utilities for checking the validity of the data. Furthermore, it is responsible for downloading and extracting the dataset into the `root` directory if one is specified and the data do not exist there yet. However, since it is an abstract base class, it cannot be instantiated directly; for example, it requires the `_dataset_name` and `_dataset_dir` class attributes, which are expected to be set by subclasses.

In [None]:
# Demo: FewShotDataset
from datasets.dataset import FewShotDataset # noqa

try:
    few_shot_dataset = FewShotDataset()
except Exception as e:
    print(f"❌ Fails with error {e}.")
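To make the mechanism concrete, here is a minimal, hypothetical sketch of an abstract base class that refuses instantiation until a subclass implements the abstract methods and sets the required class attributes. It relies only on Python's `abc` module; the class names (`BaseFewShotDataset`, `Incomplete`, `ToyDataset`) are made up, and the real `FewShotDataset` may enforce these requirements differently.

```python
from abc import ABC, abstractmethod


class BaseFewShotDataset(ABC):
    """Hypothetical base class: subclasses must set the class attributes below."""

    _dataset_name = None
    _dataset_dir = None

    def __init__(self):
        # Refuse to build an instance unless the subclass filled in the metadata
        missing = [
            attr
            for attr in ("_dataset_name", "_dataset_dir")
            if getattr(type(self), attr) is None
        ]
        if missing:
            raise ValueError(f"Missing class attributes: {missing}")

    @abstractmethod
    def __getitem__(self, idx): ...

    @abstractmethod
    def __len__(self): ...


class Incomplete(BaseFewShotDataset):
    # Implements the abstract methods but forgets the class attributes
    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 0


class ToyDataset(BaseFewShotDataset):
    _dataset_name = "toy"
    _dataset_dir = "toy_dir"

    def __init__(self, samples):
        super().__init__()
        self.samples = samples

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


for cls in (BaseFewShotDataset, Incomplete):
    try:
        cls()
    except (TypeError, ValueError) as e:
        # TypeError: abstract methods missing; ValueError: class attributes missing
        print(f"❌ {cls.__name__} fails with {type(e).__name__}: {e}")

toy = ToyDataset([1, 2, 3])
print(f"✅ {type(toy).__name__} has {len(toy)} samples")
```

The abstract methods are checked by Python itself when the object is created, whereas the class attributes need an explicit check in `__init__`.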

### FewShotSubDataset

The `FewShotSubDataset(torch.utils.data.Dataset)` is a PyTorch dataset that holds only the samples belonging to a single class.

In [None]:
# Demo FewShotSubDataset
from datasets.dataset import FewShotSubDataset #noqa

# Create a random dataset with 100 samples, 5 features and 5 classes
samples = torch.rand(100, 5)
targets = torch.randint(0, 5, (100,))  # one label in [0, 5) per sample
subset_target = 4

# Select the samples that belong to class 4
subset_samples = samples[targets == subset_target]

# Wrap them in a few-shot sub-dataset for class 4
few_shot_sub_dataset = FewShotSubDataset(subset_samples, subset_target)

# Sanity checks
assert len(few_shot_sub_dataset) == (targets == subset_target).sum().item(), "❌ Length of few-shot dataset is not correct."
assert few_shot_sub_dataset.dim == samples.shape[1], "❌ Dimension of few-shot dataset is not correct."
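Conceptually, such a wrapper only has to pair every stored sample with the shared class label. The sketch below is a torch-free, hypothetical stand-in (`SingleClassDataset` is an illustrative name); the real `FewShotSubDataset` subclasses `torch.utils.data.Dataset` and may differ in details such as how samples are converted to tensors.

```python
class SingleClassDataset:
    """Hypothetical stand-in for FewShotSubDataset: all samples share one label."""

    def __init__(self, samples, category):
        self.samples = samples    # sequence of feature vectors of a single class
        self.category = category  # class label shared by all samples

    def __getitem__(self, idx):
        # Every item is returned together with the same class label
        return self.samples[idx], self.category

    def __len__(self):
        return len(self.samples)

    @property
    def dim(self):
        # Feature dimension, read from the first sample
        return len(self.samples[0])


samples = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
sub = SingleClassDataset(samples, category=4)
print(len(sub), sub.dim, sub[1])  # 2 3 ([0.4, 0.5, 0.6], 4)
```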

### Episodic Batch Sampler

The `EpisodicBatchSampler` is a utility class that, for each of `n_episodes` episodes, randomly samples `n_way` classes out of a total of `n_classes`. It can be used in episodic training to select the classes of each episode.

In [None]:
# Demo: EpisodicBatchSampler
from datasets.dataset import EpisodicBatchSampler # noqa

# Demo of EpisodicBatchSampler
n_episodes, n_way, n_classes = 3, 5, 10
episodic_batch_sampler = EpisodicBatchSampler(n_classes, n_way, n_episodes)

print(f"Episodes: {n_episodes}, Ways: {n_way}, Classes: {n_classes}")
for batch_idx, indices in enumerate(episodic_batch_sampler):
    print(f"Episode {batch_idx+1} w/ classes {indices.numpy()}")
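The sampling logic itself is small. Below is a minimal standard-library sketch, assuming each episode draws `n_way` distinct class indices uniformly at random without replacement; `SimpleEpisodicSampler` and its `seed` argument are hypothetical, and the real class yields tensors of indices (hence the `indices.numpy()` call above).

```python
import random


class SimpleEpisodicSampler:
    """Hypothetical stdlib version of an episodic class sampler."""

    def __init__(self, n_classes, n_way, n_episodes, seed=None):
        if n_way > n_classes:
            raise ValueError("n_way cannot exceed n_classes")
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes
        self.rng = random.Random(seed)

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            # Draw n_way distinct class indices for this episode
            yield self.rng.sample(range(self.n_classes), self.n_way)


sampler = SimpleEpisodicSampler(n_classes=10, n_way=5, n_episodes=3, seed=0)
for episode, classes in enumerate(sampler, start=1):
    print(f"Episode {episode} w/ classes {classes}")
```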

## Tabula Muris

---

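**Tabula Muris** is a dataset of single-cell transcriptome data (gene expression) from mice, containing nearly `100,000` cells from `20` organs and tissues. The data allow for a direct and controlled comparison of gene expression in cell types shared between tissues, such as immune cells from distinct anatomical locations. They also allow for a comparison of two distinct technical approaches: microfluidic droplet-based 3'-end counting, which surveys thousands of cells per organ at relatively low coverage, and FACS-based full-length transcript analysis, which provides higher sensitivity and coverage.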

*More Resources*: 

- [Tabula Muris Website](https://tabula-muris.ds.czbiohub.org/)
- [SF Biohub Article](https://www.czbiohub.org/sf/tabula-muris/)

### MacaData

The `MacaData` class is responsible for loading and processing the Tabula Muris dataset. Thus, before looking at the `TMSimpleDataset` and `TMSetDataset`, let's first investigate how the data are loaded and processed.

In [None]:
from datasets.cell.utils import MacaData # noqa

path = os.path.join("..", "data", "tabula_muris", "tabula-muris-comet.h5ad")

start = time.time()
maca_data = MacaData(src_file=path)
annotated_data = maca_data.adata
print(f"⌛ Loaded data in {time.time() - start:.2f} seconds.")

Loading and processing all samples takes ~30 seconds. Since this happens on every instantiation of the Tabula Muris dataset, loading all three splits takes ~1.5 minutes (3 × ~30 seconds).

We can speed this up by processing only the cells relevant to the requested split and by introducing a subsampling flag (`subset`) that loads only 10% of the data. The `MacaDataImproved` class inherits from the `MacaData` class and implements both changes.

In [None]:
from datasets.cell.utils import MacaDataImproved # noqa

path = os.path.join("..", "data", "tabula_muris", "tabula-muris-comet.h5ad")

start = time.time()
MacaDataImproved(src_file=path, mode="train", subset=False).adata
print(f"⌛ Loaded training split in {time.time() - start:.2f} seconds.")

start = time.time()
MacaDataImproved(src_file=path, mode="train", subset=True).adata
print(f"⌛ Loaded subsetted training split in {time.time() - start:.2f} seconds.")

Processing only the cells of the requested split should make loading all splits roughly **3x** faster, and loading the subsetted data should make it roughly **10x** faster. Combined, we can therefore expect to load the data about **30x** faster (3 × 10).
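The two optimisations can be sketched in plain Python. Everything here is hypothetical: the tissue-to-split mapping, the `select_cells` helper, the dictionary layout of a cell, and the fixed seed. `MacaDataImproved` works on an `AnnData` object and may select and subsample cells differently.

```python
import random

# Hypothetical split definition: each split covers a disjoint set of tissues
SPLIT_TISSUES = {
    "train": {"TissueA", "TissueB"},
    "val": {"TissueC"},
    "test": {"TissueD"},
}


def select_cells(cells, mode, subset=False, fraction=0.1, seed=0):
    """Return the cells of the requested split, optionally subsampled."""
    # 1) Only keep the cells whose tissue belongs to the requested split
    selected = [cell for cell in cells if cell["tissue"] in SPLIT_TISSUES[mode]]
    if subset:
        # 2) Keep a random `fraction` of them (at least one, if any exist)
        k = max(1, int(len(selected) * fraction)) if selected else 0
        selected = random.Random(seed).sample(selected, k)
    return selected


# Toy data: 100 cells for each of four tissues
cells = [
    {"id": f"{tissue}_{i}", "tissue": tissue}
    for tissue in ("TissueA", "TissueB", "TissueC", "TissueD")
    for i in range(100)
]
print(len(select_cells(cells, "train")))               # 200
print(len(select_cells(cells, "train", subset=True)))  # 20
```

Because the subsample is drawn from the split-specific cells only, the two savings multiply, which is where the combined estimate comes from.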

The annotated data (`anndata.AnnData`) is a data structure that stores the data matrix together with its annotations. We can get detailed information about the data by printing the object.

In [None]:
annotated_data

We can view the annotations for each cell (sample) and each gene (feature) through the `obs` and `var` attributes of the `anndata.AnnData` object. Both are `pandas.DataFrame`s: `obs` holds the cell annotations and `var` holds the gene annotations.

In [None]:
# Cell annotations
annotated_data.obs

In [None]:
# Gene annotations
annotated_data.var

In [None]:
# We can get the features and targets as numpy arrays (this is done in the TMDataset class as well)
feature_matrix = annotated_data.X
targets = annotated_data.obs["label"].cat.codes.to_numpy()

print(f"Feature matrix: {feature_matrix.shape}, Targets: {targets.shape}")
print(f"Number of cell types: {len(np.unique(targets))}")

In [None]:
# Plot Cell Type Distribution
_, ax = plt.subplots(figsize=(20, 10))
names2cells = {v: k for k, v in maca_data.cells2names.items()}
cell_types = [names2cells[trg] for trg in targets]

top_k = 10
# Keep the top_k most frequent cell types
counts = dict(collections.Counter(cell_types).most_common(top_k))

sns.barplot(x=list(counts.keys()), y=list(counts.values()), palette="mako", ax=ax)
ax.set(xlabel="Cell type", ylabel="Count", title=f"Cell Type Distribution (Top {top_k})")
ax.tick_params(axis="x", labelsize=8);

Let's run the same analysis for the data that we get from the `MacaDataImproved` class for each split.

In [None]:
# Load MacaData for each split
maca_data_train = MacaDataImproved(src_file=path, mode="train", subset=False)
maca_data_val = MacaDataImproved(src_file=path, mode="val", subset=False)
maca_data_test = MacaDataImproved(src_file=path, mode="test", subset=False)

# Load subset of MacaData for each split
maca_data_sub_train = MacaDataImproved(src_file=path, mode="train", subset=True)
maca_data_sub_val = MacaDataImproved(src_file=path, mode="val", subset=True)
maca_data_sub_test = MacaDataImproved(src_file=path, mode="test", subset=True)

In [None]:
train_val_test_data = {"train": maca_data_train, "val": maca_data_val, "test": maca_data_test}
sub_train_val_test_data = {"train": maca_data_sub_train, "val": maca_data_sub_val, "test": maca_data_sub_test}

In [None]:
# We can get the features and targets as numpy arrays (this is done in the TMDataset class as well)
for mode, data in train_val_test_data.items():
    feature_matrix = data.adata.X
    targets = data.adata.obs["label"].cat.codes.to_numpy()

    print(f"Split {mode}")
    print(f"Feature matrix: {feature_matrix.shape}, Targets: {targets.shape}")
    print(f"Number of cell types: {len(np.unique(targets))}")

### TMSimpleDataset

The `TMSimpleDataset` inherits from the abstract `TMDataset` class, which in turn inherits from the generic `FewShotDataset` class. `TMDataset` sets `_dataset_name` to `"tabula_muris"`, defines `_dataset_url`, and provides a convenient loader utility that loads all samples and their targets. `TMSimpleDataset` initialises the data directory, loads the data, and then runs the sanity checks from the base class. It implements the basic methods `__getitem__` and `__len__` as well as the `dim` property.

Crucially, the data loader is tied to the dataset class: calling the `get_data_loader()` method returns a loader that yields batches of size `batch_size`.

*Note: Upon first call, the `TMSimpleDataset` class will download the data into the `root` directory.*
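In essence, such a loader shuffles the sample indices and groups them into batches of `batch_size`. In PyTorch this is typically delegated to `torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)`; the torch-free sketch below (with the hypothetical helper `iterate_batches`) only illustrates the batching logic and omits the collation of samples into tensors.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of dataset items, batch_size at a time (hypothetical helper).

    `dataset` only needs to support len() and integer indexing.
    """
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The last batch may be smaller than batch_size
        yield [dataset[i] for i in indices[start:start + batch_size]]


# Toy dataset of 25 (features, label) pairs
toy_dataset = [([float(i)], i % 3) for i in range(25)]
batches = list(iterate_batches(toy_dataset, batch_size=10, seed=0))
print([len(batch) for batch in batches])  # [10, 10, 5]
```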

In [None]:
# Demo: TMSimpleDataset
from datasets.cell.tabula_muris import TMSimpleDataset # noqa

# Arguments to provide
batch_size = 10 # Controls the batch_size of data loader
root = "./data" # Controls where to store the data
min_samples = 20 # Filter out tissue types with fewer than min_samples samples

modes = ["train", "val", "test"] # Controls data split (returns subset of tissue types)

# Initialise the Tabula Muris dataset for each split
data = {}
for mode in modes:
    start = time.time()
    tm_data = TMSimpleDataset(
        batch_size=batch_size,
        root=root,
        mode=mode,
        min_samples=min_samples
    )
    data[mode] = tm_data

    print(f"✅ TabulaMuris {mode} split loaded in {time.time() - start:.2f} seconds.")

In [None]:
from datasets.cell.utils import MacaDataImproved # noqa

path = os.path.join("..", "data", "tabula_muris", "tabula-muris-comet.h5ad")

start = time.time()
annotated_train = MacaDataImproved(src_file=path, mode="train", subset=False).adata
print(f"⌛ Loaded training split in {time.time() - start:.2f} seconds.")

start = time.time()
annotated_train_sub = MacaDataImproved(src_file=path, mode="train", subset=True).adata
print(f"⌛ Loaded subsetted training split in {time.time() - start:.2f} seconds.")


In [None]:
# Let's investigate the size of the downloaded data
!du -sh ../data/tabula_muris/*

`gene_association.mgi` (`84 MB`): A Gene Ontology annotation file (GAF) from Mouse Genome Informatics (MGI). It associates mouse genes (identifiers and symbols) with GO terms describing their biological processes, cellular components, and molecular functions, together with the evidence supporting each association.
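A minimal sketch of reading such a file, assuming it follows the tab-separated GAF 2.x layout: header lines start with `!`, column 3 holds the gene symbol, column 4 the qualifier (which may contain `NOT`) and column 5 the GO ID. The helper `read_gaf` and the example rows are made up for illustration.

```python
from collections import defaultdict


def read_gaf(lines):
    """Map gene symbols to the set of GO IDs they are annotated with."""
    gene_to_go = defaultdict(set)
    for line in lines:
        if line.startswith("!") or not line.strip():
            continue  # skip header comments and blank lines
        cols = line.rstrip("\n").split("\t")
        if len(cols) < 5:
            continue  # skip malformed rows
        if "NOT" in cols[3].split("|"):
            continue  # skip negated annotations
        symbol, go_id = cols[2], cols[4]  # GAF columns 3 and 5 (1-based)
        gene_to_go[symbol].add(go_id)
    return dict(gene_to_go)


# Made-up, truncated rows (real GAF rows have 17 tab-separated columns)
example = [
    "!gaf-version: 2.2",
    "\t".join(["MGI", "MGI:0000001", "GeneA", "involved_in", "GO:0000001", "REF:1", "IDA"]),
    "\t".join(["MGI", "MGI:0000001", "GeneA", "enables", "GO:0000002", "REF:1", "IDA"]),
    "\t".join(["MGI", "MGI:0000002", "GeneB", "located_in", "GO:0000001", "REF:1", "IDA"]),
    "\t".join(["MGI", "MGI:0000002", "GeneB", "NOT|involved_in", "GO:0000003", "REF:1", "IDA"]),
]

print(read_gaf(example))
```

With the downloaded file, one would pass an open file handle instead, e.g. `with open("../data/tabula_muris/gene_association.mgi") as f: gene_to_go = read_gaf(f)`.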

`go-basic.obo` (`32 MB`): The Gene Ontology (GO), a major bioinformatics initiative to unify the representation of gene and gene product attributes across all species, stored in the text-based OBO (Open Biomedical Ontologies) format. It contains the GO terms, their definitions, and the relationships between them, covering biological processes, cellular components, and molecular functions.
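A minimal, hypothetical parser for the `[Term]` stanzas of an OBO file (`read_obo_terms` is an illustrative name, and the example terms are made up). It keeps only each term's `id`, `name` and `namespace` tags and ignores other stanza types such as `[Typedef]`; a full OBO parser would also handle relationships, obsolete terms and escaped characters.

```python
from itertools import chain


def read_obo_terms(lines):
    """Return {term id: (name, namespace)} for every [Term] stanza."""
    terms = {}
    stanza, fields = None, {}
    # A sentinel header at the end flushes the last stanza
    for raw in chain(lines, ["[End]"]):
        line = raw.strip()
        if line.startswith("[") and line.endswith("]"):
            # A new stanza starts: store the previous one if it was a term
            if stanza == "[Term]" and "id" in fields:
                terms[fields["id"]] = (fields.get("name"), fields.get("namespace"))
            stanza, fields = line, {}
        elif stanza == "[Term]" and ": " in line:
            tag, value = line.split(": ", 1)
            fields.setdefault(tag, value)  # keep the first value of each tag
    return terms


example = """format-version: 1.2

[Term]
id: GO:0000001
name: example process
namespace: biological_process

[Term]
id: GO:0000002
name: example function
namespace: molecular_function

[Typedef]
id: part_of
name: part of
""".splitlines()

print(read_obo_terms(example))
```

The parser also accepts an open file handle, e.g. `with open("../data/tabula_muris/go-basic.obo") as f: go_terms = read_obo_terms(f)`.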

`tabula-muris-comet.h5ad` (`2.3 GB`): An AnnData file, a format commonly used in bioinformatics for storing large annotated datasets, particularly single-cell data. AnnData files are based on the HDF5 file format, which is designed for storing and organizing large amounts of data. This file contains the main single-cell RNA sequencing data from the Tabula Muris project: the gene expression matrix of the individual cells (`X`) together with the cell annotations (`obs`) and gene annotations (`var`) explored above.

Because the files are large (roughly 2.4 GB combined), even loading the data after it has been downloaded takes a while.

In [None]:
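# Retrieve the training and test splits loaded above
tm_train, tm_test = data["train"], data["test"]

# Statistics on the dataset
print(f"ℹ️ Tabula Muris dataset has {len(tm_train)} train samples and {len(tm_test)} test samples.")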

In [None]:
# Get sample
tr_smp, tr_trg = tm_train[0]
te_smp, te_trg = tm_test[0]

print(f"Training sample shape: {tr_smp.shape} and target {tr_trg}")
print(f"Test sample shape: {te_smp.shape} and target {te_trg}")

In [None]:
# Get data loaders
tm_train_loader = tm_train.get_data_loader()
tm_test_loader = tm_test.get_data_loader()

# Get the first batch of each split
tr_smps, tr_trgs = next(iter(tm_train_loader))
te_smps, te_trgs = next(iter(tm_test_loader))


## SwissProt

---

In [None]:
from datasets.prot.swissprot import SPSimpleDataset, SPSetDataset