In [None]:
from microsim import schema as ms
import numpy as np

In [None]:
# Define a simple flat ("top-hat") filter through a function that lets us set the
# min and max wavelengths of its passband.
def create_custom_channel(
    min_wave: int = 300,
    max_wave: int = 800,
    name: str = "FEDERICO",
) -> ms.optical_config.OpticalConfig:
    """Build a one-filter optical channel with unit transmission over ``[min_wave, max_wave)``.

    Parameters
    ----------
    min_wave : int
        First sampled wavelength of the passband (inclusive). Samples are 1 unit apart
        (presumably nm -- TODO confirm against microsim's Spectrum units).
    max_wave : int
        End of the passband (exclusive, following ``np.arange`` semantics).
    name : str
        Name given to the returned ``OpticalConfig``. Defaults to the previously
        hard-coded value, so existing calls are unchanged.

    Returns
    -------
    ms.optical_config.OpticalConfig
        A channel whose only filter is a ``SpectrumFilter`` with transmission 1.0 at
        every sampled wavelength.

    Raises
    ------
    ValueError
        If ``max_wave <= min_wave`` (the passband would be empty).
    """
    if max_wave <= min_wave:
        raise ValueError(
            f"max_wave ({max_wave}) must be greater than min_wave ({min_wave})"
        )

    wavelengths = np.arange(min_wave, max_wave, 1)
    custom_spectrum = ms.Spectrum(
        wavelength=wavelengths,
        # Derive the length from `wavelengths` so the two arrays always match
        # (e.g. also for non-integer bounds).
        intensity=np.ones(wavelengths.shape[0]),
    )

    custom_filter = ms.optical_config.SpectrumFilter(transmission=custom_spectrum) # placement=ALL by default

    custom_channel = ms.optical_config.OpticalConfig(
        name=name,
        filters=[custom_filter],
    )

    return custom_channel

In [None]:
my_ch = create_custom_channel()

# Sanity-check the transmission spectrum of the channel's single filter:
# its repr, the wavelength/intensity array shapes and types, and the peak transmission.
ch_spectrum = my_ch.filters[0].spectrum
ch_wavelength = ch_spectrum.wavelength
ch_intensity = ch_spectrum.intensity

print(ch_spectrum)
print(ch_wavelength.shape, type(ch_wavelength))
print(ch_intensity.shape, type(ch_intensity), max(ch_intensity))

Instantiate the `Simulation` object

In [None]:
from microsim.schema.optical_config import lib

# One (COSEM label, fluorophore) pair per simulated structure of the jrc_hela-3 dataset.
# NOTE(review): `light_powers` below presumably pairs with these labels by position -- confirm.
hela_labels = [
    ("ne_pred", "mTurquoise"),
    ("er-mem_pred", "EGFP"),
    ("mito-mem_pred", "EYFP"),
]

sim = ms.Simulation(
    truth_space=ms.ShapeScaleSpace(shape=(52, 512, 512), scale=(0.064, 0.064, 0.064)),
    output_space={"downscale": 2},
    sample=ms.Sample(
        labels=[
            ms.FluorophoreDistribution(
                distribution=ms.CosemLabel(dataset="jrc_hela-3", label=cosem_label),
                fluorophore=fluor_name,
            )
            for cosem_label, fluor_name in hela_labels
        ]
    ),
    # Single flat-passband channel; the upper bound is exclusive (np.arange semantics).
    channels=[create_custom_channel(min_wave=460, max_wave=550)],
    modality=ms.Confocal(pinhole_au=2),
    detector=ms.CameraCCD(qe=0.82, read_noise=6),
    output_path="h2-cf.tif",
    settings=ms.Settings(max_psf_radius_aus=2),
    emission_bins=32,
    light_powers=[3, 1, 1],
)

### Simulate `ground_truth` array

In [None]:
# Ground-truth fluorophore distribution; layout is (F, Z, Y, X) per the original note.
gt = sim.ground_truth()
gt_summary = (gt.shape, gt.coords)
print(*gt_summary)

In [None]:
# Plot the ground-truth fluorophore distributions
import matplotlib.pyplot as plt

# Maximum-intensity projection along z: with gt laid out as (F, Z, Y, X),
# this leaves one (Y, X) image per fluorophore.
gt_mip = gt.max(dim='z')

cmaps = ["Blues", "Greens", "Reds", "Grays"]
n_fluors = gt_mip.shape[0]
# squeeze=False keeps `ax` 2D even for a single fluorophore, where plt.subplots
# would otherwise return a bare Axes that cannot be indexed.
_, ax = plt.subplots(1, n_fluors, figsize=(10, 5), squeeze=False)
for i in range(n_fluors):
    # Cycle the colormaps so more fluorophores than colormaps does not raise IndexError.
    ax[0, i].imshow(gt_mip[i, ...], cmap=cmaps[i % len(cmaps)])
    ax[0, i].set_title(f"Fluorophore {i + 1}")

### Get the spectral emission flux

In [None]:
# Spectral emission flux for channel 0; the call returns a 3-tuple.
emission_outputs = sim.spectral_emission_flux(gt, channel_idx=0)
em_img, em_spectra, em_binned_spectra = emission_outputs

In [None]:
# Print the emission image
from microsim.schema.dimensions import Axis

def normalize_global(data):
    """Linearly rescale ``data`` to the full uint16 range using its global min and max.

    The minimum over *all* elements maps to 0 and the maximum to 65535. Values are
    truncated (not rounded) by the final cast. In this notebook it is called on
    xarray DataArrays; plain NumPy arrays work as well.

    Parameters
    ----------
    data : array-like supporting ``.min()``, ``.max()``, arithmetic and ``.astype()``
        Input values. A single min/max is used for the whole array, not one per slice.

    Returns
    -------
    Same container type as ``data``, with dtype ``uint16``.
    For constant input (max == min), every element is 0.
    """
    min_val = data.min().item()
    max_val = data.max().item()
    value_range = max_val - min_val
    if value_range == 0:
        # Constant input: the general formula would compute 0/0 -> NaN, and casting
        # NaN to uint16 is undefined. Here `data - min_val` is all zeros and keeps
        # the container type of `data`.
        return (data - min_val).astype(np.uint16)
    normalized_data = ((data - min_val) / value_range * 65535).astype(np.uint16)
    return normalized_data

# Split the emission image along axis 2. Per the variable names, indices 0-2 are the
# three fluorophores and index 3 is the mixed image. Then drop singleton axes,
# take the z-MIP and normalize each result to uint16 for display.
fluor1_em_img = em_img[:, :, 0, ...].squeeze()
fluor2_em_img = em_img[:, :, 1, ...].squeeze()
fluor3_em_img = em_img[:, :, 2, ...].squeeze()
mixed_em_img = em_img[:, :, 3, ...].squeeze()

fluor1_em_mip = fluor1_em_img.max(dim=Axis.Z)
fluor2_em_mip = fluor2_em_img.max(dim=Axis.Z)
fluor3_em_mip = fluor3_em_img.max(dim=Axis.Z)
mixed_em_mip = mixed_em_img.max(dim=Axis.Z)

norm_fluor1_em_mip = normalize_global(fluor1_em_mip)
norm_fluor2_em_mip = normalize_global(fluor2_em_mip)
norm_fluor3_em_mip = normalize_global(fluor3_em_mip)
norm_mixed_em_mip = normalize_global(mixed_em_mip)

# One figure row per image: (normalized MIP, first spectral band, second spectral band)
em_band_pairs = [
    (norm_fluor1_em_mip, 1, 5),
    (norm_fluor2_em_mip, 6, 13),
    (norm_fluor3_em_mip, 10, 14),
    (norm_mixed_em_mip, 7, 12),
]

_, ax = plt.subplots(len(em_band_pairs), 3, figsize=(15, 20))
for row, (band_mip, band_a, band_b) in enumerate(em_band_pairs):
    img_a = band_mip[band_a, ...]
    img_b = band_mip[band_b, ...]
    ax[row, 0].imshow(img_a, cmap="Grays")
    ax[row, 0].set_title(f"Spectral band {band_a}")
    ax[row, 1].imshow(img_b, cmap="Greens")
    ax[row, 1].set_title(f"Spectral band {band_b}")
    # normalize_global returns uint16, and uint16 subtraction wraps around
    # (e.g. 3 - 5 == 65534), which abs() cannot undo. Widen to int32 first.
    ax[row, 2].imshow(abs(img_a.astype(np.int32) - img_b.astype(np.int32)), cmap="Reds")
    ax[row, 2].set_title(f"Abs difference ({band_a} vs {band_b})")

### Get the optical image (new approach)

In [None]:
# Switch the simulation to spectral output. Presumably this keeps the W axis in the
# optical image, since later cells index and sum it over Axis.W.
# NOTE: this mutates `sim` in place, so it stays set for every later cell.
sim.spectral_image = True
opt_img = sim.optical_image(
    em_img,
    channel_idx=0,
)

In [None]:
# Split the optical image along axis 2. Per the variable names, indices 0-2 are the
# three fluorophores and index 3 is the mixed image. Then drop singleton axes,
# take the z-MIP and normalize each result to uint16 for display.
fluor1_opt_img = opt_img[:, :, 0, ...].squeeze()
fluor2_opt_img = opt_img[:, :, 1, ...].squeeze()
fluor3_opt_img = opt_img[:, :, 2, ...].squeeze()
mixed_opt_img = opt_img[:, :, 3, ...].squeeze()

fluor1_opt_mip = fluor1_opt_img.max(dim=Axis.Z)
fluor2_opt_mip = fluor2_opt_img.max(dim=Axis.Z)
fluor3_opt_mip = fluor3_opt_img.max(dim=Axis.Z)
mixed_opt_mip = mixed_opt_img.max(dim=Axis.Z)

norm_fluor1_opt_mip = normalize_global(fluor1_opt_mip)
norm_fluor2_opt_mip = normalize_global(fluor2_opt_mip)
norm_fluor3_opt_mip = normalize_global(fluor3_opt_mip)
norm_mixed_opt_mip = normalize_global(mixed_opt_mip)

# One figure row per image: (normalized MIP, first spectral band, second spectral band)
opt_band_pairs = [
    (norm_fluor1_opt_mip, 1, 5),
    (norm_fluor2_opt_mip, 6, 13),
    (norm_fluor3_opt_mip, 10, 14),
    (norm_mixed_opt_mip, 8, 12),
]

_, ax = plt.subplots(len(opt_band_pairs), 3, figsize=(15, 20))
for row, (band_mip, band_a, band_b) in enumerate(opt_band_pairs):
    img_a = band_mip[band_a, ...]
    img_b = band_mip[band_b, ...]
    ax[row, 0].imshow(img_a, cmap="Grays")
    ax[row, 0].set_title(f"Spectral band {band_a}")
    ax[row, 1].imshow(img_b, cmap="Grays")
    ax[row, 1].set_title(f"Spectral band {band_b}")
    # normalize_global returns uint16, and uint16 subtraction wraps around
    # (e.g. 3 - 5 == 65534), which abs() cannot undo. Widen to int32 first.
    ax[row, 2].imshow(abs(img_a.astype(np.int32) - img_b.astype(np.int32)), cmap="Reds")
    ax[row, 2].set_title(f"Abs difference ({band_a} vs {band_b})")

### Get the digital image

In [None]:
# Camera exposure time in milliseconds (also written to the metadata JSON later on).
exposure = 100
digital_img = sim.digital_image(
    opt_img,
    exposure_ms=exposure,
)

In [None]:
# Split the digital image along axis 2. Per the variable names, indices 0-2 are the
# three fluorophores and index 3 is the mixed image. Then drop singleton axes,
# take the z-MIP and normalize each result to uint16 for display.
fluor1_digital_img = digital_img[:, :, 0, ...].squeeze()
fluor2_digital_img = digital_img[:, :, 1, ...].squeeze()
fluor3_digital_img = digital_img[:, :, 2, ...].squeeze()
mixed_digital_img = digital_img[:, :, 3, ...].squeeze()

# Full-stack normalization, only needed by the (commented-out) napari cell below:
# norm_fluor1_digital_img = normalize_global(fluor1_digital_img)
# norm_fluor2_digital_img = normalize_global(fluor2_digital_img)
# norm_fluor3_digital_img = normalize_global(fluor3_digital_img)
# norm_mixed_digital_img = normalize_global(mixed_digital_img)

fluor1_digital_mip = fluor1_digital_img.max(dim=Axis.Z)
fluor2_digital_mip = fluor2_digital_img.max(dim=Axis.Z)
fluor3_digital_mip = fluor3_digital_img.max(dim=Axis.Z)
mixed_digital_mip = mixed_digital_img.max(dim=Axis.Z)

norm_fluor1_digital_mip = normalize_global(fluor1_digital_mip)
norm_fluor2_digital_mip = normalize_global(fluor2_digital_mip)
norm_fluor3_digital_mip = normalize_global(fluor3_digital_mip)
norm_mixed_digital_mip = normalize_global(mixed_digital_mip)

# One figure row per image: (normalized MIP, first spectral band, second spectral band)
digital_band_pairs = [
    (norm_fluor1_digital_mip, 1, 5),
    (norm_fluor2_digital_mip, 6, 13),
    (norm_fluor3_digital_mip, 10, 14),
    (norm_mixed_digital_mip, 8, 12),
]

_, ax = plt.subplots(len(digital_band_pairs), 3, figsize=(15, 20))
for row, (band_mip, band_a, band_b) in enumerate(digital_band_pairs):
    img_a = band_mip[band_a, ...]
    img_b = band_mip[band_b, ...]
    ax[row, 0].imshow(img_a, cmap="Grays")
    ax[row, 0].set_title(f"Spectral band {band_a}")
    ax[row, 1].imshow(img_b, cmap="Grays")
    ax[row, 1].set_title(f"Spectral band {band_b}")
    # normalize_global returns uint16, and uint16 subtraction wraps around
    # (e.g. 3 - 5 == 65534), which abs() cannot undo. Widen to int32 first.
    ax[row, 2].imshow(abs(img_a.astype(np.int32) - img_b.astype(np.int32)), cmap="Reds")
    ax[row, 2].set_title(f"Abs difference ({band_a} vs {band_b})")

In [None]:
# Optional interactive inspection in napari (disabled; uncomment to use).
# NOTE(review): the *_digital_img layers in the second group need
# norm_fluor{1,2,3}_digital_img, which are only created by commented-out lines in the
# digital-image cell, and norm_mixed_digital_img, which is only defined in the
# normalization cell further down. Run/uncomment those first.
# import napari

# viewer = napari.Viewer()
# Z-projected digital images
# viewer.add_image(norm_fluor1_digital_mip, name="Fluorophore 1", colormap="blue", blending="additive",  contrast_limits=(0, 65535))
# viewer.add_image(norm_fluor2_digital_mip, name="Fluorophore 2", colormap="yellow", blending="additive",  contrast_limits=(0, 65535))
# viewer.add_image(norm_fluor3_digital_mip, name="Fluorophore 3", colormap="red", blending="additive",  contrast_limits=(0, 65535))
# viewer.add_image(norm_mixed_digital_mip, name="Mixed", blending="additive",  contrast_limits=(0, 65535))

# Full (non-projected) digital images
# viewer.add_image(norm_fluor1_digital_img, name="Fluorophore 1", colormap="blue", blending="additive", contrast_limits=(0, 65535))
# viewer.add_image(norm_fluor2_digital_img, name="Fluorophore 2", colormap="yellow", blending="additive", contrast_limits=(0, 65535))
# viewer.add_image(norm_fluor3_digital_img, name="Fluorophore 3", colormap="red", blending="additive", contrast_limits=(0, 65535))
# viewer.add_image(norm_mixed_digital_img, name="Mixed", blending="additive", contrast_limits=(0, 65535))

### Save files 

In [None]:
import os
import datetime
import tifffile as tiff

def get_unique_directory_path(base_path):
    """Return the first ``f"{base_path}_v{N}"`` (N = 0, 1, 2, ...) that does not exist yet.

    Only checks for existence; nothing is created on disk, so the caller is
    responsible for creating the returned directory.
    """
    candidate_version = 0
    while True:
        candidate = f"{base_path}_v{candidate_version}"
        if not os.path.exists(candidate):
            return candidate
        candidate_version += 1

In [None]:
# Output layout: <ROOT_DIR>/<yymmdd>_v<N>/{imgs,mips}, where N is the first unused version.
# ROOT_DIR can be overridden with the MICROSIM_SPECTRAL_DATA_DIR environment variable so the
# notebook also runs outside the original cluster; by default it keeps the original path.
current_date = datetime.date.today()
formatted_date = current_date.strftime("%y%m%d")

ROOT_DIR = os.environ.get(
    "MICROSIM_SPECTRAL_DATA_DIR",
    "/group/jug/federico/microsim/sim_spectral_data/",
)
current_dir = get_unique_directory_path(os.path.join(ROOT_DIR, formatted_date))

imgs_dir = os.path.join(current_dir, "imgs")  # full (non-projected) stacks
mips_dir = os.path.join(current_dir, "mips")  # z-projected images
for sub_dir in (imgs_dir, mips_dir):
    os.makedirs(sub_dir, exist_ok=True)

print(current_dir)

In [None]:
# Save metadata (coordinates, wavelengths, channels)
import json

# (left, right) edges of every emission wavelength bin on the W axis
w_bins = [(w_bin.left, w_bin.right) for w_bin in em_img.coords[Axis.W].values]

coords_info = {
    "x_coords": gt.coords[Axis.X].values.tolist(),
    "y_coords": gt.coords[Axis.Y].values.tolist(),
    "z_coords": gt.coords[Axis.Z].values.tolist(),
    "w_bins": w_bins,
}

sim_metadata = {
    "fluorophores": [fp.fluorophore.name for fp in gt.coords[Axis.F].values],
    "light_powers": sim.light_powers,
    "downscale": sim.output_space.downscale,
    "detect_exposure_ms": exposure,
    "detect_quantum_eff": sim.detector.qe,
    # Must match the create_custom_channel(min_wave=..., max_wave=...) call used to build `sim`.
    "min_max_wavelength": [460, 550],
}

json_outputs = (
    ("sim_coords.json", coords_info),
    ("sim_metadata.json", sim_metadata),
)
for json_name, payload in json_outputs:
    with open(os.path.join(current_dir, json_name), "w") as f:
        json.dump(payload, f)

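We compute the per-fluorophore microscopy ground truth by summing each fluorophore's intensities over all spectral bands (the `W` axis), at every simulation stage (emission, optical, digital).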

In [None]:
# Per-fluorophore ground truth for every stage: collapse the spectral (W) axis by summation.
# (Statements are independent; grouped here by fluorophore rather than by stage.)
fluor1_em_img_gt = fluor1_em_img.sum(dim=Axis.W)
fluor1_opt_img_gt = fluor1_opt_img.sum(dim=Axis.W)
fluor1_digital_img_gt = fluor1_digital_img.sum(dim=Axis.W)

fluor2_em_img_gt = fluor2_em_img.sum(dim=Axis.W)
fluor2_opt_img_gt = fluor2_opt_img.sum(dim=Axis.W)
fluor2_digital_img_gt = fluor2_digital_img.sum(dim=Axis.W)

fluor3_em_img_gt = fluor3_em_img.sum(dim=Axis.W)
fluor3_opt_img_gt = fluor3_opt_img.sum(dim=Axis.W)
fluor3_digital_img_gt = fluor3_digital_img.sum(dim=Axis.W)

In [None]:
# Normalize images to uint16 for saving. Each array is rescaled with its own global
# min/max, so relative brightness between fluorophores and between stages is not preserved.
norm_fluor1_em_img_gt = normalize_global(fluor1_em_img_gt)
norm_fluor1_opt_img_gt = normalize_global(fluor1_opt_img_gt)
norm_fluor1_digital_img_gt = normalize_global(fluor1_digital_img_gt)

norm_fluor2_em_img_gt = normalize_global(fluor2_em_img_gt)
norm_fluor2_opt_img_gt = normalize_global(fluor2_opt_img_gt)
norm_fluor2_digital_img_gt = normalize_global(fluor2_digital_img_gt)

norm_fluor3_em_img_gt = normalize_global(fluor3_em_img_gt)
norm_fluor3_opt_img_gt = normalize_global(fluor3_opt_img_gt)
norm_fluor3_digital_img_gt = normalize_global(fluor3_digital_img_gt)

# Mixed images are never summed over W: they keep their spectral axis.
norm_mixed_em_img = normalize_global(mixed_em_img)
norm_mixed_opt_img = normalize_global(mixed_opt_img)
norm_mixed_digital_img = normalize_global(mixed_digital_img)

In [None]:
# Save the raw (un-normalized) simulated ground-truth distribution.
gt_path = os.path.join(current_dir, "ground_truth_img.tif")
tiff.imwrite(gt_path, gt.data)

In [None]:
# Save the full (non-projected) stacks for each stage: the three per-fluorophore
# ground truths (summed over W) followed by the mixed image. Write order is unchanged.
stage_stacks = [
    # emission
    ("emission_fluor1_gt.tif", norm_fluor1_em_img_gt),
    ("emission_fluor2_gt.tif", norm_fluor2_em_img_gt),
    ("emission_fluor3_gt.tif", norm_fluor3_em_img_gt),
    ("emission_mixed.tif", norm_mixed_em_img),
    # optical
    ("optical_fluor1_gt.tif", norm_fluor1_opt_img_gt),
    ("optical_fluor2_gt.tif", norm_fluor2_opt_img_gt),
    ("optical_fluor3_gt.tif", norm_fluor3_opt_img_gt),
    ("optical_mixed.tif", norm_mixed_opt_img),
    # digital
    ("digital_fluor1_gt.tif", norm_fluor1_digital_img_gt),
    ("digital_fluor2_gt.tif", norm_fluor2_digital_img_gt),
    ("digital_fluor3_gt.tif", norm_fluor3_digital_img_gt),
    ("digital_mixed.tif", norm_mixed_digital_img),
]
for stack_name, stack in stage_stacks:
    tiff.imwrite(os.path.join(imgs_dir, stack_name), stack.data)

In [None]:
# Save maximum intensity projections (to have lighter data to experiment with).
# Only the mixed z-MIPs (emission, optical, digital) are written. The per-fluorophore
# MIPs (norm_fluor{1,2,3}_{em,opt,digital}_mip) are computed in the plotting cells
# but are currently not saved.
mixed_mips = [
    ("emission_mixed_mip.tif", norm_mixed_em_mip),
    ("optical_mixed_mip.tif", norm_mixed_opt_mip),
    ("digital_mixed_mip.tif", norm_mixed_digital_mip),
]
for mip_name, mip in mixed_mips:
    tiff.imwrite(os.path.join(mips_dir, mip_name), mip.data)