# In [None]:
import cv2
import numpy
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter

#Functions
def animaltracking(path):
    """Interactively track one animal through a video with OpenCV's CSRT tracker.

    The first frame of the video is shown so the user can draw a bounding
    box around the animal and confirm it with Enter. Every following frame is
    passed to the tracker and displayed with the tracked box. This continues
    until the video runs out of frames or the user presses 'q'.

    Parameters
    ----------
    path : str
        Video source passed to ``cv2.VideoCapture`` (typically a file path).

    Returns
    -------
    list of tuple(int, int)
        Top-left corner ``(x, y)`` of the tracked box in pixels, one entry per
        processed frame. If the tracker reports a failure for a frame, the box
        returned by ``tracker.update`` is still recorded (unchanged from the
        original behaviour).

    Raises
    ------
    OSError
        If the first frame of the video cannot be read.
    """
    print("Chose with the rectangle the animal you want to track and press 'enter'.The animal will be tracked until the video ends or until you press 'q'")
    print("REMEMBER to adapt the dimension of the platform to the dimension of the video as close as posible for a better representation of the tracking heat maps and data reproducibility.")

    # Choose tracking method
    tracker = cv2.TrackerCSRT_create()
    tracker_name = 'CSRT tracker'

    cap = cv2.VideoCapture(path)
    result = []
    try:
        # The first frame is used to select the region of interest.
        ret, frame = cap.read()
        if not ret:
            raise OSError(f"Could not read the first frame of {path!r}")

        roi = cv2.selectROI(frame, False)
        tracker.init(frame, roi)

        while cap.isOpened():
            ret, frame = cap.read()
            # Stop before handing an empty frame (None) to the tracker:
            # at the end of the video update() would otherwise fail.
            if not ret:
                break

            success, roi = tracker.update(frame)
            x, y, w, h = (int(v) for v in roi)

            if success:
                cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 125, 25), 3)
            else:
                cv2.putText(frame, "Fail", (100, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)

            cv2.putText(frame, tracker_name, (20, 400), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)

            # Display result and store the top-left corner of the box.
            cv2.imshow(tracker_name, frame)
            result.append((x, y))

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the capture and close windows, even on error.
        cap.release()
        cv2.destroyAllWindows()

    return result

def shallowsidetime(lst, videowidth, videoheight, platformwidth, platformheight,
                    fps=25, duration=28):
    """Convert tracked pixel positions to platform units and measure side preference.

    Parameters
    ----------
    lst : iterable of (x, y)
        Tracked positions in video pixels, e.g. the output of ``animaltracking``.
    videowidth, videoheight : number
        Size of the video frame in pixels.
    platformwidth, platformheight : number
        Real size of the platform. The output columns are named ``Xcm``/``Ycm``,
        so centimetres are presumably intended.
    fps : number or None, optional
        Frame rate used to turn frame counts into seconds (default 25, the
        value previously hard-coded).
    duration : number or None, optional
        Recording length in seconds that the percentages refer to (default 28,
        the value previously hard-coded). If ``fps`` or ``duration`` is None,
        percentages are computed relative to the number of tracked frames
        instead, so they always sum to 100 (or both are 0 if nothing was tracked).

    Returns
    -------
    df : pandas.DataFrame
        One row per position with columns ``X``, ``Y``, ``Xcm``, ``Ycm``,
        ``Ycm_inv``, ``vector_x``, ``vector_y`` and ``side``. The first two rows
        are the artificial corners ``(0, 0)`` and ``(videowidth, videoheight)``,
        kept so heatmaps built from ``vector_x``/``vector_y`` span the whole
        platform.
    dt : pandas.DataFrame
        Single row with the percentage of time on the ``'Deep side'`` and
        ``'Shallow side'``.

    Side effect: draws a seaborn bar plot of ``dt`` on the current axes.
    """
    # Number of artificial corner rows prepended for heatmap extents;
    # they are not observations and must not be counted as time on a side.
    n_anchor = 2
    xs = [0, videowidth] + [x for x, _ in lst]
    ys = [0, videoheight] + [y for _, y in lst]
    df = pd.DataFrame({'X': xs, 'Y': ys})

    # Adapt the video dimensions to the real platform dimensions
    df['Xcm'] = df['X'] * platformwidth / videowidth
    df['Ycm'] = df['Y'] * platformheight / videoheight
    # Image rows grow downwards; flip so that y grows upwards in plots.
    df['Ycm_inv'] = platformheight - df['Ycm']

    # Important for heatmap plotting
    df['vector_x'] = df['Xcm'].values
    df['vector_y'] = df['Ycm_inv'].values

    # Classify each position: right of the platform midline is the shallow side.
    # (Previously hard-coded as 45 * 0.5, which was only right for a 45-wide platform.)
    midline = platformwidth * 0.5
    df['side'] = np.where(df['Xcm'] > midline, "shallowside", "deepside")

    # Count frames on each side, skipping the artificial corner rows.
    observed = df['side'].iloc[n_anchor:]
    deepcounter = int((observed == 'deepside').sum())
    shallowcounter = int((observed == 'shallowside').sum())

    if fps is None or duration is None:
        total = deepcounter + shallowcounter
        deeptimeporcentage = 100 * deepcounter / total if total else 0.0
        shallowtimeporcentage = 100 * shallowcounter / total if total else 0.0
    else:
        # frames / fps = seconds on that side; as a percentage of the recording length.
        deeptimeporcentage = (deepcounter / fps) * 100 / duration
        shallowtimeporcentage = (shallowcounter / fps) * 100 / duration

    dt = pd.DataFrame({'Deep side': [deeptimeporcentage],
                       'Shallow side': [shallowtimeporcentage]})

    # Plot result
    sns.barplot(data=dt)
    plt.ylabel("Percentage of time spent on each side per animal")
    plt.title("Virtual visual cliff")

    return df, dt

def animaltrackingheatmap(x, y, sigma, bins, platformwidht, platformheight, color,thres):
    """Draw a Gaussian-smoothed 2-D occupancy heatmap of tracked positions.

    Parameters
    ----------
    x, y : array-like
        Position coordinates, e.g. the ``vector_x`` / ``vector_y`` columns
        produced by ``shallowsidetime``.
    sigma : float or sequence
        Standard deviation passed to ``gaussian_filter`` for smoothing the counts.
    bins : int or sequence
        Binning passed to ``np.histogram2d``.
    platformwidht, platformheight : number
        Axis limits of the plot. A white vertical line is drawn at
        ``platformwidht / 2``.
    color : str or Colormap
        Matplotlib colormap.
    thres : float or None
        Upper colour limit (``vmax``); the lower limit is fixed at 0.

    Returns
    -------
    None
        The value returned by ``plt.show()``.
    """
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    smoothed = gaussian_filter(counts, sigma=sigma)

    fig = plt.figure(figsize=(10, 10))
    # histogram2d puts x on the first axis; transpose so rows follow y for imshow.
    image = plt.imshow(
        smoothed.T,
        extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
        origin='lower',
        cmap=color,
        vmin=0,
        vmax=thres,
    )

    plt.suptitle('Heatmap')
    # Vertical line at the horizontal midpoint of the platform.
    plt.axvline(x=platformwidht / 2, color='white')
    plt.xlim(0, platformwidht)
    plt.ylim(0, platformheight)
    plt.title('Heatmap of tracked animal position')
    fig.colorbar(image)

    return plt.show()