In [1]:
import pickle
from copy import deepcopy
import numpy as np
from scipy.stats import norm
from SyMBac.cell import Cell
from SyMBac.trench_geometry import trench_creator, get_trench_segments
from pymunk.pyglet_util import DrawOptions
import pymunk
import pyglet
from tqdm.auto import tqdm


In [38]:
import numpy as np
from joblib import Parallel, delayed
import pickle
from SyMBac.drawing import draw_scene, get_space_size, gen_cell_props_for_draw, generate_curve_props
from SyMBac.trench_geometry import  get_trench_segments
import napari
import os
import warnings
from tqdm.auto import tqdm
from SyMBac.cell import Cell
from pymunk.pyglet_util import DrawOptions
import pymunk
import pyglet
from scipy.stats import norm
from copy import deepcopy
from SyMBac.trench_geometry import trench_creator, get_trench_segments

class ColonySimulation:
    """
    Simulate a freely growing 2D bacterial colony (no trench) with pymunk rigid-body physics.

    Instances hold the simulation parameters, run the growth simulation, render optical path length (OPL)
    images and masks from the recorded frames, and can display the result in napari.

    Example::

        my_simulation = ColonySimulation(
            cell_max_length=6,
            max_length_var=1,
            cell_width=1.1,
            width_var=0,
            lysis_p=0,
            max_cells=30,
            pix_mic_conv=0.065,
            phys_iters=100,
            resize_amount=3,
            save_dir="/tmp/",
        )
        my_simulation.run_simulation(show_window=False)
        my_simulation.draw_simulation_OPL(label_masks=True)
        my_simulation.visualise_in_napari()
    """

    def __init__(self, cell_max_length, max_length_var, cell_width, width_var, lysis_p, max_cells, pix_mic_conv, phys_iters, resize_amount, save_dir, load_sim_dir = None):
        """
        Initialise a ColonySimulation object.

        Parameters
        ----------
        cell_max_length : float
            Maximum length a cell can reach before dividing (micron)
        max_length_var : float
            Variance of the maximum cell length
        cell_width : float
            The average cell width in the simulation (micron)
        width_var : float
            Variance of the maximum cell width
        lysis_p : float
            Probability of cell lysis per cell per simulation step
        max_cells : int
            The simulation stops once the colony contains at least this many cells
        pix_mic_conv : float
            The micron/pixel size of the image
        phys_iters : int
            Number of physics iterations per simulation frame. Increase to resolve collisions if cells are falling into one
            another, but decrease if cells begin to repel one another too much (too high a value causes cells to bounce off
            each other very hard). 20 is a good starting point
        resize_amount : int
            This is the "upscaling" factor for the simulation and the entire image generation process. Must be kept constant
            across the image generation pipeline. Starting value of 3 recommended.
        save_dir : str
            Location to save simulation output; created (including parents) if it does not exist
        load_sim_dir : str
            The directory if you wish to load a previously completed simulation
        """
        self.cell_max_length = cell_max_length
        self.max_length_var = max_length_var
        self.cell_width = cell_width
        self.width_var = width_var
        self.lysis_p = lysis_p
        self.pix_mic_conv = pix_mic_conv
        self.phys_iters = phys_iters
        self.resize_amount = resize_amount
        self.save_dir = save_dir
        self.offset = 30
        self.load_sim_dir = load_sim_dir
        self.max_cells = max_cells

        # Only "already exists" is tolerated; other failures (bad parent, permissions) surface now
        # instead of after a full simulation run, when the results are pickled.
        os.makedirs(save_dir, exist_ok=True)

        if self.load_sim_dir:
            print("Loading previous simulation, no need to call run_simulation method, but you still need to run OPL drawing and correctly define the scale")
            # NOTE: pickle.load executes arbitrary code; only load simulation directories you trust.
            with open(f"{load_sim_dir}/cell_timeseries.p", 'rb') as f:
                self.cell_timeseries = pickle.load(f)
            with open(f"{load_sim_dir}/space_timeseries.p", 'rb') as f:
                self.space = pickle.load(f)

    def run_simulation(self, show_window = True):
        """
        Run the simulation.

        :param bool show_window: Whether to show the pyglet window while running the simulation. Typically would be `false` if running SyMBac headless.
        """
        if show_window:
            warnings.warn("You are using show_window = True. If you re-run the simulation (even by re-creating the Simulation object), then for reasons which I do not understand, the state of the simulation is not reset. Restart your notebook or interpreter to re-run simulations.")

        self.run_colony_simulation(
            show_window = show_window,
        )  # growth phase

    def draw_simulation_OPL(self, label_masks = True, return_output = False):
        """
        Draw the optical path length images from the simulation. This involves drawing the 3D cells into a 2D numpy
        array, and then the corresponding masks for each cell.

        After running this function, the object gains the attributes ``self.cell_timeseries_properties``,
        ``self.OPL_scenes`` and ``self.masks``. Cells are always drawn straight (no bending transformation).

        :param bool label_masks: Sets whether the masks should be binary, or labelled. Masks should be binary if training a standard U-net, such as with DeLTA, but if training Omnipose (recommended), then mask labelling should be set to True.

        :param bool return_output: Controls whether the function returns the OPL scenes and masks. Does not affect the assignment of these attributes to the instance.

        Returns
        -------
        output : tuple(list(numpy.ndarray), list(numpy.ndarray))
           If ``return_output = True``, a tuple containing lists, each of which contains the entire simulation. The first element in the tuple contains the OPL images, the second element contains the masks
        """
        ID_props = generate_curve_props(self.cell_timeseries)

        self.cell_timeseries_properties = Parallel(n_jobs=-1)(
            delayed(gen_cell_props_for_draw)(a, ID_props) for a in tqdm(self.cell_timeseries, desc='Extracting cell properties from the simulation'))
        space_size = get_space_size(self.cell_timeseries_properties)

        do_transformation = False  # colony cells are rendered straight
        scenes = Parallel(n_jobs=-1)(delayed(draw_scene)(
        cell_properties, do_transformation, space_size, self.offset, label_masks) for cell_properties in tqdm(
            self.cell_timeseries_properties, desc='Rendering cell optical path lengths'))
        self.OPL_scenes = [_[0] for _ in scenes]
        self.masks = [_[1] for _ in scenes]

        if return_output:
            return self.OPL_scenes, self.masks

    def visualise_in_napari(self):
        """
        Open a napari window showing the OPL images with the masks overlaid as a labels layer.
        """
        viewer = napari.view_image(np.array(self.OPL_scenes), name='OPL scenes')
        viewer.add_labels(np.array(self.masks), name='Synthetic masks')
        napari.run()

    def run_colony_simulation(self, show_window = True):
        """
        Run the rigid-body simulation of colony growth, starting from a single cell.

        In headless mode (``show_window=False``) the simulation is stepped until the colony contains at least
        ``self.max_cells`` cells. With ``show_window=True`` a pyglet window displays the space while
        :meth:`step_and_update` is scheduled every ``self.dt``; :meth:`step_and_update` exits the pyglet app once
        the same cell limit is reached, and pressing "E" closes the window. There is a known issue where
        re-running with ``show_window=True`` in the same kernel is extremely slow.

        Parameters
        ----------
        show_window : bool
            Whether to display the pyglet window while simulating.

        All physical parameters (cell size and variances, lysis probability, physics iterations, scale factors)
        are read from the instance; see ``__init__``.

        Returns
        -------
        cell_timeseries : list
            One deep-copied list of cells per simulated frame; each copied cell's ``t`` attribute is its frame index.
        space : pymunk.Space
            The simulation's pymunk space.
        historic_cells : list
            Every cell object recorded during the run, each listed once.
        """
        self.create_space()
        self.space.collision_slop = 0.
        self.dt = 1 / 20  # time-step per frame
        self.pix_mic_conv_for_sim = 1 / self.pix_mic_conv  # pixels per micron
        scale_factor = self.pix_mic_conv_for_sim * self.resize_amount  # resolution scaling factor

        # Always set the N cells to 1 before adding a cell to the space, and set the mask_label
        self.space.historic_N_cells = 1
        cell1 = Cell(
            length=self.cell_max_length*0.5 * scale_factor,
            width=self.cell_width * scale_factor,
            resolution=60,
            position=(40, 100),
            angle=0.8,
            space=self.space,
            dt= self.dt,
            growth_rate_constant=1,
            max_length=self.cell_max_length * scale_factor,
            max_length_mean=self.cell_max_length * scale_factor,
            max_length_var=self.max_length_var * np.sqrt(scale_factor),
            width_var=self.width_var * np.sqrt(scale_factor),
            width_mean=self.cell_width * scale_factor,
            mother=None,
            lysis_p=self.lysis_p,
            mask_label=1,
            generation = 0,
            N_divisions=0
        )

        if show_window:

            window = pyglet.window.Window(700, 700, "SyMBac", resizable=True)
            options = DrawOptions()
            options.shape_outline_color = (10,20,30,40)

            @window.event
            def on_draw():
                window.clear()
                self.space.debug_draw(options)

            @window.event
            def on_key_press(symbol, modifier):
                # Pressing "E" closes the window.
                if symbol == pyglet.window.key.E:
                    window.close()

        self.cell_timeseries = []
        self.cells = [cell1]
        self.historic_cells = [cell1]  # every cell object seen during the run, each listed once
        self.sim_progress = 0
        if show_window:
            pyglet.clock.schedule_interval(self.step_and_update, interval = self.dt)
            pyglet.app.run()
            # Drop our callback so a later run does not keep stepping this (finished) simulation.
            pyglet.clock.unschedule(self.step_and_update)
        else:
            while len(self.cells) < self.max_cells:
                self.step_and_update(self.dt)

        # Stamp each recorded cell copy with the frame it belongs to.
        for frame, cells in enumerate(self.cell_timeseries):
            for cell in cells:
                cell.t = frame

        return self.cell_timeseries, self.space, self.historic_cells

    def create_space(self):
        """
        Create a fresh threaded pymunk space and store it as ``self.space``.
        """
        self.space = pymunk.Space(threaded=True)
        self.space.historic_N_cells = 0

    def update_pm_cells(self, cells, space):
        """
        Iterates through all cells in the simulation and updates their pymunk body and shape objects. Contains logic to
        check for cell division, and create daughters if necessary.

        Daughters are appended to ``cells`` while it is being iterated; this is deliberate, so each new daughter is
        visited later in the same pass and added to ``space`` by :meth:`cell_adder`.

        :param list(SyMBac.cell.Cell) cells: A list of all cells in the current timepoint of the simulation.
        :param pymunk.Space space: The space the cells' bodies and shapes are added to.
        """
        for cell in cells:
            cell.update_length()
            if cell.is_dividing():
                daughter_details = cell.create_pm_cell()
                # HACK: create_pm_cell sometimes returns (body, shape) instead of the daughter's keyword
                # arguments; only the latter has more than 2 entries.
                if len(daughter_details) > 2:
                    daughter = Cell(**daughter_details)
                    cell.daughter = daughter
                    daughter.mother = cell
                    cells.append(daughter)
            else:
                cell.create_pm_cell()
            self.cell_adder(cell, space)
            # NOTE(review): 150 physics sub-steps after adding *each* cell, presumably to let it settle — costly.
            for _ in range(150):
                space.step(1/100)

    def cell_adder(self, cell, space):
        """Add a cell's pymunk body and shape to ``space``."""
        space.add(cell.body, cell.shape)

    def update_cell_positions(self, cells):
        """
        Iterates through all cells in the simulation and updates their positions, keeping the cell object's position
        synchronised with its corresponding pymunk shape and body inside the pymunk space.

        :param list(SyMBac.cell.Cell) cells: A list of all cells in the current timepoint of the simulation.
        """
        for cell in cells:
            cell.update_position()

    def wipe_space(self, space):
        """
        Remove all dynamic bodies, together with their paired shapes, from the pymunk space.

        Bodies and shapes are paired by position in ``space.bodies`` / ``space.shapes``, which assumes one shape
        per body added in the same order (as :meth:`cell_adder` does).

        :param pymunk.Space space:
        """
        for body, poly in zip(space.bodies, space.shapes):
            if body.body_type == pymunk.Body.DYNAMIC:
                space.remove(body)
                space.remove(poly)

    def step_and_update(self, dt):
        """
        Evolve the simulation forward by one frame.

        Each cell may lyse (be marked dead and removed) with probability ``cell.lysis_p``, never removing the last
        cell. The space is then rebuilt from the surviving cells, stepped ``self.phys_iters`` times, and a deep copy
        of the cells is appended to ``self.cell_timeseries``. Once the colony has at least ``self.max_cells`` cells,
        the results are pickled to ``self.save_dir`` and the pyglet app is asked to exit.

        :param float dt: Unused; present because pyglet's scheduler passes the elapsed time. ``self.dt`` is used instead.

        Returns
        -------
        cells : list(SyMBac.cell.Cell)
        """
        # Record each cell object once (by identity) in the run-wide history.
        recorded = {id(c) for c in self.historic_cells}
        # Iterate over a snapshot: removing lysed cells from the list being iterated would skip the next cell.
        for cell in list(self.cells):
            if id(cell) not in recorded:
                self.historic_cells.append(cell)
                recorded.add(id(cell))
            if norm.rvs() <= norm.ppf(cell.lysis_p) and len(self.cells) > 1:   # in case all cells disappear
                cell.dead = True
                self.cells.remove(cell)
                self.space.step(self.dt)

        self.wipe_space(self.space)
        self.update_pm_cells(self.cells, self.space)
        for _ in range(self.phys_iters):
            self.space.step(self.dt)
        self.update_cell_positions(self.cells)

        self.cell_timeseries.append(deepcopy(self.cells))
        # Same threshold as the headless loop in run_colony_simulation, so reaching exactly max_cells still saves.
        if len(self.cells) >= self.max_cells:
            self._save_simulation()
            pyglet.app.exit()
            return self.cells
        self.sim_progress += 1
        return self.cells

    def _save_simulation(self):
        """Pickle the recorded frames and the final pymunk space into ``self.save_dir``."""
        with open(os.path.join(self.save_dir, "cell_timeseries.p"), "wb") as f:
            pickle.dump(self.cell_timeseries, f)
        with open(os.path.join(self.save_dir, "space_timeseries.p"), "wb") as f:
            pickle.dump(self.space, f)



In [39]:
from joblib import Parallel, delayed

In [40]:
# Parameters for the colony simulation (lengths in micron, pix_mic_conv in micron/pixel).
args = dict(
    cell_max_length=6,
    cell_width=1.1,
    max_cells=30,
    pix_mic_conv=0.065,
    phys_iters=100,
    max_length_var=1,
    width_var=0,
    save_dir="/tmp/",
    resize_amount=3,
    lysis_p=0,
)

In [41]:
# Build the colony simulation from the parameters above (the constructor also tries to create save_dir).
my_simulation = ColonySimulation(**args)

In [42]:
# Headless run: steps until the colony holds at least args["max_cells"] cells.
my_simulation.run_simulation(show_window=False)

In [44]:
# Render OPL images and labelled masks for every recorded frame.
# NOTE(review): the recorded output of this cell is an AssertionError about negative x pixels,
# raised by draw_scene's bounds check — see the offset handling there.
my_simulation.draw_simulation_OPL(label_masks=True)

Extracting cell properties from the simulation:   0%|          | 0/80 [00:00<?, ?it/s]

Rendering cell optical path lengths:   0%|          | 0/80 [00:00<?, ?it/s]

AssertionError: Cell has 0 negative pixels in x coordinate, try increasing your offset

In [49]:
# Half the diagonal of a 769 x 673 scene, i.e. the distance from its centre to a corner.
# NOTE(review): presumably used to pick the offset of 500 passed to draw_scene below — confirm.
np.sqrt((769**2 +  673**2))/2

510.9525418275165

In [61]:
# Debug: draw the last frame with an explicit scene size (769, 673) and offset 500, and show the masks (element [1]).
# NOTE(review): the recorded output is a broadcast ValueError, likely because the large offset pushes a cell
# past the right/bottom edge of the scene.
plt.imshow(draw_scene(my_simulation.cell_timeseries_properties[-1], False, (769, 673), 500, True, True)[1])

ValueError: operands could not be broadcast together with shapes (142,113) (142,137) (142,113) 

In [16]:
import itertools
import random

import numpy as np
from matplotlib import pyplot as plt
from numba import njit
from skimage.measure import label
from skimage.transform import rescale, rotate
from skimage.morphology import opening, remove_small_objects
from skimage.exposure import rescale_intensity
from skimage.segmentation import find_boundaries
from PIL import Image

def div_odd(n):
    """Return ``(n // 2, n // 2 + 1)``."""
    half = n // 2
    return half, half + 1


def perc_diff(a, b):
    """Return the difference of ``a`` from ``b`` as a percentage of ``b``."""
    return (a - b) / b * 100

def generate_curve_props(cell_timeseries):
    """
    Generate per-cell curvature properties: 3 random parameters for each unique cell ID, which are passed to a
    cosine function to modulate the cell's curvature.

    Parameters
    ---------
    cell_timeseries : list(list(cell))
        The output of :meth:`SyMBac.simulation.Simulation.run_simulation()`: one list of cells per frame. Each
        cell must have an ``ID`` attribute.

    Returns
    -------
    output : numpy.ndarray
        Float array of shape ``(n_unique_IDs, 4)``; each row is ``[ID, freq_modif, amp_modif, phase_modif]`` and rows
        are sorted by ID. ``freq_modif`` and ``amp_modif`` are drawn from U(0.9, 1.1), ``phase_modif`` from U(-1, 1).
        With no cells the result has shape ``(0, 4)``.
    """
    unique_IDs = np.unique([cell.ID for cell_list in cell_timeseries for cell in cell_list])
    ID_props = []
    for ID in unique_IDs:
        # One draw of each parameter per cell, in this order, so seeded runs stay reproducible.
        freq_modif = np.random.uniform(0.9, 1.1)
        amp_modif = np.random.uniform(0.9, 1.1)
        phase_modif = np.random.uniform(-1, 1)
        ID_props.append([int(ID), freq_modif, amp_modif, phase_modif])
    # reshape keeps the 2D (n, 4) layout even when there are no cells.
    return np.array(ID_props, dtype=float).reshape(-1, 4)


def gen_cell_props_for_draw(cell_timeseries_lists, ID_props):
    """
    Convert one simulation frame into the per-cell property lists consumed by :func:`draw_scene`.

    Parameters
    ----------
    cell_timeseries_lists : list
        A list (single frame) from cell_timeseries, the output from run_simulation. E.g: cell_timeseries[x]
    ID_props : numpy.ndarray
        The output of generate_curve_props(): rows of ``[ID, freq_modif, amp_modif, phase_modif]``. Every cell ID in
        the frame must appear in it.

    Returns
    -------
    cell_properties : list
        One entry per cell: ``[length, width, angle, centroid, freq_modif, amp_modif, phase_modif, phase_mult,
        separation, mask_label]``. ``length`` is the largest distance between two polygon vertices, ``angle`` is in
        degrees (+90 relative to the line joining those vertices), ``centroid`` is the mean of the vertices.
    """
    cell_properties = []
    for cell in cell_timeseries_lists:
        shape = cell.shape
        # World-space polygon: rotate body-local vertices by the body angle, then translate to the body position.
        vertices = []
        for v in shape.get_vertices():
            x, y = v.rotated(shape.body.angle) + shape.body.position
            vertices.append((x, y))
        vertices = np.array(vertices)

        centroid = get_centroid(vertices)
        farthest_vertices = find_farthest_vertices(vertices)
        length = get_distance(farthest_vertices[0], farthest_vertices[1])
        width = cell.width
        separation = cell.pinching_sep

        delta = farthest_vertices[0] - farthest_vertices[1]
        angle = np.rad2deg(np.arctan2(delta[1], delta[0])) + 90

        _, freq_modif, amp_modif, phase_modif = ID_props[ID_props[:, 0] == cell.ID][0]
        phase_mult = 20
        cell_properties.append([length, width, angle, centroid, freq_modif, amp_modif, phase_modif, phase_mult,
                                separation, cell.mask_label])
    return cell_properties

def get_crop_bounds_2D(img, tol=0):
    """
    Find the bounding box of the pixels of a 2D image whose value exceeds ``tol``.

    Returns ``((first_row, last_row), (first_col, last_col))``. Both "stop" values are the index of the last
    row/column containing such a pixel (inclusive), not one past it. Raises IndexError if no pixel exceeds ``tol``.
    """
    above = img > tol
    (rows,) = np.nonzero(above.any(axis=1))
    (cols,) = np.nonzero(above.any(axis=0))
    return (rows[0], rows[-1]), (cols[0], cols[-1])

def crop_image(img, rows, cols, pad):
    """
    Crop ``img`` to ``rows``/``cols`` (each a ``(start, stop)`` pair used as a slice) and zero-pad the result.

    For 3D input the first axis is kept whole and only the last two axes are cropped and padded by ``pad`` on
    each side; 2D input is padded by ``pad`` on every side.
    """
    row_slice = slice(*rows)
    col_slice = slice(*cols)
    if img.ndim == 3:
        return np.pad(img[:, row_slice, col_slice], ((0, 0), (pad, pad), (pad, pad)))
    return np.pad(img[row_slice, col_slice], pad)


def raster_cell(length, width, separation, pinching=True, FL = False):
    """
    Produces a rasterised image of a cell with the intensity of each pixel corresponding to the optical path length
    (thickness) of the cell at that point.

    The cell is a cylinder with rounded end caps, lying along the rows of the output.

    :param int length: Cell length in pixels (rounded to the nearest integer; becomes the number of rows)
    :param int width: Cell width in pixels (rounded to the nearest integer; becomes the number of columns)
    :param int separation: An int between (0, `width`) controlling how much pinching is happening. Pinching is only
        drawn when ``separation > 2``.
    :param bool pinching: Controls whether pinching is happening
    :param bool FL: Not used by this function.

    Returns
    -------

    cell : np.array
       A 2D integer array of shape ``(round(length), round(width))`` containing an OPL image of the cell (values are
       truncated to int). Can be converted to a mask by just taking ``cell > 0``.

    """

    L = int(np.rint(length))
    W = int(np.rint(width))
    new_cell = np.zeros((L, W))
    R = (W - 1) / 2  # radius of the cell's cross-section in pixels

    # Cylindrical body: every row gets the same semicircular thickness profile across the width.
    x_cyl = np.arange(0, 2 * R + 1, 1)
    I_cyl = np.sqrt(R ** 2 - (x_cyl - R) ** 2)
    L_cyl = L - W  # NOTE(review): unused
    # Fill all rows except W//2 at each end, which are left for the caps.
    # NOTE(review): when W//2 == 0 this slice is [0:0] and nothing is filled.
    new_cell[int(W / 2):-int(W / 2), :] = I_cyl

    # Half-width of each end-cap row: row c from the end gets radius sqrt(R^2 - (c - R)^2), rounded.
    x_sphere = np.arange(0, int(W / 2), 1)
    sphere_Rs = np.sqrt((R) ** 2 - (x_sphere - R) ** 2)
    sphere_Rs = np.rint(sphere_Rs).astype(int)

    # Draw both caps: row c (from the top) and row L - c - 1 (from the bottom) get a symmetric profile of
    # radius R_ centred on column W//2.
    for c in range(len(sphere_Rs)):
        R_ = sphere_Rs[c]
        x_cyl = np.arange(0, R_, 1)
        I_cyl = np.sqrt(R_ ** 2 - (x_cyl - R_) ** 2)
        new_cell[c, int(W / 2) - sphere_Rs[c]:int(W / 2) + sphere_Rs[c]] = np.concatenate((I_cyl, I_cyl[::-1]))
        new_cell[L - c - 1, int(W / 2) - sphere_Rs[c]:int(W / 2) + sphere_Rs[c]] = np.concatenate((I_cyl, I_cyl[::-1]))

    if separation > 2 and pinching:
        S = int(np.rint(separation))
        # Clear a central band of rows around the division site.
        new_cell[int((L - S) / 2) + 1:-int((L - S) / 2) - 1, :] = 0
        # Redraw rows at both edges of the band with cap profiles taken from sphere_Rs in reverse (widest first),
        # so each half tapers toward the division site.
        # NOTE(review): sphere_Rs[-c - 1] assumes (S + 1) // 2 <= len(sphere_Rs); not checked here.
        for c in range(int((S+1) / 2)):
            R__ = sphere_Rs[-c - 1]
            x_cyl_ = np.arange(0, R__, 1)
            I_cyl_ = np.sqrt(R__ ** 2 - (x_cyl_ - R__) ** 2)
            new_cell[int((L-S) / 2) + c + 1, int(W / 2) - R__:int(W / 2) + R__] = np.concatenate((I_cyl_, I_cyl_[::-1]))
            new_cell[-int((L-S) / 2) - c - 1, int(W / 2) - R__:int(W / 2) + R__] = np.concatenate((I_cyl_, I_cyl_[::-1]))
    new_cell = new_cell.astype(int)
    return new_cell


@njit
def OPL_to_FL(cell, density):
    """
    Sample fluorescent molecules inside a rasterised cell, with each pixel's probability proportional to its OPL.

    :param np.ndarray cell: A 2D numpy array consisting of a rasterised cell
    :param float density: Number of fluorescent molecules per volume element to sample in the cell; the total
        number of molecules is ``int(density * cell.sum())``
    :return: A float array shaped like ``cell`` holding the number of molecules landing on each pixel
    :rtypes: np.ndarray
    """
    n_cols = cell.shape[1]
    weights = (cell / cell.sum()).flatten()
    n_molecules = int(density * np.sum(cell))
    # Inverse-CDF sampling, a workaround for weighted np.random.choice (not available under numba).
    cdf = np.cumsum(weights)
    # Scale draws by the CDF's actual total: float rounding can leave cdf[-1] slightly below 1, and an unscaled
    # draw above it would index one past the last pixel (numba does not bounds-check array writes by default).
    choices = np.searchsorted(cdf, np.random.rand(n_molecules) * cdf[-1])
    FL_cell = np.zeros(cell.shape)
    for c in choices:
        # Flat row-major index -> (row, col), matching flatten() above.
        FL_cell[c // n_cols, c % n_cols] += 1
    return FL_cell


@njit
def generate_deviation_from_CL(centreline, thickness):
    # Return `thickness` consecutive z-indices: centreline - ceil(thickness), ..., centreline - ceil(thickness) + thickness - 1.
    # For the integer thickness passed by convert_to_3D_numba this is centreline - thickness ... centreline - 1.
    return np.arange(thickness) + centreline - int(np.ceil(thickness ))

@njit
def gen_3D_coords_from_2D(test_cells, centreline, thickness):
    # Returns a 3-tuple: the np.where indices of pixels in `test_cells` equal to `thickness`, followed by the
    # z-indices from generate_deviation_from_CL. The z-indices are shared by all matching pixels (not one per pixel).
    return np.where(test_cells == thickness) + (generate_deviation_from_CL(centreline, thickness),)

@njit
def convert_to_3D_numba(cell):
    """
    Voxelise a 2D OPL image into a binary volume (numba-compiled core of convert_to_3D).

    The volume has shape ``cell.shape + (int(2 * cell.max()),)``. Each pixel whose rounded OPL value is ``t`` gets
    voxels set to 1 at z = centreline - t ... centreline - 1, where ``centreline = int(cell.max())``. Only the lower
    half of z is filled here; convert_to_3D mirrors it. Indexing with ``[x, y, z]`` requires ``cell`` to be 2D.
    """
    expanded_scene = cell
    volume_shape = expanded_scene.shape[0:] + (int(expanded_scene.max()*2),)
    test_cells = rounder(expanded_scene)  # rounded OPL value per pixel
    centreline = int(expanded_scene.max() )
    cells_3D = np.zeros(volume_shape,dtype = np.ubyte)
    # For every possible rounded thickness t, fill a column of t voxels below the centreline at each matching pixel.
    for t in range(int(expanded_scene.max() *2 )):
        test_coords = gen_3D_coords_from_2D(test_cells, centreline, t)
        for x, y in zip(test_coords[0], (test_coords[1])):
            for z in test_coords[2]:
                cells_3D[x, y, z] = 1
    return cells_3D

@njit
def rounder(x):
    # Round `x` to 0 decimals into a new array with the same shape and dtype as `x`.
    # NOTE(review): the explicit `out` argument is presumably needed for numba's np.round support — confirm.
    out = np.empty_like(x)
    np.round(x, 0, out)
    return out

def convert_to_3D(cell):
    """
    Voxelise a 2D OPL image of a cell into a binary 3D volume with z as the first axis.

    The numba core fills the lower half of z; the upper half is then overwritten with the lower half reversed,
    so the volume is mirrored about its middle plane.
    """
    volume = np.moveaxis(convert_to_3D_numba(cell), -1, 0)
    half_depth = volume.shape[0] // 2
    lower_half_flipped = volume[:half_depth, :, :][::-1]
    volume[half_depth:, :, :] = lower_half_flipped
    return volume

def draw_scene(cell_properties, do_transformation, space_size, offset, label_masks, pinching=True):
    """
    Draws a raw scene (no trench) of cells, and returns accompanying masks for training data.

    Parameters
    ----------
    cell properties : list
        A list of cell properties for that frame
    do_transformation : bool
        True if you want cells to be bent, false and cells remain straight as in the simulation
    space_size : tuple
        The xy size of the numpy array in which the space is rendered. If too small then cells will not fit. recommend using the :meth:`SyMBac.drawing.get_space_size` function to find the correct space size for your simulation
    offset : int
        A necessary parameter which offsets the drawing a number of pixels from the left hand side of the image. 30 is a good number, but if the cells are very thick, then might need increasing.
    label_masks : bool
        If true returns cell masks which are labelled (good for instance segmentation). If false returns binary masks only. I recommend leaving this as True, because you can always binarise the masks later if you want.
    pinching : bool
        Whether or not to simulate cell pinching during division

    Returns
    -------
    space, space_masks : 2D numpy array, 2D numpy array

    space : 2D numpy array
        Not to be confused with the pymunk object called space in some other functions. Simply a 2D numpy array with an image of cells from the input frame properties.
        Pixels covered by more than one cell (and, for binary masks, cell borders) are zeroed.
    space_masks : 2D numpy array
        The masks (labelled or bool) for that scene. Pixels covered by more than one cell are 0.

    Raises
    ------
    AssertionError
        If a cell would start at a negative row or column (increase ``offset``).
    ValueError
        If a cell would extend past the bottom or right edge of the scene (increase ``space_size``).
    """
    space_size = np.array(space_size)  # 1000, 200 a good value
    space = np.zeros(space_size)
    space_masks_label = np.zeros(space_size)    # sum of sim_mask_label over the cells covering each pixel
    space_masks_nolabel = np.zeros(space_size)  # number of cells covering each pixel

    if len(cell_properties) == 0:
        # Empty frame: blank scene, masks float (labelled) or bool (binary).
        space_masks = np.zeros(space_size)
        if not label_masks:
            space_masks = space_masks.astype(bool)
        return space, space_masks

    for properties in cell_properties:
        length, width, angle, position, freq_modif, amp_modif, phase_modif, phase_mult, separation, sim_mask_label = properties
        position = np.array(position)
        x = position.astype(int)[0] + offset
        y = position.astype(int)[1] + offset
        OPL_cell = raster_cell(length=length, width=width, separation=separation, pinching=pinching)

        if do_transformation:
            # NOTE(review): transform_func is not defined in this view; it must exist elsewhere in the module.
            OPL_cell_2 = np.zeros((OPL_cell.shape[0], int(OPL_cell.shape[1] * 2)))
            midpoint = int(np.median(range(OPL_cell_2.shape[1])))
            OPL_cell_2[:,
            midpoint - int(OPL_cell.shape[1] / 2):midpoint - int(OPL_cell.shape[1] / 2) + OPL_cell.shape[1]] = OPL_cell
            roll_coords = np.array(range(OPL_cell_2.shape[0]))
            freq_mult = (OPL_cell_2.shape[0])
            amp_mult = OPL_cell_2.shape[1] / 10
            sin_transform_cell = transform_func(amp_modif, freq_modif, phase_modif)
            roll_amounts = sin_transform_cell(roll_coords, amp_mult, freq_mult, phase_mult)
            for B in roll_coords:
                OPL_cell_2[B, :] = np.roll(OPL_cell_2[B, :], roll_amounts[B])
            OPL_cell = (OPL_cell_2)

        rotated_OPL_cell = rotate(OPL_cell, -angle, resize=True, clip=False, preserve_range=True, center=(x, y))
        cell_y, cell_x = (np.array(rotated_OPL_cell.shape) / 2).astype(int)
        # Odd-sized images are one pixel larger than 2 * half-size; the extra row/column goes on the far side.
        offset_y = rotated_OPL_cell.shape[0] - 2 * cell_y
        offset_x = rotated_OPL_cell.shape[1] - 2 * cell_x
        top, left = y - cell_y, x - cell_x
        bottom, right = y + cell_y + offset_y, x + cell_x + offset_x

        # Check bounds before slicing: a negative start would wrap around, and an end past the edge would be
        # silently clipped (then fail to broadcast). A start of exactly 0 is valid.
        assert top >= 0, "Cell has {} negative pixels in y coordinate, try increasing your offset".format(-top)
        assert left >= 0, "Cell has {} negative pixels in x coordinate, try increasing your offset".format(-left)
        if bottom > space.shape[0] or right > space.shape[1]:
            raise ValueError(
                "Cell extends {} pixels past the bottom and {} pixels past the right edge of the scene {}, "
                "try increasing space_size".format(
                    max(bottom - space.shape[0], 0), max(right - space.shape[1], 0), tuple(space.shape)))

        region = (slice(top, bottom), slice(left, right))
        footprint = rotated_OPL_cell > 0
        space[region] += rotated_OPL_cell
        space_masks_label[region] += footprint * sim_mask_label
        space_masks_nolabel[region] += footprint * 1

    # Mask post-processing works on the whole scene, so it is done once after all cells are accumulated
    # (previously it was redone for every cell).
    label_mask = space_masks_label.astype(int)
    nolabel_mask = space_masks_nolabel.astype(int)
    # Pixels covered by more than one cell are ambiguous and are removed from the mask.
    label_mask_fixed = np.where(nolabel_mask > 1, 0, label_mask)
    if label_masks:
        space_masks = label_mask_fixed
    else:
        mask_borders = find_boundaries(label_mask_fixed, mode="thick", connectivity=2)
        space_masks = np.where(mask_borders, 0, label_mask_fixed)
        space_masks = opening(space_masks)
        space_masks = space_masks.astype(bool)
    # Keep OPL only where the final mask is non-zero.
    space = space * space_masks.astype(bool)
    return space, space_masks


def get_distance(vertex1, vertex2):
    """
    Get the Euclidean distance between two 2D vertices.

    Uses ``np.hypot``, which is always non-negative and does not overflow for large coordinates the way an explicit
    sum of squares does.

    :param tuple(float, float) vertex1: Vertex 1
    :param tuple(float, float) vertex2: Vertex 2

    :return: Distance between the two points
    :rtype: float
    """
    return np.hypot(vertex1[0] - vertex2[0], vertex1[1] - vertex2[1])


def find_farthest_vertices(vertex_list):
    """Given a list of vertices, find the pair of vertices which are farthest from each other.

    Pairs are considered in ``itertools.combinations`` order and the first pair reaching the maximum distance
    is returned.

    Parameters
    ----------
    vertex_list : list(tuple(float, float)) or numpy.ndarray
        Vertices [(x,y), (x,y), ...]; at least two are required

    Returns
    -------
    numpy.ndarray
        Array of shape (2, 2): the two vertices maximally far apart

    Raises
    ------
    ValueError
        If fewer than two vertices are given
    """
    points = np.asarray(vertex_list)
    if len(points) < 2:
        raise ValueError("find_farthest_vertices needs at least two vertices, got {}".format(len(points)))
    # Upper-triangle index pairs come out in the same (row-major) order as itertools.combinations,
    # and argmax returns the first maximum, so ties resolve as a sequential scan would.
    first, second = np.triu_indices(len(points), k=1)
    diffs = points[first] - points[second]
    distances = np.sqrt(diffs[:, 0] ** 2 + diffs[:, 1] ** 2)
    best = np.argmax(distances)
    return points[[first[best], second[best]]]


def get_midpoint(vertex1, vertex2):
    """
    Return the point halfway between two vertices.

    :param tuple(float, float) vertex1: First vertex as (x, y)
    :param tuple(float, float) vertex2: Second vertex as (x, y)
    :return: Midpoint as an array [x, y]
    :rtype: numpy.ndarray
    """
    return np.array([(vertex1[axis] + vertex2[axis]) / 2 for axis in (0, 1)])


def vertices_slope(vertex1, vertex2):
    """
    Return the slope (rise over run) of the line through two vertices.

    Vertices sharing the same x coordinate give a zero run, so the division
    fails (Python numbers) or yields inf/nan (NumPy scalars).

    :param tuple(float, float) vertex1: First vertex as (x, y)
    :param tuple(float, float) vertex2: Second vertex as (x, y)
    :return: Slope dy/dx between the two vertices
    :rtype: float
    """
    rise = vertex1[1] - vertex2[1]
    run = vertex1[0] - vertex2[0]
    return rise / run


def midpoint_intercept(vertex1, vertex2):
    """
    Return the y-intercept of the straight line passing through two vertices.

    The midpoint of the vertices is used as the reference point on the line,
    and the slope comes from ``vertices_slope``. Vertices sharing the same x
    coordinate therefore hit the same zero-run division.

    :param tuple(float, float) vertex1: First vertex as (x, y)
    :param tuple(float, float) vertex2: Second vertex as (x, y)
    :return: Y intercept of the line through vertex 1 and 2
    :rtype: float
    """
    mid_x, mid_y = get_midpoint(vertex1, vertex2)
    return mid_y - vertices_slope(vertex1, vertex2) * mid_x


def get_centroid(vertices):
    """Return the arithmetic mean position of a collection of vertices.

    :param list(tuple(float, float)) vertices: Vertices, each given as (x, y)
    :return: Centroid as an array [x, y]
    :rtype: numpy.ndarray
    """
    points = np.asarray(vertices)
    return points.sum(axis=0) / len(vertices)


def place_cell(length, width, angle, position, space, pad=100):
    """Rasterise a cell, rotate it, and add its image into a scene array in place.

    Parameters
    ----------
    length : float
        Length of the cell, passed to ``raster_cell``.
    width : float
        Width of the cell, passed to ``raster_cell``.
    angle : float
        Rotation of the cell in radians. It is converted to degrees before
        being passed to ``rotate``.
    position : tuple
        x, y coordinates of the cell centroid, truncated to ``int``.
    space : numpy.ndarray
        2D scene array, indexed ``space[y, x]``, that the rotated cell image
        is added into with ``+=``.
    pad : int, optional
        Pixel offset added to both coordinates before writing. Presumably
        ``space`` carries a border of this width around the scene (confirm
        against callers). Defaults to 100, the previously hard-coded value.

    Returns
    -------
    None
        ``space`` is updated in place.
    """
    angle = np.rad2deg(angle)
    x, y = np.array(position).astype(int)
    OPL_cell = raster_cell(length=length, width=width)
    rotated_OPL_cell = rotate(OPL_cell, angle, resize=True, clip=False, preserve_range=True)
    cell_h, cell_w = rotated_OPL_cell.shape
    # Top-left corner of the cell's footprint, centred on (x, y), in padded coordinates.
    top = y - cell_h // 2 + pad
    left = x - cell_w // 2 + pad
    # Size the target window directly from the cell image. The previous version
    # derived an odd-size correction from the *un-padded* window, which was
    # clipped or empty whenever the cell sat near the top/left edge. That gave a
    # window of the wrong size and a broadcasting error. In the normal case both
    # approaches select the same pixels.
    space[top:top + cell_h, left:left + cell_w] += rotated_OPL_cell


def transform_func(amp_modif, freq_modif, phase_modif):
    """Build an integer-valued cosine transform with fixed base modifiers.

    The returned callable evaluates
    ``int(amp_mult * amp_modif * cos((x / (freq_mult * freq_modif) - phase_mult * phase_modif) * pi))``
    element-wise, casting the result with ``.astype(int)`` (truncation toward zero).

    :param float amp_modif: Base amplitude multiplier
    :param float freq_modif: Base divisor applied to x
    :param float phase_modif: Base phase shift (in units of pi)
    :return: Function ``perm_transform_func(x, amp_mult, freq_mult, phase_mult)``
    """
    def perm_transform_func(x, amp_mult, freq_mult, phase_mult):
        amplitude = amp_mult * amp_modif
        x_scale = freq_mult * freq_modif
        phase = phase_mult * phase_modif
        wave = np.cos((x / x_scale - phase) * np.pi)
        return (amplitude * wave).astype(int)

    return perm_transform_func


def scene_plotter(scene_array, output_dir, name, a, matplotlib_draw):
    """Write one scene frame to ``output_dir``.

    With ``matplotlib_draw == True`` the array is rendered through matplotlib
    and saved as ``<name>_<a zero-padded to 3>.png``. Otherwise it is cast to
    ``uint8`` and saved directly as ``<name>_<a zero-padded to 3>.tif``.
    """
    frame_tag = str(a).zfill(3)
    if matplotlib_draw == True:
        plt.figure(figsize=(3, 10))
        plt.imshow(scene_array)
        plt.tight_layout()
        plt.savefig(output_dir + f"/{name}_{frame_tag}.png")
        # Clear and close every figure so repeated calls do not accumulate them.
        plt.clf()
        plt.close('all')
    else:
        frame = Image.fromarray(scene_array.astype(np.uint8))
        frame.save(output_dir + f"/{name}_{frame_tag}.tif")


def make_images_same_shape(real_image, synthetic_image, rescale_int=True):
    """Crop a synthetic image so that it lines up with a real image.

    Rows are removed from the top of the synthetic image. Columns are removed
    from its left and right edges according to ``div_odd`` applied to the
    width difference. When the two widths have the same parity, the crop
    starts one column further left, keeping one extra column.

    NOTE(review): whether this yields exactly the real image's width depends
    on what ``div_odd`` returns. Verify against its definition.

    Parameters
    ----------
    real_image : numpy.ndarray
        Reference image. Must be strictly smaller than ``synthetic_image`` on
        axes 0 and 1.
    synthetic_image : numpy.ndarray
        Image to crop.
    rescale_int : bool, optional
        If True (default), both outputs are cast to float32 and rescaled to
        the range (0, 1) with ``rescale_intensity``.

    Returns
    -------
    tuple(numpy.ndarray, numpy.ndarray)
        ``(real_image, synthetic_image)`` after cropping and optional rescaling.

    Raises
    ------
    AssertionError
        If the real image is not strictly smaller than the synthetic image on
        axis 0 or axis 1.
    """
    assert real_image.shape[0] < synthetic_image.shape[0], \
        "Real image is not smaller than the synthetic image on axis 0, increase y_border_expansion_coefficient"
    assert real_image.shape[1] < synthetic_image.shape[1], \
        "Real image is not smaller than the synthetic image on axis 1, increase x_border_expansion_coefficient"

    x_diff = synthetic_image.shape[1] - real_image.shape[1]
    remove_from_left, remove_from_right = div_odd(x_diff)
    y_diff = synthetic_image.shape[0] - real_image.shape[0]

    if y_diff > 0:
        # The original four parity branches reduce to one rule: matching width
        # parity starts the crop one column earlier.
        same_parity = real_image.shape[1] % 2 == synthetic_image.shape[1] % 2
        left = remove_from_left - 1 if same_parity else remove_from_left
        synthetic_image = synthetic_image[y_diff:, left:-remove_from_right]
    else:
        # Unreachable while the axis-0 assertion holds (it forces y_diff > 0).
        # Kept for runs under ``python -O``, where assertions are stripped.
        synthetic_image = synthetic_image[:, remove_from_left:-remove_from_right]
        real_image = real_image[abs(y_diff):, :]

    if rescale_int:
        real_image = rescale_intensity(real_image.astype(np.float32), out_range=(0, 1))
        synthetic_image = rescale_intensity(synthetic_image.astype(np.float32), out_range=(0, 1))
    return real_image, synthetic_image


def get_space_size(cell_timeseries_properties):
    """
    Find an image size large enough to contain every cell across a timeseries.

    Each cell entry is read as: index 0 = length, index 1 = width, index 3 =
    (x, y) position. All three are rounded up. The largest ``y + length`` and
    ``x + width`` over all cells and timepoints are then scaled by 1.2 and 1.5
    respectively.

    :param cell_timeseries_properties: A list of cell properties over time. Generated from :meth:`SyMBac.simulation.Simulation.draw_simulation_OPL`
    :return: ``(int(1.2 * max_y), int(1.5 * max_x))``, or ``(0, 0)`` when there are no cells.
    :rtype: tuple(int, int)
    """
    extent_x = extent_y = 0
    every_cell = (cell for timepoint in cell_timeseries_properties for cell in timepoint)
    for cell in every_cell:
        x_pos, y_pos = np.ceil(cell[3]).astype(int)
        cell_length = np.ceil(cell[0]).astype(int)
        cell_width = np.ceil(cell[1]).astype(int)
        extent_y = max(extent_y, y_pos + cell_length)
        extent_x = max(extent_x, x_pos + cell_width)
    return (int(1.2 * extent_y), int(1.5 * extent_x))


def clean_up_mask(mask):
    """Label the connected regions of ``mask``, then drop small objects.

    Uses ``remove_small_objects`` with its default size threshold and returns
    the cleaned labelled image.
    """
    labelled_mask = label(mask)
    return remove_small_objects(labelled_mask)