In [1]:
from typing import Literal, Any, Sequence
import os
import datetime
import json

import tifffile as tiff
import numpy as np
import xarray as xr

from microsim import schema as ms
from microsim.schema.optical_config.lib import spectral_detector

In [2]:
# --- Set simulation parameters
# The labels of the structures to simulate.
labels: list[str] = ["ER", "F-actin", "Microtubules"]
# The fluorophores associated with the structures to simulate (paired element-wise with `labels`).
fluorophores: list[str] = ["mTurquoise", "EGFP", "EYFP"]
# The number of spectral bands to acquire (i.e., physically, the number of cameras).
num_bands: int = 32
# Wavelengths of the lasers used for excitation (presumably in nm).
light_wavelengths: Sequence[int] = [435, 488, 514]
# Power associated with each light source (acts as a scaling factor), paired with `light_wavelengths`.
light_powers: Sequence[float] = [1., 1., 1.]
# The range of wavelengths of the acquired spectrum in nm, as (min, max).
out_range: tuple[int, int] = (400, 700)
# The exposure time for the detector cameras in ms.
exposure_ms: float = 100

# Sanity checks: catch inconsistent edits to the configuration above early.
assert len(labels) == len(fluorophores), "each label needs exactly one fluorophore"
assert len(light_wavelengths) == len(light_powers), "each laser needs exactly one power"
assert out_range[0] < out_range[1], "out_range must be (min, max)"

'The exposure time for the detector cameras in ms.'

In [3]:
def create_distribution(
    label: Literal["CCPs", "ER", "F-actin", "Microtubules"],
    fluorophore: str,
    root_dir: str,
    idx: int | None = None,
) -> ms.FluorophoreDistribution:
    """Pair a BioSR structure with a fluorophore.

    Parameters
    ----------
    label : {"CCPs", "ER", "F-actin", "Microtubules"}
        BioSR structure used as the spatial distribution.
    fluorophore : str
        Name of the fluorophore labelling that structure.
    root_dir : str
        Root directory of the BioSR dataset, forwarded to `ms.BioSR`.
    idx : int, optional
        Forwarded as `img_idx` to `ms.FluorophoreDistribution`; `None` by default.

    Returns
    -------
    ms.FluorophoreDistribution
    """
    biosr_structure = ms.BioSR(root_dir=root_dir, label=label)
    return ms.FluorophoreDistribution(
        distribution=biosr_structure,
        fluorophore=fluorophore,
        img_idx=idx,
    )


def init_simulation(
    labels: list[Literal["CCPs", "ER", "F-actin", "Microtubules"]],
    fluorophores: list[str],
    root_dir: str,
    detector_qe: float = 0.8,
    *,
    bins: int | None = None,
    wave_range: tuple[int, int] | None = None,
    lasers: Sequence[int] | None = None,
    powers: Sequence[float] | None = None,
) -> ms.Simulation:
    """Build a spectral-detector simulation of BioSR structures.

    Parameters
    ----------
    labels : list of {"CCPs", "ER", "F-actin", "Microtubules"}
        BioSR structures to place in the sample, one per fluorophore.
    fluorophores : list of str
        Fluorophore names, paired element-wise with `labels`.
    root_dir : str
        Root directory of the BioSR dataset, forwarded to `ms.BioSR`.
    detector_qe : float, optional
        Quantum efficiency of the CCD camera, by default 0.8.
    bins : int, optional
        Number of spectral bands. Defaults to the notebook-level `num_bands`.
    wave_range : tuple of int, optional
        (min, max) detected wavelength in nm. Defaults to the notebook-level
        `out_range`.
    lasers : sequence of int, optional
        Excitation laser wavelengths. Defaults to the notebook-level
        `light_wavelengths`.
    powers : sequence of float, optional
        Per-laser power scaling factors. Defaults to the notebook-level
        `light_powers`.

    Returns
    -------
    ms.Simulation
        Simulation with disk caching disabled, identity modality and a CCD
        detector (read noise 6, 12-bit).

    Raises
    ------
    ValueError
        If `labels` and `fluorophores` have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(labels) != len(fluorophores):
        raise ValueError(
            f"labels and fluorophores must have the same length, "
            f"got {len(labels)} and {len(fluorophores)}"
        )

    # Disable on-disk caching so every call produces a fresh simulation.
    custom_cache_settings = ms.settings.CacheSettings(
        read=False,
        write=False,
    )
    # create the GT sample
    sample = ms.Sample(
        labels=[
            create_distribution(label, fp, root_dir)
            for label, fp in zip(labels, fluorophores)
        ]
    )

    # Fall back to the notebook-level configuration when no explicit value is
    # given, so existing calls keep their original behavior.
    bins = num_bands if bins is None else bins
    wave_range = out_range if wave_range is None else wave_range
    lasers = light_wavelengths if lasers is None else lasers
    powers = light_powers if powers is None else powers

    # create the channels simulating the spectral detector
    # NOTE: this assumes excitation is done with lasers
    detect_channels = spectral_detector(
        bins=bins,
        min_wave=wave_range[0],
        max_wave=wave_range[1],
        lasers=lasers,
        powers=powers,
    )

    return ms.Simulation(
        # Truth space: 1 x 1004 x 1004 voxels, 0.02 per voxel (units presumably um).
        truth_space=ms.ShapeScaleSpace(shape=(1, 1004, 1004), scale=(0.02, 0.02, 0.02)),
        # Output is downscaled 4x along the last two axes.
        output_space={"downscale": (1, 4, 4)},
        sample=sample,
        channels=detect_channels,
        modality=ms.Identity(),
        settings=ms.Settings(
            cache=custom_cache_settings, spectral_bins_per_emission_channel=1
        ),
        detector=ms.CameraCCD(qe=detector_qe, read_noise=6, bit_depth=12),
    )

    
def run_simulation(
    sim: ms.Simulation,
    detect_exposure: float | None = None,
) -> xr.DataArray:
    """Run every simulation stage and return the final digital image.

    The intermediate stages (ground truth, emission flux, optical image) are
    computed and their sizes printed for inspection.

    Parameters
    ----------
    sim : ms.Simulation
        A configured simulation, e.g. as built by `init_simulation`.
    detect_exposure : float, optional
        Detector exposure time in ms. If None (default), `sim.digital_image`
        is called without an explicit exposure, exactly as before.

    Returns
    -------
    xr.DataArray
        The digital image, computed from the optical image summed over the
        fluorophore axis ("f"); expected dims (C, Z, Y, X).
    """
    sep = "----------------------------------"
    gt = sim.ground_truth()
    print(f"Ground truth: {gt.sizes}")  # (F, Z, Y, X)
    print(sep)
    em_img = sim.emission_flux()
    print(f"Emission image: {em_img.sizes}")  # (C, F, Z, Y, X)
    print(sep)
    opt_img_per_fluor = sim.optical_image_per_fluor()  # (C, F, Z, Y, X)
    opt_img = opt_img_per_fluor.sum("f")
    print(f"Optical image: {opt_img.sizes}")  # (C, Z, Y, X)
    print(sep)
    if detect_exposure is None:
        digital_img = sim.digital_image(opt_img)
    else:
        # NOTE(review): assumes microsim's `Simulation.digital_image` accepts an
        # `exposure_ms` keyword - confirm against the installed version.
        digital_img = sim.digital_image(opt_img, exposure_ms=detect_exposure)
    # TODO: add digital GT (C, F, Z, Y, X)
    print(f"Digital image: {digital_img.sizes}")  # (C, Z, Y, X)
    print(sep)
    return digital_img


def simulate_dataset(
    labels: list[Literal["CCPs", "ER", "F-actin", "Microtubules"]],
    fluorophores: list[str],
    num_simulations: int,
    root_dir: str,
    detector_qe: float = 0.8,
    detect_exposure: float = 100,
) -> tuple[list[xr.DataArray], dict[str, Any]]:
    """Run `num_simulations` independent simulations and collect metadata.

    Parameters
    ----------
    labels : list of {"CCPs", "ER", "F-actin", "Microtubules"}
        BioSR structures to simulate, one per fluorophore.
    fluorophores : list of str
        Fluorophore names, paired element-wise with `labels`.
    num_simulations : int
        Number of images to simulate; must be >= 1.
    root_dir : str
        Root directory of the BioSR dataset.
    detector_qe : float, optional
        Quantum efficiency of the detector, by default 0.8.
    detect_exposure : float, optional
        Detector exposure time in ms, by default 100. It is forwarded to
        `run_simulation` and recorded in the metadata.

    Returns
    -------
    tuple[list[xr.DataArray], dict[str, Any]]
        The simulated digital images and a metadata dict describing them.

    Raises
    ------
    ValueError
        If `num_simulations` is smaller than 1. Without this check the metadata
        step would fail with an obscure UnboundLocalError.
    """
    if num_simulations < 1:
        raise ValueError(f"num_simulations must be >= 1, got {num_simulations}")

    sep = "----------------------------------"
    sim_imgs: list[xr.DataArray] = []
    for i in range(num_simulations):
        print(sep)
        print(f"SIMULATING IMAGE {i+1}")
        print(sep)
        sim = init_simulation(labels, fluorophores, root_dir, detector_qe)
        sim_imgs.append(run_simulation(sim, detect_exposure))

    # Create simulation metadata. Every simulation is built with identical
    # arguments, so the last `sim` is used to describe the whole dataset.
    transmission_wl = sim.channels[0].filters[0].transmission.wavelength
    wave_range = [
        int(transmission_wl[0].magnitude),
        int(transmission_wl[-1].magnitude),
    ]
    sim_metadata = {
        "structures": labels,
        "fluorophores": fluorophores,
        # Last three dims of the digital image (expected Z, Y, X).
        "shape": list(sim_imgs[0].shape[-3:]),
        "downscale": sim.output_space.downscale,
        "detect_exposure_ms": detect_exposure,
        "detect_quantum_eff": detector_qe,
        "wavelength_range": wave_range,
        "dtype": str(sim_imgs[0].dtype),
    }
    return sim_imgs, sim_metadata


In [4]:
# Dataset locations. Each can be overridden with an environment variable so the
# notebook runs on other machines; the defaults are the original cluster paths.
# ROOT_DIR: root of the BioSR dataset (passed to `ms.BioSR`).
ROOT_DIR = os.environ.get(
    "BIOSR_ROOT_DIR", "/group/jug/federico/careamics_training/data/BioSR"
)
# SAVE_DIR: presumably the destination for the simulated spectral dataset.
SAVE_DIR = os.environ.get(
    "BIOSR_SPECTRAL_SAVE_DIR", "/group/jug/federico/microsim/BIOSR_spectral_data"
)

# Build one simulation instance on its own, presumably so its configuration
# can be inspected before the full dataset is generated below.
sim = init_simulation(
    labels,
    fluorophores,
    ROOT_DIR,
    detector_qe=0.8,
)

In [None]:
# Generate the full dataset. Each simulation runs every stage, so this is slow.
NUM_SIMULATIONS = 100
res, sim_metadata = simulate_dataset(
    labels=labels,
    fluorophores=fluorophores,
    num_simulations=NUM_SIMULATIONS,
    root_dir=ROOT_DIR,
    # Wire through the configured exposure; otherwise edits to `exposure_ms`
    # in the config cell would be silently ignored.
    detect_exposure=exposure_ms,
)