# PyMMM: Python Mother Machine Manager

In [None]:
from PyMMM_main.Experiment.experiment import * 

In [None]:
# Show the current working directory; the next cell uses it as the experiment directory.
os.getcwd()

The first thing to do is to define your experiment's directory and instantiate an `Experiment` object using that directory.
Let's also print the experiment's properties.

In [None]:
# Use the notebook's working directory as the experiment directory and create the
# Experiment; save_filetype="png" selects PNG as the output image type.
directory = os.getcwd()
my_experiment = Experiment(
    directory,  
    save_filetype="png"
)
# Print a summary of the experiment's properties.
print(my_experiment)

Here, we can choose the times and FOVs to take forward for registration

In [None]:
# Optionally restrict the analysis to a window of time points:
# my_experiment.set_analysis_times(0, 50)

# Drop FOVs "xy030" and "xy031" from the analysis.
discarded_FOV_names = [f"xy0{fov_idx:02d}" for fov_idx in range(30, 32)]
my_experiment.discard_FOVs(discarded_FOV_names)

In [None]:
# Display the FOVs that remain after discarding.
my_experiment.FOVs

### Let's now ensure the quality of the data

We're going to register the images to minimise the effects of stage drift.

* First we will set the experiment's registration channel, here we are using `"PC"` for phase contrast.

In [None]:
# Register against the phase-contrast channel ("PC"); later cells reuse this
# attribute as the channel for trench finding.
my_experiment.registration_channel = "PC"

Next we will do two things:
* The first is to calculate mean images for the experiment. These are average images which will be used for image registration
* By default, PyMMM takes an average over the last 15 images of the experiment, as the experiment's stage drift has likely stopped by then.
* We can also rotate the mean image if the experiment was not properly aligned. This will rotate the mean images, and all other images will be registered against them, rotating them in the process.
* Let's just check our experiment without the rotation argument.

In [None]:
# NOTE(review): presumably the number of images averaged into each mean image
# (the notes above give 15 as PyMMM's default) — confirm against Experiment.
my_experiment.mean_amount = 10

In [None]:
# Compute the mean images without any rotation and plot them to inspect alignment.
my_experiment.get_mean_images(plot = True)

It looks like there's some rotation in our experiment. Let's rotate the images and recalculate the mean images.

In [None]:
# Recompute the mean images with a rotation of 0.9 and plot them again.
# NOTE(review): the unit of `rotation` is assumed to be degrees — confirm.
my_experiment.get_mean_images(rotation = 0.9, plot = True)

In [None]:
# Display the rotation stored on the experiment (presumably the value passed above).
my_experiment.rotation

That's looking much better.
* We can now register the images. We only need to do this once, so we can check whether the experiment has already been registered with the `is_registered` property.
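* Here we simply register the experiment using all of our CPU cores (`n_jobs = -1`). Registration only needs to run once, so on a re-run you can first check `my_experiment.is_registered` and skip this step if it is already `True`.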
* There is also a `force` argument, which if set to `True` will overwrite any previous registered images.

In [None]:
# Register every image against the mean images.
# force=True overwrites previously registered images; n_jobs=-1 uses all CPU cores.
# NOTE(review): mode=5, sum=False and the y_lims/x_lims crop are PyMMM-specific
# settings — check register_experiment; if x_lims is applied as a Python slice,
# the -1 end excludes the last column.
my_experiment.register_experiment(force=True, mode=5, sum=False, n_jobs = -1,  y_lims = (300,900), x_lims = (0,-1))

### Let's now find the trenches

There are several methods you can use to explore the experiment:

In [None]:
# Fetch and plot one registered image, selecting FOV and channel by position.
an_image = my_experiment.get_image(FOV = 1, channel = 0, time = 1, plot = True, registered = True)
# Alternatively, select FOV and channel by name:
#an_image = my_experiment.get_image(FOV = "A23", channel = "Green", time = 1, plot = True, registered = True)

In [None]:
# Mean over time of one FOV's stack (FOV 1, channel 1 — both by position), plotted.
a_mean_image = my_experiment.get_mean_of_timestack(1, 1, plot=True)

In [None]:
# Mean-over-time image for every FOV in the registration channel. Using the
# attribute rather than a hard-coded "PC" keeps this cell consistent with the
# other exploration cells if the registration channel is changed.
mean_timestacks = [
    my_experiment.get_mean_of_timestack(FOV, my_experiment.registration_channel)
    for FOV in my_experiment.FOVs
]

### Finding the x limits

In [None]:
# Getting the mean of the timestack over the x direction for FOV 1 in the
# registration channel, and plotting it.
a_t_x_mean = my_experiment.mean_t_x(1, my_experiment.registration_channel, plot = True)

In [None]:
# Strength of the Gaussian blurring applied to the x profile (`sigma` argument).
sigma = 4

blurred_profile = my_experiment.mean_t_x(1, my_experiment.registration_channel, sigma = sigma, plot = True)
a_t_x_mean = blurred_profile

# If Gaussian blurring alone isn't working well, a section of the blurred profile
# is used as a convolution filter (passed as `conv_filter=f` in the cells below).
# The 175:300 window is specific to this dataset; adjust it per experiment.
filter_window = slice(175, 300)
f = deepcopy(blurred_profile[filter_window])
plt.plot(f)

In [None]:
# Find trench peaks for one FOV, reusing `sigma` and the convolution filter `f`
# from the previous cell; `distance` and `prominence` tune the peak detection.
# NOTE(review): this inspects FOV 0, while the x-profile cells above use FOV 1.
distance = 100
prominence=10
peaks = my_experiment.find_trench_peaks(0, 
                                        my_experiment.registration_channel, 
                                        sigma = sigma, distance = distance, 
                                        prominence=prominence, 
                                        conv_filter=f,
                                        plot = True)
# Image dimensions; dims[0] is used as the lane-peak `distance` further down.
print(my_experiment.dims)

We've found good values for `sigma` and `distance`.
Let's now call `find_all_trench_x_positions` on the phase contrast channel to identify all the x limits of the trenches in every FOV

In [None]:
# Detect the trench x limits in every FOV of the registration channel, using the
# parameters tuned above. trench_width=128 with use_exact_trench_width=True
# presumably forces every trench to 128 px — the zarr export below needs a
# single trench width. plot=False suppresses inline plots; plot_save=True
# presumably writes them to disk instead — confirm.
trench_x_positions = my_experiment.find_all_trench_x_positions(my_experiment.registration_channel, 
                                                               sigma = sigma, 
                                                               distance = distance, 
                                                               prominence=prominence, 
                                                               conv_filter=f,
                                                               shrink_scale = 4, 
                                                               trench_width = 128,
                                                               use_exact_trench_width = True, 
                                                               plot = False, 
                                                               plot_save=True)

### Discard any bad trenches

In [None]:
# Drop trenches judged unusable, by trench index. These indices are specific to
# this dataset (NOTE(review): presumably read off the saved x-position plots).
my_experiment.discard_trenches([20,21,30,202,203,216,217,263,264,268])

### Finding the y limits
So we've successfully found the x limits of the trenches. Let's now find the y limits

In [None]:
# Mean of FOV 1's time stack over the y direction (registration channel), with
# Gaussian blurring controlled by the sigma argument, plotted.
sigma = 40
a_t_y_mean = my_experiment.mean_t_y(1, my_experiment.registration_channel, sigma = sigma, plot = True)

In [None]:
# Find the lane (halo) peak in FOV 1's y profile.
# sigma: Gaussian blurring; height: presumably the minimum peak height — confirm.
sigma = 40
height = 5000
# A peak separation of dims[0] presumably allows only one peak per FOV (the
# y_peaks check below expects exactly one) — confirm dims[0] is the profile length.
distance = my_experiment.dims[0]
a_t_y_mean, y_peak = my_experiment.find_lane_peaks(1, sigma = sigma, distance = distance, height=height, plot = True)

In [None]:
# Lane peaks (second return value of find_lane_peaks) for every FOV, keyed by FOV;
# plots one figure per FOV.
y_peaks = {FOV: my_experiment.find_lane_peaks(FOV, sigma=sigma, distance=distance, height=height, plot=True)[1] for FOV in my_experiment.FOVs}

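Set the y offsets relative to the halo (lane) peak. In the zarr export below, each trench is cut from rows `y_peak - offsets[0]` to `y_peak - offsets[1]`, so the difference between the two offsets is the trench height in pixels (knowing the trench length is helpful here).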

In [None]:
# (offset_0, offset_1) measured up from the lane peak: the zarr export cuts rows
# y_peak - 1152 to y_peak - 272, i.e. trenches 880 px tall.
my_experiment.trench_y_offsets = (1152, 272)

Check that only one y peak has been found for each FOV

In [None]:
# Display the per-FOV lane peaks; each entry should contain a single peak.
y_peaks

Find the y limits

In [None]:
# Compute the trench y limits for all FOVs with the settings above.
# NOTE(review): this rebinds `y_peaks` (replacing the dict built above) and uses
# `my_experiment.PC_channel` rather than `registration_channel`; the zarr export
# reads `my_experiment.y_peaks`, not this variable.
y_peaks = my_experiment.find_all_trench_y_positions_PC(channel=my_experiment.PC_channel,sigma=sigma,distance=distance,height=height,plot=False, plot_save=True)

### Extract the trenches as png

In [None]:
# Write every trench to disk as an image file (PNG, per the save_filetype chosen
# when creating the experiment). NOTE(review): force=True presumably overwrites
# earlier output, as it does for register_experiment — confirm.
my_experiment.extract_trenches(force=True)

In [None]:
# Count the entries in the "trenches" folder, resolved against the current working
# directory (which was used as the experiment directory above).
print(len(os.listdir("trenches")))

### Extract trenches as zarr

In [None]:
import zarr
from numcodecs import Blosc
from glob import glob
import numpy as np

# Give every trench a global index: FOVs in order, trenches in order within each
# FOV. The index is the trench's position along axis 0 of the zarr array; the
# metadata cells further down reproduce the same numbering.
trench_num = 0
trench_id_dict = dict()
for FOV in my_experiment.FOVs:
    # Trenches are cut relative to the first lane peak found for this FOV.
    y_pos = my_experiment.y_peaks[FOV][0]
    x_pos = my_experiment.pruned_experiment_trench_x_lims[FOV]
    # x limits arrive either as a dict whose values are (L, R) pairs or as a
    # sequence of (L, R) pairs; both are handled the same way.
    x_limits = x_pos.values() if isinstance(x_pos, dict) else x_pos
    trenches_id = []
    for L, R in x_limits:
        trenches_id.append([(L, R), y_pos, trench_num])
        trench_num += 1
    trench_id_dict[FOV] = trenches_id

# Total number of trenches over all FOVs (first axis of the zarr array).
n_trenches = trench_num

# Each zarr chunk holds exactly one trench image, so every trench must share one
# width. Derive it from all FOVs (not just one), so an FOV whose trenches were all
# discarded cannot break this, and mismatched widths fail here with a clear message.
trench_widths = {R - L for trenches in trench_id_dict.values() for (L, R), _, _ in trenches}
if len(trench_widths) != 1:
    raise ValueError(
        f"Expected a single trench width across all FOVs, found {sorted(trench_widths)}; "
        "use a fixed trench width (e.g. use_exact_trench_width=True) when finding x positions."
    )
(trench_x_size,) = trench_widths
# Height of the window image[y_pos - offsets[0] : y_pos - offsets[1]].
trench_y_size = my_experiment.trench_y_offsets[0] - my_experiment.trench_y_offsets[1]
compressor = Blosc(cname='zstd', clevel=9, shuffle=Blosc.BITSHUFFLE)
# Layout: (trench, time, channel, y, x), one chunk per single trench image.
z1 = zarr.open(f'{my_experiment.directory}/trenches.zarr', mode='w', shape=(n_trenches, len(my_experiment.times), len(my_experiment.channels), trench_y_size, trench_x_size),
                chunks=(1,1,1,trench_y_size, trench_x_size), dtype='uint16', compressor = compressor)


def extract_trenches_from_image(FOV, t, time, c, channel):
    """Cut every trench of one FOV out of a single image and write it into ``z1``.

    ``t`` and ``c`` are the zarr time/channel indices to write to; ``time`` and
    ``channel`` are passed to ``my_experiment.get_image``. Trench x limits, lane
    peak and zarr trench index come from the module-level ``trench_id_dict``.
    """
    image = my_experiment.get_image(FOV, channel, time, registered=my_experiment.is_registered)
    y_start_offset = my_experiment.trench_y_offsets[0]
    y_stop_offset = my_experiment.trench_y_offsets[1]
    for (left, right), lane_y, trench_idx in trench_id_dict[FOV]:
        # Rows lane_y - offsets[0] .. lane_y - offsets[1], columns left .. right.
        z1[trench_idx, t, c] = image[lane_y - y_start_offset:lane_y - y_stop_offset, left:right]

# One job per (FOV, time point, channel). The enumerate() indices are the zarr
# time/channel positions; the values are what get_image() receives.
a = [*product(trench_id_dict, enumerate(my_experiment.times), enumerate(my_experiment.channels))]

extraction_jobs = (
    delayed(extract_trenches_from_image)(fov, t_idx, t_value, c_idx, c_name)
    for fov, (t_idx, t_value), (c_idx, c_name) in tqdm(a)
)
Parallel(n_jobs=-1)(extraction_jobs)

### Save trench zarr metadata to json files
The aim here is to create a loadable mapping from the trench zarr indices to the information that was available prior to extraction. Three JSON files are created. When a file is loaded back as a dictionary, each key corresponds to the relevant index of the trench zarr (as a string, because JSON object keys are always strings).
* A mapping from zarr trench number -> FOV
* A mapping from zarr timepoint -> experiment time point (could be a useful record if any time points were discarded)
* A mapping from zarr channel index -> channel string

In [None]:
import json

In [None]:
# Map each FOV to the zarr trench indices it owns. This follows the numbering used
# when the trench zarr was written (FOVs in order, trenches in order within each
# FOV), so only the number of trenches per FOV is needed here.
trench_num = 0
FOV_to_trench_dict = dict()
for FOV in my_experiment.FOVs:
    n_FOV_trenches = len(my_experiment.pruned_experiment_trench_x_lims[FOV])
    FOV_to_trench_dict[FOV] = list(range(trench_num, trench_num + n_FOV_trenches))
    trench_num += n_FOV_trenches

In [None]:
# Invert the mapping: each zarr trench index becomes a key whose value is the FOV it
# came from (if an index appeared under two FOVs, the later FOV would win).
trench_to_FOV_dict = {
    trench_idx: fov_name
    for fov_name, trench_indices in FOV_to_trench_dict.items()
    for trench_idx in trench_indices
}
trench_to_FOV_dict

In [None]:
nd2_file = "20230803_SB7_segmentation"  # could be imported from nd2 metadata json

# Zarr trench index -> FOV name. JSON object keys are always strings, so the
# integer trench indices are stored (and reloaded) as "0", "1", ...
file = "metadata_trench_zarr_FOVs_" + nd2_file + ".json"
# Use a dedicated handle name: `f` is the convolution filter passed as
# `conv_filter=f` in the trench-finding cells, and rebinding it to a file object
# would break re-running those cells.
with open(file, 'w', encoding="utf-8") as json_file:
    json.dump(trench_to_FOV_dict, json_file)

In [None]:
# Zarr channel index (as a string key) -> channel name.
channel_info_dict = {str(count): channel for count, channel in enumerate(my_experiment.channels)}

file = "metadata_trench_zarr_channels_" + nd2_file + ".json"
# Dedicated handle name so the notebook-global convolution filter `f` is not overwritten.
with open(file, 'w', encoding="utf-8") as json_file:
    json.dump(channel_info_dict, json_file)

In [None]:
# Zarr time index -> experiment time point, so discarded time points stay traceable.
# NOTE(review): json.dump fails if the time values are numpy scalars — confirm
# that my_experiment.times holds plain Python numbers.
times = my_experiment.times
trench_zarr_times = dict(enumerate(times))

file = "metadata_trench_zarr_times_" + nd2_file + ".json"
# Dedicated handle name so the notebook-global convolution filter `f` is not overwritten.
with open(file, 'w', encoding="utf-8") as json_file:
    json.dump(trench_zarr_times, json_file)

In [None]:
# Load a metadata JSON back as a dictionary (here the last file written, held in
# `file`). Keys come back as strings, e.g. "0" rather than 0.
# The handle gets its own name so the convolution filter `f` survives.
with open(file, 'r', encoding="utf-8") as json_file:
    test = json.load(json_file)
test