## Gel Image Manual Band Selector

This notebook allows a user to interactively select lanes and gel bands from an input gel image.  Lanes are modelled as rectangles, which can be rotated freely, and each band is selected as a section of a lane.  Final selections can be saved to a CSV file and read into the accompanying analysis notebook (see the usage guidelines at the end of this notebook).

### Imports and file setup

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable, AxesGrid
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import Layout, Output
from PIL import Image
from matplotlib.widgets import RectangleSelector
from matplotlib.patches import Polygon
from numpy.linalg import lstsq
import copy
import json
from scipy.signal import find_peaks
import scipy.signal as signal
from scipy import sparse
from scipy.sparse.linalg import spsolve

In [None]:
# Set filepaths according to your system

base_path = r'SET INPUT FOLDER'
main_output_loc = 'SET OUTPUT LOCATION'

# Main DNA File
base_data = os.path.join(base_path, r'LOD_v7_7:10:21/images')
print('Main DNA file available:', os.path.isdir(base_data))
DNA_file = os.path.join(base_data, '1_3.tif')  # 48 hour BRCA1, TBE

# 24-hr DNA File
base_data = os.path.join(base_path, r'LOD_v8_8:10:21/images')
print('24-hr DNA file available:', os.path.isdir(base_data))
DNA_file_24 = os.path.join(base_data, 'dna_2.tif')  # 24 hour BRCA1, TBE

# TAE DNA File
base_data = os.path.join(base_path, r'LOD_processing_v2_28:9:21/images')
print('Better TAE DNA file available:', os.path.isdir(base_data))
better_tae_DNA_file = os.path.join(base_data, '2_top.tif')  # 24 hour BRCA1, TAE

# Main RNA File
base_data = os.path.join(base_path, r'LOD_v6_6:10:21/images')
print('Main RNA file available:', os.path.isdir(base_data))
RNA_file = os.path.join(base_data, 'tbe_2_2.tif')  # 24 hour RNA, TBE

# Main Aldosterone File
base_data = os.path.join(base_path, r'LOD_v8_8:10:21/images')
print('Main aldosterone file available:', os.path.isdir(base_data))
Aldos_file = os.path.join(base_data, 'aldos_2.tif') # 24 hour aldosterone, TBE

# Mutations Analysis
base_data = os.path.join(base_path, r'LOD_v6_6:10:21/images')
print('TAE Mutation file available:', os.path.isdir(base_data))
mutations_file = os.path.join(base_data, 'tae_3_2.tif')  # 24 hour detection, TAE

# Mutations Analysis (TBE)
base_data = os.path.join(base_path, r'LOD_v6_6:10:21/images')
print('TBE Mutation analysis file available:', os.path.isdir(base_data))
mutations_tbe_file = os.path.join(base_data, 'tbe_1_zoomed.tif')  # 24 hour detection, TBE

# Aldosterone FBS analysis (TBE)
base_data = os.path.join(base_path,'..', r'FBS Phase 5/FBS_long_term_testing_3_7:2:22/images')
print('Aldosterone FBS file available:', os.path.isdir(base_data))
aldos_fbs_file = os.path.join(base_data, 'aldos_1.tif')  # 2 hour detection, TBE
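
The checks above report whether each *folder* exists (`os.path.isdir`), not whether the `.tif` file inside it is present. A minimal sketch of a per-file check follows. `report_missing` is a hypothetical helper that is not part of the original notebook, and the labels passed to it are arbitrary.

```python
import os

def report_missing(files):
    """Print whether each labelled file (not just its folder) exists.

    `files` maps a free-text label to a file path (hypothetical helper).
    Returns the labels whose files could not be found.
    """
    missing = []
    for label, path in files.items():
        ok = os.path.isfile(path)
        print(f'{label}: {"found" if ok else "MISSING"} ({path})')
        if not ok:
            missing.append(label)
    return missing
```

For example, `report_missing({'Main DNA': DNA_file, 'Aldosterone FBS': aldos_fbs_file})` would print a line for each of those two files and return the labels of any that are absent.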

### Image Loading and Quick Image Display

In [None]:
image = Image.open(aldos_fbs_file)
im_array = np.array(image)
# fig = plt.figure(figsize=(20,20))
plt.imshow(im_array, cmap='gray_r')  # reversed colormap shows the gel inverted without integer wrap-around
plt.axis('off')
pass
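
Gel TIFFs are often stored as unsigned integers (`uint8` or `uint16`). NumPy arithmetic on unsigned arrays wraps around instead of going negative, so `1 - image` is only an approximate inversion: pixels with value 0 or 1 end up at the dark end of the scale instead of the bright end. The sketch below uses a small synthetic `uint8` array to show the wrap-around and a wrap-free alternative. This dtype is an assumption, because the real image's dtype depends on the file. For display only, a reversed colormap such as `cmap='gray_r'` avoids the arithmetic entirely.

```python
import numpy as np

# A tiny stand-in for an 8-bit gel image (values chosen to show the edge cases)
demo = np.array([0, 1, 2, 128, 255], dtype=np.uint8)

# Unsigned arithmetic wraps modulo 256, so "1 - image" is not a clean inversion:
# 0 -> 1, 1 -> 0, 2 -> 255, 128 -> 129, 255 -> 2
wrapped = 1 - demo

# Subtracting from the dtype maximum gives a true inversion with no wrap-around:
# 0 -> 255, 1 -> 254, 2 -> 253, 128 -> 127, 255 -> 0
inverted = np.iinfo(demo.dtype).max - demo
```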

## Main Lane Selection GUI

In [None]:
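# Config initialization
# (%matplotlib notebook targets the classic Notebook; in JupyterLab use %matplotlib widget, which requires ipympl)
%matplotlib notebook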
lane = None
trace = None
point = None
lane_line = None
current_setup = 'width'
lane_queue = []
lane_data = {}
patches = [] # lanes plotted on screen
temp_patches = []
l_lock = False
band_line = []

def dist_2_pts(p1, p2):
    # finds distance between two points (2D)
    return np.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2)

def plot_v_line(btn):
    # Plots a vertical line on the gel image, and a linear trace in the top plot (currently unused)
    global lane, ax, trace
    if lane is not None:
        lane.remove()
    if trace is not None:
        trace.remove()
        
    lane = ax[1].axvline(slider.value)
    trace = ax[0].scatter(range(im_array.shape[0]), im_array[:,slider.value])
#     ax[1].scatter(x=last_mouse_click[0],y=last_mouse_click[1])

def plot_point(x, y):
    # plots a single point at the coordinate specified
    global point
    if point is not None:
        point.remove()
    point = ax[1].scatter(x=x,y=y)

def plot_im_line(p1, p2):
    # plots a line between a mouse-defined start and stop position (will delete previous line if present)
    global ax, lane_line
    
    if lane_line is not None:
        ll = lane_line.pop(0)
        ll.remove()
        
    xs = [p1[0], p2[0]]
    ys = [p1[1], p2[1]]
    
    lane_line = ax[1].plot(xs, ys)

def plot_dbl_line(p1, p2):
    # plots a line between a mouse-defined start and stop position (at most two lines are shown; drawing a third clears the previous two)
    global ax, band_line
    
    if len(band_line) == 2:
        for band in band_line:
            ll = band.pop(0)
            ll.remove()
        band_line = []
        
    xs = [p1[0], p2[0]]
    ys = [p1[1], p2[1]]
    
    band = ax[1].plot(xs, ys)
    band_line.append(band)
    
def clear_dbl_line():
    # clears double lines
    global band_line
    if len(band_line) == 2:
        for band in band_line:
            ll = band.pop(0)
            ll.remove()
        band_line = []

def clear_im_line():
    # clears any single lines on screen
    global lane_line
    if lane_line is not None:
        ll = lane_line.pop(0)
        ll.remove()
    lane_line = None

def diff_check_sign(a, b):
    # subtracts two values and returns sign of result
    sub = a - b
    if sub < 0:
        return np.abs(sub), False
    else:
        return np.abs(sub), True

def measure_width_two_clicks(p1, p2):
    # measures distance between two clicks, and sets current lane width value
    dist = dist_2_pts(p1, p2)
    lane_width.value = str(int(dist))
    
def dual_click_mouse_capture(event):
    # captures mouse click data and plots a line between two specified points
    global lane_queue, current_setup, temp_patches, l_lock, lane_length, lane_data, output
    
    with output:  # this is required to allow printing of error callbacks
        if event.inaxes is not ax[1] or event.xdata is None:  # ignore clicks outside the gel image (e.g. on the profile plot)
            return
        lane_queue.append((int(event.xdata),int(event.ydata)))

        if len(lane_queue) == 2:  # will trigger only if two clicks registered
            
            if current_setup == 'positioning': # set start/stop positions of lane
                p1 = (lane_queue[0][0], lane_queue[0][1])
                p2 = (lane_queue[1][0], lane_queue[1][1])
                if l_lock:
                    p2 = reduce_length_to_standard(p1, p2) # reduce length to standard value if set
                else:
                    lane_length.value = str(int(round(dist_2_pts(p1, p2))))
                lane_pos.value = 'Lane Position: [%d, %d] to [%d, %d]' % (p1[0], p1[1], p2[0], p2[1])
                draw_current_lane()
                
            elif current_setup == 'width':  # measuring band width
                if len(temp_patches) != 0:  # delete old patches
                    for p in temp_patches:
                        p.remove()
                    temp_patches = []
                measure_width_two_clicks(lane_queue[0], lane_queue[1])
                
            else: # band selection
                found_bands = []

                for lane_id, lane in lane_data.items():  # checks all saved lanes one by one
                    lane_positions = lane.get('lane_divisions', [])  # empty until Plot Profile has been run
                    for point in lane_queue:
                        for p_index, position in enumerate(lane_positions):
                            if point in position:
                                plot_dbl_line(position[0], position[-1])
                                found_bands.append((lane_id, p_index))
                                break
                    if found_bands:  # stop at the first lane containing a click; both clicks must be in this lane
                        break
                if len(found_bands) == 2:
                    lane_data[found_bands[0][0]]['band'] = [found_bands[0][1], found_bands[1][1]]
                else:
                    print('Error in band finding, please try again.')

            if current_setup != 'band_select':
                plot_im_line(lane_queue[0], lane_queue[1])
            lane_queue = []

def profile_plot(btn):
    # for each saved lane, calculates and plots the pixel-intensity profile along the lane (run Draw Lanes first so each lane has 'corners'; sketching the rectangle geometry helps when following the corner logic below)
    global ax, im_array, plot_intensity, lane_data
    for lane_id, lane in lane_data.items():
        corners = copy.copy(lane['corners'])  # extract corners data for each lane
        
        corners_all = copy.copy(corners)
        max_point = 0
        ind_max = 0
        
        for ind in range(corners.shape[0]):  # finding lowest coordinate (i.e. highest y-value)
            if corners[ind][1] > max_point:
                anchor_1 = corners[ind]
                ind_max = ind
            max_point = max(max_point, corners[ind][1]) 

        corners = np.delete(corners, ind_max, 0) # deleting lowest coordinate from central corners list
        
        dists = []
        for c in corners:  # calculating distance between bottom point and the rest of the corners
            dists.append(dist_2_pts(anchor_1, c))

        dist_sort = sorted(dists)
        anchor_2 = corners[dists.index(dist_sort[0])]  # closest corner is second from the bottom (anchor 2)
        base_1 = corners[dists.index(dist_sort[1])]  # next closest is the base corner matching the first anchor
        base_2 = corners[dists.index(dist_sort[2])]  # finally, this base will link to anchor 2

        # eqn for rect height
        points = [anchor_1,base_1]
        x_coords, y_coords = zip(*points)
        A = np.vstack([y_coords,np.ones(len(y_coords))]).T
        m_b, c_b = lstsq(A, x_coords, rcond=None)[0]  # this is the long edge containing anchor_1 and base_1

        # eqn for rect height 2
        points = [anchor_2,base_2]
        x_coords, y_coords = zip(*points)
        A = np.vstack([y_coords,np.ones(len(y_coords))]).T
        m_b2, c_b2 = lstsq(A, x_coords, rcond=None)[0]  # this is the long edge containing anchor_2 and base_2

        x_as = []
        plot_intensity = []
        lane_divisions = []
        for ind, i_y in enumerate(range(base_1[1], anchor_1[1]+1)):  # cycle through y values between base_1 and anchor_1
            i_y2 = base_2[1] + ind  # find corresponding y value on other edge 
            x_a = m_b * i_y + c_b  # calculate corresponding x position for both ys
            x_b = m_b2 * i_y2 + c_b2
        
            x_as.append(round(x_a))

            x_coords = (round(x_a), round(x_b))
            y_coords = (i_y, i_y2)
            
            A = np.vstack([x_coords,np.ones(len(x_coords))]).T
            m_bt, c_bt = lstsq(A, y_coords, rcond=None)[0]  # find equation of line connecting the two points found on the edges of the rectangle
            
            if x_a > x_b:
                direction = -1
            else:
                direction = 1
                
            x_test = range(int(round(x_a)), int(round(x_b)), direction)
            y_test = m_bt*x_test + c_bt  # extract all available x/y pixels in this section of the profile
            
            data = []  # accumulate actual image data here
            inner_divisions = []  # accumulate each line's coverage here, which will help for easy band selection
            for x,y in zip(x_test, y_test): 
                data.append(im_array[int(round(y)), x])  # NumPy images are indexed [row, column], i.e. [y, x]
                inner_divisions.append((x, int(round(y))))
            plot_intensity.append(np.mean(data))  # profile intensity is an average of entire width of the lane
            lane_divisions.append(inner_divisions)  # contains the entire pixel coverage corresponding to the position along the lane
        
        lane_data[lane_id]['lane_divisions'] = lane_divisions
        lane_data[lane_id]['trace'] = plot_intensity 
        
        ax[0].scatter(range(len(plot_intensity)), plot_intensity, label=lane_id, s=10)  # plots profile in top plot
    ax[0].legend()
    
def reduce_length_to_standard(p1, p2):
    # checks to see how long drawn lane is, and changes its size to match standard value if not equal
    
    target_length = int(lane_length.value)
    actual_length = dist_2_pts(p1, p2)
    
    if actual_length == target_length:
        return p2
        
    change, change_sign = diff_check_sign(actual_length, target_length)
    
    xdev, h_sign = diff_check_sign(p1[0], p2[0])
    ydev, v_sign = diff_check_sign(p1[1], p2[1])
    
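    if xdev == 0:  # vertical lane: adjust y only, in the direction that shortens or lengthens the lane
        v_change = -change if change_sign else change  # signed change in length: negative if the drawn lane is too long
        if v_sign:  # p2 lies above p1, so lengthening the lane means decreasing y
            v_change = -v_change
        return (p2[0], int(round(p2[1] + v_change)))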

    points_angle = np.arctan(ydev/xdev)
           
    v_change = np.abs(np.sin(points_angle) * change)
    h_change = np.abs(np.cos(points_angle) * change)

    if v_sign:
        v_change = - v_change 
    if h_sign:
        h_change = - h_change
    if change_sign:
        h_change = - h_change
        v_change = - v_change
        
    new_p2 = (int(round(p2[0] + h_change)), int(round(p2[1] + v_change)))
        
    return new_p2
    
################## State Callbacks #################

def set_width_on(btn):
    global current_setup, mode
    current_setup = 'width'
    mode.value = 'Mode: %s' % current_setup
    clear_im_line()
    clear_dbl_line()
    
def set_pos_on(btn):
    global current_setup, mode
    current_setup = 'positioning'
    mode.value = 'Mode: %s' % current_setup
    clear_im_line()
    clear_dbl_line()
    
def set_band_sel_on(btn):
    global current_setup, mode
    current_setup = 'band_select'
    mode.value = 'Mode: %s' % current_setup
    clear_im_line()
    clear_dbl_line()
    
def get_current_lane():
    start = lane_pos.value.split('[')[1].split(']')[0]
    st_x = int(start.split(',')[0])
    st_y = int(start.split(',')[1])
    
    stop = lane_pos.value.split('[')[2].split(']')[0]
    sp_x = int(stop.split(',')[0])
    sp_y = int(stop.split(',')[1])
    return int(lane_width.value), int(lane_length.value), (st_x, st_y), (sp_x, sp_y)
    
def save_lane(btn):
    # saves data in a dict for future exporting
    width, length, start, stop = get_current_lane()
    lane_data[current_lane.value] = {
        'width': width,
        'length': length,
        'start': start,
        'stop': stop
    }
    
def save_to_csv(btn):
    # export data to csv file (per-pixel lane divisions are dropped; Plot Profile recomputes them)
    df = pd.DataFrame.from_dict(lane_data, orient='index').drop(columns=['lane_divisions'], errors='ignore')
    out_dir = os.path.dirname(output_f.value)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # create the output folder if it does not exist yet
    df.to_csv(output_f.value)
    
def load_from_csv(btn):
    # reload previously saved data to be able to reconstruct lanes
    global lane_data
    df = pd.read_csv(input_f.value, header=0, index_col=0)
    df['width'] = df['width'].astype(int)
    df['start'] = df['start'].map(lambda x: tuple(map(int,x.lstrip('(').rstrip(')').split(','))))
    df['stop'] = df['stop'].map(lambda x: tuple(map(int,x.lstrip('(').rstrip(')').split(','))))
    lane_data = df.to_dict(orient='index')

def rot_rect(width, p1, p2):
    # find corners of a rotated (not horizontal/vertical) rectangle
    rot_corners = np.zeros((4, 2), dtype=int)
        
    cx = p1[0] + (p2[0] - p1[0])/2 # with lane width + start/stop, we can find centre of lane 'rectangle'
    cy = p1[1] + (p2[1] - p1[1])/2

    height = dist_2_pts(p1, p2)
        
    # after finding corners of vertical rectangle, can then move on to find angle of rotation required for specified lane
    corners = [(cx - width/2, cy+height/2), (cx+width/2, cy+height/2), (cx+width/2, cy-height/2), (cx-width/2, cy-height/2)]
    angle = -np.arctan((p2[0] - cx) / (p2[1] - cy) )

    for index, (x,y) in enumerate(corners):  # after finding angle, need to apply transformation to all corners of rectangle
        tx = x - cx
        ty = y - cy
        nx = tx*np.cos(angle) - ty*np.sin(angle) + cx
        ny = tx*np.sin(angle) + ty*np.cos(angle) + cy
        rot_corners[index, :] = [int(nx), int(ny)]
    
    return rot_corners
    
def draw_lanes(btn):
    # draw all saved lanes on image
    global ax, temp_patches, patches, lane_line
    
    if len(patches) != 0:  # delete old patches
        for p in patches:
            p.remove()
        patches = []
        
    if len(temp_patches) != 0:  # delete old patches
        for p in temp_patches:
            p.remove()
        temp_patches = []
    
    clear_im_line()
    
    for lane_id, lane in lane_data.items():  # draws all saved lanes one by one
        width = lane['width']
        p1 = lane['start']
        p2 = lane['stop']
        rot_corners = rot_rect(width, p1, p2)
        patch = ax[1].add_patch(Polygon(rot_corners, linewidth=1, edgecolor='r', facecolor='none'))
        patches.append(patch)
        lane_data[lane_id]['corners'] = rot_corners

def draw_current_lane():
    # extracts current lane parameters, and plots directly
    global temp_patches
    if len(temp_patches) != 0:  # delete old patches
        for p in temp_patches:
            p.remove()
        temp_patches = []
    width, length, p1, p2 = get_current_lane()
    rot_corners = rot_rect(width, p1, p2)
    patch = ax[1].add_patch(Polygon(rot_corners, linewidth=1, edgecolor='r', facecolor='none'))
    temp_patches.append(patch)

def lock_length(btn):
    # ensure lanes have same length
    global l_lock, length_lock
    l_lock = not l_lock
    if l_lock:
        length_lock.description = 'Unlock Length' 
    else:
        length_lock.description = 'Lock Length'
        
def update_lane_length(btn):
    lane_length.value = str(lane_data[current_lane.value]['length'])
    
################## UI #################

max_pixel = im_array.shape[1]
layout = widgets.Layout(width='auto') #set width to expand to description size

slider = widgets.IntSlider(
                                value=int(max_pixel/2),
                                min=0,
                                max=max_pixel - 1,  # last valid column index
                                step=1,
                                description='Lane Positioning:',
                                disabled=False,
                                layout=Layout(width="80%"),
                                continuous_update=False,
                                orientation='horizontal',
                                readout=True,
                                readout_format='d',
                                style={'description_width': 'initial'}
                            )
slider.observe(plot_v_line, names='value')  # calls this function only when the slider value changes


##### Lane Info
current_lane = widgets.Text(value='0', 
                            description='Lane ID:',
                            style={'description_width': 'initial'},
                            layout=widgets.Layout(display='flex', width='20%'))
input_f = widgets.Text(
                            description='Load CSV:',
                            value= os.path.join(os.getcwd(),'lane_data', 'lane_file.csv'),
                            disabled=False,
                            style={'description_width': 'initial'},
                            layout=widgets.Layout(display='flex', width='40%')
                        )
output_f = widgets.Text(
                            description='Save CSV:',
                            value= os.path.join(main_output_loc, 'lane_file.csv'),
                            disabled=False,
                            style={'description_width': 'initial'},
                            layout=widgets.Layout(display='flex', width='40%')
                        )

lane_stack = widgets.HBox([current_lane, input_f, output_f])
##### 

##### Labels
lane_pos = widgets.Label('Lane Position:', layout=widgets.Layout(display='flex', width='30%', padding='0px 0px 0px 5px', border='solid 1px'))
lane_width = widgets.Text(
                            placeholder='Enter lane width here',
                            description='Lane Width:',
                            value='80', 
                            disabled=False,
                            layout=widgets.Layout(display='flex', width='20%'),
                            style={'description_width': 'initial'}
                        )
lane_length = widgets.Text(
                            placeholder='Enter lane length here',
                            description='Lane Length:',
                            value='150', 
                            disabled=False,
                            layout=widgets.Layout(display='flex', width='20%'),
                            style={'description_width': 'initial'}
                        )
mode = widgets.Label('Mode: %s' % current_setup, layout=widgets.Layout(display='flex', width='15%'))

length_lock = widgets.Button(button_style='danger', description = 'Lock Length', layout=layout)
length_lock.on_click(lock_length)

label_stack = widgets.HBox([lane_pos, lane_width, lane_length, length_lock, mode])
##### 

##### Bands

band_text = widgets.Text(
                            placeholder='Enter band positions here',
                            description='Band Positions:',
                            value='150', 
                            disabled=False,
                            layout=widgets.Layout(display='flex', width='20%'),
                            style={'description_width': 'initial'}
                        )
update_length_btn = widgets.Button(button_style='success', description = 'Update Length from Selected Lane', layout=layout)
update_length_btn.on_click(update_lane_length)
band_stack = widgets.HBox([band_text, update_length_btn])
##### 

##### Buttons

width_btn = widgets.Button(button_style='success', description = 'Set Lane Width', layout=layout)
width_btn.on_click(set_width_on)

define_btn = widgets.Button(button_style='success',description = 'Set Lane Start/Stop', layout=layout)
define_btn.on_click(set_pos_on)

band_btn = widgets.Button(button_style='success',description = 'Set Band', layout=layout)
band_btn.on_click(set_band_sel_on)

draw_btn = widgets.Button(button_style='warning',description = 'Draw Lanes', layout=layout)
draw_btn.on_click(draw_lanes)

add_btn = widgets.Button(button_style='warning',description = 'Save Lane', layout=layout)
add_btn.on_click(save_lane)

profile_btn = widgets.Button(description = 'Plot Profile', layout=layout, button_style='primary')
profile_btn.on_click(profile_plot)

save_btn = widgets.Button(button_style='info',description = 'Save Lanes to CSV', layout=layout)
save_btn.on_click(save_to_csv)

load_btn = widgets.Button(button_style='info',description = 'Load Lanes from CSV', layout=layout)
load_btn.on_click(load_from_csv)

button_stack = widgets.HBox([width_btn, define_btn, band_btn, draw_btn, add_btn, profile_btn, save_btn, load_btn])
##### 

full_stack = widgets.VBox([lane_stack, label_stack, band_stack, button_stack])

display(full_stack)

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(9,10), gridspec_kw={'height_ratios': [1, 3]})
ax[1].imshow(im_array, cmap='gray_r')  # main gel image, shown inverted via a reversed colormap
ax[1].axis('off')
ax[0].set_ylabel('Average Intensity')
ax[0].set_xlabel('Position on Lane (top to bottom)')
plt.tight_layout()
cb = fig.canvas.mpl_connect('button_press_event', dual_click_mouse_capture)  # general mouse capture system

output = widgets.Output()  # for error checking mouse callbacks
display(fig.canvas, output)
pass
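
The `profile_plot` callback fits the two long edges of each lane by least squares and averages the pixels along each cross-section between them. The sketch below uses a simpler formulation, swapped in for illustration and not the notebook's method. It steps one pixel at a time along the lane's centre line from `start` to `stop` and averages nearest-neighbour samples taken perpendicular to that line. `lane_profile` is a hypothetical name. Its output is ordered from `start` to `stop`, which matches the notebook's top-to-bottom order only when `start` is the upper point.

```python
import numpy as np

def lane_profile(img, start, stop, width):
    """Average intensity across a (possibly rotated) lane, one value per pixel step along it.

    start/stop are (x, y) pixel coordinates of the lane centre line; width is in pixels.
    Uses nearest-neighbour sampling; samples falling outside the image are ignored.
    """
    p0 = np.asarray(start, dtype=float)
    p1 = np.asarray(stop, dtype=float)
    axis = p1 - p0
    length = np.hypot(*axis)
    if length == 0:
        raise ValueError('start and stop must differ')
    u = axis / length                               # unit vector along the lane
    n = np.array([-u[1], u[0]])                     # unit vector across the lane
    offsets = np.arange(width) - (width - 1) / 2    # symmetric sample offsets across the lane
    profile = []
    for t in range(int(round(length)) + 1):
        centre = p0 + t * u
        pts = centre + offsets[:, None] * n         # (width, 2) array of (x, y) sample points
        xs = np.rint(pts[:, 0]).astype(int)
        ys = np.rint(pts[:, 1]).astype(int)
        inside = (xs >= 0) & (xs < img.shape[1]) & (ys >= 0) & (ys < img.shape[0])
        # NumPy images are indexed [row, column], i.e. [y, x]
        profile.append(img[ys[inside], xs[inside]].mean() if inside.any() else np.nan)
    return np.array(profile)
```

For a lane saved in `lane_data`, a call such as `lane_profile(im_array, lane['start'], lane['stop'], lane['width'])` should give a trace comparable to the one stored under `'trace'`, up to rounding and ordering differences.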

## Usage Guidelines

### UI Guide

File IO:
- Lane ID: Set to name of current lane (can be any string)
- Load CSV: Filepath to load CSV file from a previous analysis (for continuation/visualization)
- Save CSV: Filepath to save current analysis (add .csv extension manually)
- Save Lanes to CSV: Click to save to csv file (filepath as above)
- Load Lanes from CSV: Click to load csv file (filepath as above)
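
`save_to_csv` writes each lane as one row, with `start` and `stop` stored as text such as `(100, 20)`. `load_from_csv` parses them back by stripping the parentheses. If the file needs to be read outside this notebook, a sketch using only the standard library might look like the one below. `read_lanes` is hypothetical, `ast.literal_eval` stands in for the notebook's string splitting, and any other columns (for example `band`) are left as raw strings.

```python
import ast
import csv
import io

def read_lanes(csv_text):
    """Parse lane rows written by save_to_csv into a dict keyed by lane ID (hypothetical helper).

    Assumes an index column first, followed by at least width, length, start and stop,
    with start/stop stored as "(x, y)" strings; other columns stay as raw strings.
    """
    lanes = {}
    reader = csv.DictReader(io.StringIO(csv_text))
    id_col = reader.fieldnames[0]          # pandas writes the index as the first (unnamed) column
    for row in reader:
        lane_id = row.pop(id_col)
        row['width'] = int(row['width'])
        row['length'] = int(row['length'])
        row['start'] = tuple(ast.literal_eval(row['start']))
        row['stop'] = tuple(ast.literal_eval(row['stop']))
        lanes[lane_id] = row
    return lanes

# A minimal example of the assumed file layout (one lane with ID "0")
example = ',width,length,start,stop\n0,80,150,"(100, 20)","(104, 170)"\n'
```

To read a saved file, pass its contents, e.g. `read_lanes(open(path).read())`.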

Lane Specs:
- Lane Width:  Set current lane width (in pixels)
- Lane Length: Set the current lane length (in pixels); this is updated automatically when a lane is drawn while the length is unlocked
- Lock/Unlock Length: Toggle to force all subsequent lanes to have the currently specified length
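
With the length locked, `reduce_length_to_standard` moves the second click so that the lane matches the value in the Lane Length box, branching on the signs of the x and y offsets. The same result can be expressed more compactly by scaling the vector from the first click to the second. The sketch below uses a hypothetical `rescale_to_length` that is not the notebook's code. It keeps the first click fixed and rounds to whole pixels.

```python
import math

def rescale_to_length(p1, p2, target_length):
    """Move p2 along the p1 -> p2 direction so the segment has target_length (pixels).

    p1 stays fixed; the new end point is rounded to integer pixel coordinates.
    """
    dx, dy = p2[0] - p1[0], p2[1] - p1[1]
    actual = math.hypot(dx, dy)
    if actual == 0:
        raise ValueError('p1 and p2 coincide; the lane direction is undefined')
    scale = target_length / actual   # < 1 shrinks the lane, > 1 extends it
    return (round(p1[0] + dx * scale), round(p1[1] + dy * scale))
```

Because the direction comes from the vector itself, vertical, horizontal and upward-drawn lanes need no special cases.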

Lane Manipulation:

- Set Lane Width: Click to activate the width measurement tool.  Click two points on the gel image to measure the distance between them and set the lane width for subsequent new lanes.
- Set Lane Start/Stop: Click to activate the lane length tool.  Click the start and end points of a new lane's centre line on the gel image.  The lane width is taken from the Lane Width box (either entered manually or set with the width measurement tool).
- Draw Lanes: Draw all lanes saved in memory (including those loaded from a CSV file).
- Save Lane: Save the current lane (width, start/stop and ID) to memory.  Subsequently saving to CSV will also save this lane to CSV.
- Plot Profile: Plot the intensity profile of all currently saved lanes.  Run Draw Lanes first, since the profiles are computed from the drawn lane corners.
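
Each lane is a rectangle whose centre line runs from the start click to the stop click, with the lane width measured perpendicular to that line. The notebook's `rot_rect` builds an upright rectangle about the lane centre and rotates it by `-arctan(dx/dy)`, which divides by zero for a perfectly horizontal lane. A trigonometry-free alternative is sketched below as a hypothetical `lane_corners`. It offsets the two end points by half the width along the unit normal. Unlike `rot_rect`, it returns unrounded float corners.

```python
import math

def lane_corners(width, start, stop):
    """Return the four (x, y) corners of a lane rectangle whose centre line runs start -> stop.

    The long edges are parallel to start -> stop and the width is measured perpendicular
    to it. Corners are returned in order around the rectangle, as unrounded floats.
    """
    dx, dy = stop[0] - start[0], stop[1] - start[1]
    length = math.hypot(dx, dy)
    if length == 0:
        raise ValueError('start and stop must differ')
    # half-width offset along the unit normal (-dy, dx) / length
    ox, oy = -dy / length * width / 2, dx / length * width / 2
    return [
        (start[0] + ox, start[1] + oy),
        (stop[0] + ox, stop[1] + oy),
        (stop[0] - ox, stop[1] - oy),
        (start[0] - ox, start[1] - oy),
    ]
```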

Band Manipulation:

- Set Band: Click to activate the band selection tool.  Click two points within the same lane to mark the two edges of a band.  The selection is saved to memory automatically and will be written to CSV along with the corresponding lane.  To overwrite a band, select two new points on that lane.  The profile plot must be computed (Plot Profile) before switching to band selection mode.
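
Band selection in the notebook only succeeds when the clicked pixel exactly matches one of the pixels stored in a lane's `lane_divisions`. A more tolerant approach is sketched below. It assumes the lane is described only by its `start`, `stop` and `width`, projects each click onto the lane axis, and rounds the result to the nearest pixel step. `division_index` and `band_from_clicks` are hypothetical helpers. Their indices count from `start`, and the two indices are sorted, whereas the notebook stores them in click order.

```python
import math

def division_index(click, start, stop, width):
    """Map a click (x, y) to a pixel-step index along a lane, or None if it falls outside the lane."""
    dx, dy = stop[0] - start[0], stop[1] - start[1]
    length = math.hypot(dx, dy)
    if length == 0:
        raise ValueError('start and stop must differ')
    ux, uy = dx / length, dy / length          # unit vector along the lane
    rx, ry = click[0] - start[0], click[1] - start[1]
    along = rx * ux + ry * uy                  # distance along the lane axis
    across = ux * ry - uy * rx                 # signed distance from the centre line
    if not (0 <= along <= length) or abs(across) > width / 2:
        return None
    return round(along)

def band_from_clicks(c1, c2, start, stop, width):
    """Return sorted [index1, index2] band edges for one lane, or None if either click misses it."""
    i1 = division_index(c1, start, stop, width)
    i2 = division_index(c2, start, stop, width)
    if i1 is None or i2 is None:
        return None
    return sorted([i1, i2])
```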

### Typical Workflow

- Set the lane width using the width measurement tool (Set Lane Width).
- Set the length of the first lane using the lane length tool (Set Lane Start/Stop).
- Lock the lane length for all subsequent lanes.
- Save the current lane to memory.
- Change the lane ID, set the position of the new lane and save it to memory.  Repeat for all lanes.
- Draw all saved lanes on the plot (required before plotting profiles).
- Plot all profiles.
- Select a band for each lane.
- Save to CSV.