In [None]:
# This code is used to create a template for each session of the reaching task
# The code will prompt the user to enter the rat name and session number
# The code will then display the trials for the session and the user should pick the three best trials for the template
# The best trials are ones with an obvious reach in the trajectory plot
# The code will then display each selected trial; click on the point of the reach in each one
# Only the x-axis position (the frame where the grab begins) matters for the selection, not the y-axis
# Once 3 points are selected, the code will create a template and display the correlation between the template and the trials
# It will also display the average correlation for the session
# The code will then save the template to the designated excel sheet
# All templates must be created to run proper correlation analysis on the entire rat data

import re
import glob
import pandas as pd
import numpy as np
import os.path
import sys
from reaching_task_utils import list_available_rats, make_template_spread_sheet,templates_required, process_files

# Directory that holds the videos. Use a raw string (r'...') so Windows
# backslashes are kept as written, and end it with \\ as in the example.
# Example: User_Dir = r'C:\Users\username\Documents\Reach_Task\\'
User_Dir = r''  # Enter the path to the directory containing the videos
if not User_Dir:
    # Fail early with a clear message when the path has not been filled in.
    raise ValueError('User_Dir is empty: set it to the directory containing the videos.')

rat_list = list_available_rats(User_Dir)

# Ask for a rat until a valid name is given, then ask which session to build a
# template for and load that session's trials into myList.
while True:
    rat_name = input(f"Acceptable names are {rat_list}\n\nEnter rat name: ")
    if rat_name not in rat_list:
        print(f'{rat_name} Is not an acceptable rat name')
        continue

    # Fariborz, Iraj and Tur use 7 sessions; every other rat uses 10.
    sessions = 7 if rat_name in ['Fariborz', 'Iraj', 'Tur'] else 10
    make_template_spread_sheet(User_Dir, rat_name, sessions)
    required_sessions = templates_required(User_Dir, rat_name, sessions)

    if len(required_sessions) == 0:
        # Accept "Yes", " no " etc.; any other answer restarts from the rat prompt.
        k = input("\nNo sessions are needed, do you want to overwrite a session? ").strip().lower()
        if k == 'no':
            print(f'{rat_name} Templates are complete!')
            sys.exit()
        if k != 'yes':
            continue
    else:
        print(f"\nSessions needed are {required_sessions}")

    session_number = input("\nEnter the session number: ")
    myList = process_files(User_Dir, rat_name, session_number)
    break
   


import tkinter as tk
from tkinter import ttk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

LARGEFONT = ("Verdana", 35)


class tkinterApp(tk.Tk):
    """Trial-selection window: a scrollable checklist beside one plot page per trial.

    Builds one checkbox and one ``PlotPage`` for every entry of the
    module-level ``myList``. The indices of the ticked trials are kept in
    ``self.selected_trials``. These are 0-based positions in ``myList``, so
    they can be read after ``mainloop()`` returns.
    """

    def __init__(self, *args, **kwargs):
        tk.Tk.__init__(self, *args, **kwargs)

        container = tk.Frame(self)
        container.pack(side="top", fill="both", expand=True)

        container.grid_rowconfigure(0, weight=1)
        container.grid_columnconfigure(1, weight=1)  # The plot column absorbs extra width

        self.frames = {}            # page index -> PlotPage
        self.selected_trials = []   # 0-based indices of ticked trials
        self.trial_vars = []        # one "on"/"off" StringVar per checkbox

        self._build_checklist(container, len(myList))
        self.container = self._build_plot_pages(container)

        # Show the first plot page initially (nothing to show if there are no trials).
        if self.frames:
            self.show_frame(0)

        # Bind the close event of the application to a custom method
        self.protocol("WM_DELETE_WINDOW", self.on_close)

    def _build_checklist(self, container, n_trials):
        """Create the scrollable column of trial checkboxes in grid column 0."""
        checklist_frame = tk.Frame(container, bd=2, relief="ridge", width=200)
        checklist_frame.grid(row=0, column=0, sticky="nsew")

        canvas = tk.Canvas(checklist_frame)
        inner = tk.Frame(canvas)

        scroll_y = tk.Scrollbar(checklist_frame, orient="vertical", command=canvas.yview)
        canvas.configure(yscrollcommand=scroll_y.set)

        scroll_y.pack(side="right", fill="y")
        canvas.pack(side="left", fill="both", expand=True)
        canvas.create_window((0, 0), window=inner, anchor="nw")

        # All checkboxes share one named style, so configure it once.
        ttk.Style(self).configure("Trial.TCheckbutton", font=("Verdana", 10))
        for i in range(n_trials):
            var = tk.StringVar(value="off")  # Checkbox starts unticked
            self.trial_vars.append(var)
            checkbox = ttk.Checkbutton(inner, text=f'Trial {i + 1}', variable=var,
                                       onvalue="on", offvalue="off",
                                       command=self.on_trial_checked,
                                       style="Trial.TCheckbutton")
            checkbox.pack(anchor="w")  # Align checkboxes to the left

        # Keep the scroll region in sync with the checklist's size.
        inner.bind("<Configure>", lambda e: canvas.configure(scrollregion=canvas.bbox("all")))

    def _build_plot_pages(self, container):
        """Create one PlotPage per trial, stacked in grid column 1; return their parent frame."""
        plot_frame = tk.Frame(container)
        plot_frame.grid(row=0, column=1, sticky="nsew")
        # All pages share cell (0, 0); give it weight so they stretch with the window.
        plot_frame.grid_rowconfigure(0, weight=1)
        plot_frame.grid_columnconfigure(0, weight=1)

        for i, data in enumerate(myList):
            page = PlotPage(plot_frame, self, data, i)
            self.frames[i] = page
            page.grid(row=0, column=0, sticky="nsew")
        return plot_frame

    def show_frame(self, cont):
        """Raise the plot page with index ``cont`` above the others."""
        frame = self.frames[cont]
        frame.tkraise()

    def on_trial_checked(self):
        """Recompute ``selected_trials`` from the current checkbox states."""
        self.selected_trials = [i for i, var in enumerate(self.trial_vars) if var.get() == "on"]

    def on_close(self):
        """Destroy the window, which ends ``mainloop()``."""
        self.destroy()


class PlotPage(tk.Frame):
    """One viewer page showing the traces of a single trial.

    Parameters
    ----------
    parent : tk widget
        Frame the page is placed in (all pages are stacked in the same grid cell).
    controller : tkinterApp
        Object whose ``show_frame(index)`` switches the visible page. It is
        used by the Previous/Next buttons.
    data :
        The trial to plot; every column is drawn against its sample index.
        Presumably a pandas DataFrame with rows = samples (TODO: confirm).
    trial_number : int
        0-based position of ``data`` in ``myList``; shown as ``trial_number + 1``.

    Reads the module-level names ``rat_name``, ``session_number``, ``myList``
    and ``LARGEFONT``.
    """

    def __init__(self, parent, controller, data, trial_number):
        # Local import keeps this block self-contained (matplotlib is already a
        # dependency of this file).
        from matplotlib.figure import Figure

        tk.Frame.__init__(self, parent)

        self.data = data
        self.trial_number = trial_number

        label = ttk.Label(self, text=f'Rat {rat_name}: Session {session_number}, Trial {trial_number + 1}', font=LARGEFONT)
        label.pack(pady=10, padx=10)

        # Build the Figure directly instead of through pyplot. A pyplot figure
        # is registered with pyplot's own figure manager; under the TkAgg
        # backend that means an extra window, which had to be torn down again
        # with plt.close(). An embedded figure needs none of that.
        fig = Figure()
        ax = fig.add_subplot(111)
        ax.plot(np.arange(len(data)), data)
        ax.set_xlabel('Samples')
        ax.set_ylabel('Coordinate')
        ax.grid(True)
        # If myList[0] is a DataFrame, iterating it yields its column labels
        # (one per trace); assumes every trial has the same columns.
        ax.legend(labels=myList[0], bbox_to_anchor=(1.1, 1.05))
        fig.tight_layout()

        # Embed the figure in this page
        canvas = FigureCanvasTkAgg(fig, master=self)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Navigation: the first page has no "Previous" button, the last no "Next".
        if trial_number > 0:
            prev_button = ttk.Button(self, text="Previous Trial",
                                     command=lambda: controller.show_frame(trial_number - 1))
            prev_button.pack(side=tk.LEFT, pady=10, padx=10)

        if trial_number < len(myList) - 1:
            next_button = ttk.Button(self, text="Next Trial",
                                     command=lambda: controller.show_frame(trial_number + 1))
            next_button.pack(side=tk.RIGHT, pady=10, padx=10)


if __name__ == "__main__":

    # Let the user tick the trials to build the template from; this blocks
    # until the window is closed.
    app = tkinterApp()
    app.mainloop()

    # 0-based positions in myList of the ticked trials. They are used as
    # indices into myList further down (myList[temp[i]]), so they must NOT be
    # shifted. Only the printed numbers are made 1-based, to match the
    # "Trial N" labels shown in the window.
    temp = list(app.selected_trials)

    print("\nSelected Trial Numbers:", [trial + 1 for trial in temp])
    
    
import matplotlib
# Switch to the interactive Tk backend so the figures opened by
# choose_three_points can be clicked on via plt.ginput.
matplotlib.use('TKAgg')  # Set the backend to TkAgg

def choose_three_points(data):
    """Show each chosen trial and let the user click the reach point on it.

    Parameters
    ----------
    data : sequence of int
        Indices into the module-level ``myList``. One figure is shown per
        index, in order.

    Returns
    -------
    list of int
        The x coordinate of the click on each trial, truncated toward zero.
        It is a sample index and is returned in the same order as ``data``.
        Exactly one entry is returned per trial, so ``result[i]`` always
        belongs to ``data[i]``.
    """
    selected_points = []

    for i, trial_idx in enumerate(data):
        point = []
        while not point:
            fig = plt.figure()  # Create a new figure for each plot
            plt.plot(myList[trial_idx])
            plt.title(f'Trial {i+1}')
            plt.xlabel('Samples')
            plt.ylabel('Coordinate')
            plt.grid(True)
            # ginput draws the figure itself and blocks until one click. A
            # plt.show() before it would block on its own in non-interactive
            # mode, and ginput would then wait on a new, empty figure.
            point = plt.ginput(1, timeout=0)
            if not point:
                # No click (window closed or input cancelled): ask again so
                # that every trial gets exactly one point.
                plt.close(fig)
                print(f'No point selected for Trial {i+1}, please click on the reach.')

        x, y = point[0]
        # Mark the selection on the existing axes instead of re-plotting the
        # trial, which would draw a second, differently coloured copy of it.
        plt.plot(x, y, 'ro', markersize=10)
        plt.title(f'Trial {i+1} with Selected Point')
        plt.show()

        selected_points.append(int(x))

    plt.close('all')  # Close all figures to prevent issues with the kernel

    return selected_points

selected_points = choose_three_points(temp)
print("Selected Points:", selected_points)


matplotlib.use('agg')  # Set the backend to agg

%matplotlib inline

    
template=[]
for i in range(len(temp)):
    if selected_points[i]>25:
        template.append(myList[temp[i]].iloc[selected_points[i]-25:selected_points[i]+25])
    else: template.append(myList[temp[i]].iloc[0:50])
temp_mean = np.nanmean(template,0)
plt.plot(temp_mean)
plt.plot(temp_mean)
plt.title(f'Rat {rat_name}: Session {session_number}, Template')
plt.xlabel('Sample')
plt.ylabel('Coordinate')
plt.grid(True)

# Write the averaged template to this rat's template workbook, replacing the
# sheet for this session if it already exists. mode="a" appends to an existing
# file; presumably the workbook created by make_template_spread_sheet (confirm).
df = pd.DataFrame(temp_mean, columns=myList[0].columns)
template_path = os.path.join(User_Dir, 'Templates', f'Rat_{rat_name}_Templates.xlsx')
template_sheet = f'Rat {rat_name} S{session_number} Template'
with pd.ExcelWriter(template_path, mode="a", engine="openpyxl", if_sheet_exists="replace") as writer:
    df.to_excel(writer, sheet_name=template_sheet)
from scipy import signal
# Compare the template with the first trial (myList[0]) and plot every step:
# raw trial, template, per-coordinate correlation, and trial with best lag.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=False, figsize=(8, 10))
Corr = []
trial = myList[0]
template_df = pd.DataFrame(temp_mean)
# Loop through the first 10 coordinate columns
for i in range(10):
    x = trial.iloc[:, i]
    y = template_df.iloc[:, i]

    # Plot the original signal on the first subplot
    ax1.plot(x, label=f'Signal {i+1}')

    # Plot the template signal on the second subplot
    ax2.plot(y, label=f'Template {i+1}')

    # Normalise: remove the mean, then scale to unit Euclidean norm
    x = x - np.nanmean(x)
    y = y - np.nanmean(y)
    x /= np.linalg.norm(x)
    y /= np.linalg.norm(y)
    # Slide the template along the trial. 'valid' keeps only full overlaps,
    # so lag k aligns the template's first sample with sample k of the trial.
    corr = signal.correlate(x, y, mode='valid', method='fft')
    Corr.append(corr)
    # Plot the correlation result on the third subplot
    ax3.plot(corr, label=f'Correlation {i+1}', linestyle=':', linewidth=1.25)

# Average over the coordinates (each counted exactly once, as in the per-trial
# loop below), then find the overall peak and the last peak above 0.15.
avg_corr = np.nanmean(Corr, 0)
peak = np.max(avg_corr)
peak_idx = np.argmax(avg_corr)
index, _ = signal.find_peaks(avg_corr, height=0.15)
if len(index) != 0:
    last_idx = index[-1]   # Get the last index value
else:
    last_idx = peak_idx
sub_peak = avg_corr[last_idx]
print(f'The Peak for the 1st video of S{session_number} =', peak, 'and the index is', peak_idx)
print(f'The Last Peak for the 1st video of S{session_number} =', sub_peak, 'and the index is', last_idx)

# Plot average corr across all body parts
ax3.plot(avg_corr, color='black')
ax4.plot(trial)
ax4.axvline(x=peak_idx, color='r', linestyle='-')

# Add legends to the subplots ('Avg' labels the black average line on ax3)
labels = ['Wrist_x', 'Wrist_y', 'Thumb_x', 'Thumb_y', 'Index_x', 'Index_y',
          'Middle_x', 'Middle_y', 'Last_x', 'Last_y', 'Avg']
ax1.legend(labels=labels, bbox_to_anchor=(1.1, 1.05))
ax2.legend(labels=labels, bbox_to_anchor=(1.1, 1.05))
ax3.legend(labels=labels, bbox_to_anchor=(1.1, 1.05))
ax4.legend(labels=labels[:-1], bbox_to_anchor=(1.1, 1.05))

# Set titles and labels for each subplot
ax1.set_title('Original Signals')
ax1.set_ylabel('Amplitude')
ax1.set_xlabel('Sample')
ax1.margins(0, 0.1)

ax2.set_title('Template Signals')
ax2.set_ylabel('Amplitude')
ax2.set_xlabel('Sample')
ax2.margins(0, 0.1)

ax3.set_title('Correlation Results')
ax3.set_xlabel('Lag')
ax3.set_ylabel('Correlation')
ax3.margins(0, 0.1)

ax4.set_title('Signal With Temp Point')
ax4.set_ylabel('Amplitude')
ax4.set_xlabel('Sample')
ax4.margins(0, 0.1)

# Add grids to each subplot
for ax in (ax1, ax2, ax3, ax4):
    ax.grid(True)

# Adjust the layout and display the plots
fig.tight_layout()
plt.show()


# Correlate the template with every trial of the session and average the
# per-trial peak correlations.
trial = myList
Corr = []        # Corr[j][i]: correlation of trial j's coordinate i with the template
peak_idx = []    # per trial: lag of the highest averaged correlation
peak = []        # per trial: value of that highest averaged correlation

# The template does not change between trials, so normalise each of its first
# 10 coordinate columns once (zero mean, unit Euclidean norm).
template_df = pd.DataFrame(temp_mean)
norm_template = []
for i in range(10):
    y = template_df.iloc[:, i]
    y = y - np.nanmean(y)
    y /= np.linalg.norm(y)
    norm_template.append(y)

# Loop through each trial, then each coordinate column
for trial_df in trial:
    sub = []
    for i in range(10):
        x = trial_df.iloc[:, i]
        x = x - np.nanmean(x)
        x /= np.linalg.norm(x)
        sub.append(signal.correlate(x, norm_template[i], mode='valid', method='fft'))
    Corr.append(sub)

# Find the peak of the coordinate-averaged correlation for each trial
for sub in Corr:
    avg_corr = np.nanmean(sub, 0)
    peak_idx.append(np.argmax(avg_corr))
    peak.append(np.max(avg_corr))

avg_peak = np.nanmean(peak)
print(f'The Average peak correlation for Rat {rat_name}, S{session_number}: ', avg_peak)