In [1]:
import os
from pathlib import Path

import torch
import numpy as np
import pandas as pd

from microfilm import microplot

from FUCCIDataset import FUCCIDatasetInMemory

In [2]:
# Root directory of the filtered FUCCI dataset. Set the FUCCI_DATASET_DIR
# environment variable to point the notebook at a copy stored elsewhere;
# when it is unset, the original location is used.
dataset_dir = Path(os.environ.get("FUCCI_DATASET_DIR", "/data/ishang/Fucci-dataset-v3_filtered/"))

The dataset has some cache files and metadata in its top-level (home) directory.
Below that, the data is separated into train, valid, and test splits,
with experiment folders inside each split. These folders contain the individual FOV images per channel.
They also contain the caches of single-cell (sc) cropped and downsampled images, and cell segmentation by-products.

In [5]:
# Collect every experiment's image cache file.
# Expected layout: dataset_dir/<split>/<experiment>/cells_256.npy
# Paths yielded by iterdir()/glob() already include the dataset_dir prefix, so
# they are used as-is. Re-joining them onto dataset_dir only works when
# dataset_dir is absolute and yields duplicated, wrong paths for a relative one.
image_cache_files = [
    image_cache_file
    for stage_split in dataset_dir.iterdir()
    if stage_split.is_dir()
    for experiment_dir in stage_split.iterdir()
    if experiment_dir.is_dir()
    for image_cache_file in experiment_dir.glob("cells_256.npy")
]

print("Number of full FOV images:", len(image_cache_files))
if image_cache_files:
    print(image_cache_files[0])
else:
    # Avoid an opaque IndexError when the directory layout is not as expected.
    print(f"No cells_256.npy files found under {dataset_dir}")

Number of full FOV images: 2415
/data/ishang/Fucci-dataset-v3_filtered/valid/Overview 1_Image 54--Stage13/cells_256.npy


The cell below takes about 30 seconds to run: roughly 10 to load and 20 to concatenate.

In [6]:
# Read every cache file and join the arrays along the first axis
# (np.concatenate's default axis=0).
fucci_images = np.concatenate(list(map(np.load, image_cache_files)))
print(fucci_images.shape)

In [None]:
# Wrap the numpy array as a tensor (no copy) and swap the first two axes.
# NOTE(review): presumably this moves the channel axis to the front — confirm
# against the layout of the cells_256.npy caches.
images = torch.permute(torch.from_numpy(fucci_images), (1, 0, 2, 3))

# Sanity checks: shape, dtype, value range, and absence of NaN/inf.
for summary in (images.shape, images.dtype):
    print(summary)
print(images.min(), images.max())
print(torch.isfinite(images).all())

torch.Size([4, 55777, 256, 256])
torch.float32


tensor(-1.) tensor(1.)
tensor(True)


The images are now normalized to the range -1 to 1.

In [None]:
# Write one tensor file per channel: dataset_dir/<channel>.pt
channel_names = ["dapi", "tubulin", "geminin", "cdt1"]
if images.shape[0] != len(channel_names):
    raise ValueError(
        f"expected {len(channel_names)} channels on the first axis, got shape {tuple(images.shape)}"
    )
for c, channel in enumerate(channel_names):
    # images[c] is a view into the storage shared by all channels, and
    # torch.save serializes a view's entire underlying storage. Clone first so
    # each file holds only its own channel instead of the full array.
    torch.save(images[c].clone(), dataset_dir / f"{channel}.pt")

In [None]:
# List the contents of the dataset root directory (shell command via IPython).
!ls $dataset_dir

cdt1.pt			     fucci_logvar.pt	      reference_mu.pt
colors.npy		     fucci_mu.pt	      reference_var.pt
dapi.pt			     geminin.pt		      test
fucci_embeddings_flipped.pt  index.csv		      train
fucci_embeddings.pt	     reference_embeddings.pt  tubulin.pt
fucci_indices		     reference_indices	      valid
fucci_indices_flipped.npy    reference_indices.npy
fucci_indices.npy	     reference_logvar.pt


The code below takes about 2.5 minutes to run.

In [3]:
from Dataset import MultiModalDataModule

# Configuration for the multi-modal data module.
# Every channel is read from the same dataset root, so the directory is
# repeated once per channel name.
channel_names = ["dapi", "tubulin", "geminin", "cdt1"]
colors = ["blue", "yellow", "green", "red"]
dataset_dirs = [dataset_dir] * len(channel_names)

# NOTE(review): presumably (train, val, test) fractions — confirm against
# MultiModalDataModule.
split = (0.64, 0.16, 0.2)
batch_size = 8
num_workers = 8

In [4]:

# Build the data module from the configuration above, then report the size of
# each of its three subsets.
dm = MultiModalDataModule(
    dataset_dirs,
    channel_names,
    colors,
    "combined",
    split,
    batch_size,
    num_workers,
)
print(*map(len, (dm.data_train, dm.data_val, dm.data_test)))

142790 35697 44621


In [5]:
# Shape of a single training item, indexed straight from the dataset.
first_train_item = dm.data_train[0]
print(first_train_item.shape)

# Shape of the first element of the first batch produced by the train dataloader.
first_batch = next(iter(dm.train_dataloader()))
print(first_batch[0].shape)

torch.Size([1, 256, 256])
torch.Size([1, 256, 256])


In [5]:
# Split each per-channel tensor file into shards of at most `shard_size`
# items along the first axis, written as dataset_dir/<channel>_<k>.pt.
shard_size = 1000
for dataset in channel_names:
    print(f"loading {dataset}")
    data = torch.load(dataset_dir / f"{dataset}.pt")
    print(f"{dataset} loaded, shape: {data.shape}")
    print("saving shards")
    for shard_idx, start in enumerate(range(0, data.shape[0], shard_size)):
        shard_path = dataset_dir / f"{dataset}_{shard_idx}.pt"
        print(f"Saving {shard_path.name}")
        # A slice is a view of `data`, and torch.save writes a view's entire
        # underlying storage. Without the clone every shard file would contain
        # the whole channel, making the save slow and the files huge.
        torch.save(data[start:start + shard_size].clone(), shard_path)

loading dapi
dapi loaded, shape: torch.Size([55777, 256, 256])
saving shards
Saving dapi_0.pt
Saving dapi_1.pt
Saving dapi_2.pt
Saving dapi_3.pt
Saving dapi_4.pt


KeyboardInterrupt: 

In [13]:
from multiprocessing import Pool
import sys

# Parallel variant of the sharding step: split each per-channel tensor file
# into shards of at most `shard_size` items, written by a pool of 8 workers.
shard_size = 1000
for dataset in channel_names:
    print(f"loading {dataset}")
    data = torch.load(dataset_dir / f"{dataset}.pt")
    print(f"{dataset} loaded, shape: {data.shape}")
    print("saving shards")

    def save_shard(i):
        """Save the shard of `data` that starts at item index `i`.

        Reads `data`, `dataset`, `shard_size` and `dataset_dir` from the
        notebook's globals rather than taking them as arguments.
        NOTE(review): this relies on worker processes inheriting those globals
        (the "fork" start method); confirm before running on a platform that
        uses "spawn".
        """
        shard_name = f"{dataset}_{i // shard_size}.pt"
        print(f"Saving {shard_name}")
        sys.stdout.flush()
        # Slicing already clamps at the end of the tensor, so no min() is
        # needed. The clone matters: torch.save writes a view's entire
        # underlying storage, so saving the bare slice would put the whole
        # channel into every shard file.
        torch.save(data[i:i + shard_size].clone(), dataset_dir / shard_name)

    shard_starts = list(range(0, len(data), shard_size))
    print(shard_starts)
    with Pool(8) as p:
        p.map(save_shard, shard_starts)

loading dapi
dapi loaded, shape: torch.Size([55777, 256, 256])
saving shards
[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000, 31000, 32000, 33000, 34000, 35000, 36000, 37000, 38000, 39000, 40000, 41000, 42000, 43000, 44000, 45000, 46000, 47000, 48000, 49000, 50000, 51000, 52000, 53000, 54000, 55000]
Saving dapi_0.ptSaving dapi_2.ptSaving dapi_10.ptSaving dapi_12.ptSaving dapi_14.ptSaving dapi_6.ptSaving dapi_8.ptSaving dapi_4.pt







Saving dapi_3.pt
Saving dapi_11.pt
Saving dapi_7.pt
Saving dapi_15.pt
Saving dapi_5.pt
Saving dapi_9.pt
Saving dapi_13.pt
Saving dapi_1.pt


KeyboardInterrupt: 