In [3]:
# Importo librerías necesarias

import cv2
import pylab as plt
import matplotlib
import numpy as np
import pickle
import pandas as pd
import os
from sklearn.linear_model import LinearRegression
import time
import pykalman
from scipy.optimize import curve_fit
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
import functools
import traceback

In [5]:
# funciones auxiliares

# registro el tiempo que pasó desde el último punto de chequeo
def save_time(bff):
    """Record the CPU time spent since the last checkpoint.

    Appends ``time.process_time() - bff["last_time"]`` to ``bff["rec_times"]``
    and moves the checkpoint (``bff["last_time"]``) to now.
    """
    now = time.process_time()
    elapsed = now - bff["last_time"]
    bff["rec_times"].append(elapsed)
    bff["last_time"] = now

# calculo la diferencia entre dos imágenes (solo la parte positiva, solo la negativa o su valor absoluto); si se da un umbral, los valores que lo superan se saturan a 255 y el resto se conserva
def detect_movement(img, prev_img, take_abs = False, positive = True, threshold = None):
    """Return the element-wise difference ``prev_img - img`` as a uint8 image.

    Both inputs are promoted to int16 before subtracting so the difference
    cannot wrap around.

    Parameters:
      img, prev_img: arrays of the same shape.
      take_abs: keep the absolute difference (takes precedence over ``positive``).
      positive: when ``take_abs`` is False, keep only positive differences
                (prev_img brighter than img); otherwise keep only the magnitude
                of the negative ones.
      threshold: if not None, every value above it is saturated to 255;
                 values at or below it are left unchanged.

    Returns:
      uint8 array with the same shape as the inputs.
    """
    difference = prev_img.astype('int16') - img.astype('int16')

    if take_abs:
        difference = np.abs(difference)
    elif positive:
        difference = np.maximum(difference, 0)
    else:
        difference = np.abs(np.minimum(difference, 0))

    result = difference.astype('uint8')
    if threshold is not None:
        result[result > threshold] = 255
    return result

# dada una posición y un rango, encuentro el cuadrado cuyo centro es esa posición y devuelvo sus coordenadas
def get_bounds(pos, r, width, height):
    """Return (minx, maxx, miny, maxy) of the square of half-size r centred on pos.

    The square is clipped to [0, width] x [0, height]; coordinates are truncated
    with int(). With ``pos`` None the whole image (0, width, 0, height) is returned.
    """
    if pos is None:
        return 0, width, 0, height

    cx = pos[0]
    cy = pos[1]
    left = max(int(cx - r), 0)
    right = min(int(cx + r), width)
    top = max(int(cy - r), 0)
    bottom = min(int(cy + r), height)
    return left, right, top, bottom

# dada una máscara, encuentro sus contornos y me quedo con el de mayor área, trasladado a coordenadas de la imagen completa
# si no hay ningún contorno con área positiva, devuelvo None
def get_contour(p, bff):
    """Return the largest contour of bff["maskroi"], shifted to full-image coordinates.

    The contour with the greatest positive area (first one on ties) is kept and
    offset by (bff["minx"], bff["miny"]), the ROI origin. Its shape is
    (k, 1, 2), the layout OpenCV uses for contours.

    bff["cv2version"] selects the findContours return signature: OpenCV 2 and 4
    return (contours, hierarchy), OpenCV 3 returns (image, contours, hierarchy).

    Returns None (and prints a message) if no contour has positive area.
    ``p`` is not used.
    """
    if bff["cv2version"] == 2 or bff["cv2version"] == 4:
        contours, _ = cv2.findContours(bff["maskroi"], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    else:
        _, contours, _ = cv2.findContours(bff["maskroi"], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    contour = None
    best_area = 0
    for candidate in contours:
        candidate_area = cv2.contourArea(candidate)  # computed once per contour
        if candidate_area > best_area:
            best_area = candidate_area
            contour = candidate

    if contour is None:
        print("Can't find animal.")
        return None

    # reshape (not squeeze) so even a tiny contour keeps its (k, 2) layout
    offset = np.array([bff["minx"], bff["miny"]], dtype = contour.dtype)
    shifted = contour.reshape(-1, 2) + offset
    return np.expand_dims(shifted, axis = 1)

# dado un frame del video, encuentro sus diferencias con n frames del video al azar (precargados), promedio todas las diferencias
# y me quedo solo con lo que supere un cierto threshold
def get_mask(p, bff, samples = None):
    """Build the movement mask for the current frame and store the ROI bounds in bff.

    ROI: a window of half-size p["maxRange"][0] around bff["anterior"] when it is
    known, otherwise the whole frame minus p["xmarginleft"][0] / p["xmarginright"][0]
    on the x axis. The bounds are written to bff["minx"], bff["maxx"], bff["miny"]
    and bff["maxy"], and a copy of the raw ROI to bff["unprocessed_roi"].

    The ROI is compared (detect_movement) against ``samples`` frames drawn at
    random, with replacement, from bff["sparse_frames"] (default
    p["samples"][0]). The differences are averaged, scaled by p["alpha"][0] via
    cv2.addWeighted, and rescaled to 0..255. A constant figure maps to all zeros
    instead of NaN.

    Returns:
      (mean_figure, fishroi): fishroi is the rescaled figure; mean_figure is the
      same figure binarised to 0/255 when p["threshold"][0] is not None.
      Also stores bff["masked_fishroi"], the inverted ROI weighted by the mask.
    """
    if bff["anterior"] is not None:
        bff["minx"], bff["maxx"], bff["miny"], bff["maxy"] = get_bounds(bff["anterior"], p["maxRange"][0], bff["width"], bff["height"])
    else:
        bff["minx"] = p["xmarginleft"][0]
        bff["miny"] = 0
        bff["maxx"] = bff["width"] - p["xmarginright"][0]
        bff["maxy"] = bff["height"]

    rows = slice(bff["miny"], bff["maxy"])
    cols = slice(bff["minx"], bff["maxx"])
    roi = bff["image"][rows, cols]
    bff["unprocessed_roi"] = roi.copy()

    if samples is None:
        samples = p["samples"][0]

    movements = [detect_movement(roi, bff["sparse_frames"][idx][rows, cols])
                 for idx in np.random.choice(len(bff["sparse_frames"]), size = samples)]

    movements = np.array(movements, dtype ='uint8')
    mean_figure = np.array(movements.mean(axis = 0), dtype ='uint8')

    mean_figure = cv2.addWeighted(mean_figure, p["alpha"][0], np.zeros(mean_figure.shape, mean_figure.dtype), 0, 0)

    # rescale to 0..255; a constant figure (peak 0) stays all zeros instead of 0/0 = NaN
    mean_figure = mean_figure - mean_figure.min()
    peak = mean_figure.max()
    if peak > 0:
        mean_figure = mean_figure / peak
    mean_figure = (mean_figure * 255).astype('uint8')

    fishroi = mean_figure.copy()

    threshold = p["threshold"][0]
    if threshold is not None:
        mean_figure[mean_figure >= threshold] = 255
        mean_figure[mean_figure < threshold] = 0

    masked_fishroi = 255 - roi

    if len(masked_fishroi.shape) > 2:
        masked_fishroi = cv2.cvtColor(masked_fishroi, cv2.COLOR_RGB2GRAY)

    # weight the inverted ROI by the normalised mask; an empty mask gives zero weights
    mask_peak = mean_figure.max()
    if mask_peak > 0:
        weights = mean_figure / mask_peak
    else:
        weights = np.zeros(mean_figure.shape)
    bff["masked_fishroi"] = (masked_fishroi * weights).astype('uint8')

    return mean_figure, fishroi

# para un dado contorno, calculo el vector que coincide con la dirección de mayor varianza de los puntos que lo componen
# y devuelvo los parámetros necesarios para construir la recta que coincide con esa dirección
def getMajorAxis(bff, contour):
    """Fit a straight line through the points of ``contour``.

    The regression is done along the contour's longer extent: y on x when the
    x range is at least the y range, otherwise x on y (then converted back).

    Parameters:
      bff: not used.
      contour: points convertible to a float32 array whose squeezed shape is (k, 2).

    Returns:
      (pt_min, pt_max, slope, intercept): the fitted line's endpoints at the
      extremes of the regression axis (as int tuples (x, y)), and slope/intercept
      of y = slope * x + intercept. ``intercept`` keeps the length-1 array form
      returned by sklearn. A perfectly vertical fit yields an infinite slope
      (numpy division by zero).
    """
    points = np.array(contour, dtype = 'float32').squeeze()

    regX = points[:, 0].reshape(-1, 1)
    regY = points[:, 1].reshape(-1, 1)

    # use scalars: int()/comparisons on 1-element arrays are deprecated in NumPy
    rangeX = regX.max() - regX.min()
    rangeY = regY.max() - regY.min()

    swapped = rangeY > rangeX
    if swapped:
        regX, regY = regY, regX

    reg = LinearRegression().fit(regX, regY)
    a = reg.coef_[0][0]
    intercept = reg.intercept_
    intercept_value = intercept[0]

    regMinX = int(regX.min())
    regMaxX = int(regX.max())
    regMinY = int(regMinX * a + intercept_value)
    regMaxY = int(regMaxX * a + intercept_value)

    if not swapped:
        return (regMinX, regMinY), (regMaxX, regMaxY), a, intercept
    # fitted x = a*y + c  ->  y = x/a - c/a
    return (regMinY, regMinX), (regMaxY, regMaxX), 1/a, (-1)*intercept/a

# promedio un array cada n elementos
def average_by(arr, n):
    """Average ``arr`` over consecutive, non-overlapping blocks of ``n`` elements.

    Trailing elements that do not fill a whole block are dropped, so the result
    has ``len(arr) // n`` entries.
    """
    values = np.array(arr)
    usable = len(values) - len(values) % n
    blocks = values[:usable].reshape(-1, n)
    return blocks.mean(axis=1)

# promedio un array en ventanas de n elementos (moving window con stride de 1)
def moving_average(arr, n):
    """Average ``arr`` over sliding windows of ``n`` consecutive elements (stride 1).

    Returns a float array of ``len(arr) - n + 1`` window means, or an empty array
    when ``arr`` is shorter than ``n`` (consistent with average_by).

    Raises:
      ValueError: if ``n`` < 1 (a zero-length window has no mean).
    """
    if n < 1:
        raise ValueError("moving_average: n must be >= 1")

    arr = np.array(arr)

    # clamp at 0: a window longer than the data produces no output instead of
    # a negative array size
    final_n = max(len(arr) - n + 1, 0)

    res = np.zeros((final_n,))
    for i in range(final_n):
        res[i] = arr[i:i + n].mean(axis=0)

    return res
        
# toma un array de posiciones (x, y) y lo transforma a un array de distancias entre cada posicion consecutiva
def pos_to_dist(pos, n = None, avg = None):
    """Convert a sequence of (x, y) positions into distances between consecutive positions.

    Parameters:
      pos: array-like of shape (k, 2).
      n: if given, the x and y columns are first smoothed independently with
         ``avg(column, n)``.
      avg: smoothing function; defaults to moving_average (looked up at call
           time, so a redefined moving_average is picked up).

    Returns:
      1-D array of Euclidean step lengths between consecutive (smoothed) positions.
    """
    pos = np.asarray(pos)

    if n is not None:
        if avg is None:
            avg = moving_average
        posx = avg(pos[:, 0], n)
        posy = avg(pos[:, 1], n)  # y column (previously column 0 was smoothed twice)
        pos = np.column_stack([posx, posy])

    deltas = np.diff(pos, axis = 0)
    dx = deltas[:, 0]
    dy = deltas[:, 1]
    return np.sqrt(dx**2 + dy**2)

# toma dos pendientes y calcula el ángulo entre ellas
def angle_from_slopes(m1, m2):
    """Return the unsigned angle between two lines given by their slopes, in turns.

    Each slope m is turned into the direction (1, m). The angle between the two
    directions is expressed as a fraction of a full turn (1 turn = 2*pi) and
    folded into [0, 0.25], so lines are compared without regard to orientation.
    """
    if m1 == m2:
        return 0.0

    v1 = np.array([1, m1])
    v1 = v1/np.linalg.norm(v1)

    v2 = np.array([1, m2])
    v2 = v2/np.linalg.norm(v2)

    # rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], where arccos returns NaN
    cos_angle = np.clip(v1 @ v2, -1.0, 1.0)
    a = np.arccos(cos_angle)/(2*np.pi)

    if a > 0.25:
        a = 0.5 - a

    return a

def turn_from_slopes(axis1, axis2):
    """Apply angle_from_slopes pairwise to two slope sequences.

    Element i of the returned list is ``angle_from_slopes(axis1[i], axis2[i])``.
    ``axis2`` must be at least as long as ``axis1``; extra entries are ignored.
    """
    return [angle_from_slopes(slope, axis2[idx]) for idx, slope in enumerate(axis1)]

def two_vector_angle(v1, v2):
    """Signed angle from v1 to v2 in turns (1 turn = 360 degrees), in (-0.5, 0.5].

    In a y-up frame a counter-clockwise rotation is positive, e.g.
    (1, 0) -> (0, 1) gives 0.25 and (1, 0) -> (0, -1) gives -0.25.
    """
    # https://stackoverflow.com/questions/14066933/direct-way-of-computing-clockwise-angle-between-2-vectors
    first = np.array(v1)
    second = np.array(v2)

    first = first / np.linalg.norm(first)
    second = second / np.linalg.norm(second)

    # cosine (clamped against rounding) and sine (2-D cross product) of the angle
    cos_term = np.clip(first @ second, -1, 1)
    sin_term = first[0] * second[1] - first[1] * second[0]

    turns = -(np.arctan2(cos_term, sin_term) / (2 * np.pi) - 0.25)

    return turns - 1 if turns > 0.5 else turns

def three_point_angle(center, p1, p2):
    """Signed angle (in turns) at ``center`` from the ray towards p1 to the ray towards p2.

    Delegates to two_vector_angle with the vectors p1 - center and p2 - center.
    """
    origin = np.array(center)
    return two_vector_angle(np.array(p1) - origin, np.array(p2) - origin)

# toma un array de pendientes y lo convierte en un array de ángulos entre pendientes consecutivas
def slope_to_angle(slopes, n = None, avg = moving_average):
    """Turn a sequence of slopes into the angles (in turns) between consecutive slopes.

    If ``n`` is given, the slopes are first smoothed with ``avg(slopes, n)``.
    Returns a numpy array one element shorter than the (smoothed) slopes.
    """
    if n is not None:
        slopes = avg(slopes, n)

    slopes = np.array(slopes)
    pairs = zip(slopes[:-1], slopes[1:])
    return np.array([angle_from_slopes(previous, current) for previous, current in pairs])

def kalman(data, n_dim_obs = 2, n_dim_state = 2, n_iter=5, initial_cov = 0.1):
    """Smooth ``data`` with a pykalman KalmanFilter whose parameters are fitted by EM.

    The initial state mean is set to the first observation (``data[0]``); the
    transition/observation matrices and covariances and the initial state
    covariance are estimated with ``n_iter`` EM iterations.
    ``initial_cov`` is currently not used.

    Returns:
      (mu, sigma): smoothed state means and covariances from ``kf.smooth``.
    """
    learned_params = ['transition_matrices', 'transition_covariance',
                      'observation_matrices', 'observation_covariance',
                      'initial_state_covariance']
    filt = pykalman.KalmanFilter(
        n_dim_state=n_dim_state,
        n_dim_obs=n_dim_obs,
        em_vars=learned_params,
    )

    # start the state at the first observation
    filt.initial_state_mean = data[0]

    filt.em(data, n_iter=n_iter)

    state_means, state_covs = filt.smooth(data)
    return state_means, state_covs

def draw_bounds(img, corner1, corner2, color, alpha, thickness):
    """Blend a rectangle outline into ``img`` in place and return it.

    Covered elements become ``alpha * img + (1 - alpha) * color``. The mask is
    taken element-wise from the drawn rectangle, so channels where ``color`` is
    zero are left untouched.
    """
    overlay = np.zeros_like(img, np.uint8)
    cv2.rectangle(overlay, corner1, corner2, color, thickness = thickness)
    drawn = overlay.astype(bool)
    blended = cv2.addWeighted(img, alpha, overlay, 1 - alpha, 0)
    img[drawn] = blended[drawn]
    return img

def findOrthogonal(pt1, pt2, goes_through = 2):
    """Line orthogonal to the segment pt1 -> pt2, through pt1 (goes_through == 1) or pt2.

    Returns:
      (a, b) for the line y = a*x + b, or (None, x0) when the orthogonal line is
      vertical (x = x0), i.e. when the segment is horizontal.
    """
    start = np.asarray(pt1)
    end = np.asarray(pt2)

    direction = np.array(end - start)
    direction = direction / np.linalg.norm(direction)

    # the segment direction rotated by 90 degrees
    normal_x = direction[1]
    normal_y = -direction[0]

    anchor = start if goes_through == 1 else end

    if normal_x == 0:
        return None, anchor[0]

    slope = normal_y / normal_x
    return slope, anchor[1] - slope * anchor[0]
   
def drawTrajectory(positions, img, color, thickness):
    """Draw the polyline through ``positions`` onto ``img`` (in place) and return it.

    Every consecutive pair of positions is joined, including the last segment
    (the previous loop bound skipped it). A ``None`` entry breaks the line: the
    segments touching it are not drawn. Coordinates are truncated to int.
    """
    for start, end in zip(positions[:-1], positions[1:]):
        if start is None or end is None:
            continue

        cv2.line(img, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])),
                 color = color, thickness = thickness)

    return img

def convolve(image, kernel):
    """Slide ``kernel`` over a single-channel ``image`` and rescale the result to uint8.

    output[r, c] = sum_{i, j} padded[r + i, c + j] * kernel[i, j], where ``padded``
    is the image extended by replicating its border pixels. The output therefore
    has the same size as the input. The response is then shifted and scaled to
    0..255; a constant response (including an all-zero one) maps to all zeros
    instead of dividing by zero.

    The kernel must be square with an odd side length. Instead of looping in
    Python over every pixel, the sum is accumulated with one vectorized pass per
    kernel element.

    Raises:
      ValueError: for a non-2-D image or a non-square / even-sized kernel.
    """
    image = np.asarray(image)
    kernel = np.asarray(kernel)

    if image.ndim != 2:
        raise ValueError("convolve expects a single-channel (2-D) image")
    (kH, kW) = kernel.shape[:2]
    if kH != kW or kW % 2 == 0:
        raise ValueError("convolve expects a square kernel with an odd side length")

    (iH, iW) = image.shape
    pad = (kW - 1) // 2

    # 'edge' padding replicates border pixels (same as cv2.BORDER_REPLICATE)
    padded = np.pad(image.astype("float64"), pad, mode="edge")

    response = np.zeros((iH, iW), dtype="float64")
    for dy in range(kH):
        for dx in range(kW):
            response += kernel[dy, dx] * padded[dy:dy + iH, dx:dx + iW]

    output = response.astype("float32")

    # rescale to [0, 255]; guard the constant case against 0/0
    output = output - np.min(output)
    peak = np.max(output)
    if peak > 0:
        output = output / peak
    return (output * 255).astype("uint8")

def get_convolved_image(img, n=9, k=1.3):
    """Convolve ``img`` with an n x n kernel that decays geometrically from its centre.

    For odd n the kernel entry at row/column offsets (dy, dx) from the centre is
    ``k ** -(|dx| + |dy|)``, i.e. it decays with Manhattan distance.
    """
    half = (n - 1) / 2
    base_row = -np.abs(np.arange(-half, half + 1))

    # grow the exponent matrix one ring at a time: each new outer row is the
    # centre row shifted down by its distance from the centre
    exponents = base_row.copy()
    for dist in np.arange(1, half + 1):
        shifted = base_row - dist
        exponents = np.vstack((shifted, exponents, shifted))

    return convolve(img, k ** exponents)


def isValidAnterior(p, bff, x, y, roi):
    """Decide whether ROI pixel (x, y) is accepted as the animal's anterior end.

    Samples p["angle_precision_validation"][0] points on a circle of radius
    p["angle_distance_validation"][0] around (x, y) and finds the sample whose
    surrounding window is brightest on average. It then looks at the samples
    roughly opposite that direction (a fan of about a quarter of the samples
    centred half a turn away). The candidate is accepted only if every opposite
    pixel is darker than p["max_opp_b_validation"][0]. Presumably this ensures
    the point sits at an end of the bright body rather than in its middle
    (interpretation; confirm with the author).

    Parameters:
      p: parameter dict; each entry is read as p[key][0].
      bff: shared state dict; "validationPoints", "max_index_validation" and
           "opp_index_validation" are written to it (useful for debugging).
      x, y: candidate position in ROI coordinates (column, row).
      roi: image indexed as roi[row, col].

    Returns:
      numpy bool: True when all opposite samples are below the threshold.
    """
    # sample points on a circle around the candidate
    points = get_n_directions([x, y], n = p["angle_precision_validation"][0], l = p["angle_distance_validation"][0])
    bff["validationPoints"] = np.array(points) # kept in ROI coordinates (no minx/miny offset added)
    
    bs = []
    window = 2  # half-size of the window averaged around each sample (see get_bounds)
    for point in points:
        # window around the sample point, clamped to the ROI by get_bounds
        minxi, maxxi, minyi, maxyi = get_bounds(point, window, roi.shape[1], roi.shape[0])
        
        # NOTE: a point lying far enough outside the ROI gives an empty slice,
        # whose mean() is NaN (numpy emits a RuntimeWarning)
        roi_roi = roi[minyi:maxyi,minxi:maxxi]
        
        # mean brightness of the window
        bs.append(roi_roi.mean())
    
    # sample direction with the brightest surroundings
    max_index = np.argmax(bs)
    bff["max_index_validation"] = max_index
    
    # sample index half a turn away from the brightest direction
    opposite_index = int((max_index + (len(points)/2)) % len(points))
    bff["opp_index_validation"] = opposite_index
    
    
    max_point = points[max_index]  # not used below

        
    # samples within about n_opposites/2 steps of opposite_index on either side
    n_opposites = int(p["angle_precision_validation"][0] / 4)
    opposites = []
    for i in range(int(-n_opposites/2), int(n_opposites/2 + 1)):
        idx = ((opposite_index + i) + len(points)) % len(points)  # wrap around the circle
        opposites.append(points[idx])
    
    # brightness of the single pixel at each opposite sample, clamped inside the ROI
    bs = []
    for opposite_point in opposites:
        op_y = min(max(0, int(opposite_point[1])),roi.shape[0]-1)
        op_x = min(max(0, int(opposite_point[0])),roi.shape[1]-1)
        
        b = roi[op_y, op_x]
        bs.append(b)
    
    
    # valid only if even the brightest opposite sample is below the threshold
    return np.max(bs) < p["max_opp_b_validation"][0]

def get_anterior(p, bff, counter = 0):
    """Locate the anterior point as the brightest validated pixel of the blurred fish ROI.

    bff["fishroi"] is blurred with cv2.GaussianBlur (square kernel of side
    p["kernel_size"][0], sigma p["kernel_k"][0]) and rescaled to 0..255. A constant
    image stays all zeros instead of dividing by zero. Pixels are then scanned
    from brightest to darkest, skipping the first ``counter`` of them. The first
    pixel accepted by isValidAnterior wins.

    Returns:
      ([x, y], blurred): the anterior in full-image coordinates (ROI position
      plus bff["minx"], bff["miny"]) and the rescaled blurred ROI. If no pixel
      is accepted, the ROI origin (bff["minx"], bff["miny"]) is returned.
    """
    ksize = p["kernel_size"][0]
    output = cv2.GaussianBlur(bff["fishroi"], (ksize, ksize), p["kernel_k"][0])

    # rescale to 0..255; a constant image (peak 0) stays zero instead of 0/0 = NaN
    output = output - output.min()
    peak = output.max()
    if peak > 0:
        output = output / peak
    output = (output * 255).astype('uint8')

    # pixel coordinates ordered from brightest to darkest
    ys, xs = np.unravel_index(np.argsort(output.ravel()), output.shape)
    ys = np.flip(ys)
    xs = np.flip(xs)

    maxPos = np.array([0, 0])

    for i in range(counter, len(xs)):
        if isValidAnterior(p, bff, xs[i], ys[i], output):
            maxPos = np.array([xs[i], ys[i]])
            break

    return [maxPos[0] + bff["minx"], maxPos[1] + bff["miny"]], output

def get_point_at_an_angle(center, angle, l):
    """Return [x, y] of the point at distance ``l`` from ``center`` along ``angle`` (radians)."""
    return [center[0] + l * np.cos(angle), center[1] + l * np.sin(angle)]

def get_n_directions(center, n, l, limit = None):
    """Return points evenly spaced on a circle of radius ``l`` around ``center``.

    Point i lies at angle 2*pi*i/n radians (computed with get_point_at_an_angle),
    so up to ``n`` points are returned as [x, y] lists.

    If ``limit`` is a pair of points (pt1, pt2), the line orthogonal to
    pt1 -> pt2 through pt2 is built with findOrthogonal. Only circle points on
    the opposite side of that line from pt1 are kept. (In the visible code,
    find_animal_shape passes limit=None.)
    """
    if limit is not None:
        pt1, pt2 = limit
        
        rangeX = abs(pt1[0] - pt2[0])
        rangeY = abs(pt1[1] - pt2[1])
        
        # keep the caller's points; pt1/pt2 may be replaced by swapped copies below
        buffered_point_1 = pt1
        buffered_point_2 = pt2

        # mostly-horizontal segment: work in swapped (y, x) coordinates, presumably to
        # avoid a vertical/very steep orthogonal (findOrthogonal returns a=None for
        # a horizontal segment)
        if rangeY < rangeX:
            pt1 = [pt1[1], pt1[0]]
            pt2 = [pt2[1], pt2[0]]
        
        a, b = findOrthogonal(pt1, pt2, goes_through = 2)
        
        # map the line back from swapped coordinates: x = a*y + b  ->  y = x/a - b/a
        if rangeY < rangeX:
            if a != 0:
                b = (-1)*b/a
                a = 1/a
            else:
                # x = b: a vertical line in the original coordinates
                a = None
    
        pt1 = buffered_point_1
        pt2 = buffered_point_2
    
    points = []
    
    for i in range(n):
        angle = (i/n)*2*np.pi
        new_point = get_point_at_an_angle(center, angle, l)
        
        add_point = True
        
        if limit is not None:
            # side of the limit line for pt1 (despite the name, not for center) and for
            # the new point: y >= a*x + b, or x <= b when the line is vertical
            if a is not None:
                center_side = pt1[0]*a + b <= pt1[1]
                point_side = new_point[0]*a + b <= new_point[1]
            else:
                center_side = pt1[0] <= b
                point_side = new_point[0] <= b
            
            # keep only points on the other side of the line from pt1
            add_point = center_side != point_side
        
        if add_point:
            points.append(new_point)
    
    return points

def build_rectangle(pt1, pt2, w):
    """Rectangle of half-width ``w`` whose centre line runs from pt1 to pt2.

    Returns:
      (crn11, crn12, crn22, crn21, area): the corners in perimeter order
      (crn1x at pt1, crn2x at pt2, offset by -/+ the unit normal times w) and
      the rectangle area |pt2 - pt1| * 2w.
    """
    start = np.array(pt1)
    end = np.array(pt2)

    direction = end - start
    normal = np.array([direction[1], -direction[0]])
    normal = normal / np.sqrt(normal[0]**2 + normal[1]**2) * w

    corner_start_minus = start - normal
    corner_start_plus = start + normal
    corner_end_minus = end - normal
    corner_end_plus = end + normal

    rect_area = np.linalg.norm(direction) * w * 2

    return corner_start_minus, corner_start_plus, corner_end_plus, corner_end_minus, rect_area
    
#(ΔABC) = (1/2) |x1(y2 − y3) + x2(y3 − y1) + x3(y1 − y2)|
def area_of_triangle(pt1, pt2, pt3):
    """Area of the triangle pt1-pt2-pt3 via the shoelace formula (always non-negative)."""
    x1, y1 = pt1[0], pt1[1]
    x2, y2 = pt2[0], pt2[1]
    x3, y3 = pt3[0], pt3[1]
    signed_double_area = x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)
    return 0.5 * abs(signed_double_area)

def is_point_in_rectangle(my_point, crn11, crn12, crn22, crn21, rect_area, epsilon = 0.1):
    """Test whether ``my_point`` lies inside the rectangle with the given corners.

    The four triangles formed by the point and each rectangle edge cover exactly
    the rectangle when the point is inside and more when it is outside. The point
    is accepted when the excess is below ``epsilon``.
    """
    edges = ((crn11, crn12), (crn12, crn22), (crn22, crn21), (crn21, crn11))
    covered = sum(area_of_triangle(a, b, my_point) for a, b in edges)
    return covered - rect_area < epsilon

def pick_next_center(center, points, roi, lowerx, lowery, w = 2, step = 2, length = 10):
    """Choose the brightest direction from ``center`` and step ``length`` along it.

    For every candidate point, the rectangle of half-width ``w`` from ``center``
    to the point is sampled every ``step`` pixels (full-image coordinates). The
    mean ``roi`` value inside it is the score; a rectangle with no samples scores 0.
    ``roi`` starts at (lowerx, lowery) in full-image coordinates. Samples outside
    it are clamped to the nearest ROI border on every side (negative indices used
    to wrap to the opposite border).

    Returns:
      numpy array: ``center`` moved ``length`` units towards the best candidate.

    Raises:
      ValueError: (from np.argmax) if ``points`` is empty.
    """
    max_row = roi.shape[0] - 1
    max_col = roi.shape[1] - 1

    values = []
    for point in points:
        crn11, crn12, crn22, crn21, area = build_rectangle(center, point, w)

        corners = np.array([crn11, crn12, crn22, crn21])
        minx = int(corners.min(axis = 0)[0])
        miny = int(corners.min(axis = 0)[1])
        maxx = int(corners.max(axis = 0)[0]) + 1
        maxy = int(corners.max(axis = 0)[1]) + 1

        pixel_values = []
        for y in range(miny, maxy, step):
            row = min(max(y - lowery, 0), max_row)
            for x in range(minx, maxx, step):
                if is_point_in_rectangle([x, y], crn11, crn12, crn22, crn21, area):
                    col = min(max(x - lowerx, 0), max_col)
                    pixel_values.append(roi[row, col])

        values.append(np.mean(pixel_values) if len(pixel_values) > 0 else 0)

    next_center = np.array(points[np.argmax(values)])
    center = np.array(center)

    # unit direction towards the winner, scaled to the segment length
    dif = next_center - center
    dif = dif/(np.linalg.norm(dif))
    dif = dif * length

    return center + dif

def filter_points_by_angle(center, last_center, points, angle):
    """Keep the points that do not bend back too sharply towards ``last_center``.

    A point is kept when the angle (in turns) at ``center`` between ``last_center``
    and the point exceeds ``0.5 - angle / 2`` in absolute value, i.e. it lies
    within a cone of total width ``angle`` turns around the straight-ahead direction.
    """
    min_turn = 0.5 - angle / 2
    return [candidate for candidate in points
            if abs(three_point_angle(center, last_center, candidate)) > min_turn]

def filter_points_by_bound(points, minx, miny, maxx, maxy):
    """Keep the points strictly inside the open box (minx, maxx) x (miny, maxy)."""
    return [pt for pt in points if minx < pt[0] < maxx and miny < pt[1] < maxy]

def pad_image(image, margin = 20):
    """Surround a 2-D image with a zero border ``margin`` pixels wide.

    The result is a new float64 array of shape (rows + 2*margin, cols + 2*margin).
    """
    rows = image.shape[0]
    cols = image.shape[1]
    padded = np.zeros((rows + 2 * margin, cols + 2 * margin))
    padded[margin:margin + rows, margin:margin + cols] = image
    return padded

def find_animal_shape(p, bff):
    """Trace the body axis as a chain of segment centres starting at bff["anterior"].

    For each of p["n_segments"][0] steps, candidate points are placed on a circle
    (p["angle_precision"][0] points, radius p["angle_distance"][0]) around the last
    centre. From the second step on, they are restricted to a forward cone
    (p["next_segment_angle"][0]). pick_next_center then chooses the brightest
    direction in bff["convolvedRoi"] and advances p["segment_length"][0].

    The candidates of every step are stored in ROI coordinates in
    bff["segment_points"]. Returns the list of centres (full-image coordinates),
    anterior first.
    """
    roi_origin = np.array([bff["minx"], bff["miny"]])
    centers = [bff["anterior"]]
    bff["segment_points"] = []

    for _ in range(p["n_segments"][0]):
        candidates = get_n_directions(centers[-1], n = p["angle_precision"][0],
                                      l = p["angle_distance"][0], limit = None)

        if len(centers) > 1:
            candidates = filter_points_by_angle(centers[-1], centers[-2], candidates,
                                                p["next_segment_angle"][0])

        bff["segment_points"].append(np.array(candidates) - roi_origin)

        centers.append(pick_next_center(centers[-1], candidates, bff["convolvedRoi"],
                                        bff["minx"], bff["miny"],
                                        w = p["angle_rect_width"][0], step = p["angle_step"][0],
                                        length = p["segment_length"][0]))

    return centers

def append_df_to_excel(filename, df, sheet_name='Sheet1', startrow=None,
                       truncate_sheet=False, 
                       **to_excel_kwargs):
    """
    Append a DataFrame [df] to the Excel file [filename]
    into the [sheet_name] sheet.
    If [filename] doesn't exist, it is created.

    Parameters:
      filename : path of the .xlsx file
                 (Example: '/path/to/file.xlsx')
      df : dataframe to save to workbook
      sheet_name : Name of sheet which will contain DataFrame.
                   (default: 'Sheet1')
      startrow : upper left cell row to dump data frame.
                 Per default (startrow=None) the data is written after the last
                 used row of an existing [sheet_name] sheet, or at row 0.
      truncate_sheet : truncate (remove and recreate) [sheet_name]
                       before writing DataFrame to Excel file
      to_excel_kwargs : arguments which will be passed to `DataFrame.to_excel()`
                        [can be dictionary]; any 'engine' entry is ignored.

    Returns: None

    NOTE(review): on recent pandas versions, writing into a sheet that already
    exists in append mode may also need ``if_sheet_exists='overlay'`` on the
    ExcelWriter; confirm against the installed pandas version.
    """
    import pandas

    # the writer always uses openpyxl, so a caller-supplied engine is dropped
    to_excel_kwargs.pop('engine', None)

    # pandas' append mode raises FileNotFoundError when the file does not exist,
    # so pick the mode from whether there is a workbook to append to
    file_exists = os.path.isfile(filename)
    writer = pandas.ExcelWriter(filename, engine='openpyxl',
                                mode='a' if file_exists else 'w')

    if file_exists:
        # in append mode the openpyxl writer has already loaded the existing workbook
        book = writer.book

        # get the last row in the existing Excel sheet
        # if it was not specified explicitly
        if startrow is None and sheet_name in book.sheetnames:
            startrow = book[sheet_name].max_row

        # truncate sheet: remove it and recreate an empty one at the same index
        if truncate_sheet and sheet_name in book.sheetnames:
            idx = book.sheetnames.index(sheet_name)
            book.remove(book.worksheets[idx])
            book.create_sheet(sheet_name, idx)

        # register the existing sheets so to_excel writes into them; newer pandas
        # exposes .sheets as a read-only view derived from .book
        try:
            writer.sheets = {ws.title: ws for ws in book.worksheets}
        except AttributeError:
            pass

    if startrow is None:
        startrow = 0

    # write out the new sheet
    df.to_excel(writer, sheet_name=sheet_name, startrow=startrow, **to_excel_kwargs)

    # close() saves the workbook (ExcelWriter.save() is deprecated / removed)
    writer.close()

In [7]:
# encapsulo partes del loop principal

def find_animal(p, bff, res):
    """Locate the animal in the current frame; results go to bff.

    Computes the movement mask (get_mask), then tries anterior candidates from
    the brightest pixel downwards (``counter`` of get_anterior). A candidate is
    accepted if there is no previous head vector, or if its head vector
    (axis[0] - axis[1]) differs from the previous one by less than a quarter
    turn. The accepted vector becomes bff["last_head_vector"].
    NOTE(review): nothing bounds the number of retries. ``res`` is not used.
    """
    bff["maskroi"], bff["fishroi"] = get_mask(p, bff)

    skipped = 0
    while True:
        bff["anterior"], bff["convolvedRoi"] = get_anterior(p, bff, skipped)
        bff["axis"] = find_animal_shape(p, bff)

        head_vector = np.array(bff["axis"][0]) - np.array(bff["axis"][1])
        previous = bff.get("last_head_vector")

        if previous is None or abs(two_vector_angle(previous, head_vector)) < 0.25:
            bff["last_head_vector"] = head_vector
            return

        # head flipped relative to the last frame: try the next-brightest candidate
        skipped += 1

def save_params(p, bff, res):
    """Append the current anterior point and body axis to the result lists in ``res``."""
    for field in ("anterior", "axis"):
        res[field].append(bff[field])
        

def draw_frame(p, bff, res, live = True, estimatedCenter = None, center = None, slope = None):
    """Show the current frame in the 'image' window, with the tracking overlay when live.

    The grayscale raw frame is converted to BGR into bff["image"]. With ``live``
    set and the mask view off, the overlay is the anterior trajectory, the body
    axis (points and segments), the anterior point and the ROI rectangle.
    ``estimatedCenter``, ``center`` and ``slope`` are not used.
    """
    canvas = bff["raw_image"]
    if len(canvas.shape) == 2:
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
    bff["image"] = canvas

    if live and not bff["mask"]:
        canvas = drawTrajectory(res["anterior"], canvas, color=(0,0,0), thickness=1)

        for axis_point in bff["axis"]:
            cv2.circle(canvas, (int(axis_point[0]), int(axis_point[1])), 1, (255,0,150), -1)

        head = bff["anterior"]
        cv2.circle(canvas, (int(head[0]), int(head[1])), 2, (0,0,255), -1)

        for start, end in zip(bff["axis"][:-1], bff["axis"][1:]):
            cv2.line(canvas, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])),
                     color = (0, 255, 0), thickness = 1)

        canvas = draw_bounds(canvas,
                             (bff["minx"], bff["miny"]),
                             (bff["maxx"], bff["maxy"]),
                             (255, 0, 255), 0.5, 2)
        bff["image"] = canvas

    cv2.imshow('image', canvas)

def draw_params(p, bff, res, pshow):
    """Render the tunable parameters and the frame number in the 'parameters' window.

    Each entry of ``p`` is drawn as "key: value" (numeric values rounded to two
    decimals; others such as a disabled ``None`` threshold are shown as-is). The
    label of the entry at index bff["current_parameter"] is highlighted in
    green, and its key is stored in bff["current_key"]. The canvas is kept in
    bff["p_frame"]. ``res`` is not used.
    """
    scale = p["imgResize"][0]
    bff["p_frame"] = np.zeros((int(pshow["base_p_height"]*scale),
                               int(pshow["base_p_width"]*scale), 3), np.uint8)

    font = cv2.FONT_HERSHEY_SIMPLEX
    text_x = int(pshow["base_p_height"]*pshow["word_x"]*scale)
    word_size = pshow["word_size"]*scale
    stroke = pshow["stroke_width"]

    for count, key in enumerate(p):
        raw_value = p[key][0]
        # round() raises TypeError on non-numeric values such as None
        if isinstance(raw_value, (int, float, np.number)):
            value = round(raw_value, 2)
        else:
            value = raw_value

        origin = (text_x,
                  int(bff["height"]*(pshow["word_start_y"] + count*pshow["word_vertical_jump"])*scale))

        # whole "key: value" line, then the "key: " label redrawn on top in white
        cv2.putText(bff["p_frame"], key + ": " + str(value), origin, font, word_size, (150, 0, 255), stroke)
        cv2.putText(bff["p_frame"], key + ": ", origin, font, word_size, (255, 255, 255), stroke)

        if count == bff["current_parameter"]:
            cv2.putText(bff["p_frame"], key + ": ", origin, font, word_size, (0, 255, 0), stroke)
            bff["current_key"] = key

    frame_origin = (text_x, int(pshow["base_p_height"]*0.05*scale))
    frame_size = pshow["word_size"]*scale*1.3
    cv2.putText(bff["p_frame"], "Frame: " + str(bff["frame"]), frame_origin, font,
                frame_size, (150, 0, 255), stroke)
    cv2.putText(bff["p_frame"], "Frame: ", frame_origin, font,
                frame_size, (255, 255, 255), stroke)

    cv2.imshow('parameters', bff["p_frame"])

def draw_mask(p, bff, res):
    """Show bff["convolvedRoi"] with a tracking overlay in the 'roi' window.

    Normal view (bff["mask"] falsy): body-axis points and segments plus the anterior.
    Mask view: the candidate points of every segment after the first, then the
    body-axis points. Axis and anterior coordinates are full-image and are
    shifted by (minx, miny); segment points are already stored relative to the ROI.
    """
    drawing = cv2.cvtColor(bff["convolvedRoi"], cv2.COLOR_GRAY2RGB)
    off_x = bff["minx"]
    off_y = bff["miny"]

    def to_roi(pt):
        return (int(pt[0] - off_x), int(pt[1] - off_y))

    if bff["mask"]:
        for candidates in bff["segment_points"][1:]:
            for cand in candidates:
                cv2.circle(drawing, (int(cand[0]), int(cand[1])), 1, (0,255,150), -1)

        for axis_point in bff["axis"]:
            cv2.circle(drawing, to_roi(axis_point), 1, (255,0,150), -1)
    else:
        for axis_point in bff["axis"]:
            cv2.circle(drawing, to_roi(axis_point), 1, (255,0,150), -1)

        for start, end in zip(bff["axis"][:-1], bff["axis"][1:]):
            cv2.line(drawing, to_roi(start), to_roi(end), color = (0, 255, 0), thickness = 1)

        cv2.circle(drawing, to_roi(bff["anterior"]), 1, (150,0,150), -1)

    cv2.imshow("roi", drawing)
    
def handle_keys(p, bff, res):
    """Poll the keyboard once (cv2.waitKey) and update the interactive state.

    Keys (letters are case-insensitive):
      Esc    stop (bff["broken"] = True)
      p      toggle play / pause
      x      toggle the mask view (bff["mask"])
      w / s  select the previous / next parameter
      a / d  decrease / increase the selected parameter by its step,
             within its [min, max] (p[key] is [value, min, max, step])
      q / e  go back / forward one frame
      z / c  go back / forward 100 frames
    Backward moves never go below frame 0. bff["read"] is set to False whenever
    the frame index changes. While playing, the current frame is recorded in
    res["frames"] and the index then advances by p["precision"][0].
    """
    k = cv2.waitKey(1) & 0xFF
    key = chr(k).lower()
    bff["read"] = True

    if k == 27:
        bff["broken"] = True

    elif key == 'p':
        bff["play"] = not bff["play"]

    elif key == 'x':
        bff["mask"] = not bff["mask"]

    elif key == 's':
        bff["current_parameter"] = (bff["current_parameter"] + 1) % len(p)
    elif key == 'w':
        bff["current_parameter"] = (bff["current_parameter"] - 1 + len(p)) % len(p)
    elif key == 'a':
        p[bff["current_key"]][0] = max(p[bff["current_key"]][0] - p[bff["current_key"]][3], p[bff["current_key"]][1])
    elif key == 'd':
        p[bff["current_key"]][0] = min(p[bff["current_key"]][0] + p[bff["current_key"]][3], p[bff["current_key"]][2])

    elif key == 'e':
        bff["frame"] = bff["frame"] + 1
        bff["read"] = False
    elif key == 'q':
        bff["frame"] = max(bff["frame"] - 1, 0)
        bff["read"] = False

    elif key == 'c':       # jump 100 frames forward
        bff["frame"] = bff["frame"] + 100
        bff["read"] = False
    elif key == 'z':       # jump 100 frames back, clamped at 0 like 'q'
        bff["frame"] = max(bff["frame"] - 100, 0)
        bff["read"] = False

    if bff["play"]:
        res["frames"].append(bff["frame"])
        bff["frame"] = bff["frame"] + p["precision"][0]
        bff["read"] = False
        
    
def init_video(p, bff, res, video_dir):
    """Open the video, pre-sample frames, estimate the fish size and create the windows.

    Writes into ``bff``: "frame", "vidcap", "total_frames", "fps",
    "sparse_frames", "anterior", "broken", "last_head_vector", "last_time" and
    "rec_times" (get_video_frames / get_fish_size also write keys of their own).
    Overwrites p["maxRange"][0] and p["segment_length"][0] with values derived
    from the estimated fish size. ``res`` is not used here.
    """
    bff["frame"] = 0
    
    bff["vidcap"] = cv2.VideoCapture(video_dir)

    # test read; only used to report whether the video could be opened
    bff["vidcap"].set(cv2.CAP_PROP_POS_FRAMES, 0)
    success,image = bff["vidcap"].read()

    print("Opened video: " + str(success))

    bff["total_frames"] = int(bff["vidcap"].get(cv2.CAP_PROP_FRAME_COUNT))
    bff["fps"] = bff["vidcap"].get(cv2.CAP_PROP_FPS)
    
    # background samples used by get_mask
    bff["sparse_frames"] = get_video_frames(p, bff)
    
    # provisional value, replaced below once the fish size is known
    p["maxRange"][0] = 100
    
    # no previous position yet: get_mask uses the x margins instead of a window around it
    bff["anterior"] = None
    
    # ROI half-size: fish size times rangeMultiplier, but never below minRange
    fish_size = get_fish_size(p, bff, n = 20)
    p["maxRange"][0] = max(fish_size*p["rangeMultiplier"][0], p["minRange"][0])
    p["segment_length"][0] = fish_size*p["segmentMultiplier"][0]
    
    cv2.namedWindow('image', cv2.WINDOW_NORMAL)

    # switch to fullscreen and straight back to normal; presumably a trick to
    # bring the window to the front (TODO confirm)
    cv2.setWindowProperty('image',cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
    cv2.setWindowProperty('image',cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_NORMAL)

    cv2.namedWindow('roi', cv2.WINDOW_NORMAL)

    bff["broken"] = False
    
    
    # no head direction yet: find_animal accepts its first candidate
    bff["last_head_vector"] = None
    
    # CPU-time checkpoint and records used by save_time
    bff["last_time"] = time.process_time()
    bff["rec_times"] = []

# tomo las características del video y sampleo una serie de fotogramas que almaceno para luego usarlos rápidamente

def get_video_frames(p, bff):
    """Read every p["sampledFrameFrequency"][0]-th frame of the video as a grayscale uint8 image.

    Colour frames are converted to grayscale. When p contains "image_limits"
    ([y0, y1, x0, x1]), each frame is cropped to it and zero-padded with pad_image.
    Frames are stored as uint8, the same dtype init_frame produces for live
    frames; pad_image itself returns float64, which would use 8x the memory.
    Frames that cannot be read are skipped.

    Returns:
      list of 2-D arrays.
    """
    sparse_frames = []

    for i in range(0, bff["total_frames"], p["sampledFrameFrequency"][0]):
        bff["vidcap"].set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = bff["vidcap"].read()

        # a failed seek/read returns no image; skip it instead of crashing on image.shape
        if not success or image is None:
            continue

        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        if "image_limits" in p:
            limits = p["image_limits"]
            image = pad_image(image[limits[0]:limits[1], limits[2]:limits[3]])

        sparse_frames.append(image.astype(np.uint8, copy = False))

    return sparse_frames

def get_fish_size(p, bff, n = 20):
    """Estimate the fish length, in pixels, from a subset of the sparse frames.

    Roughly every ``len(sparse_frames) / n``-th sampled frame is segmented
    with get_mask/get_contour. For each frame that yields a contour, the
    contour's diameter (largest distance between any two of its points) is
    recorded.

    Returns the median diameter, or 25.0 when no frame yields a contour
    (including when there are no sparse frames at all).

    Side effect: bff["image"], bff["width"], bff["height"], bff["maskroi"]
    and bff["fishroi"] are left holding the values of the last frame processed.
    """
    frames = bff["sparse_frames"]
    # With fewer than n sampled frames int(len / n) is 0, and range() rejects
    # a zero step, so clamp it to at least 1.
    step = max(1, int(len(frames) / n))
    distances = []

    for i in range(0, len(frames), step):
        image = frames[i]
        bff["image"] = image
        bff["width"] = image.shape[1]
        bff["height"] = image.shape[0]

        bff["maskroi"], bff["fishroi"] = get_mask(p, bff, samples = 150)
        contour = get_contour(p, bff)
        if contour is None:
            continue

        # Contours are assumed to be OpenCV-style (N, 1, 2) point arrays.
        # reshape keeps one row per point even when N == 1. squeeze() turned a
        # single point into a flat (x, y) pair whose coordinates were then
        # measured against each other as if they were two points.
        points = np.asarray(contour).reshape(-1, 2)

        diameter = 0
        for j in range(len(points) - 1):
            # Distances from point j to every later point, computed in one step.
            farthest = np.linalg.norm(points[j + 1:] - points[j], axis=1).max()
            diameter = max(diameter, farthest)

        distances.append(diameter)

    if not distances:
        print("Can't find fish size.")
        return 25.0

    return np.median(distances)

def init_frame(p, bff, pshow, res):
    """Load frame ``bff["frame"]`` into ``bff`` and size the display windows.

    Frame processing:
      - Converted to grayscale when it has colour channels.
      - When ``p`` contains "image_limits", cropped to rows ``lim[0]:lim[1]``
        and columns ``lim[2]:lim[3]``, then passed through pad_image.
      - Cast to uint8 and stored as ``bff["raw_image"]``.

    ``bff["image"]`` receives a working copy of the frame, and
    ``bff["width"]``/``bff["height"]`` are updated. The 'image' and 'roi'
    windows are resized using ``p["imgResize"][0]``. A zeroed 3-channel uint8
    canvas of ``pshow``'s base size (scaled the same way) is stored in
    ``bff["p_frame"]``.

    ``res`` is not used here.
    """
    capture = bff["vidcap"]
    capture.set(cv2.CAP_PROP_POS_FRAMES, bff["frame"])
    frame = capture.read()[1]

    if frame.ndim > 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    if "image_limits" in p:
        lim = p["image_limits"]
        frame = pad_image(frame[lim[0]:lim[1], lim[2]:lim[3]])

    bff["raw_image"] = frame.astype(np.uint8)
    working = bff["raw_image"].copy()
    bff["image"] = working

    bff["width"], bff["height"] = working.shape[1], working.shape[0]

    scale = p["imgResize"][0]
    cv2.resizeWindow('image', int(bff["width"] * scale), int(bff["height"] * scale))

    # Keep the factor order of the product so the truncated size is unchanged.
    roi_side = int(p["maxRange"][0] * 2 * scale * 3)
    cv2.resizeWindow('roi', roi_side, roi_side)

    canvas_shape = (int(pshow["base_p_height"] * scale),
                    int(pshow["base_p_width"] * scale), 3)
    bff["p_frame"] = np.zeros(canvas_shape, np.uint8)

def track(p, bff, pshow, video_dir, image_limits = None):
    """Run the interactive tracking loop over the video at *video_dir*.

    Args:
        p: tracking parameters. ``p["image_limits"]`` is set from
            *image_limits* when that argument is given.
        bff: mutable working state shared with the helper functions.
        pshow: display parameters forwarded to init_frame and draw_params.
        video_dir: path handed to init_video (and from there to
            cv2.VideoCapture).
        image_limits: optional crop bounds; init_frame uses them as rows
            ``lim[0]:lim[1]`` and columns ``lim[2]:lim[3]``.

    Returns:
        dict with lists under "frames", "anterior" and "axis". These are
        presumably filled by save_params (TODO confirm).

    A frame is loaded and analysed only while ``bff["read"]`` is False.
    Drawing, key handling and timing run on every iteration. The loop ends
    when ``bff["frame"]`` passes the last frame or ``bff["broken"]`` is set;
    both are presumably driven by handle_keys (TODO confirm). The OpenCV
    windows are destroyed even if a helper raises or the loop is
    interrupted.
    """
    if image_limits is not None:
        p["image_limits"] = image_limits

    res = {"frames": [], "anterior": [], "axis": []}

    try:
        init_video(p, bff, res, video_dir)
        bff["read"] = False

        while bff["frame"] < bff["total_frames"] and not bff["broken"]:
            if bff["frame"] < 0:
                bff["frame"] = 0

            # Still needed after the clamp above, e.g. when total_frames is 0.
            if bff["frame"] >= bff["total_frames"]:
                break

            if not bff["read"]:
                init_frame(p, bff, pshow, res)
                find_animal(p, bff, res)
                save_params(p, bff, res)

            draw_frame(p, bff, res)
            draw_params(p, bff, res, pshow)
            draw_mask(p, bff, res)
            handle_keys(p, bff, res)

            save_time(bff)
    finally:
        # Previously skipped on any exception (including KeyboardInterrupt),
        # leaving stale OpenCV windows open.
        cv2.destroyAllWindows()

    return res

def process_results(p, bff, res, video, video_dir, n = 5, xlim = None):
    """Derive kinematic series from tracking output, plot them, then save.

    Adds to / converts in ``res``:
      - "anterior": converted to a NumPy array.
      - "dist": ``pos_to_dist(anterior, n, avg=moving_average)`` followed by
        ``moving_average(..., n)``.
      - "axis": converted to a NumPy array. Indexed below as
        ``[frame, point, coord]``.
      - "axis_angles": absolute three_point_angle at every interior axis
        point, one row per frame.
      - "angles": two_vector_angle between the first axis segment
        (point 0 -> point 1) of consecutive frames. It has one fewer entry
        than there are frames.

    Each series gets its own plot. Every axis-angle column is plotted
    (previously only columns 0-2, which failed for shorter axes), followed by
    their per-frame sum. save_results is called last.

    Args:
        n: window passed to pos_to_dist / moving_average for "dist".
        xlim: optional x-axis limits applied to every plot. This parameter
            was previously accepted but ignored.
    """
    def show(series):
        # One figure per series; xlim lets the caller zoom on a frame range.
        plt.plot(series)
        if xlim is not None:
            plt.xlim(xlim)
        plt.show()

    res["anterior"] = np.array(res["anterior"])
    res["dist"] = pos_to_dist(res["anterior"], n, avg = moving_average)
    res["dist"] = moving_average(res["dist"], n)
    show(res["dist"])

    res["axis"] = np.array(res["axis"])
    print(res["axis"].shape)

    # Angle at vertex j+1 between axis points j and j+2, for every frame.
    res["axis_angles"] = np.abs(np.array([
        [three_point_angle(points[j + 1], points[j], points[j + 2])
         for j in range(len(points) - 2)]
        for points in res["axis"]
    ]))

    for col in range(res["axis_angles"].shape[1]):
        show(res["axis_angles"][:, col])
    show(res["axis_angles"].sum(axis=1))

    # Direction of the first axis segment in each frame, then the angle
    # between each pair of consecutive frames.
    segments = res["axis"][:, 1, :] - res["axis"][:, 0, :]
    res["angles"] = np.array([
        two_vector_angle(v1, v2) for v1, v2 in zip(segments[:-1], segments[1:])
    ])
    show(res["angles"])

    save_results(p, bff, res, video, video_dir)

def save_results(p, bff, res, video, video_dir):
    """Write the per-frame tracking measurements to ``<video name>.xlsx``.

    Changes the working directory to *video_dir* and writes the spreadsheet
    there. The file is named after *video* with only its final extension
    removed.

    Columns, in order:
      - frame
      - anterior_x, anterior_y
      - anterior_angular_change (res["angles"])
      - distance
      - axis_point_<i>_x / axis_point_<i>_y for every axis point
      - axis_angle_<i> for every column of res["axis_angles"]

    res["angles"] and res["axis_angles"] must already be present, as
    produced by process_results.

    Every per-frame column drops its last frame. This lines them up with
    res["angles"], which has one value per pair of consecutive frames. The
    distance series is assumed to have that same length (pandas raises
    ValueError if lengths differ).

    ``p`` and ``bff`` are not used.
    """
    os.chdir(video_dir)

    frames = np.array(res["frames"])
    anterior = np.array(res["anterior"])
    axis = np.array(res["axis"])
    axis_angles = np.array(res["axis_angles"])

    dist = pos_to_dist(res["anterior"], n=1, avg = moving_average)
    dist = np.array(moving_average(dist, 1))

    # Build the columns directly (dict order = column order) instead of
    # stacking rows, transposing and splitting them back into columns.
    measurements = {
        "frame": frames[:-1],
        "anterior_x": anterior[:-1, 0],
        "anterior_y": anterior[:-1, 1],
        "anterior_angular_change": np.array(res["angles"]),
        "distance": dist,
    }

    for i in range(axis.shape[1]):
        for j in range(axis.shape[2]):
            coord = "x" if j == 0 else "y"
            measurements["axis_point_" + str(i) + "_" + coord] = axis[:-1, i, j]

    for i in range(axis_angles.shape[1]):
        measurements["axis_angle_" + str(i)] = axis_angles[:-1, i]

    output = pd.DataFrame.from_dict(measurements)
    print(output.shape)
    print(list(output.columns))

    # splitext strips only the final extension. split('.')[0] truncated names
    # containing extra dots and produced ".xlsx" for paths like "./video.mp4".
    output.to_excel(os.path.splitext(video)[0] + ".xlsx")

def look_for_video(main, found = None):
    """Recursively collect the paths of all files below the directory *main*.

    Paths are built as ``main + '/' + name``. Entries are visited in sorted
    order; each subdirectory is descended into and every other entry is
    appended to *found*.

    Args:
        main: directory to scan.
        found: list to append to. Defaults to the module-level
            ``directories`` list used by existing callers.

    Returns:
        The list that was appended to.

    The process working directory is not changed.
    """
    if found is None:
        found = directories

    # List *main* itself rather than the current working directory, and
    # decide file vs. directory with isdir. Looking for '.' anywhere in the
    # path misclassified every entry under a dotted directory name, and
    # crashed on files without an extension.
    for name in sorted(os.listdir(main)):
        path = main + '/' + name
        if os.path.isdir(path):
            look_for_video(path, found)
        else:
            found.append(path)

    return found