In [None]:
import math
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import requests
import scanpy as sc  # type: ignore
import torch
from moscot import datasets as mds
from nicheflow.preprocessing.h5ad_dataset_type import load_h5ad_dataset_dataclass
from nicheflow.preprocessing.h5ad_preprocessor import H5ADPreprocessor
from tqdm import tqdm


# Utils

In [None]:
def plot_data(preprocessor: H5ADPreprocessor, max_cols: int = 5) -> None:
    """Plot each timepoint's cells with the subsampled centroid cells highlighted.

    One subplot is drawn per key of ``preprocessor.subsampled_timepoint_idx``,
    laid out row-major in a grid with at most ``max_cols`` columns. All cells of a
    timepoint are drawn in light grey; the centroid cells are drawn on top.

    Args:
        preprocessor: A preprocessor on which ``preprocess_data`` has been run. The
            function reads its ``coords``, ``timepoint_indices`` and
            ``subsampled_timepoint_idx`` attributes.
        max_cols: Maximum number of subplot columns (default 5).

    Raises:
        ValueError: If ``max_cols`` is smaller than 1.
    """
    if max_cols < 1:
        raise ValueError("max_cols must be a positive integer")

    timepoints = list(preprocessor.subsampled_timepoint_idx.keys())
    num_timepoints = len(timepoints)
    if num_timepoints == 0:
        # plt.subplots rejects a grid with zero rows, so there is nothing to draw.
        return

    # A single row shrinks to exactly as many columns as there are timepoints.
    cols = min(max_cols, num_timepoints)
    rows = math.ceil(num_timepoints / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    fig.subplots_adjust(wspace=0.3, hspace=0.4)
    # Row-major flattening gives the same ordering as (i // cols, i % cols).
    flat_axes = axes.ravel()

    for ax, timepoint in zip(flat_axes, timepoints):
        coords = preprocessor.coords[preprocessor.timepoint_indices[timepoint]]
        # The centroid indices are used as positions within this timepoint's coords.
        centroid_coords = coords[preprocessor.subsampled_timepoint_idx[timepoint]]

        # Background: every cell of the timepoint.
        ax.scatter(coords[:, 0], coords[:, 1], s=5, color=(0.95, 0.95, 0.95))
        # Foreground: the centroid cells.
        ax.scatter(centroid_coords[:, 0], centroid_coords[:, 1], s=5, label="Centroids")

        ax.set_title(str(timepoint))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
        ax.legend(loc="upper right", fontsize=8)

    # Hide the grid cells left over after the last timepoint.
    for ax in flat_axes[num_timepoints:]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()
    # The function is called once per dataset in a loop; release the figure.
    plt.close(fig)

# Downloading the datasets

In [None]:
# Every downloaded and preprocessed file lives in ../data (relative to the notebook);
# create it, including missing parents, if it is not there yet.
data_folder = Path("..") / "data"
data_folder.mkdir(exist_ok=True, parents=True)

In [None]:
def download_and_save(
    url: str,
    output_file: Path,
    timeout: float | tuple[float, float] | None = 60.0,
    chunk_size: int = 1024 * 1024,
) -> None:
    """Stream ``url`` to ``output_file`` while showing a tqdm progress bar.

    The body is first written to ``<output_file>.part``. It is renamed onto
    ``output_file`` only after the whole response has been received. An interrupted
    or failed download therefore never leaves a truncated file at the final path,
    where a later ``exists()`` check would accept it.

    Args:
        url: HTTP(S) URL to download.
        output_file: Destination path; replaced if it already exists.
        timeout: Forwarded to ``requests.get`` (seconds, or a (connect, read) tuple).
            ``None`` waits indefinitely.
        chunk_size: Number of bytes requested per ``iter_content`` iteration.

    Raises:
        requests.HTTPError: If the server responds with an error status code.
    """
    output_file = Path(output_file)
    partial_file = output_file.with_name(output_file.name + ".part")

    # Using the response as a context manager releases the streamed connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        # Servers may omit Content-Length; None gives tqdm an open-ended counter
        # instead of a bar with total=0.
        total_size = int(response.headers.get("content-length", 0)) or None

        try:
            with (
                open(partial_file, "wb") as file,
                tqdm(
                    desc="Downloading",
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar,
            ):
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
                    bar.update(len(chunk))
        except BaseException:
            # Includes KeyboardInterrupt: never keep a half-written file around.
            partial_file.unlink(missing_ok=True)
            raise

    partial_file.replace(output_file)

In [None]:
aging_url = "https://zenodo.org/records/13883177/files/aging_coronal.h5ad?download=1"
aging_output_file = data_folder.joinpath("aging_coronal.h5ad")

axolotl_dev_url = "https://figshare.com/ndownloader/files/44714629"
axolotl_dev_output_file = data_folder.joinpath("axolotl_development.h5ad")

# Skip files that are already on disk, so "Restart & Run All" does not fetch the
# datasets again on every run. Delete a file to force a fresh download.
for url, output_file in (
    (aging_url, aging_output_file),
    (axolotl_dev_url, axolotl_dev_output_file),
):
    if output_file.exists():
        print(f"Skipping download, {output_file} already exists")
        continue
    download_and_save(url=url, output_file=output_file)

In [None]:
# Fetch the embryonic development dataset through moscot, prepare it, and
# write it next to the other datasets.
moscot_output_file = data_folder / "moscot.h5ad"

adata_moscot = mds.mosta()
# Store the "time" observation column as a categorical.
time_column = adata_moscot.obs["time"]
adata_moscot.obs["time"] = time_column.astype("category")
# Keep a copy of the matrix as it was loaded, before any downstream processing.
raw_matrix = adata_moscot.X
adata_moscot.layers["counts"] = raw_matrix.copy()

adata_moscot.write_h5ad(moscot_output_file)

In [None]:
# Validate that the datasets have been downloaded
assert aging_output_file.exists()
assert axolotl_dev_output_file.exists()
assert moscot_output_file.exists()

# Manually preprocess the AnnData objects with scanpy

In [None]:
# Re-derive the dataset locations from the shared `data_folder`, so this section
# can be run on its own (e.g. after a kernel restart) without re-running the
# download cells. Paths stay relative, so the notebook works on any machine.
moscot_output_file = data_folder / "moscot.h5ad"
axolotl_dev_output_file = data_folder / "axolotl_development.h5ad"
aging_output_file = data_folder / "aging_coronal.h5ad"

In [None]:
def _read_and_report(path: Path, message: str) -> sc.AnnData:
    """Read one ``.h5ad`` file with scanpy and print ``message`` once it is loaded."""
    loaded = sc.read_h5ad(path)
    print(message)
    return loaded


moscot_adata = _read_and_report(moscot_output_file, "Loaded the embryonic development dataset")
axolotl_adata = _read_and_report(axolotl_dev_output_file, "Loaded the axolotl brain development dataset")
aging_adata = _read_and_report(aging_output_file, "Loaded the mouse brain aging dataset")

In [None]:
# n_comps passed to sc.pp.pca for every dataset, and the number of highly
# variable genes kept by sc.pp.highly_variable_genes (used for the axolotl data).
N_PRINCIPAL_COMPONENTS, N_TOP_GENES = 50, 2_000

### Embryonic development

In [None]:
# Dataset-specific obs columns for the embryonic development data.
moscot_timepoint_column = "timepoint"
moscot_cell_type_column = "annotation"

sc.pp.pca(moscot_adata, n_comps=N_PRINCIPAL_COMPONENTS)

# Order the categorical timepoint labels numerically after dropping their first
# character. NOTE(review): presumably labels look like "E9.5" — confirm.
moscot_timepoints_ordered = sorted(
    {label for label in moscot_adata.obs[moscot_timepoint_column].cat.categories},
    key=lambda label: float(label[1:]),
)

In [None]:
# Display the numerically ordered embryonic timepoints.
moscot_timepoints_ordered

### Axolotl brain development

In [None]:
# Set dataset-specific attributes
axolotl_timepoint_column = "condition"
axolotl_cell_type_column = "Annotation"

# Timepoints: the order is given explicitly, because sorting these labels
# alphabetically would not reproduce it.
axolotl_timepoints_ordered = [
    "Stage44",
    "Stage54",
    "Stage57",
    "Injury control",
    "Adult",
    "Meta",
]

# Catch typos in the hand-written labels before running the expensive steps below.
_missing_axolotl_timepoints = set(axolotl_timepoints_ordered) - set(
    axolotl_adata.obs[axolotl_timepoint_column]
)
assert not _missing_axolotl_timepoints, (
    f"Timepoints not found in column {axolotl_timepoint_column!r}: "
    f"{sorted(_missing_axolotl_timepoints)}"
)

# subset=True drops the non-selected genes from axolotl_adata in place.
sc.pp.highly_variable_genes(axolotl_adata, n_top_genes=N_TOP_GENES, subset=True)
print("Finished selecting the highly variable genes")

sc.pp.pca(axolotl_adata, n_comps=N_PRINCIPAL_COMPONENTS)
print("PCA is done")

### Mouse brain aging

In [None]:
# Keep a copy of X before it is normalised in place.
aging_adata.layers["counts"] = aging_adata.X.copy()

sc.pp.normalize_total(aging_adata)
print("Finished the total normalization")
sc.pp.log1p(aging_adata)
print("Finished log1p")
sc.pp.pca(aging_adata, n_comps=N_PRINCIPAL_COMPONENTS)
print("PCA is done")

# Keep a random 20% of the cells; random_state=0 is scanpy's default, made explicit.
fraction = 0.2
sc.pp.subsample(aging_adata, fraction=fraction, random_state=0)
print("Subsampling is done")

# Dataset-specific obs columns.
aging_timepoint_column = "age"
aging_cell_type_column = "celltype"

# Distinct ages in sorted order.
# NOTE(review): if "age" holds strings, this sort is lexicographic — confirm the dtype.
aging_timepoints_ordered = sorted({*aging_adata.obs[aging_timepoint_column]})

In [None]:
# Display the sorted ages found in the subsampled aging dataset.
aging_timepoints_ordered

# Use the H5AD Preprocessor

The preprocessor will:
- normalize the cell positions
- standardize the PCA components
- compute the microenvironments within the given radius
- create a discrete set of test microenvironments that together cover almost the entire slide
- save the processed data to disk to avoid recomputation during training

In [None]:
# Output files (.pkl) for the preprocessed datasets, stored in the data folder.
moscot_save_filepath = data_folder / "embryonic_data.pkl"
axolotl_save_filepath = data_folder / "axolotl_brain_dev.pkl"
aging_save_filepath = data_folder / "mouse_brain_aging.pkl"

In [None]:
# One entry per dataset, in the order they are preprocessed:
# (AnnData, timepoint column, cell-type column, ordered timepoints, output path as str).
data_to_process: list[tuple[sc.AnnData, str, str, list[Any], str]] = []
data_to_process.append(
    (moscot_adata, moscot_timepoint_column, moscot_cell_type_column,
     moscot_timepoints_ordered, str(moscot_save_filepath))
)
data_to_process.append(
    (axolotl_adata, axolotl_timepoint_column, axolotl_cell_type_column,
     axolotl_timepoints_ordered, str(axolotl_save_filepath))
)
data_to_process.append(
    (aging_adata, aging_timepoint_column, aging_cell_type_column,
     aging_timepoints_ordered, str(aging_save_filepath))
)

In [None]:
# Preprocessor settings shared by every dataset; only the column names and the
# timepoint order differ per dataset.
preprocessor_kwargs: dict[str, Any] = dict(
    standardize_coordinates=True,
    radius=0.15,
    dx=0.15,
    dy=0.2,
    device="cpu",
    chunk_size=1000,
    # Use the same number of test microenvironments for every timepoint.
    fixed_microenvironments=True,
)

for adata, timepoint_column, cell_type_column, timepoints_ordered, save_filepath in data_to_process:
    preprocessor = H5ADPreprocessor(
        timepoint_column=timepoint_column,
        cell_type_column=cell_type_column,
        timepoints_ordered=timepoints_ordered,
        **preprocessor_kwargs,
    )

    # Preprocess, persist, then visualise the centroids chosen for testing.
    preprocessor.preprocess_data(adata)
    preprocessor.save(save_filepath)
    plot_data(preprocessor=preprocessor)

# Data loading

You can now load the dataset dataclass, which holds both the preprocessed data and information about the preprocessing.

In [None]:
# Reload the embryonic development dataset saved by the preprocessing loop.
moscot_pickle_path = str(moscot_save_filepath)
ds_dataclass = load_h5ad_dataset_dataclass(moscot_pickle_path)

In [None]:
# Inspect the timepoint_num_neighbors attribute of the loaded dataclass
# (presumably per-timepoint neighbour counts — confirm against the dataclass definition).
ds_dataclass.timepoint_num_neighbors

In [None]:
# Inspect the PCA features stored on the loaded dataclass.
# NOTE(review): this may be large; consider displaying only a slice or its shape.
ds_dataclass.X_pca

In [None]:
# Show which device torch allocates new tensors on when none is given explicitly.
torch.zeros((10, 10)).device.type