### **Packages install and initialization** <br>
___
*If you are using the program for the first time or in Google Colab, make sure the necessary packages are installed (uncomment and run the cell below). At the end, the kernel restarts automatically so that the installed packages can be used.*

In [1]:
# One-time environment setup: uncomment the lines below to install the required
# packages (e.g. on first use of the program or in Google Colab).
# NOTE(review): the -f URL points at cupy's aarch64 wheel index; confirm it matches the target machine.
# %pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64
# %pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
# %pip install scipy
# %pip install git+https://github.com/endolith/wavelets
# %pip install ssqueezepy
# %pip install kaleido
# %pip install plotly
# Terminating the kernel process makes Jupyter restart the kernel, so the freshly
# installed packages can be imported in the following cells.
# import os
# os._exit(00)

*Import the required libraries and project modules.*

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import read_write
import visualization as vs
import preprocess as pr
import machine_learning as ml

[32;1m[🗸]   [0m|15:22:53|  [32;1m [0;32mPython package initialization done in read_write.[0m
[32;1m[🗸]   [0m|15:22:53|  [32;1m [0;32mPython package initialization done in visualization.[0m
[32;1m[🗸]   [0m|15:22:58|  [32;1m [0;32mPython package initialization done in preprocess.[0m
[32;1m[🗸]   [0m|15:23:08|  [32;1m [0;32mPython package initialization done in machine_learning.[0m


### **Define parameters** <br>
___
*The settings can be changed below. If `is_test` is True, no file is saved and only one chunk is processed.*

In [3]:
base_path = './' # Folder that contains the source files.
is_test   = True # If True, no output data is saved and only a minimal run is made to check that the functions work.
is_save   = True # If True, all results are saved (e.g. plots, models).

#------------------SIGNAL FILE------------------------------------
source_file_path        = base_path + 'R2470_experiment_1.nwb' # Path of the source file (base_path + source file name).
original_srate          = 30000 # Sampling rate of the raw signal in Hz. Default: 30000.
led_srate               = 1000  # Sampling rate of the raw LED signal in Hz. Default: 1000.
new_srate               = 20000 # Target sampling rate after down-sampling, in Hz. Default: 20000.
window_size             = 2000  # Size of one window in milliseconds. Default: 2000.
chunk_size              = 30000 # Size of one chunk in milliseconds. Default: 30000.

#----------------- VISUALIZATION----------------------------------
is_save_images          = True # If True, the plots are saved.
image_save_dir          = './images/' # Target folder for the saved plots. Default: './images/'.
# NOTE(review): start_point/end_point were documented as milliseconds, but the plotting
# cells use them directly as array indices (sample positions) when slicing the signals.
# Confirm the intended unit.
start_point             = 0 # Start point of the plotted signal parts. Default: 0.
end_point               = 1000 # End point of the plotted signal parts. Default: 1000.

#------------------PREPROCESS-------------------------------------
# `tasks` lists the required preprocessing steps and their settings; the steps run in the given order.
# Options:
#     'Cheby_band'      : [lowcut, highcut, order]. Typically [2,250,5]
#     'narrow_filt'     : [notch frequency, quality factor]. Typically [50,20]
#     'down_samp'       : [new sampling rate]. Leave it as [new_srate]
#     'detrend'         : True or False.
#     'roll_mean'       : [window factor]. Typically [20]
#     'normalize'       : True or False.
#     'feature_extract' : [lowcut, highcut, number of voices, scaletype ('log-piecewise', 'linear' or 'log'),
#                          wavelet ('morlet'), rescaled size]. Typically [2,250,5,'log-piecewise','morlet',64]
#
# To switch off a preprocessing step, set its value to False (e.g. 'down_samp': False).
tasks                   = {
                        'Cheby_band'     : [2,250,5],
                        'narrow_filt'    : [50,20],
                        'down_samp'      : [new_srate],
                        'detrend'        : True,
                        'roll_mean'      : [20],
                        'normalize'      : True,
                        'feature_extract': [2,250,5,'log-piecewise','morlet',64]
                        }
cutoff                  = 0 # Size (in milliseconds) of the extra part cut off from the beginning and end of each chunk to avoid edge effects.
use_gpu                 = False # If True, the GPU is used for preprocessing when available.
use_paral               = True  # If True, parallel computing is used to speed up preprocessing when more than 5 CPU cores are available.
use_saver               = True # If True, the program picks the chunk size with the minimum data surplus.
live_write              = False # If True, processed data is written to disk during preprocessing.
                                # This saves RAM but slows preprocessing down. Default: False.

#--------------------LEARNING--------------------------------------
model_name              = 'CNN_behav_cloning' # Name of the deep learning model. Options: 'CNN_behav_cloning', 'CNN_transf' or 'CNN_BiLSTM'.
learn_rate              = 1e-2 # Learning rate. Default: 1e-2.
epochs                  = 10 # Number of epochs. Recommended: 15.
batch_size              = 2 # Size of a training batch. Recommended: 128.
validation_batch_size   = 2 # Size of a validation batch. Recommended: 200.
decay_steps             = 10 # Decay steps of the learning-rate schedule (how often the learning rate is reduced during training).
process_live            = False # If True, preprocessing is done on the fly during training instead of beforehand. Useful to save memory.
validation_ratio        = 0.3 # Fraction of the whole dataset used for validation. Default: 0.3.
test_ratio              = 0.1 # Fraction of the whole dataset used for testing. Default: 0.1.
use_shuffle             = True # If True, the dataset is shuffled before training. Default: True.

# Datapoint filter options
# Training can be restricted to the windows that satisfy a given relation.
# For example, to use only the windows where the speed of the animal was greater than 5:
#     paramater_name = 'speed'
#     relation       = '>'
#     limit_value    = 5
# Only one parameter and one condition can be specified at a time.
paramater_name          = None # Options: ['speed','head_dir','position_x','position_y']. Default: None.
relation                = None # Options: ['<','<=','==','!=','>=','>']. Default: None.
limit_value             = None # Default: None.

#------------------OTHER-------------------------------------
# In test mode nothing is saved, regardless of the flags above.
if is_test:
    is_save = False
    is_save_images = False

### **Load source file** <br>
___
*`data_load_pipeline` checks the file path given in `source_file_path` and loads the file if there are no problems.*

In [4]:
# Check and load the source recording. The loader also returns the (possibly updated)
# source path and a load code that later cells pass on to the plotting/preprocessing helpers.
dataset, source_file_path, load_code = read_write.data_load_pipeline(
    path=source_file_path,
    original_srate=original_srate,
    led_srate=led_srate,
    is_test=is_test,
)

[33;1m[>>]  [0m|15:23:08|  [33;1m [0;33mThe source file is a raw file.[0m

The file has the extension .nwb. The program can generate a more
compact .h5 raw signal format from the current file, that contains
the necessary parameters to simplify later use.

[33;1m[>>]  [0m|15:23:08|  [33;1m [0;33mTest mode is active. No file will be saved.[0m
[32;1m[🗸]   [0m|15:23:08|  [32;1m [0;32mThe signal loading from raw .nwb file is done.[0m
[32;1m[🗸]   [0m|15:23:08|  [32;1m [0;32mFile loading has been finished[0m


**Plot raw tetrode signal from the first five channels** <br>
*Plots the first 5 channels of the raw tetrode signal*

In [5]:
# Build the sample range once and reuse it for both the timestamps and the signal.
# NOTE(review): start_point/end_point are used here as array indices, not converted from ms.
plot_range = slice(start_point, end_point)
raw_plot_path = '{0}tetrode_raw.svg'.format(image_save_dir)
vs.plot_raw_tetrode(
    dataset.tetrode_timestamps[plot_range],
    dataset.raw_signal[plot_range, :],
    is_save=is_save_images,
    load_code=load_code,
    save_path=raw_plot_path,
)

[32;1m[🗸]   [0m|15:23:18|  [32;1m [0;32mTetrode signal plotted and saved to: ./images/tetrode_raw.svg[0m


### **Preprocessing** <br>
___
*Runs the entire preprocessing pipeline in the order and with the parameters specified in the `tasks` variable, and optionally saves the result.*

In [6]:
# Preprocess the recording up front unless preprocessing is done live during training.
if not process_live:
    pr.preprocess_pipeline(
        dataset,
        tasks,
        file_path=source_file_path,
        window_size=window_size,
        chunk_size=chunk_size,
        is_gpu=use_gpu,
        is_paral=use_paral,
        cutoff=cutoff,
        load_code=load_code,
        use_saver=use_saver,
        is_test=is_test,
        live_write=live_write,
    )

[33;1m[>>]  [0m|15:23:18|  [33;1m [0;33mUse GPU: False[0m
[33;1m[>>]  [0m|15:23:18|  [33;1m [0;33mParalell computing with CPU: False[0m
[32;1m[🗸]   [0m|15:23:18|  [32;1m [0;32mPreprocessing function initialized.[0m
[33;1m[>>]  [0m|15:23:20|  [33;1m [0;33mPreprocessing has been started...[0m


Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[33;1m[>>]  [0m|15:24:19|  [33;1m [0;33mTest mode is active. No file will be saved and only 1 chunk will be processed.[0m
[32;1m[🗸]   [0m|15:24:19|  [32;1m [0;32mPreprocessing finished without saving the results.[0m


**Plot preprocessing results**<br>
*The plot shows one channel (index 1, i.e. the second channel) of the signal before and after preprocessing. Also the head direction vs. speed <br> 
and the position of the animal during that period.*

In [7]:
# Compare one channel of the raw signal with the same channel of the first processed chunk.
# NOTE(review): both arrays are sliced with the same indices, although preprocessing may have
# down-sampled the processed data (see 'down_samp' in tasks) — confirm the time ranges match.
channel_idx = 1  # index 1, i.e. the second channel
plot_range = slice(start_point, end_point)
processed_plot_path = '{0}tetrode_processed.svg'.format(image_save_dir)
vs.processed_plot(
    dataset.tetrode_timestamps[plot_range],
    dataset.raw_signal[plot_range, channel_idx],
    dataset.processed_data[0][plot_range, channel_idx],
    convert_ms=True,
    is_save=is_save_images,
    load_code=load_code,
    save_path=processed_plot_path,
)

[32;1m[🗸]   [0m|15:24:25|  [32;1m [0;32mProcessed signal plotted and saved.[0m


*Spectrogram of the feature-extracted data*

In [8]:
# Take index 0 on the first and last axes of the 4-D feature array
# (presumably the first window and first channel — TODO confirm) and display it as a heat map.
spectrogram_slice = dataset.ft_extracted[0, :, :, 0]
px.imshow(spectrogram_slice, color_continuous_scale='RdBu_r', aspect='equal')

### **Train model** <br>
___
*This is where the selected model is trained with the preprocessed data.*

In [9]:
# Train the selected model. train_pipeline returns the training history (its .history dict
# is plotted below) and the train/validation/test split indices, which are saved with the settings.
# The result is unpacked directly; the former chained `= df =` bound an unused tuple to `df`.
history, train_idx, valid_idx, test_idx = ml.train_pipeline(
    dataset,
    mname=model_name,
    window_size=window_size,
    init_lr=learn_rate,
    epochs=epochs,
    batch_size=batch_size,
    val_batch_size=validation_batch_size,
    decay_steps=decay_steps,
    live_stream=process_live,
    valr=validation_ratio,
    testr=test_ratio,
    shuff=use_shuffle,
    col=paramater_name,
    opr=relation,
    val=limit_value,
    is_save=is_save,
    use_saver=use_saver,
    save_path=base_path,
)

[33;1m[>>]  [0m|15:24:27|  [33;1m [0;33mUse GPU: False[0m
[33;1m[>>]  [0m|15:24:27|  [33;1m [0;33mParalell computing with CPU: False[0m
[32;1m[🗸]   [0m|15:24:27|  [32;1m [0;32mTraining class initialized.[0m
[32;1m[🗸]   [0m|15:24:30|  [32;1m [0;32mBehaviour cloning modell assembling done with input shape (64, 64, 124).[0m
[32;1m[🗸]   [0m|15:24:33|  [32;1m [0;32mDataset indexes has been generated.[0m
[32;1m[🗸]   [0m|15:24:33|  [32;1m [0;32mGenerators indexes has been initialized.[0m
[32;1m[🗸]   [0m|15:24:33|  [32;1m [0;32mModel compiling done.[0m
[32;1m[🗸]   [0m|15:24:33|  [32;1m [0;32mModel callbacks have been created.[0m
[33;1m[>>]  [0m|15:24:33|  [33;1m [0;33mModel starts learning...[0m
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
[32;1m[🗸]   [0m|15:25:08|  [32;1m [0;32mModel learning done.[0m
[32;1m[🗸]   [0m|15:25:09|  [32;1m [0;32mModel and history saved to: ['./',

**Save variables**<br>
*Save all the important variables and settings to make later version tracking and analysis easier.*

In [10]:
# Record the settings, hyperparameters and data-split indices of this run for later
# version tracking; `is_save` is forwarded so the helper can honour test mode.
read_write.save_variables(
    path=source_file_path,
    tasks=tasks,
    tetr_srate=dataset.tetrode_srate,
    led_srate=led_srate,
    wind_size=window_size,
    chunk_size=chunk_size,
    cutoff_size=cutoff,
    mod_name=model_name,
    lr_rate=learn_rate,
    nepochs=epochs,
    tbatch=batch_size,
    val_batch=validation_batch_size,
    lr_ind=train_idx,
    val_ind=valid_idx,
    ts_ind=test_idx,
    shuffl=use_shuffle,
    fparam=paramater_name,
    frel=relation,
    flim=limit_value,
    is_save=is_save,
)

[33;1m[>>]  [0m|15:25:12|  [33;1m [0;33mFile path changed to:./R2470_experiment_1_22_09_2022_15_25.nwb[0m
[32;1m[🗸]   [0m|15:25:12|  [32;1m [0;32mVariables and settings sucessfuly saved.[0m


**Plot losses**<br>
*TODO: move these plots into separate functions in the `visualization` module.*

In [12]:
# Plot the per-epoch training vs. validation loss recorded in history.history and
# optionally save the figure. (plotly.graph_objects is imported as `go` in the top import cell.)
loss_plot_path = image_save_dir + 'overall_loss.png'

fig = go.Figure()
fig.add_trace(go.Scattergl(y=history.history['loss'], name='Train'))
fig.add_trace(go.Scattergl(y=history.history['val_loss'], name='Valid'))
fig.update_layout(height=500,
                  width=700,
                  title='Overall loss',
                  xaxis_title='Epoch',
                  yaxis_title='Loss')
fig.show()

if is_save_images:
    read_write.save_plot(path=loss_plot_path, fig=fig)
    # NOTE(review): save_plot may change the file name (its log reports a timestamped path),
    # so the path printed here can differ from the file actually written.
    print(read_write.print_terminal(type='done', message='Overall loss plot saved to {}'.format(loss_plot_path)))

[33;1m[>>]  [0m|15:26:32|  [33;1m [0;33mFile path changed to:./images/overall_loss_22_09_2022_15_26.png[0m
[32;1m[🗸]   [0m|15:26:38|  [32;1m [0;32mMAE plot of speed saved to ./images/overall_loss.png[0m


In [13]:
# Plot the per-epoch training vs. validation error of the speed output and optionally save it.
# The tracked metric key is 'speed_output_mse', i.e. a mean squared error, so it is labelled MSE.
# No matplotlib call is needed here: the figure is built with plotly only.
speed_plot_path = image_save_dir + 'speed_MSE.png'

fig = go.Figure()
fig.add_trace(go.Scattergl(y=history.history['speed_output_mse'], name='Train'))
fig.add_trace(go.Scattergl(y=history.history['val_speed_output_mse'], name='Valid'))
fig.update_layout(height=500,
                  width=700,
                  title='Mean Squared Error of speed feature',
                  xaxis_title='Epoch',
                  yaxis_title='Mean Squared Error')
fig.show()

if is_save_images:
    read_write.save_plot(path=speed_plot_path, fig=fig)
    print(read_write.print_terminal(type='done', message='MSE plot of speed saved to {}'.format(speed_plot_path)))

[32;1m[🗸]   [0m|15:26:45|  [32;1m [0;32mMAE plot of speed saved to ./images/speed_MAE.png[0m


<Figure size 432x288 with 0 Axes>

In [14]:
# Plot the per-epoch training vs. validation cyclic MAE of the head-direction output and
# optionally save it. The metric key ends in '_rad', so the axis is labelled in radians.
# No matplotlib call is needed here: the figure is built with plotly only.
head_dir_plot_path = image_save_dir + 'head_dir_cyclicMAE.png'

fig = go.Figure()
fig.add_trace(go.Scattergl(y=history.history['head_dir_output_cyclical_mae_rad'], name='Train'))
fig.add_trace(go.Scattergl(y=history.history['val_head_dir_output_cyclical_mae_rad'], name='Valid'))
fig.update_layout(height=500,
                  width=700,
                  title='Cyclic Mean Absolute Error of head direction',
                  xaxis_title='Epoch',
                  yaxis_title='Cyclic Mean Absolute Error (rad)')
fig.show()

if is_save_images:
    read_write.save_plot(path=head_dir_plot_path, fig=fig)
    print(read_write.print_terminal(type='done', message='Cyclic MAE plot of head direction saved to {}'.format(head_dir_plot_path)))

[32;1m[🗸]   [0m|15:26:54|  [32;1m [0;32mMAE plot of model saved to ./images/head_dir_cyclicMAE.png[0m


<Figure size 432x288 with 0 Axes>