In [2]:
from pathlib import Path

import torch
import numpy as np
import pandas as pd

from microfilm import microplot

from FUCCIDataset import FUCCIDatasetInMemory

In [3]:
# Root directory of the filtered FUCCI dataset; every path used below is built from it.
dataset_dir = Path("/data/ishang/Fucci-dataset-v3_filtered/")

The dataset root directory contains some cache files and metadata.
The data is then separated into train, valid, and test splits,
each of which contains experiment folders. These folders contain the individual FOV images per channel.
They also contain the caches of the sc cropped and downsampled images, and the cell segmentation by-products.

In [4]:
# Empty placeholder frame; none of the cells visible in this notebook read or fill it.
dataset_index = pd.DataFrame()

In [5]:
# Collect the cells_256.npy cache file of every experiment folder.
# Layout walked here: dataset_dir/<split>/<experiment>/cells_256.npy
image_cache_files = []
for stage_split in dataset_dir.iterdir():
    if not stage_split.is_dir():
        continue
    for experiment_dir in stage_split.iterdir():
        if not experiment_dir.is_dir():
            continue
        # Paths yielded by iterdir()/glob() already start with dataset_dir, so they are
        # used as-is; joining them onto dataset_dir again would duplicate the prefix
        # whenever dataset_dir is a relative path.
        image_cache_files.extend(experiment_dir.glob("cells_256.npy"))

print("Number of full FOV images:", len(image_cache_files))
print(image_cache_files[0])

Number of full FOV images: 2415
/data/ishang/Fucci-dataset-v3_filtered/valid/Overview 1_Image 54--Stage13/cells_256.npy


The cell below takes about 30 seconds to run: roughly 10 seconds to load the files and 20 seconds to concatenate them.

In [6]:
# Read every cache file and join the arrays along their first axis.
fucci_images = np.concatenate(list(map(np.load, image_cache_files)))
print(fucci_images.shape)

In [None]:
# Wrap the numpy array as a tensor and swap its first two axes.
images = torch.from_numpy(fucci_images).permute(1, 0, 2, 3)
# Sanity checks: layout, dtype, value range, and absence of NaN/inf.
print(images.shape, images.dtype, sep="\n")
print(images.min(), images.max())
print(torch.isfinite(images).all())

torch.Size([4, 55777, 256, 256])
torch.float32


tensor(-1.) tensor(1.)
tensor(True)


The images are now normalized to the range [-1, 1].

In [None]:
# Write each channel to its own <channel>.pt file in the dataset root.
channel_names = ["dapi", "tubulin", "geminin", "cdt1"]
for c, channel in enumerate(channel_names):
    # images[c] is a view into the storage shared by all channels, and torch.save
    # serializes a view's entire underlying storage. Cloning to a contiguous tensor
    # first keeps each file limited to its own channel.
    channel_images = images[c].clone(memory_format=torch.contiguous_format)
    torch.save(channel_images, dataset_dir / f"{channel}.pt")

In [None]:
# Shell out to list the dataset root directory (IPython expands $dataset_dir from the Python variable).
!ls $dataset_dir

cdt1.pt			     fucci_logvar.pt	      reference_mu.pt
colors.npy		     fucci_mu.pt	      reference_var.pt
dapi.pt			     geminin.pt		      test
fucci_embeddings_flipped.pt  index.csv		      train
fucci_embeddings.pt	     reference_embeddings.pt  tubulin.pt
fucci_indices		     reference_indices	      valid
fucci_indices_flipped.npy    reference_indices.npy
fucci_indices.npy	     reference_logvar.pt


In [None]:
from torch.utils.data import Dataset

class ImageChannelDataset(Dataset):
    """Dataset over the images of one channel, loaded from ``<dataset_dir>/<channel_name>.pt``.

    Args:
        dataset_dir: Directory holding the per-channel ``.pt`` files (``str`` or ``Path``).
        channel_name: File stem of the channel to load, e.g. ``"dapi"``.
        color: Optional color label for the channel; kept on ``self.color``.
    """

    def __init__(self, dataset_dir, channel_name, color=None):
        # Coerce so that plain string directories work with the "/" join below.
        self.dataset_dir = Path(dataset_dir)
        self.channel_name = channel_name
        # Previously accepted but dropped; keep it so callers can read it back.
        self.color = color
        self.images = torch.load(self.dataset_dir / f"{self.channel_name}.pt")

    def __len__(self):
        """Return the number of entries along the first axis of the loaded object."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the entry (or slice) of the loaded images at ``idx``."""
        return self.images[idx]

In [None]:
# Load the DAPI channel on its own and report how many images it holds.
dapi_dataset = ImageChannelDataset(dataset_dir, channel_name="dapi", color="blue")
print(len(dapi_dataset))

55777


In [None]:
from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    """Map-style dataset that serves entries of a single tensor along its first axis."""

    def __init__(self, tensor) -> None:
        # Keep a reference to the tensor; no copy is made.
        self.tensor = tensor

    def __len__(self):
        """Number of entries along dimension 0."""
        return self.tensor.shape[0]

    def __getitem__(self, index):
        """Return ``tensor[index]``."""
        return self.tensor[index]

In [None]:
from lightning import LightningDataModule
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from torch.utils.data import random_split, DataLoader

class MultiModalDataModule(LightningDataModule):
    """Data module that merges several per-modality datasets and splits them three ways.

    Args:
        datasets: Sequence of datasets; each is materialized with ``dataset[:]`` and the
            results are combined with ``torch.stack`` / ``torch.cat``.
        mode: One of ``modes()``.
            ``"paired"`` stacks the modalities and swaps the first two axes, giving
            ``(samples, modalities, ...)``.
            ``"combined"`` concatenates all modalities along the first axis.
            ``"unpaired"`` is not implemented.
        split: Three lengths or fractions handed to ``random_split`` (train, val, test).
        batch_size: Batch size for every dataloader.
        num_workers: Worker count for every dataloader.

    Raises:
        ValueError: If ``mode`` is unknown or ``split`` does not have three entries.
        NotImplementedError: If ``mode`` is ``"unpaired"``.
    """

    def __init__(self, datasets, mode, split, batch_size, num_workers):
        super().__init__()
        self.split = split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mode = mode

        # Validate the cheap arguments before any data is materialized or copied.
        if self.mode not in self.modes():
            raise ValueError(f"Mode must be one of {self.modes()}. Got {mode}.")
        if self.mode == "unpaired":
            raise NotImplementedError("Unpaired mode not implemented yet.")
        if len(self.split) != 3:
            raise ValueError("split must be a tuple of length 3")

        self.datasets = [dataset[:] for dataset in datasets]
        if self.mode == "paired":
            # stack gives (modalities, samples, ...);
            # swapaxes then gives (samples, modalities, ...) so that they are paired
            self.dataset = torch.stack(self.datasets).swapaxes(0, 1)
        else:  # "combined"
            self.dataset = torch.cat(self.datasets)

        self.dataset = SimpleDataset(self.dataset)
        self.data_train, self.data_val, self.data_test = random_split(self.dataset, self.split)

    def modes(self):
        """Return the list of accepted ``mode`` values."""
        return ["paired", "unpaired", "combined"]

    def __shared_dataloader(self, dataset, shuffle=True):
        """Build a DataLoader over ``dataset`` with the module's batch and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,  # previously accepted but never forwarded
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when there are no workers.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self.__shared_dataloader(self.data_train)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self.__shared_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self.__shared_dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self):
        """Defer to the base class implementation."""
        return super().predict_dataloader()


The cell below takes about 1.5 minutes to run.

In [None]:
# One ImageChannelDataset per channel, in channel_names order.
datasets = [
    ImageChannelDataset(dataset_dir, channel_name=name)
    for name in channel_names
]

In [None]:
from Dataset import MultiModalDataModule as mmdm
# Positional arguments follow the notebook's own MultiModalDataModule signature:
# (datasets, mode, split, batch_size, num_workers) - TODO confirm the imported class matches.
dm = mmdm(
    datasets,
    "combined",
    (0.64, 0.16, 0.2),
    8,
    1,
)
# Sizes of the train / val / test subsets.
print(*(len(subset) for subset in (dm.data_train, dm.data_val, dm.data_test)))

142790 35697 44621


In [None]:
# Peek at the shape of the first training sample (requires `dm` from the cell above).
first_train_sample = dm.data_train[0]
print(first_train_sample.shape)

NameError: name 'dm' is not defined