In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import re
from pathlib import Path
import shutil
import tifffile
import cv2
import math
from napari_correct_drift import CorrectDrift
from skimage import filters

# MM3 requirement 1: Channels are mostly vertical
# MM3 requirement 2: Channel ends are at least 20 pixels away from the top and bottom edges of the image
Channel ends are in the top and bottom third of the image, regardless of orientation.
There are no artefacts in the image above and below the channels. The numbers on the mother machine can confuse the script.

In [None]:
# Raw acquisition folder for one experiment. This is a machine-specific absolute
# path: edit it before running the notebook anywhere else.
root_dir = '/Users/nvivanco/Desktop/20240926/dimm_giTG69_glucose_1'
# Hyperstack, drift-correct and unstack the raw files. The helper is defined in a
# later cell of this notebook, so that cell has to be executed first.
# The return value is the 'cyx_files_for_mm3' output folder.
unstacked_path = napari_ready_format_drift_correct(root_dir, 'dimm', c = 0)

In [None]:
# Hard-coded copy of the folder returned by napari_ready_format_drift_correct
# (presumably so this cell can be run without repeating the drift correction).
unstacked_path = '/Users/nvivanco/Desktop/20240926/dimm_giTG69_glucose_1/hyperstacked/drift_corrected/cyx_files_for_mm3'
# Plot the detected horizontal lines on the earliest frame of every position.
plot_lines_across_FOVs(unstacked_path, c = 0)

In [None]:
# Rotate and crop every cyx file; results are written to '<unstacked_path>/rotated'.
rotate_stack(unstacked_path, c = 0, growth_channel_length = 290)

# Why are trenches not getting detected? The cells below step through the function in napari-mm3's compile.py that is responsible for identifying trenches (peaks / microfluidic channels)

In [None]:
# Folder with the rotated files of a single position (xy006).
# NOTE(review): rotate_stack writes directly into 'rotated'; the 'xy006'
# sub-folder is presumably created by hand - confirm.
unstacked_path = '/Users/nvivanco/Desktop/20240926/dimm_giTG69_glucose_1/hyperstacked/drift_corrected/cyx_files_for_mm3/rotated/xy006'

In [None]:
# Index the TIFFs in the folder as file_dict[position][time][channel] -> path.
file_dict = org_by_timepoint([unstacked_path])

In [None]:
# Show the grouping to check which positions and timepoints were found.
file_dict

In [None]:
# Example file: position '006', timepoint 0, the channel-stacked ('stacked') TIFF.
ex_path = file_dict['006'][0]['stacked']

In [None]:
# Show the path of the example file.
ex_path

In [None]:
# I counted 31 peaks in this image

In [None]:
from scipy.signal import find_peaks_cwt

In [None]:
# Load the example stack. tifffile.imread opens and closes the file itself,
# whereas a bare TiffFile(...) object would leave the file handle open.
image_data = tifffile.imread(ex_path)

In [None]:
# The later cells index this array as (channel, row, column).
image_data.shape

In [None]:
# Sum channel 0 over the rows: one intensity value per image column.
projection_x = image_data[0].sum(axis=0).astype(np.int32)

In [None]:
projection_x.shape  # length = number of image columns; only channel 0 (phase) is used

In [None]:
# Stand-ins for params["compile"]["channel_width"] and
# params["compile"]["channel_separation"] used by find_channel_locs.
chan_w = 10
chan_sep = 45

In [None]:
# Wavelet widths that will be passed to find_peaks_cwt below.
np.arange(chan_w - 5, chan_w + 5)

In [None]:
# min_snr determines how many peaks are detected. In find_channel_locs it is read from the parameter dict as chan_snr = params["compile"]["channel_detection_snr"]; I am not sure where that value is set.

In [None]:

# find_peaks_cwt is a function which attempts to find the peaks in a 1-D array by
# convolving it with a wave. Here the wave is the default Mexican hat wave,
# and the minimum signal-to-noise ratio is fixed to 1 for this experiment
# (find_channel_locs reads it from the parameter dict instead).
# *** The range here should be a parameter or changed to a fraction.
peaks = find_peaks_cwt(
    projection_x, np.arange(chan_w - 5, chan_w + 5), min_snr=1)

In [None]:
# Number of candidate channels before the edge filtering below.
len(peaks)

In [None]:
# image_data is a (channel, row, column) stack in this notebook, so
# image_data.shape[1] is the image HEIGHT. The width is the length of the
# x projection, which is what the right-edge test has to use.
# The len() guards avoid an IndexError when no peak was detected at all.

# If the left-most peak position is within half of a channel separation,
# discard the channel from the list.
if len(peaks) > 0 and peaks[0] < (chan_sep / 2):
    peaks = peaks[1:]
# If the difference between the right-most peak position and the right edge
# of the image is less than half of a channel separation, discard the channel.
if len(peaks) > 0 and projection_x.shape[0] - peaks[-1] < (chan_sep / 2):
    peaks = peaks[:-1]

In [None]:
# Number of channels left after discarding the ones too close to the edges.
len(peaks)

In [None]:
# Find the average channel ends from the y projection of channel 0
projection_y = image_data[0].sum(axis=1)
# Take the derivative along y; cast to int32 first so that negative
# differences are representable (the original comment notes the data was
# unsigned 16-bit).
proj_y_d = np.diff(projection_y.astype(np.int32))
# Closed end: row of the largest derivative within the top third
onethirdpoint_y = int(projection_y.shape[0] / 3.0)
default_closed_end_px = proj_y_d[:onethirdpoint_y].argmax()
# Open end: row of the smallest derivative within the bottom third
twothirdpoint_y = int(projection_y.shape[0] * 2.0 / 3.0)
default_open_end_px = twothirdpoint_y + proj_y_d[twothirdpoint_y:].argmin()
default_length = default_open_end_px - default_closed_end_px  # used for checks

In [None]:
crop_wp = 10
# go through peaks and assign information
# dict for channel dimensions
chnl_loc_dict = {}
# key is peak location, value is dict with {'closed_end_px': px, 'open_end_px': px}

for peak in peaks:
    # set defaults
    chnl_loc_dict[peak] = {
        "closed_end_px": default_closed_end_px,
        "open_end_px": default_open_end_px,
    }
    # redo the previous y projection finding with just this channel.
    # Clamp the left bound at 0: a negative start would wrap around and
    # select the wrong (usually empty) column range.
    channel_slice = image_data[0][:, max(peak - crop_wp, 0) : peak + crop_wp]
    slice_projection_y = channel_slice.sum(axis=1)
    slice_proj_y_d = np.diff(slice_projection_y.astype(np.int32))
    slice_closed_end_px = slice_proj_y_d[:onethirdpoint_y].argmax()
    slice_open_end_px = twothirdpoint_y + slice_proj_y_d[twothirdpoint_y:].argmin()
    slice_length = slice_open_end_px - slice_closed_end_px

    # check if these values make sense. If so, use them. If not, use default
    # make sure length is not more than 15 pixels bigger or smaller than default
    # *** This 15 should probably be a parameter or at least changed to a fraction.
    if slice_length + 15 < default_length or slice_length - 15 > default_length:
        continue
    # make sure ends are at least 15 pixels from the image edge.
    # image_data is (channel, row, column), so the image height is shape[1];
    # shape[0] is the channel count and made this test reject every channel.
    if slice_closed_end_px < 15 or slice_open_end_px > image_data.shape[1] - 15:
        continue

    # if you made it to this point then update the entry
    chnl_loc_dict[peak] = {
        "closed_end_px": slice_closed_end_px,
        "open_end_px": slice_open_end_px,
    }

In [None]:
# Image size taken from the (channel, row, column) stack
image_rows = image_data.shape[1]
image_cols = image_data.shape[2]
# Padding in pixels added around each channel: width (x) and length (y)
crop_wp = 10
chan_lp = 10

In [None]:
mask_corners_dict = {}
consensus_mask = np.zeros([image_rows, image_cols])  # mask for labeling entire image
# for each trench in each image make a single mask
img_chnl_mask = np.zeros([image_rows, image_cols])

# and add the channel/peak mask to it.
# chnl_loc_dict maps peak x position -> {'closed_end_px', 'open_end_px'}
for chnl_peak, peak_ends in chnl_loc_dict.items():
    # pull out the peak location and top and bottom location
    # and expand by padding, clipped to the image bounds
    x1 = max(chnl_peak - crop_wp, 0)
    x2 = min(chnl_peak + crop_wp, image_cols)
    y1 = max(peak_ends["closed_end_px"] - chan_lp, 0)
    y2 = min(peak_ends["open_end_px"] + chan_lp, image_rows)
    mask_corners_dict[chnl_peak] = [y1, y2, x1, x2]

    # add it to the mask for this image
    img_chnl_mask[y1:y2, x1:x2] = 1

# add it to the consensus mask
consensus_mask += img_chnl_mask

# Normalize consensus mask between 0 and 1. When no channel was detected the
# mask is all zeros; skip the division so it does not turn into NaNs.
if consensus_mask.any():
    consensus_mask = consensus_mask.astype("float32") / float(np.amax(consensus_mask))
else:
    consensus_mask = consensus_mask.astype("float32")

In [None]:
# Raw channel-0 image for reference. The title used to be the copy-pasted
# 'Channel mask', which does not describe this figure.
plt.imshow(image_data[0], cmap='gray')
plt.title('Phase image')
plt.axis('off')  # Hide axis labels
plt.show()

In [None]:
# Binary mask of all detected channels (1 inside a padded channel box, 0 elsewhere).
plt.imshow(consensus_mask, cmap='gray')
plt.title('Channel mask')
plt.axis('off')  # Hide axis labels
plt.show()

In [None]:
# Keep only the pixels of channel 0 that fall inside a detected channel box.
masked_image = image_data[0]*consensus_mask

In [None]:
# Should equal (image_rows, image_cols).
consensus_mask.shape

In [None]:
# Channel-0 image with everything outside the detected channels zeroed out.
# The title used to be the copy-pasted 'Channel mask'.
plt.imshow(masked_image, cmap='gray')
plt.title('Masked phase image')
plt.axis('off')  # Hide axis labels
plt.show()

In [None]:
# Peak x positions (one per detected channel) that have a crop box.
mask_corners_dict.keys()

In [None]:
# Crop box of one example channel.
# NOTE(review): 38 is a peak position copied from the keys shown above for this
# particular image; it raises KeyError if that peak is not detected.
y1,y2, x1,x2= mask_corners_dict[38]

In [None]:
# Crop of channel 0 around the example channel selected above. The title used
# to be the copy-pasted 'Channel mask'.
plt.imshow(image_data[0][y1:y2, x1:x2], cmap='gray')
plt.title('Example trench crop')
plt.axis('off')  # Hide axis labels
plt.show()

In [None]:
# Extract trenches and save as stacked TIFF images
# NOTE(review): the crops are written into unstacked_path itself, i.e. the same
# folder the input files are read from.
for trench in mask_corners_dict.keys():
    # crop box stored as [y1, y2, x1, x2]
    y1,y2, x1,x2= mask_corners_dict[trench]
    trench_region = image_data[:, y1:y2, x1:x2]  # all channels kept; assumes the image is stacked as c y x
    filename = f'region_{trench}.tif'
    path = os.path.join(unstacked_path, filename)
    tifffile.imwrite(path, trench_region)

In [None]:
def make_consensus_mask(
    fov: int,
    analyzed_imgs: dict,
    image_rows: int,
    image_cols: int,
    crop_wp: int,
    chan_lp: int,
) -> np.ndarray:
    """
    Generate a labeled consensus channel mask for a given fov.

    Parameters
    ----------
    fov: int
        fov to analyze
    analyzed_imgs: dict
        per-image records; each value needs a "fov" entry and a "channels"
        dict mapping peak x position to
        {"closed_end_px": px, "open_end_px": px}
    image_rows: int
        image height
    image_cols: int
        image width
    crop_wp: int
        channel width padding (added on both sides of the peak in x)
    chan_lp: int
        channel length padding (added beyond both channel ends in y)

    Returns
    -------
    consensus_mask: np.ndarray
        integer array of shape (image_rows, image_cols) in which every
        connected channel region carries its own label and background is 0.
        All zeros when no image of the requested fov contains a channel.
    """
    # Imported here because the notebook's import cell does not provide
    # scipy.ndimage under the name `ndi`.
    from scipy import ndimage as ndi

    consensus_mask = np.zeros([image_rows, image_cols])  # mask for labeling

    # bring up information for each image
    for img_v in analyzed_imgs.values():
        # skip this one if it is not of the current fov
        if img_v["fov"] != fov:
            continue

        # for each channel in each image make a single mask
        img_chnl_mask = np.zeros([image_rows, image_cols])

        # and add the channel mask to it
        for chnl_peak, peak_ends in img_v["channels"].items():
            # pull out the peak location and top and bottom location
            # and expand by padding, clipped to the image bounds
            x1 = max(chnl_peak - crop_wp, 0)
            x2 = min(chnl_peak + crop_wp, image_cols)
            y1 = max(peak_ends["closed_end_px"] - chan_lp, 0)
            y2 = min(peak_ends["open_end_px"] + chan_lp, image_rows)

            # add it to the mask for this image
            img_chnl_mask[y1:y2, x1:x2] = 1

        # add it to the consensus mask
        consensus_mask += img_chnl_mask

    # Normalize consensus mask between 0 and 1. Skip the division when nothing
    # was marked, otherwise 0/0 would fill the mask with NaNs.
    max_votes = float(np.amax(consensus_mask))
    consensus_mask = consensus_mask.astype("float32")
    if max_votes > 0:
        consensus_mask /= max_votes

    # Label connected regions; every non-zero pixel counts as foreground.
    # The [0] is the label array ([1] is the number of regions).
    consensus_mask = ndi.label(consensus_mask)[0]

    return consensus_mask

In [None]:
def find_channel_locs(params: dict, image_data: np.ndarray) -> dict:
    """Find the location of channels in a single 2-D (row, column) image.

    The channels are returned in a dictionary where the key is the x position
    of the channel in pixels and the value is a dictionary with the closed and
    open end in pixels in y.

    Parameters
    ----------
    params: dict
        parameter dict; reads params["compile"]["channel_width"],
        ["channel_separation"], ["channel_width_pad"] and
        ["channel_detection_snr"].
    image_data: np.ndarray
        2-D image, indexed as [row, column].

    Returns
    -------
    dict
        {peak_x: {"closed_end_px": px, "open_end_px": px}}; empty when no
        channel peak is found.

    Called by
    get_tif_params
    """

    # declare temp variables from yaml parameter dict.
    chan_w = params["compile"]["channel_width"]
    chan_sep = params["compile"]["channel_separation"]
    crop_wp = int(params["compile"]["channel_width_pad"] + chan_w / 2)
    chan_snr = params["compile"]["channel_detection_snr"]

    # Detect peaks in the x projection (i.e. find the channels)
    projection_x = image_data.sum(axis=0).astype(np.int32)
    # find_peaks_cwt is a function which attempts to find the peaks in a 1-D array by
    # convolving it with a wave. here the wave is the default Mexican hat wave
    # but the minimum signal to noise ratio is specified
    # *** The range here should be a parameter or changed to a fraction.
    peaks = find_peaks_cwt(
        projection_x, np.arange(chan_w - 5, chan_w + 5), min_snr=chan_snr
    )

    # If the left-most peak position is within half of a channel separation,
    # discard the channel from the list.
    if len(peaks) > 0 and peaks[0] < (chan_sep / 2):
        peaks = peaks[1:]
    # If the difference between the right-most peak position and the right edge
    # of the image is less than half of a channel separation, discard the channel.
    if len(peaks) > 0 and image_data.shape[1] - peaks[-1] < (chan_sep / 2):
        peaks = peaks[:-1]

    # Nothing detected (or everything discarded): report no channels instead
    # of failing with an IndexError.
    if len(peaks) == 0:
        return {}

    # Find the average channel ends for the y-projected image
    projection_y = image_data.sum(axis=1)
    # find derivative; cast to int32 first so negative differences are representable
    proj_y_d = np.diff(projection_y.astype(np.int32))
    # use the top third to look for closed end, is pixel location of highest deriv
    onethirdpoint_y = int(projection_y.shape[0] / 3.0)
    default_closed_end_px = proj_y_d[:onethirdpoint_y].argmax()
    # use bottom third to look for open end, pixel location of lowest deriv
    twothirdpoint_y = int(projection_y.shape[0] * 2.0 / 3.0)
    default_open_end_px = twothirdpoint_y + proj_y_d[twothirdpoint_y:].argmin()
    default_length = default_open_end_px - default_closed_end_px  # used for checks

    # go through peaks and assign information
    # dict for channel dimensions
    chnl_loc_dict = {}
    # key is peak location, value is dict with {'closed_end_px': px, 'open_end_px': px}

    for peak in peaks:
        # set defaults
        chnl_loc_dict[peak] = {
            "closed_end_px": default_closed_end_px,
            "open_end_px": default_open_end_px,
        }
        # redo the previous y projection finding with just this channel.
        # Clamp the window to the image: a negative start index would wrap
        # around and select the wrong (usually empty) column range.
        left = max(peak - crop_wp, 0)
        right = min(peak + crop_wp, image_data.shape[1])
        channel_slice = image_data[:, left:right]
        slice_projection_y = channel_slice.sum(axis=1)
        slice_proj_y_d = np.diff(slice_projection_y.astype(np.int32))
        slice_closed_end_px = slice_proj_y_d[:onethirdpoint_y].argmax()
        slice_open_end_px = twothirdpoint_y + slice_proj_y_d[twothirdpoint_y:].argmin()
        slice_length = slice_open_end_px - slice_closed_end_px

        # check if these values make sense. If so, use them. If not, use default
        # make sure length is not more than 15 pixels bigger or smaller than default
        # *** This 15 should probably be a parameter or at least changed to a fraction.
        if slice_length + 15 < default_length or slice_length - 15 > default_length:
            continue
        # make sure ends are at least 15 pixels from the image edge
        if slice_closed_end_px < 15 or slice_open_end_px > image_data.shape[0] - 15:
            continue

        # if you made it to this point then update the entry
        chnl_loc_dict[peak] = {
            "closed_end_px": slice_closed_end_px,
            "open_end_px": slice_open_end_px,
        }

    return chnl_loc_dict

# Functions for pre-processing mother machine images

In [None]:
def midpoint_distance(line, center):
    """Return the Euclidean distance between a segment's midpoint and `center`.

    `line` is indexed as line[0] = (x1, y1, x2, y2); `center` is (x, y).
    """
    segment = line[0]
    mid_x = (segment[0] + segment[2]) / 2
    mid_y = (segment[1] + segment[3]) / 2
    offset_x = mid_x - center[0]
    offset_y = mid_y - center[1]
    return np.sqrt(offset_x**2 + offset_y**2)

def crop_around_central_flow(h_lines, w, h, growth_channel_length=150,
                             max_center_distance=200, flow_margin=150):
    """Choose the row range to keep around the horizontal line nearest the image centre.

    Args:
        h_lines: iterable of line segments, each indexed as
            line[0] = (x1, y1, x2, y2); may be None or empty.
        w, h: image width and height in pixels.
        growth_channel_length: number of rows kept above the selected line.
        max_center_distance: a line is only considered when its midpoint lies
            within this many pixels of the image centre.
        flow_margin: number of rows kept below the selected line.

    Returns:
        (crop_start, crop_end) row indices clipped to [0, h], or None when no
        usable line is available.
    """
    if h_lines is None or len(h_lines) == 0:
        print('Warning: horizontal lines were not detected')
        return None

    center = (w // 2, h // 2)
    # Keep only the lines whose midpoint is close to the image centre
    central_lines = [
        line for line in h_lines
        if midpoint_distance(line, center) <= max_center_distance
    ]
    if not central_lines:
        # Previously this case failed with an IndexError
        print('Warning: no horizontal lines were detected near the image center')
        return None

    # The first qualifying line defines the reference row
    x1, y1, x2, y2 = central_lines[0][0]

    # Determine crop boundaries, clipped to the image
    crop_start = int(max(y1 - growth_channel_length, 0))
    crop_end = int(min(y1 + flow_margin, h))

    print('Cropping reference image')

    return crop_start, crop_end

In [None]:
def rotate_stack(path_to_stack, c = 0, growth_channel_length = 295):
    """Rotate and crop every cyx TIFF found in `path_to_stack`.

    For each position the earliest timepoint is used as reference: horizontal
    lines are detected in its phase channel, their mean angle gives the
    rotation, and the crop window is computed on the rotated reference. The
    same rotation and crop are then applied to every timepoint of that
    position and written to '<path_to_stack>/rotated' as 'rotated_<name>'.
    Positions without usable horizontal lines are skipped with a warning.

    Args:
    path_to_stack: Path to stack of cyx format files, in string format
    c: Phase channel index in integer format, default = 0
    growth_channel_length: length in pixels of the growth channel, i.e. the
    number of rows kept above the central line. Default is 295.
    """
    ## TODO: make this work on tcyx files, so the drift corrected files need not
    ## be unstacked when trenches are isolated without napari-mm3
    # create an output directory for the rotated files
    path_to_rotated_images = os.path.join(path_to_stack, 'rotated')
    os.makedirs(path_to_rotated_images, exist_ok=True)

    file_groups = org_by_timepoint([path_to_stack])

    for position in file_groups.keys():
        earliest_timepoint = min(file_groups[position].keys())
        first_file_path = file_groups[position][earliest_timepoint]['stacked']
        print(first_file_path)
        ref_img = tifffile.imread(first_file_path)
        ref_phase_img = ref_img[c, :, :]
        h, w = ref_phase_img.shape
        horizontal_lines, vertical_lines = id_lines(ref_phase_img)
        if len(horizontal_lines) == 0:
            # No lines means no rotation angle can be estimated
            print('Warning: no horizontal lines found in FOV ' + position + ', skipping')
            continue
        rotation_angle = calculate_rotation_angle(ref_phase_img, horizontal_lines)
        ref_rotated_image = apply_image_rotation(ref_phase_img, rotation_angle)
        rot_horizontal_lines, rot_vertical_lines = id_lines(ref_rotated_image)
        if len(rot_horizontal_lines) == 0:
            print('Warning: no horizontal lines found after rotating FOV ' + position + ', skipping')
            continue
        crop_bounds = crop_around_central_flow(rot_horizontal_lines, w, h, growth_channel_length)
        if crop_bounds is None:
            # crop_around_central_flow already printed the reason
            print('Warning: could not determine crop for FOV ' + position + ', skipping')
            continue
        crop_start, crop_end = crop_bounds
        ref_cropped_img = ref_rotated_image[crop_start:crop_end, :]
        print('Lines identified in FOV ' + position)
        plt.figure()
        plt.imshow(ref_cropped_img, cmap='gray')
        plt.show()

        # apply rotation and crop to all other images in path
        for time in file_groups[position]:
            path = file_groups[position][time]['stacked']
            time_img = tifffile.imread(path)
            rotated_image = apply_image_rotation(time_img, rotation_angle)
            # rows are axis 1 of a cyx stack
            cropped_img = rotated_image[:, crop_start:crop_end, :]
            filename = os.path.basename(path)
            new_filename = f'rotated_{filename}'
            new_path = os.path.join(path_to_rotated_images, new_filename)
            tifffile.imwrite(new_path, cropped_img)
        print('Successfully rotated stack')

In [None]:
def detect_clear_image(image, threshold=0):
    """Return True when the image counts as in focus.

    The focus score is the variance of the Laplacian of `image`; the image is
    considered clear when that score is at least `threshold`.

    NOTE: with the default threshold of 0 every image with a finite score is
    accepted, because a variance is never negative. Pass a positive threshold
    to actually reject blurry frames.

    Args:
        image: 2-D image array.
        threshold: minimum Laplacian variance for a clear image.

    Returns:
        bool
    """
    laplacian_image = filters.laplace(image)
    blur_score = np.var(laplacian_image)
    # Always return a real bool (previously the function fell through to None)
    return bool(blur_score >= threshold)

In [None]:
def napari_ready_format_drift_correct(root_dir, experiment_name, c = 0):
    """Run the full pre-processing chain: hyperstack, drift-correct, unstack.

    Args:
    root_dir: parent directory containing multiple 'Pos#' directories,
    each containing tif files of single timepoints and channels of the given position.
    File name is default from the Covert lab microscope.
    experiment_name: unique id to label output files
    c: int representing phase channel index; forwarded to hyperstack_tif_tcyx,
    which uses it for the clear/blurry check.

    Returns:
    Path of the 'cyx_files_for_mm3' directory
    (<root_dir>/hyperstacked/drift_corrected/cyx_files_for_mm3) that holds the
    drift corrected files across positions and timepoints.
    """

    hyperstacked_path = os.path.join(root_dir, 'hyperstacked')
    drift_corrected_path = os.path.join(hyperstacked_path, 'drift_corrected')
    output_dir_path = os.path.join(drift_corrected_path, 'cyx_files_for_mm3')

    # `c` used to be accepted here but never passed on, so the callee always
    # fell back to its own default channel.
    time_dict = hyperstack_tif_tcyx(root_dir, experiment_name, c)

    drift_correction_napari(hyperstacked_path)

    unstack_tcyx_to_cyx(drift_corrected_path, time_dict)

    return output_dir_path

def hyperstack_tif_tcyx(root_dir, experiment_name, c = 0):
    """Copy raw TIFFs under a standard name and build cyx and tcyx stacks.

    For every position found below `root_dir`:
      * each raw file is copied (originals are kept) to '<root_dir>/renamed',
      * the channels of one timepoint are stacked and written to '<root_dir>/stacked',
      * all timepoints judged clear are stacked along time and written to
        '<root_dir>/hyperstacked'.

    Args:
    root_dir: parent directory that contains the 'Pos*' directories.
    experiment_name: The desired experiment name, used as file name prefix.
    c: index of the phase channel used for the clear/blurry check.

    Returns:
    dict of the form {position: {time: status}} where status is 'clear',
    'blurry', or 'missing' (at least one channel file could not be copied).
    """
    root = Path(root_dir)
    input_dirs = [str(path) for path in root.glob('**//Pos*') if path.is_dir()]

    # Create output directories if they don't exist
    output_dir_path = os.path.join(root_dir, 'renamed')
    os.makedirs(output_dir_path, exist_ok=True)
    stacked_path = os.path.join(root_dir, 'stacked')
    os.makedirs(stacked_path, exist_ok=True)
    hyperstacked_path = os.path.join(root_dir, 'hyperstacked')
    os.makedirs(hyperstacked_path, exist_ok=True)

    time_clear_dict = {}

    file_groups = org_by_timepoint(input_dirs)
    for position, timepoints in sorted(file_groups.items()):
        time_clear_dict[position] = {}
        time_stacked_image_data = []
        for time, channels in sorted(timepoints.items()):
            image_data = []
            for channel, image_path in sorted(channels.items()):
                new_filename = f'{experiment_name}_t{time:04.0f}xy{position}c{channel}.tif'
                new_path = os.path.join(output_dir_path, new_filename)
                try:
                    # Copy the file to the new path
                    shutil.copy(str(image_path), str(new_path))
                    # imread closes the file again, unlike a bare TiffFile object
                    image_data.append(tifffile.imread(new_path))
                except OSError as e:
                    print(f'Error copying file: {e}')
            if len(image_data) != len(channels):
                # A frame with missing channels cannot be stacked with the
                # others, so leave the whole timepoint out.
                time_clear_dict[position][time] = 'missing'
                continue
            stacked_image = np.stack(image_data, axis=0)  # channels are the first dimension
            output_stacked_file = Path(stacked_path) / f"{experiment_name}_t{time:04.0f}xy{position}.tif"
            tifffile.imwrite(str(output_stacked_file), stacked_image)
            phase_image = stacked_image[c, :, :]
            if detect_clear_image(phase_image): # only time stack clear images
                time_stacked_image_data.append(stacked_image)
                time_clear_dict[position][time] = 'clear'
            else:
                time_clear_dict[position][time] = 'blurry'
                print('blurry')
        if not time_stacked_image_data:
            # np.stack cannot stack an empty list
            print(f'No usable frames for position {position}, hyperstack not written')
            continue
        hyperstacked_image = np.stack(time_stacked_image_data, axis=0)  # time as the first dimension
        output_hyperstacked_file = Path(hyperstacked_path) / f"{experiment_name}_xy{position}.tif"
        tifffile.imwrite(str(output_hyperstacked_file), hyperstacked_image)

    return time_clear_dict

def drift_correction_napari(hyperstacked_path, channel=0):
    """Drift-correct every '<experiment>_xy<position>.tif(f)' movie in a folder.

    Each matching file is read as a tcyx movie, its drift is estimated with
    napari_correct_drift.CorrectDrift relative to frame 0, and the corrected
    movie is written to '<hyperstacked_path>/drift_corrected' as
    'drift_cor_<experiment>_xy<position>.tif'.

    Args:
        hyperstacked_path: folder containing the tcyx hyperstacks.
        channel: index of the channel used to estimate the drift (default 0).
    """
    output_dir_path = os.path.join(hyperstacked_path, 'drift_corrected')
    os.makedirs(output_dir_path, exist_ok=True)

    for filename in os.listdir(hyperstacked_path):
        if not filename.endswith(('.tif', '.tiff')):
            continue
        match = re.match(r'(.*)_xy(\d+)\.', filename)
        if match is None:
            continue
        experiment, position = match.groups()
        img_path = os.path.join(hyperstacked_path, filename)
        hyperstacked_img = tifffile.imread(img_path)
        # multi-channel 2D-movie
        cd = CorrectDrift(hyperstacked_img, "tcyx")
        # estimate drift table
        drifts = cd.estimate_drift(t0=0, channel=channel)
        # correct drift
        img_cor = cd.apply_drifts(drifts)
        img_cor_file = Path(output_dir_path) / f"drift_cor_{experiment}_xy{position}.tif"
        tifffile.imwrite(str(img_cor_file), img_cor)

def org_by_timepoint(input_dirs):
    r"""Group the TIFF files of several folders by position, time and channel.

    Three file name layouts are recognised, tried in this order:

    * raw scope export ``img_channel(\d+)_position(\d+)_time(\d+)_z(\d+).``:
      channel, position and time are taken from the name (z is ignored)
    * channel-stacked file ``<experiment>_t(\d+)xy(\d+).``:
      channel is ``'stacked'``
    * time-stacked file ``<experiment>_xy(\d+).``:
      time is ``'hyperstacked'`` and channel is ``'stacked'``

    Files matching none of the layouts are ignored. When several files map to
    the same (position, time, channel), the first one in sorted name order wins.

    Returns a dictionary in the following format:
    dict[position][time][channel] = '/path/to/tif/file'
    where position and channel are strings and time is an int (or the string
    'hyperstacked').
    """
    raw_pattern = re.compile(r'img_channel(\d+)_position(\d+)_time(\d+)_z(\d+)\.')
    stacked_pattern = re.compile(r'(.*)_t(\d+)xy(\d+)\.')
    hyperstacked_pattern = re.compile(r'(.*)_xy(\d+)\.')

    file_groups = {}

    for input_dir in input_dirs:
        # sorted() makes the "first file wins" rule independent of the OS listing order
        for filename in sorted(os.listdir(input_dir)):
            if not filename.endswith(('.tif', '.tiff')):
                continue

            # Reset the fallbacks for every file so that values parsed from a
            # previous file name cannot leak into this one.
            time = 'hyperstacked'
            channel = 'stacked'

            match = raw_pattern.match(filename)
            if match:
                channel, position, time, _z = match.groups()
                time = int(time)
            else:
                match = stacked_pattern.match(filename)
                if match:
                    _experiment, time, position = match.groups()
                    time = int(time)
                else:
                    match = hyperstacked_pattern.match(filename)
                    if match is None:
                        # Unknown naming scheme: not part of the data set
                        continue
                    _experiment, position = match.groups()

            path = os.path.join(input_dir, filename)
            # setdefault keeps an existing entry, i.e. the first file wins
            file_groups.setdefault(position, {}).setdefault(time, {}).setdefault(channel, path)

    return file_groups

def unstack_tcyx_to_cyx(path_to_hyperstacked, time_clear_dict):
    """Split every tcyx hyperstack in a folder into one cyx TIFF per frame.

    path_to_hyperstacked: directory where movies are hyperstacked as tcyx
    time_clear_dict: {position: {time: status}}; the times whose status is
    'clear' are assigned, in order, to the frames of the hyperstack and used
    in the output file names.

    The cyx files are written to '<path_to_hyperstacked>/cyx_files_for_mm3'.
    """
    # Create output directory if it doesn't exist
    output_dir_path = os.path.join(path_to_hyperstacked, 'cyx_files_for_mm3')
    os.makedirs(output_dir_path, exist_ok=True)

    grouped_files = org_by_timepoint([path_to_hyperstacked])
    for _position_key, by_time in sorted(grouped_files.items()):
        for _time_key, by_channel in sorted(by_time.items()):
            for _channel_key, image_path in sorted(by_channel.items()):
                name_match = re.match(r'(.*)_xy(\d+)\.', os.path.basename(image_path))
                if not name_match:
                    continue
                experiment, position = name_match.groups()
                hyperstacked_img = tifffile.imread(image_path)
                frame_count = hyperstacked_img.shape[0]
                # Frame i of the movie corresponds to the i-th clear timepoint
                clear_times = [
                    real_time
                    for real_time, status in time_clear_dict[position].items()
                    if status == 'clear'
                ]
                index_to_real_time = dict(zip(range(frame_count), clear_times))

                for index in range(frame_count):
                    cyx_image = hyperstacked_img[index, :, :, :]
                    real_time = index_to_real_time[index]
                    output_cyx_file = Path(output_dir_path) / f"{experiment}_t{real_time:04.0f}xy{position}.tif"
                    tifffile.imwrite(str(output_cyx_file), cyx_image)
                        

def plot_lines_across_FOVs(path_to_stack, c = 0):
    """Plot the horizontal lines detected in the first frame of every position.

    Args:
    path_to_stack: Path to stack of cyx format files, in string format
    c: Phase channel index in integer format, default = 0
    Output is one plotted image per FOV/position with the identified
    horizontal lines drawn on top.
    """
    for position, by_time in org_by_timepoint([path_to_stack]).items():
        first_time = min(by_time.keys())
        reference_stack = tifffile.imread(by_time[first_time]['stacked'])
        phase_plane = reference_stack[c, :, :]
        horizontal_lines, _vertical_lines = id_lines(phase_plane)
        print('Lines identified in FOV ' + position)
        print('Horizontal lines')
        plot_lines(phase_plane, horizontal_lines)

def calculate_line_angle(x1, y1, x2, y2):
    """Return the direction of the segment (x1, y1) -> (x2, y2) in degrees.

    The angle is measured with arctan2(dy, dx), so it lies in [-180, 180].
    """
    return np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi

def find_lines(img):
    """Detect straight line segments in an image.

    The image is rescaled to 8 bit by its maximum, edges are extracted with
    Canny (thresholds 50 and 150) and segments are found with the
    probabilistic Hough transform. Returns whatever cv2.HoughLinesP returns.
    """
    img_8bit = (img / img.max() * 255).astype(np.uint8)
    edge_map = cv2.Canny(img_8bit, 50, 150)
    return cv2.HoughLinesP(edge_map, 1, np.pi/180, 100, minLineLength=100, maxLineGap=10)
    
def id_lines(img):
    """Split the segments found in `img` into near-horizontal and near-vertical ones.

    A segment is horizontal when |angle| < 30 degrees and vertical when
    60 <= |angle| <= 120 degrees; anything else is dropped. Scanning stops
    once 50 segments have been classified.

    Returns (h_lines, v_lines), two lists of segments.
    """
    h_lines, v_lines = [], []
    segments = find_lines(img)
    if segments is None:
        return h_lines, v_lines

    classified = 0
    for segment in segments:
        x1, y1, x2, y2 = segment[0]
        tilt = abs(calculate_line_angle(x1, y1, x2, y2))
        if tilt < 30:  # Adjust threshold as needed
            h_lines.append(segment)
            classified += 1
        elif 60 <= tilt <= 120:  # Adjust threshold as needed
            v_lines.append(segment)
            classified += 1
        if classified >= 50:
            break
    return h_lines, v_lines


def calculate_rotation_angle(img, lines):
    """Return the mean signed angle, in degrees, of near-horizontal `lines`.

    The sign of each angle is kept so that the result says in which direction
    the image is tilted; taking absolute values would level only images tilted
    one way and double the tilt of the others. Intended for near-horizontal
    segments (angles close to 0).

    Args:
        img: unused; kept so existing callers keep working.
        lines: iterable of segments indexed as line[0] = (x1, y1, x2, y2).

    Returns:
        Mean angle in degrees, or 0.0 (no rotation) when `lines` is None or empty.
    """
    if lines is None or len(lines) == 0:
        # Nothing to average; previously this was a ZeroDivisionError
        return 0.0
    angles = []
    for line in lines:
        # Calculate angle of the line
        x1, y1, x2, y2 = line[0]
        angles.append(calculate_line_angle(x1, y1, x2, y2))
    return sum(angles) / len(angles)

def plot_lines(original_img, lines):
    """Show `original_img` in grayscale with every segment of `lines` drawn in green.

    Each segment is indexed as line[0] = (x1, y1, x2, y2). When `lines` is
    None only the figure with the image is created.
    """
    plt.figure()
    plt.imshow(original_img, cmap='gray')
    if lines is None:
        return
    for segment in lines:
        x1, y1, x2, y2 = segment[0]
        plt.plot([x1, x2], [y1, y2], color='green', linewidth=2)
    plt.show()


def apply_image_rotation(image_stack, rotation_angle):
    """Rotate every 2-D plane of an image (stack) about the plane centre.

    The last two axes are treated as (row, column); all leading axes (for
    example c in cyx, or t and c in tcyx) are iterated over, and each plane is
    rotated with the same affine matrix. The output has the same shape and
    dtype as the input.

    Args:
        image_stack: array with at least 2 dimensions.
        rotation_angle: The rotation angle in degrees, passed to
            cv2.getRotationMatrix2D.

    Returns:
        Rotated array of the same shape as `image_stack`.

    Raises:
        ValueError: if `image_stack` has fewer than 2 dimensions.
    """
    image_stack = np.asarray(image_stack)
    if image_stack.ndim < 2:
        # Previously unsupported shapes silently produced an all-zero array
        raise ValueError(
            f'image_stack must have at least 2 dimensions, got {image_stack.ndim}'
        )

    h, w = image_stack.shape[-2:]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, rotation_angle, 1.0)

    if image_stack.ndim == 2:
        return cv2.warpAffine(image_stack, M, (w, h))

    # Freshly allocated (C-contiguous) output, so the reshape below is a view
    # and writing into `planes_out` fills `rotated_stack`.
    rotated_stack = np.empty(image_stack.shape, dtype=image_stack.dtype)
    planes_in = image_stack.reshape(-1, h, w)
    planes_out = rotated_stack.reshape(-1, h, w)
    for i in range(planes_in.shape[0]):
        planes_out[i] = cv2.warpAffine(np.ascontiguousarray(planes_in[i]), M, (w, h))

    return rotated_stack