# Import modules

In [1]:
from pathlib import Path
import tifffile
from matplotlib import pyplot as plt
from skimage import morphology, filters, measure, feature, registration
import numpy as np
import pandas as pd
import math
import miguel_tools as mt
import iTools as iT
import seaborn as sns
import math
from collections import defaultdict
from scipy.spatial import distance
import napari
from skimage import data
from scipy.ndimage import gaussian_laplace
import datetime
from tqdm.notebook import tqdm

# To add metadata to an image

In [2]:
# Shared metadata written into every saved TIFF; filled in by later cells
metadata = dict()

In [3]:
# tifffile.imsave(file="test.tif", data=np.uint8(np.ones(shape=(125, 125))), imagej=True, 
#                 metadata={'axes': 'TYX', 'name':'pablito', 'timeseries':[1,2,3,4,5,6]})

# mt.extract_metadata_TiFF("test.tif", short=False)

In [4]:
%gui qt
# NOTE(review): enables IPython's Qt event-loop integration so the QFileDialog
# helpers and napari viewers below can open windows from the notebook.

from PyQt5.QtWidgets import QFileDialog

def gui_fname(dir=None):
    """Ask the user to pick a file with a Qt dialog and return its path.

    Parameters
    ----------
    dir : str, optional
        Folder the dialog opens in; './' when not given.

    Returns
    -------
    str
        First element of the tuple returned by ``QFileDialog.getOpenFileName``
        (the selected file name).
    """
    start_in = dir if dir is not None else './'
    dialog_result = QFileDialog.getOpenFileName(
        None,
        "Select data file...",
        start_in,
        filter="All files (*);; SM Files (*.sm)",
    )
    return dialog_result[0]

def gui_dir(dir=None):
    """Ask the user to pick a folder with a Qt dialog and return its path.

    Parameters
    ----------
    dir : str or path-like, optional
        Folder the dialog opens in; './' when not given (same default as
        ``gui_fname``).

    Returns
    -------
    str
        The selected directory. NOTE(review): Qt presumably returns '' when the
        dialog is cancelled, so callers should check for an empty string.
    """
    if dir is None:
        dir = './'
    # Honour the requested start folder; Qt expects a plain string, so convert Path objects
    dir_name = QFileDialog.getExistingDirectory(None, 'Select a folder:', str(dir), QFileDialog.ShowDirsOnly)
    return dir_name

In [5]:
def calculate_shift_with_marker(im, slide):
    """Measure a translation manually from two points clicked in napari.

    Shows frames ``slide-2`` .. ``slide+1`` of ``im`` together with an empty points
    layer named 'point1'. The user places exactly two points: a landmark before the
    shift and the same landmark after it. Returns the displacement from the first
    point to the second.

    Parameters
    ----------
    im : array
        Image stack indexed by frame first (points are read as (frame, y, x)).
    slide : int
        Frame index around which the stack is shown.

    Returns
    -------
    list of int
        ``[dy, dx]`` rounded to whole pixels.

    Raises
    ------
    ValueError
        If the 'point1' layer does not contain exactly two points.
    """
    # Clamp so that slide < 2 does not produce a negative (wrapping) start index
    start = max(slide - 2, 0)
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(im[start:slide+2])
        viewer.add_points(name='point1', face_color='red', opacity=0.9, symbol='square', size=1, n_dimensional=False)
    # NOTE(review): if the Qt event loop is already running (`%gui qt`), gui_qt may
    # return without blocking; confirm the points have been placed at this point.

    # Public `.data` instead of the private `_data` attribute
    points = np.asarray(viewer.layers['point1'].data)
    if len(points) != 2:
        raise ValueError("Expected exactly 2 points in layer 'point1', got {}".format(len(points)))
    # Drop the frame coordinate and keep (y, x) of both points
    (p1y, p1x), (p2y, p2x) = points[:, 1:]

    manual_shift = [int(round(d)) for d in (p2y - p1y, p2x - p1x)]
    return manual_shift

def calculate_shift(im1, im2):
    """Estimate the translation that registers ``im2`` onto ``im1``.

    Uses ``skimage.registration.phase_cross_correlation(im1, im2)``. The error and
    phase-difference outputs are discarded. Each shift component is converted with
    ``int()``, which truncates toward zero.

    Returns
    -------
    list of int
        ``[dy, dx]`` for 2-D inputs (one entry per image axis).
    """
    shift_yx, _error, _phasediff = registration.phase_cross_correlation(im1, im2)
    return list(map(int, shift_yx))

def shiftIM(im, shift):
    """Circularly translate an image by ``shift = (dy, dx)`` pixels.

    Axis 0 is rolled by ``shift[0]`` and axis 1 by ``shift[1]``. Pixels pushed past an
    edge wrap around to the opposite edge (``np.roll`` semantics).
    """
    dy, dx = shift[0], shift[1]
    return np.roll(im, shift=(dy, dx), axis=(0, 1))

def registration_checkpoint(im1, im2, shift, savein=None):
    """Draw red/green overlays to check a registration shift by eye.

    The left panel ("Raw") overlays ``im1`` (Reds) and ``im2`` (Greens) as given. The
    right panel ("Shifted") overlays ``im1`` and ``shiftIM(im2, shift)``. Axes are hidden.

    Parameters
    ----------
    im1, im2 : array
        Reference and moving frames (presumably 2-D images).
    shift : sequence of int
        ``(dy, dx)`` applied to ``im2`` via ``shiftIM``.
    savein : str or path-like, optional
        If given, the figure is also saved there at 300 dpi.
    """
    fig, ax = plt.subplots(ncols=2, figsize=(20,10))
    ax[0].set_title("Raw")
    ax[0].imshow(im1, alpha=0.5, cmap='Reds')
    ax[0].imshow(im2, alpha=0.5, cmap='Greens')

    im2_shifted = shiftIM(im2, shift)
    ax[1].set_title("Shifted")
    ax[1].imshow(im1, alpha=0.5, cmap='Reds')
    ax[1].imshow(im2_shifted, alpha=0.5, cmap='Greens')

    for i in ax:
        i.axis('off')

    # Identity comparison with None (PEP 8) instead of `!=`
    if savein is not None:
        plt.savefig(savein, dpi=300, bbox_inches = 'tight')



In [6]:
# Ask for the folder that holds the time-lapse JPG images
selected_dir = gui_dir()
if not selected_dir:
    # An empty selection (presumably a cancelled dialog) would become Path('') == '.',
    # silently processing the current directory instead
    raise RuntimeError("No folder selected; re-run this cell and choose the folder with the JPG images.")
folder_jpg = Path(selected_dir)
jpg_paths = sorted(folder_jpg.glob("*jpg"))
if not jpg_paths:
    raise FileNotFoundError("No '*jpg' files found in {}".format(folder_jpg))

# Results are written to a sibling folder named "<input folder>_results"
save_folder = folder_jpg.parent / "{}_results".format(folder_jpg.name)
save_folder.mkdir(exist_ok=True)

In [7]:
# Each file stem is a capture timestamp "YYYYmmdd_HHMMSS"; any trailing 'W'
# characters are stripped before parsing
timepoints = [
    datetime.datetime.strptime(jpg_path.stem.rstrip("W"), "%Y%m%d_%H%M%S")
    for jpg_path in jpg_paths
]

In [8]:
metadata.update(timepoints=timepoints)

In [None]:
# Read channel 0 of every JPG (presumably red for RGB images)
ims = [plt.imread(i)[:, :, 0] for i in tqdm(jpg_paths)]
# The cropping below uses one frame size for the whole series, so all frames must match
unique_shapes = {im.shape for im in ims}
if len(unique_shapes) != 1:
    raise ValueError("Expected every image to have the same shape, found: {}".format(sorted(unique_shapes)))
shapes, = unique_shapes
shapes

HBox(children=(FloatProgress(value=0.0, max=420.0), HTML(value='')))

In [None]:
# Crop every frame to a centred square whose side is the smaller plate dimension,
# so all frames share the same field of view
plate_size = 120 # mm
pixel_plate_dim = (3324, 3312) # Pixel plate dimensions extracted from ImageJ measurements
min_dim = np.min(pixel_plate_dim)

pixel_size = plate_size/min_dim
print("{} mm/pixel --> {} um/pixel".format(round(pixel_size, 3), round(pixel_size*pow(10, 3), 2)))

if shapes[0] < min_dim or shapes[1] < min_dim:
    raise ValueError("Frames of shape {} are smaller than the {}-pixel crop".format(shapes, min_dim))

# Top-left corner of the centred crop. `heigth` (sic) keeps its spelling because later cells use it.
heigth = int((shapes[0] - min_dim) // 2)
width = int((shapes[1] - min_dim) // 2)

In [None]:
metadata.update(pixel_size_mm=pixel_size)

In [None]:
# Crop every frame to the centred min_dim x min_dim square and stack along a new first axis
imscut = np.stack(
    [frame[heigth:heigth + min_dim, width:width + min_dim] for frame in tqdm(ims)],
    axis=0,
)
# The uncropped frames are no longer needed; drop them to release memory
del ims

In [None]:
# To visually check the slide where the pictures start rotated
# mt.napariView(imscut)

In [None]:
# Frames from `rot_from` onwards were captured with the plate turned; rotate each one
# by 90 degrees (np.rot90, axes=(1, 0)). In-place assignment works because the frames are square.
rot_from = int(input("Choose the slide number where the rotation starts:\t"))
# Later cells compare frame rot_from-1 with frame rot_from, so both must exist
if not 1 <= rot_from < imscut.shape[0]:
    raise ValueError("rot_from must be between 1 and {}, got {}".format(imscut.shape[0] - 1, rot_from))
for i in tqdm(range(rot_from, imscut.shape[0])):
    imscut[i] = np.rot90(imscut[i], axes=(1,0))

In [None]:
metadata.update(slide_rotated=rot_from)

In [None]:
# mt.napariView(imscut)

In [None]:
# tifffile.imsave(file=save_folder / "{}_rotated.tif".format(folder_jpg.name), data=np.uint8(imscut), imagej=True)

In [None]:
# Estimate the translation between the last unrotated frame and the first rotated one
shift = calculate_shift(im1=imscut[rot_from - 1], im2=imscut[rot_from])
shift

In [None]:
# Save a before/after overlay so the estimated shift can be checked by eye
registration_checkpoint(
    im1=imscut[rot_from - 1],
    im2=imscut[rot_from],
    shift=shift,
    savein=save_folder.joinpath("{}_registration_check.png".format(folder_jpg.name)),
)

In [None]:
# Apply the global shift to every frame from `shift_from` on; earlier frames are kept as-is
shift_from = rot_from
first_part = imscut[:shift_from]
second_part = np.stack(
    [shiftIM(imscut[i], shift) for i in tqdm(range(shift_from, imscut.shape[0]))],
    axis=0,
)
combined = np.concatenate([first_part, second_part], axis=0)
imscut.shape, combined.shape

In [None]:
# mt.napariView(combined)

In [None]:
# # Using the shift values to relocate all the pictures after the change happened.
# shift_from = 119
# for i in tqdm(range(shift_from, imscut.shape[0])):
#     imscut[i] = shiftIM(imscut[i], shift)

In [None]:
# Difference of every frame from the per-pixel maximum over time (values <= 0).
# The subtraction is done in int16: with the uint8 frames read from JPG, the plain
# `imscut - max` wraps around to large positive values. NOTE: this allocates an
# int16 array the size of the whole stack.
imscut_highlithed = np.subtract(imscut, np.max(imscut, axis=0), dtype=np.int16)

In [None]:
# Save the rotated *and* shift-corrected stack (`combined`) under the "_rotated_reg" name.
# The crop corners below are picked on `combined`, and that crop is applied to this
# file once it is read back. `imscut` itself is only rotated, not shifted.
# `imwrite` replaces the deprecated `imsave` alias.
tifffile.imwrite(file=save_folder / "{}_rotated_reg.tif".format(folder_jpg.name),
                 data=np.uint16(combined),
                 imagej=True,
                 metadata=metadata)

In [None]:
# mt.napariView(imscut)

In [None]:
# # This can be removed because the plate is already cut
# image = ims[-1]
# otsu = image > filters.threshold_otsu(image)
# otsu_filled = morphology.remove_small_holes(otsu, area_threshold=18000)
# props = measure.regionprops(morphology.label(otsu_filled))
# bbox = [i['bbox'] for i in props if i['area'] > 9000000 ][0]
# dify, difx = bbox[2] - bbox[0],  bbox[3] - bbox[1]
# width = np.max([dify, difx])
# ims_plate_cutted = np.stack([i[bbox[0]:bbox[0]+width, bbox[1]:bbox[1]+width] for i in ims], axis=0)

# Read the rotated and registered image

In [None]:
imscut = tifffile.imread(save_folder.joinpath("{}_rotated_reg.tif".format(folder_jpg.name)))
imscut.shape

# Extracting plants by their upper-left and lower-right corners

In [None]:
# # Trying with squares but doesn't give back a proper label shape
# with napari.gui_qt():
#     viewer = napari.Viewer()
#     viewer.add_image(imscut[-1])

# labels = viewer.layers['Shapes'].to_labels()
# labels.shape, imscut[-1].shape
# plt.imshow(labels, alpha=0.5)
# plt.imshow(imscut[-1], alpha=0.5)

In [None]:
# Show the last shift-corrected frame and collect crop corners in a points layer named 'plants'
with napari.gui_qt():
    viewer = napari.Viewer()
    viewer.add_image(combined[-1])
    viewer.add_points(
        name='plants',
        face_color='red',
        opacity=0.9,
        symbol='disc',
        size=1,
        n_dimensional=False,
    )
    # Start the layer in 'add' mode so clicks place points
    viewer.layers['plants'].mode = 'add'

In [None]:
# Corner clicks from the 'plants' layer as integer (y, x) positions (public `.data`, not `_data`)
coordinates = np.asarray(viewer.layers['plants'].data).astype(int)
# Each seedling needs two clicks (upper-left, then lower-right corner)
if len(coordinates) == 0 or len(coordinates) % 2:
    raise ValueError("Expected an even, non-zero number of points (two corners per seedling), got {}".format(len(coordinates)))
# NOTE(review): assumes mt.sublist(seq, 2) chunks seq into consecutive pairs; confirm
coordinates = mt.sublist(list(coordinates), 2)

In [None]:
# One 1-based name per pair of corners
seedling_names = ["seedling_{}".format(n) for n in range(1, len(coordinates) + 1)]

In [None]:
# Crop every frame of the stack to each seedling's box. Corners may be clicked in any
# order; they are sorted per axis and clamped at 0 so a click just outside the image
# does not turn into a wrapping negative index.
plants_separated = []
for coord in tqdm(coordinates):
    (ya, xa), (yb, xb) = coord[0], coord[1]
    y1, y2 = sorted((max(int(ya), 0), max(int(yb), 0)))
    x1, x2 = sorted((max(int(xa), 0), max(int(xb), 0)))
    # .copy() so each crop owns its data and the full `imscut` stack can actually be freed later
    plants_separated.append(imscut[:, y1:y2, x1:x2].copy())

In [None]:
# mt.napariView(plants_separated)

In [None]:
# Drop the notebook's reference to the full stack; only the seedling crops are used from here on.
# NOTE(review): if the crops in plants_separated are plain slices (views) of imscut, the
# underlying buffer stays alive while they exist — confirm memory is actually released.
del imscut

In [None]:
# test = plants_separated[0].copy()

# Showing the extracted plants

In [None]:
# Square grid with the last frame of every extracted seedling, saved as an overview PNG
n = max(1, math.ceil(np.sqrt(len(plants_separated))))
# squeeze=False always returns a 2-D array of Axes (also for a 1x1 grid), so .ravel() works
fig, ax = plt.subplots(nrows=n, ncols=n, figsize=(10,10), squeeze=False)
ax = ax.ravel()
for ix, i in enumerate(ax):
    if ix < len(plants_separated):
        i.imshow(plants_separated[ix][-1])
        i.set_title("seedling {}".format(ix+1))
    # Hide axes on every panel, including the unused ones at the end of the grid
    i.axis('off')

plt.savefig(save_folder / "{}_seedlings_extracted.png".format(folder_jpg.name), dpi=300, bbox_inches = 'tight')

In [None]:
# plt.imshow(plants_separated[0][-1][0:200,plants_separated[0].shape[-1]-200:plants_separated[0].shape[-1]])

# Re-alignment

In [None]:
# Re-align each seedling crop on its own. First estimate the shift between frame
# shift_from-1 and frame shift_from inside the crop. Then roll every frame from
# shift_from on by that amount; earlier frames are kept unchanged.
plants_separated_aligned = []
for each_plant in plants_separated:
    # Without a frame before and at least one frame from `shift_from`, there is nothing to realign
    if shift_from < 1 or shift_from >= each_plant.shape[0]:
        plants_separated_aligned.append(each_plant)
        continue
    first_part = each_plant[:shift_from]
    shift_value = calculate_shift(each_plant[shift_from-1], each_plant[shift_from])
    second_part = np.stack([shiftIM(sl, shift_value) for sl in each_plant[shift_from:]], axis=0)
    plants_separated_aligned.append(np.concatenate([first_part, second_part], axis=0))

In [None]:
# The unaligned crops are no longer needed; later cells use plants_separated_aligned
del plants_separated

In [None]:
# # Doesn't work so well
# shift_from = 119
# for ix, i in tqdm(enumerate(plants_separated)):
#     shift_plant = calculate_shift(plants_separated[ix][shift_from-1], plants_separated[ix][shift_from])
# #     registration_checkpoint(plants_separated[0][shift_from-1], plants_separated[0][shift_from], shift_plant)
#     print(shift_plant)
#     for sl in range(shift_from, i.shape[0]):
# #         i[ix][sl] = shiftIM(i[sl], shift_plant)
#         print(ix, sl)

In [None]:
# Print the (frames, height, width) shape of every realigned seedling stack
for aligned_stack in plants_separated_aligned:
    print(aligned_stack.shape)

In [None]:
# Visual check of the realigned seedling stacks (project helper; presumably opens them in napari)
mt.napariView(plants_separated_aligned)

# Saving the extracted plants

In [None]:
# Save one ImageJ TIFF per seedling. Its metadata adds the seedling name to the shared metadata.
for ix, plant_stack in tqdm(enumerate(plants_separated_aligned), total=len(plants_separated_aligned)):
    # Copy so the shared `metadata` dict is not modified
    meta_plant = metadata.copy()
    meta_plant['seedling'] = seedling_names[ix]
    # Pass the per-seedling dict (the shared one was being written, dropping the seedling name).
    # `imwrite` replaces the deprecated `imsave` alias.
    tifffile.imwrite(file=save_folder / "{}_seedling_{}.tif".format(folder_jpg.name, ix+1),
                     data=np.uint8(plant_stack),
                     imagej=True,
                     metadata=meta_plant)

# Reading metadata

In [None]:
# Read back the metadata of the first saved seedling, building the path from this run's
# folders instead of a hard-coded absolute path
meta_dict = mt.extract_metadata_TiFF(save_folder / "{}_seedling_1.tif".format(folder_jpg.name), short=False)

In [None]:
# Parse the ImageJ description's "key=value" lines into a dict of strings
d = {}
for line in meta_dict['ImageDescription'][0].split("\n"):
    # partition splits on the first '=' only, so values that contain '=' stay intact
    key, sep, value = line.partition("=")
    # Lines without '=' carry no key/value pair and are skipped
    if sep:
        d[key] = value

In [None]:
# d

In [None]:
# Possibly to be discarded: extract_baby_plant is only referenced from commented-out code below

def extract_baby_plant(im_slice):
    """Return a mask and bounding box of the largest dark region in an image slice.

    Pixels below the Otsu threshold of ``im_slice`` are treated as foreground. The
    foreground is split into connected regions, and the region with the largest area
    is kept (ties: the first region in label order).

    Parameters
    ----------
    im_slice : array
        Grey-level image (presumably 2-D).

    Returns
    -------
    baby_mask : ndarray of int
        1 inside the largest region, 0 elsewhere.
    bbox : tuple
        Bounding box of that region from ``regionprops`` (for 2-D input presumably
        ``(min_row, min_col, max_row, max_col)``).

    Raises
    ------
    ValueError
        If no pixel lies below the threshold, so there is no region to return.
    """
    mask_otsu = im_slice < filters.threshold_otsu(im_slice)
    labelIM = morphology.label(mask_otsu)
    props = measure.regionprops(label_image=labelIM, intensity_image=im_slice)
    if not props:
        raise ValueError("No region below the Otsu threshold was found in the slice")

    # Pick the largest region directly instead of sorting an object-dtype DataFrame twice
    biggest = max(props, key=lambda region: region['area'])
    baby_mask = np.where(labelIM == biggest['label'], 1, 0)

    return baby_mask, biggest['bbox']

# Edge-enhance every frame of the first seedling independently. A single sobel call on
# the whole (frame, y, x) stack would also differentiate across time (or be rejected
# by older scikit-image versions that only accept 2-D input).
plant_highlithed = np.stack([filters.sobel(frame) for frame in plants_separated_aligned[0]], axis=0)

# Rotate each frame by 90 degrees within its (y, x) plane
plant_highlithed_rot = np.rot90(plant_highlithed, axes=(1,2))

# fix pixels by mm plot

In [None]:
def extract_points(im):
    """Let the user click a path on ``im`` in napari and measure each step of it.

    Opens a napari viewer showing ``im`` with an empty points layer 'path' in add mode.
    Once the viewer returns, the clicked points are turned into two tables.

    Parameters
    ----------
    im : array
        Image stack shown in napari. Points are read as (timepoint, y, x), so a 3-D
        (frame, y, x) stack is expected.

    Returns
    -------
    DATA : pandas.DataFrame
        One row per clicked point, with columns ``time, cap, timepoint, y, x, x_mm, y_mm``:

        - ``cap`` is the click order.
        - ``time`` is taken from the global ``timepoints`` list at the point's frame index.
        - ``x_mm``/``y_mm`` use the global ``pixel_size`` (mm/pixel).
    DFresults : pandas.DataFrame
        One row per pair of consecutive clicks, with columns ``caps``,
        ``y_degrees_label``, ``x_degrees_label`` (segment midpoint in pixels),
        ``degree`` (atan2 angle in pixel space), ``euclidean (mm)``, ``time_dif (s)``,
        ``speed (mm/s)`` and ``speed (mm/h)``. The table is empty (columns kept) when
        fewer than two points were clicked.
    """
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(im)
        viewer.add_points(name='path', face_color='red', opacity=0.9, symbol='square', size=1, n_dimensional=False)
        viewer.layers['path'].mode = 'add'
    # NOTE(review): if the Qt event loop is already running (`%gui qt`), gui_qt may
    # return without blocking; confirm the points have been placed at this point.

    # Public `.data` instead of the private `_data` attribute
    DATA = pd.DataFrame(viewer.layers['path'].data, columns=['timepoint', 'y', 'x'])
    DATA.insert(loc=0, column='cap', value=DATA.index)
    # Take each point's acquisition time from the frame it was clicked on, instead of
    # pairing the i-th click with the i-th timepoint
    frame_times = [timepoints[int(round(t))] for t in DATA['timepoint']]
    DATA.insert(loc=0, column='time', value=pd.Series(frame_times, index=DATA.index, dtype='datetime64[ns]'))
    DATA['x_mm'] = DATA['x'] * pixel_size
    DATA['y_mm'] = DATA['y'] * pixel_size

    columns = ['caps', 'y_degrees_label', 'x_degrees_label', 'degree',
               'euclidean (mm)', 'time_dif (s)', 'speed (mm/s)']
    middles = {name: [] for name in columns}
    # Explicit range over consecutive pairs (no bare except to stop at the last point)
    for i in range(len(DATA) - 1):
        p1, p2 = DATA.iloc[i], DATA.iloc[i + 1]
        middles['caps'].append("{}-{}".format(int(p1['cap']), int(p2['cap'])))
        # Segment midpoint in pixels, used to place the angle label in degrees_plot
        middles['y_degrees_label'].append(p1['y'] + (p2['y'] - p1['y']) / 2)
        middles['x_degrees_label'].append(p1['x'] + (p2['x'] - p1['x']) / 2)
        middles['degree'].append(math.degrees(math.atan2(p2['y'] - p1['y'], p2['x'] - p1['x'])))

        euclidean = distance.euclidean((p1['y_mm'], p1['x_mm']), (p2['y_mm'], p2['x_mm']))
        time_dif = (p2['time'] - p1['time']).total_seconds()
        middles['euclidean (mm)'].append(euclidean)
        middles['time_dif (s)'].append(time_dif)
        # Two clicks on the same frame have no elapsed time: report NaN rather than dividing by zero
        middles['speed (mm/s)'].append(euclidean / time_dif if time_dif else float('nan'))

    # Passing `columns` keeps the expected columns even when there are no segments
    DFresults = pd.DataFrame(middles, columns=columns)
    DFresults['speed (mm/h)'] = DFresults['speed (mm/s)'] * 60 * 60

    return DATA, DFresults

In [None]:
# Click the path on the edge-enhanced, rotated stack of the first seedling
DATA, DFresults = extract_points(im=plant_highlithed_rot)

In [None]:
# Per-point table, then per-segment table
display(DATA, DFresults)

In [None]:
# def main(files, ims, index):
#     im = ims[index]
#     r, g, b = [im[:,:,:, i] for i in range(im.shape[-1])]
#     baby_plant_highlighted = hightlight_babyplant(r)
    
    
#     ## Extraction of the plant for one of the plots
#     med = np.mean(r, axis=0)
#     med_sharp = filters.unsharp_mask(med, radius=8, amount=40)
#     a, bbox = extract_baby_plant(med_sharp)
#     a = morphology.remove_small_holes(a.astype(bool), area_threshold=300)
#     mask = np.where(a == True, 50, 1)
#     result = mask * r # Cleaned background
#     result_cut = result[:, bbox[0]:bbox[2], bbox[1]:bbox[3]]
#     max_result_cut = np.max(result_cut, axis=2)
    
#     ## Lateral view of the plant extracted
#     fig = plt.figure(figsize=(20,20))
#     plt.imshow(max_result_cut)
#     xaxis, yaxis = max_result_cut.shape[-1], max_result_cut.shape[-2]
#     xticks = np.linspace(1, xaxis, 16)
#     pixel_size = 2 # mm
#     xlabels = np.array([int(i)*pixel_size for i in xticks])
#     plt.xticks(ticks=xticks, labels=xlabels, rotation=90, size=12);
#     plt.xlabel("mm", size=15)

#     yticks = np.linspace(1, yaxis, 16)
#     ylabels = [timepoints[i] for i in range(len(yticks))]
#     plt.yticks(ticks=yticks, labels=ylabels, rotation=0, size=12);
#     plt.ylabel("time (s)", size=15)
#     plt.title("{}".format(files[index].stem))

#     plt.grid(color='black')
#     plt.savefig(folder / "{}_lateral_view.png".format(files[index].stem), dpi=300,bbox_inches = 'tight')
    
    

    

In [None]:
def degrees_plot(DATA, results, seedling_name):
    """Plot one seedling's clicked path and annotate every segment with its angle.

    Parameters
    ----------
    DATA : pandas.DataFrame
        Per-point table from ``extract_points`` (uses ``timepoint``, ``x``, ``y``).
    results : pandas.DataFrame
        Per-segment table from ``extract_points`` (uses ``degree``,
        ``x_degrees_label``, ``y_degrees_label``).
    seedling_name : str
        Plot title and prefix of the saved PNG.

    The plot is built as follows:

    - Points are scattered as one series per frame, labelled with the global
      ``timepoints``, and joined by a dashed grey line.
    - Each segment gets its angle rounded up to whole degrees.
    - The y axis is inverted to match image coordinates.
    - The figure is saved to ``save_folder / "<seedling_name>_lateral_view_degrees.png"``.
    """
    plt.figure(figsize=(15,5))
    # Group by the column name (not a one-element list): with a list, recent pandas
    # yields 1-tuples as keys and int(frame) would fail
    for frame, group in DATA.groupby('timepoint'):
        plt.scatter(group.x, group.y, label=timepoints[int(frame)])

    plt.plot(DATA.x, DATA.y, color='grey', linestyle='--')
    # Angle labels come from the `results` argument, not from a global table
    for _, segment in results.iterrows():
        plt.text(x=segment['x_degrees_label']-2, y=segment['y_degrees_label']+2,
                 s=r"{}$^\circ$".format(math.ceil(segment['degree'])), size=12)

    # x/y are napari pixel coordinates, not micrometres
    plt.ylabel("y (pixels)")
    plt.xlabel("x (pixels)")
    # Max first, so y grows downwards as in the image
    plt.ylim(DATA.y.max()+10, DATA.y.min()-10)
    plt.title("{}".format(seedling_name))
    plt.grid(axis='both')

    plt.savefig(save_folder / "{}_lateral_view_degrees.png".format(seedling_name), dpi=300, bbox_inches = 'tight')

In [None]:

degrees_plot(DATA, results=DFresults, seedling_name="_".join([folder_jpg.stem, seedling_names[0]]))

In [None]:
# Export the per-point and per-segment tables (project helper; presumably one sheet per table)
mt.writting_excel(
    DF=[DATA, DFresults],
    pathname=save_folder.joinpath("{}_data.xlsx".format(folder_jpg.stem)),
    sheet_name=['raw_data', 'degrees'],
)

In [None]:
# `ax` is the raveled array of overview Axes, so query every panel.
# np.ravel also copes with a single Axes object.
[list(a.get_xticks()) for a in np.ravel(ax)]

In [None]:
import seaborn as sns

In [None]:
# Strip plot of clicked y positions per acquisition time. DATA has no 'time2' or 'y_um'
# columns, so derive them here without modifying DATA:
#   time2: acquisition time as text (presumably one category per timepoint)
#   y_um:  y position converted from mm to um
strip_data = DATA.assign(time2=DATA['time'].astype(str), y_um=DATA['y_mm'] * 1000)
sns.catplot(kind='strip', x='time2', y='y_um', data=strip_data)
plt.xticks(rotation='vertical');
plt.ylim(strip_data.y_um.max()+10, strip_data.y_um.min()-10);

In [None]:
# Show the per-point table
DATA