In [1]:
from typing import Literal, Sequence

import xarray as xr

from microsim import schema as ms
from microsim.schema.optical_config.lib import spectral_detector
from microsim.schema.optical_config import OpticalConfig

In [None]:
# --- Set simulation parameters
# Labels of the structures to simulate.
labels: list[str] = ["ER", "F-actin", "Microtubules", "CCPs", "F-actin_Nonlinear"]
# Fluorophores attached to the structures, matched to `labels` by position.
fluorophores: list[str] = ["mTurquoise", "EGFP", "EYFP", "tdTomato", "mCherry"]
# Number of spectral bands to acquire (i.e., physically, the number of cameras).
num_bands: int = 32
# Wavelengths of the lasers used for excitation.
light_wavelengths: Sequence[int] = [435, 488, 514, 555, 586]
# Power associated with each light source (acts as a scaling factor),
# matched to `light_wavelengths` by position.
light_powers: Sequence[float] = [4., 3., 1., 1., 6.]
# Range of wavelengths of the acquired spectrum, in nm: (min, max).
out_range: tuple[int, int] = (450, 650)
# Exposure time of the detector cameras, in ms.
exposure_ms: float = 5
# Bandwidth of the bandpass filters used in the beam splitter, in nm.
bp_bandwidth: int = 5

# Fail fast on inconsistent configuration instead of deep inside the simulation.
assert len(labels) == len(fluorophores), "each label needs exactly one fluorophore"
assert len(light_wavelengths) == len(light_powers), "each laser needs exactly one power"
assert out_range[0] < out_range[1], "out_range must be (min_wave, max_wave)"

In [33]:
def create_distribution(
    label: Literal["CCPs", "ER", "F-actin", "Microtubules", "F-actin_Nonlinear"],
    fluorophore: str,
    root_dir: str,
    idx: int | None = None,
) -> ms.FluorophoreDistribution:
    """Wrap a BioSR structure and a fluorophore into a fluorophore distribution.

    Parameters
    ----------
    label : str
        BioSR structure to load.
    fluorophore : str
        Name of the fluorophore that labels the structure.
    root_dir : str
        Root directory of the BioSR dataset, forwarded to ``ms.BioSR``.
    idx : int | None
        Forwarded unchanged as ``img_idx`` (``None`` is passed through as-is).

    Returns
    -------
    ms.FluorophoreDistribution
    """
    biosr_structure = ms.BioSR(root_dir=root_dir, label=label)
    return ms.FluorophoreDistribution(
        distribution=biosr_structure, fluorophore=fluorophore, img_idx=idx
    )

def init_simulation(
    labels: list[Literal["CCPs", "ER", "F-actin", "Microtubules", "F-actin_Nonlinear"]],
    fluorophores: list[str],
    root_dir: str,
    channels: Sequence[OpticalConfig],
    detector_qe: float = 0.8,
) -> ms.Simulation:
    """Assemble a microsim Simulation for the given structures and detection channels.

    Parameters
    ----------
    labels : list of str
        BioSR structures to include; one distribution is created per label.
    fluorophores : list of str
        Fluorophore for each label, matched by position.
    root_dir : str
        Root directory of the BioSR dataset.
    channels : Sequence[OpticalConfig]
        Detection channels of the simulation.
    detector_qe : float
        Quantum efficiency of the CCD camera.

    Returns
    -------
    ms.Simulation
        Truth space of shape (1, 1004, 1004) with scale 0.02 on every axis,
        output downscaled by (1, 4, 4), identity modality, caching disabled
        and one spectral bin per emission channel.

    Raises
    ------
    AssertionError
        If ``labels`` and ``fluorophores`` differ in length.
    """
    assert len(labels) == len(fluorophores)

    # Neither read from nor write to the cache.
    no_cache = ms.settings.CacheSettings(read=False, write=False)

    # Ground-truth sample: one BioSR-backed distribution per (label, fluorophore) pair.
    sample = ms.Sample(
        labels=[
            create_distribution(structure, fluor, root_dir)
            for structure, fluor in zip(labels, fluorophores)
        ]
    )

    truth_space = ms.ShapeScaleSpace(shape=(1, 1004, 1004), scale=(0.02, 0.02, 0.02))
    modality = ms.Identity()
    sim_settings = ms.Settings(cache=no_cache, spectral_bins_per_emission_channel=1)
    camera = ms.CameraCCD(qe=detector_qe, read_noise=6, bit_depth=12)

    return ms.Simulation(
        truth_space=truth_space,
        output_space={"downscale": (1, 4, 4)},
        sample=sample,
        channels=channels,
        modality=modality,
        settings=sim_settings,
        detector=camera,
    )

    
def run_simulation(
    sim: ms.Simulation,
) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]:
    """Run every stage of ``sim`` and print the sizes of each intermediate.

    The ground truth is computed only so that its sizes can be printed; it is
    not part of the returned tuple.

    Parameters
    ----------
    sim : ms.Simulation
        A configured simulation, e.g. as returned by ``init_simulation``.

    Returns
    -------
    tuple of four xr.DataArray
        ``(em_img, opt_img_per_fluor, dig_img_per_fluor, digital_img)``:
        the emission flux, the per-fluorophore optical image, the digital
        image of the per-fluorophore optical image, and the digital image of
        the optical image summed over the fluorophore dimension ``"f"``.
    """
    gt = sim.ground_truth()
    print(f"Ground truth: {gt.sizes}")  # (F, Z, Y, X)
    print("----------------------------------")
    em_img = sim.emission_flux()
    print(f"Emission image: {em_img.sizes}")  # (C, F, Z, Y, X)
    print("----------------------------------")
    opt_img_per_fluor = sim.optical_image_per_fluor()  # (C, F, Z, Y, X)
    # Mixed optical image: contributions of all fluorophores added together.
    opt_img = opt_img_per_fluor.sum("f")
    print(f"Optical image: {opt_img.sizes}")  # (C, Z, Y, X)
    print("----------------------------------")
    dig_img_per_fluor = sim.digital_image(opt_img_per_fluor)
    digital_img = sim.digital_image(opt_img)
    print(f"Digital image: {digital_img.sizes}")  # (C, Z, Y, X)
    print("----------------------------------")
    return em_img, opt_img_per_fluor, dig_img_per_fluor, digital_img


def simulate_dataset(
    labels: list[Literal["CCPs", "ER", "F-actin", "Microtubules", "F-actin_Nonlinear",]],
    fluorophores: list[str],
    num_simulations: int,
    root_dir: str,
    detect_channels: Sequence[OpticalConfig],
    detector_qe: float = 0.8,
) -> tuple[list[tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]], dict]:
    """Run ``num_simulations`` freshly initialised simulations and collect the results.

    Parameters
    ----------
    labels, fluorophores, root_dir, detector_qe
        Forwarded to ``init_simulation`` for every simulation.
    num_simulations : int
        Number of simulations to run; must be at least 1.
    detect_channels : Sequence[OpticalConfig]
        Detection channels, forwarded to ``init_simulation`` as ``channels``.

    Returns
    -------
    sim_imgs : list of tuple
        One ``run_simulation`` result per simulation:
        ``(em_img, opt_img_per_fluor, dig_img_per_fluor, digital_img)``.
    sim_metadata : dict
        Description of the dataset. ``shape`` and ``dtype`` are taken from the
        first simulation's mixed digital image (``digital_img``).

    Raises
    ------
    ValueError
        If ``num_simulations`` is smaller than 1 (the metadata needs at least
        one simulated result).
    """
    if num_simulations < 1:
        raise ValueError(f"num_simulations must be >= 1, got {num_simulations}")

    sim_imgs = []
    for i in range(num_simulations):
        print("----------------------------------")
        print(f"SIMULATING IMAGE {i+1}")
        print("----------------------------------")
        sim = init_simulation(
            labels=labels,
            fluorophores=fluorophores,
            root_dir=root_dir,
            channels=detect_channels,
            detector_qe=detector_qe,
        )
        sim_imgs.append(run_simulation(sim))

    # Each entry of sim_imgs is a tuple; the metadata describes its last element,
    # the mixed digital image (dims noted as (C, Z, Y, X) in run_simulation),
    # so shape[2:] is its (Y, X) size.
    digital_img = sim_imgs[0][-1]
    sim_metadata = {
        "structures": labels,
        "fluorophores": fluorophores,
        "shape": list(digital_img.shape[2:]),
        "downscale": sim.output_space.downscale,
        # NOTE(review): the next four entries read notebook-level globals, not
        # parameters; they must match the settings used to build detect_channels.
        "detect_exposure_ms": exposure_ms,
        "detect_quantum_eff": detector_qe,
        "light_powers": light_powers,
        "light_wavelengths": light_wavelengths,
        "wavelength_range": out_range,
        "dtype": str(digital_img.dtype),
    }
    return sim_imgs, sim_metadata


In [34]:
# Data locations (absolute, machine-specific paths).
ROOT_DIR = "/group/jug/federico/careamics_training/data/BioSR"
SAVE_DIR = "/group/jug/federico/microsim/BIOSR_spectral_data"

# Spectral detection channels covering out_range with num_bands bins,
# excited by the configured lasers.
detector_kwargs = dict(
    bins=num_bands,
    min_wave=out_range[0],
    max_wave=out_range[1],
    lasers=light_wavelengths,
    powers=light_powers,
    exposure_ms=exposure_ms,
    bp_bandwidth=bp_bandwidth,
)
detect_channels = spectral_detector(**detector_kwargs)

# Single simulation used for the inspection cells below.
sim = init_simulation(
    labels=labels,
    fluorophores=fluorophores,
    root_dir=ROOT_DIR,
    channels=detect_channels,
    detector_qe=0.8,
)

In [None]:
# em: emission flux, opt: per-fluorophore optical image,
# dig_per_fluor: digital image per fluorophore, dig: mixed digital image.
(em, opt, dig_per_fluor, dig) = run_simulation(sim)

In [None]:
# Shapes of the four simulation outputs, in the order returned by run_simulation.
tuple(arr.shape for arr in (em, opt, dig_per_fluor, dig))

In [None]:
import matplotlib.pyplot as plt

# Grid preview of the emission flux at z=0: one row per fluorophore,
# one column per selected spectral channel.
preview_channels = (2, 8, 14, 20, 26)
_, ax = plt.subplots(len(labels), len(preview_channels), figsize=(15, 3*len(labels)))
for i in range(len(fluorophores)):
    for col, ch in enumerate(preview_channels):
        ax[i, col].imshow(em[ch, i, 0])

# Statistics of three channels for the LAST fluorophore: `i` keeps its final loop value.
em1, em2, em3 = [em[ch, i, 0] for ch in (14, 20, 26)]
print(*[img.mean().item() for img in (em1, em2, em3)])
print(*[img.std().item() for img in (em1, em2, em3)])

In [None]:
import matplotlib.pyplot as plt

# Grid preview of the per-fluorophore optical image at z=0: one row per
# fluorophore, one column per selected spectral channel.
preview_channels = (2, 8, 14, 20, 26)
_, ax = plt.subplots(len(labels), len(preview_channels), figsize=(15, 3*len(labels)))
for i in range(len(fluorophores)):
    for col, ch in enumerate(preview_channels):
        ax[i, col].imshow(opt[ch, i, 0])

# Statistics of three channels for the LAST fluorophore: `i` keeps its final loop value.
opt1, opt2, opt3 = [opt[ch, i, 0] for ch in (14, 20, 26)]
print(*[img.mean().item() for img in (opt1, opt2, opt3)])
print(*[img.std().item() for img in (opt1, opt2, opt3)])

In [None]:
import matplotlib.pyplot as plt

# Grid preview of the per-fluorophore digital image, indexed as [channel, fluorophore, z]
# (presumably the same dim order as `opt` — confirm): one row per fluorophore,
# one column per selected spectral channel.
preview_channels = (2, 8, 14, 20, 26)
_, ax = plt.subplots(len(labels), len(preview_channels), figsize=(15, 3*len(labels)))
for i in range(len(fluorophores)):
    for col, ch in enumerate(preview_channels):
        ax[i, col].imshow(dig_per_fluor[ch, i, 0])

# Statistics of three channels for the LAST fluorophore: `i` keeps its final loop value.
dig_per_fluor1, dig_per_fluor2, dig_per_fluor3 = [dig_per_fluor[ch, i, 0] for ch in (14, 20, 26)]
print(*[img.mean().item() for img in (dig_per_fluor1, dig_per_fluor2, dig_per_fluor3)])
print(*[img.std().item() for img in (dig_per_fluor1, dig_per_fluor2, dig_per_fluor3)])

In [None]:
from microsim.util import intensity_histograms

# Intensity histograms of the emission flux; the two numeric values are passed
# positionally, exactly as intensity_histograms expects them.
histogram_args = (em, "Emission", 1e6, 1.5e4)
intensity_histograms(*histogram_args)

In [None]:
from microsim.util import view_multi_channel

# First z-plane of every channel of the mixed digital image.
dig_first_plane = dig[:, 0]
view_multi_channel([dig_first_plane])

In [25]:
# Fluorophore objects of the simulated sample, in the same order as its labels.
fps = [dist.fluorophore for dist in sim.sample.labels]

In [None]:
import matplotlib.pyplot as plt

# One row per fluorophore plus a final row for the sum over all fluorophores.
# Left column: per-channel maximum of the simulated emission (x = channel index).
# Right column: reference emission spectrum, x limited to out_range.
n_rows = len(fluorophores) + 1
_, ax = plt.subplots(n_rows, 2, figsize=(20, 3 * n_rows))
for i in range(len(fluorophores)):
    ax[i, 0].plot(em[:, i, ...].max(axis=(1, 2, 3)))
    ax[i, 0].set_title(f"{labels[i]} / {fluorophores[i]}: simulated emission (max per channel)")
    ax[i, 1].plot(fps[i].emission_spectrum.wavelength, fps[i].emission_spectrum.intensity)
    ax[i, 1].set_xlim(out_range)
    ax[i, 1].set_title(f"{fluorophores[i]}: reference emission spectrum")

# Total over fluorophores: sum along the fluorophore axis, then max per channel.
ax[-1, 0].plot(em.sum(axis=1).max(axis=(1, 2, 3)))
ax[-1, 0].set_title("All fluorophores: simulated emission (max per channel)")
ax[-1, 0].set_xlabel("spectral channel index")

# Accumulate with `total = total + x` (not `+=`) so the first fluorophore's
# spectrum object cannot be modified in place if its type defines __iadd__.
em_sp_tot = fps[0].emission_spectrum
for fp in fps[1:]:
    em_sp_tot = em_sp_tot + fp.emission_spectrum
ax[-1, 1].plot(em_sp_tot.wavelength, em_sp_tot.intensity)
ax[-1, 1].set_xlim(out_range)
ax[-1, 1].set_title("All fluorophores: summed reference emission spectrum")
ax[-1, 1].set_xlabel("wavelength");


In [None]:
# `np` and `tiff` were never imported in this notebook, so import them here.
# `imwrite` is tifffile's writer; `imsave` is only its deprecated alias.
import numpy as np
import tifffile as tiff

# Save the mixed digital image as a 16-bit TIFF (the camera is configured with
# bit_depth=12 in init_simulation, so its values are expected to fit in uint16).
tiff.imwrite("../digital_try.tif", dig.values.astype(np.uint16))

In [None]:
# Simulate the full dataset. `detect_channels` is a required parameter of
# simulate_dataset; reuse the spectral detector built in the setup cell.
res, sim_metadata = simulate_dataset(
    labels=labels,
    fluorophores=fluorophores,
    num_simulations=100,
    root_dir=ROOT_DIR,
    detect_channels=detect_channels,
)