In [None]:
import sys
sys.path.insert(1, '../SyMBac/') 

from SyMBac.drawing import raster_cell
from SyMBac.PSF import PSF_generator
from SyMBac.renderer import convolve_rescale
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
import os

from skimage.util import img_as_uint
from skimage.measure import label

In [None]:
def pad_np_arrays_to_largest_and_return_masks(cells, pad_width=None):
    """Pad 2-D cell images to a common height and tile them left to right.

    Each image is zero-padded at the bottom to the height of the tallest
    image, then zero-padded on every side with ``np.pad(..., pad_width)``.
    The padded images are concatenated along axis 1.

    Parameters
    ----------
    cells : sequence of 2-D numpy arrays
        Cell images. Heights may differ; widths may differ too, because only
        the height is equalised.
    pad_width : pad specification accepted by ``np.pad``, optional
        Extra zero border around every image. When omitted, the notebook-level
        ``additional_pad`` setting is used if it is defined; otherwise
        ``((60, 60), (60, 60))``.

    Returns
    -------
    tiled : 2-D numpy array
        The padded images concatenated along axis 1.
    masks : 2-D bool array
        Same shape as ``tiled``; True where ``tiled > 0``.

    Raises
    ------
    ValueError
        If ``cells`` is empty.
    """
    if len(cells) == 0:
        raise ValueError("cells must contain at least one image")
    if pad_width is None:
        # Keep honouring the notebook's `additional_pad` config cell, but make
        # the dependency explicit and survive when it is not defined.
        pad_width = globals().get("additional_pad", ((60, 60), (60, 60)))

    # Use the real maximum instead of assuming the last image is the tallest;
    # otherwise a taller earlier image yields a negative pad size and crashes.
    max_height = max(cell.shape[0] for cell in cells)

    padded_cells = []
    for cell in cells:
        bottom_pad = np.zeros((max_height - cell.shape[0], cell.shape[1]))
        padded_cells.append(np.pad(np.concatenate([cell, bottom_pad]), pad_width))

    tiled = np.concatenate(padded_cells, axis=1)
    # Thresholding the tiled image is element-wise identical to thresholding
    # each padded tile and concatenating the per-tile masks.
    return tiled, tiled > 0

In [None]:
# Create the output directory for generated images and masks.
# exist_ok=True makes re-running this cell a no-op, while genuine failures
# (permission errors, a file already occupying the path) are still raised
# instead of being silently swallowed by a bare `except`.
os.makedirs("synthetic_training_data", exist_ok=True)

In [None]:
# Parameter grid for the synthetic dataset.
lengths = np.arange(5, 11, 1)  # cell lengths, microns
widths = np.arange(1, 11, 1)  # cell widths, microns
densities = np.linspace(1, 100, 10)  # not referenced by the cells shown here
NAs = [1.45]  # numerical aperture(s) passed to PSF_generator
ns = [1.518]  # presumably immersion refractive index; combos need n > NA
wavelengths = [0.45, 0.5, 0.575, 0.65]  # presumably microns, paired with `names`
names = ["blue", "green", "orange", "far red"]

# Number of parameter combinations to render (each yields one image and one
# mask file): rod-shaped cells only (length >= 2 * width), optics with n > NA,
# and one entry per (wavelength, name) pair. Shown as the cell's output.
a = sum(
    1
    for length in lengths
    for width in widths
    if length >= width * 2
    for NA in NAs
    for n in ns
    if n > NA
    for _ in zip(wavelengths, names)
)
a

In [None]:
# Rendering resolution:
#   resize_amount -> multiplier applied to raster cell dimensions
#   pix_mic_conv  -> pixel size in microns per pixel
resize_amount, pix_mic_conv = 1, 0.065

In [None]:
# Zero border, in np.pad form ((rows before, rows after), (cols before, cols after)),
# used when padding each cell image before the images are tiled together.
additional_pad = ((60, 60),) * 2

In [None]:
def _division_series(raster_cell_length, raster_cell_width):
    """Return raster images of one cell progressing through division.

    The first 10 images come from ``raster_cell`` with ``separation`` rising
    from 0 to ``raster_cell_width - 1`` (pinching=True). The next 5 take the
    most-pinched image, split it at its middle row and insert a zero-filled
    gap of 0 up to ``raster_cell_width // 2`` rows.
    """
    series = [
        raster_cell(length=raster_cell_length, width=raster_cell_width,
                    separation=separation, pinching=True)
        for separation in np.linspace(0, raster_cell_width - 1, 10)
    ]
    # Every gap is applied to the same fully pinched image. The previous code
    # re-read `cells[-1]` while appending to `cells`, so each gap was inserted
    # into the already-gapped previous image and the gaps compounded
    # (~0, W/8, 3W/8, 6W/8, 10W/8), overshooting the W/2 upper bound.
    pinched = series[-1]
    split_row = pinched.shape[0] // 2
    for gap in np.linspace(0, raster_cell_width // 2, 5):
        gap_rows = np.zeros((int(gap), pinched.shape[1]))
        series.append(np.concatenate([pinched[:split_row, :], gap_rows, pinched[split_row:, :]]))
    return series


# 2-D PSF kernels keyed by (wavelength, NA, n). The kernel is built only from
# these optics arguments (plus fixed settings), so compute it once per key
# rather than once per cell size. Assumes PSF_generator is deterministic for
# fixed arguments.
psf_kernels = {}

for length in tqdm(lengths):
    for width in widths:
        if length < width * 2:
            continue  # keep only cells with length >= 2 * width
        for NA in NAs:
            for n in ns:
                if n <= NA:
                    continue  # keep only combinations with n > NA
                for wavelength, name in zip(wavelengths, names):
                    # length/width are in microns; convert to (resized) pixels.
                    raster_cell_length = length / pix_mic_conv * resize_amount
                    raster_cell_width = width / pix_mic_conv * resize_amount

                    cells, masks = pad_np_arrays_to_largest_and_return_masks(
                        _division_series(raster_cell_length, raster_cell_width)
                    )
                    masks = label(masks)  # integer label per connected foreground region

                    psf_key = (wavelength, NA, n)
                    if psf_key not in psf_kernels:
                        PSF = PSF_generator(
                            radius=150,
                            wavelength=wavelength,
                            NA=NA,
                            n=n,
                            resize_amount=resize_amount,
                            pix_mic_conv=pix_mic_conv,
                            apo_sigma=10,
                            mode="3d fluo",
                            condenser="Ph3",
                            z_height=50,
                        )
                        PSF.calculate_PSF()
                        # Sum over axis 0 (presumably the z-stack of the 3-D PSF)
                        # to obtain a 2-D kernel.
                        PSF.kernel = np.sum(PSF.kernel, axis=0)
                        psf_kernels[psf_key] = PSF.kernel

                    conv_cell = convolve_rescale(cells, psf_kernels[psf_key], 1, True)
                    conv_cell = img_as_uint(conv_cell)

                    Image.fromarray(conv_cell).save(f"synthetic_training_data/{length}_{width}_{wavelength}_{NA}_{n}.png")
                    Image.fromarray(masks).save(f"synthetic_training_data/{length}_{width}_{wavelength}_{NA}_{n}_masks.png")