# This is just a notebook to visualise 1kHz filtered raw data

## Setup everything

### Import packages

In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from ephyviewer import mkQApp, MainViewer, TraceViewer, TimeFreqViewer, InMemoryAnalogSignalSource, EventList
from ephyviewer import AnalogSignalSourceWithScatter, SpikeInterfaceRecordingSource, InMemoryEventSource

import mbTools

## Choose experiment
Select the folder of the experiment to display. If the experiment was already analyzed, you can select the interimAnalysis folder. Otherwise, select the raw data recording folder.

In [None]:
# Create the experiment object used by every cell below.
# NOTE(review): presumably this is where the experiment folder gets selected — confirm in mbTools.
theExpe = mbTools.experiment()

## Load Data

### Map the whole data into memory

In [None]:
# Locate the data of the experiment; the cells below then read it from theExpe.data.
# NOTE(review): fullSampling=False presumably selects the downsampled signals and
# spindleBN the base name of the spindle properties file — confirm in mbTools.
theExpe.analyseExpe_findData(fullSampling=False,spindleBN='SPINDproperties')
#theExpe.setNumLFPchannels(32)

### Extract submatrix of interest

In [None]:
# Four independent, initially empty dictionaries filled by the cells below and
# read by the display cell: signals to show, channel names, sampling rates and
# start times.
combined, channelLabels, sample_rates, t_start = ({} for _ in range(4))

In [None]:
# LFP: when the experiment holds an 'OE_LFP' entry, register its sampling rate,
# start time, combined signal and channel labels under the 'LFP' key.
if 'OE_LFP' not in theExpe.data:
    print("no LFP data to combine")
else:
    oe_lfp = theExpe.data['OE_LFP']
    sample_rates['LFP'] = oe_lfp.sampling_rate
    t_start['LFP'] = oe_lfp.start
    # "All" combines every structure; pass a list such as ['M1'] to restrict it.
    # channelLabels is read only after combineStructures has run, and copied.
    combined['LFP'] = oe_lfp.combineStructures("All")
    channelLabels['LFP'] = oe_lfp.channelLabels[:]
    print("LFP data combined")

In [None]:
# LFP_DS: when the experiment holds an 'LFP_DS' entry, register it for display
# under the 'LFP_DS' key.
if 'LFP_DS' in theExpe.data:
    lfp_ds = theExpe.data['LFP_DS']
    # The sampling rate and start time are forced here (1000 Hz, starting at 0)
    # rather than read from the recording.
    lfp_ds.sampling_rate = 1000
    lfp_ds.start = 0
    print(lfp_ds.sampling_rate)

    sample_rates['LFP_DS'] = lfp_ds.sampling_rate
    t_start['LFP_DS'] = lfp_ds.start
    # "All" combines every structure; pass a list such as ['M1'] to restrict it.
    # channelLabels is read only after combineStructures has run, and copied.
    combined['LFP_DS'] = lfp_ds.combineStructures("All")
    channelLabels['LFP_DS'] = lfp_ds.channelLabels[:]
    # Messages name LFP_DS explicitly so they cannot be confused with the LFP cell.
    print("LFP_DS data combined")
else:
    print("no LFP_DS data to combine")

In [None]:
# NPX: when the experiment holds an 'NPX' entry, register its sampling rate,
# start time, 'spike' signal and channel labels under the 'NPX' key.
if 'NPX' not in theExpe.data:
    print("no NPX data to combine")
else:
    npx = theExpe.data['NPX']
    sample_rates['NPX'] = npx.sampling_rate
    t_start['NPX'] = npx.start
    combined['NPX'] = npx.signal['spike']
    # The labels are stored as-is (no copy), unlike in the LFP cells.
    channelLabels['NPX'] = npx.channelLabels
    print("NPX data combined")

In [None]:
# Spindles of one structure. All_Spindle is always assigned (None when the
# experiment has no 'Spindles' entry) so that later cells can test it without
# raising NameError or reusing a value left over from a previous experiment.
All_Spindle = None
if 'Spindles' in theExpe.data:
    structure = 'M1'
    All_Spindle = theExpe.data['Spindles'][structure]
    print(All_Spindle)
else:
    print("no Spindles data to display")

In [None]:
#this cell can be used to plot very precisely time of interest. Beware that it conflicts with ephyviewer however. It might be possible to have 2 notebooks open simultanéeously...
if False:
    %matplotlib widget
    #you can confiure a y-offset and some scaling, have a look at the help of superCleanPlot
    mbTools.superCleanPlot(thedata.data['OE_LFP'], thedata.data['NPX'], structureLFP=['M1'], canauxNPX=[0], time=55) #canauxLFP=16, 
    picFN = os.path.sep.join([theExpe.rawDataPath,'A1-8978.svg'])
    plt.savefig(picFN, format="svg")

## Display

In [None]:
%gui qt
app = mkQApp()

#Create the main window that can contain several viewers
win = MainViewer(debug=True)

if 'LFP' in combined:
    source = InMemoryAnalogSignalSource(combined['LFP'], np.round(sample_rates['LFP']), t_start['LFP'], channel_names=channelLabels['LFP'])
    view1 = TraceViewer(source=source, name = 'LFP')

    #Parameters can be set in script
    view1.params['display_labels'] = True
    view1.params['scale_mode'] = 'same_for_all'
    view1.auto_scale()

    cmap = matplotlib.colormaps["hsv"]#Wistia"]
    nCh = len(view1.by_channel_params.children())
    for ch in range(nCh):
        #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.00002
        #view1.by_channel_params[f'ch{ch}', 'offset'] = 0.1
        view1.by_channel_params[f'ch{ch}', 'color'] = matplotlib.colors.to_hex(cmap(ch/nCh), keep_alpha=False)
        pass

    #create a time freq viewer conencted to the same source
    view2 = TimeFreqViewer(source=source, name='tfr')
    view2.params['show_axis'] = False
    view2.params['timefreq', 'deltafreq'] = 1
    #view2.by_channel_params['ch3', 'visible'] = False
    view2.auto_scale()

    win.add_view(view1)
    #win.add_view(view2)

if 'LFP_DS' in combined:

    if All_Spindle is not None:
        #Create one data source with 3 event channel
        all_events = []
        conditions = ['All','Good','Bad']
        for c,cond in enumerate(conditions):
            match cond:
                case 'All':
                    selection = "All_Spindle['toKeep'] | ~All_Spindle['toKeep']"
                case 'Good':
                    selection = "All_Spindle['toKeep']"
                case 'Bad':
                    selection = "~All_Spindle['toKeep']"
            ev_times = mbTools.convertTheoricIndex2realTime(All_Spindle.loc[pd.eval(selection),'peak time'].values, realFreq=sample_rates['LFP_DS'], offset=t_start['LFP_DS'])
            ev_labels = [f'spindle {i}'for i in All_Spindle[pd.eval(selection)].index]
            all_events.append({ 'time':ev_times, 'label':ev_labels, 'name': conditions[c] })
        source_ev = InMemoryEventSource(all_events=all_events)

        Spindle_peak = All_Spindle['peak time'].astype(int)
        Spindle_start = All_Spindle['start time'].astype(int)
        Spindle_end = All_Spindle['end time'].astype(int)

        #create 2 familly scatters from theses 2 indexes
        scatter_indexes = {0: Spindle_peak, 1: Spindle_start, 2: Spindle_end}
        #and asign them to some channels each
        scatter_channels = {0: [0], 1: [0], 2: [0]}
        source = AnalogSignalSourceWithScatter(combined['LFP_DS'], sample_rates['LFP_DS'], t_start['LFP_DS'], scatter_indexes, scatter_channels, channel_names=channelLabels['LFP_DS'])
        view_Events = EventList(source=source_ev, name='event')
        
    else:
        source = InMemoryAnalogSignalSource(combined['LFP_DS'], sample_rates['LFP_DS'], t_start['LFP_DS'], channel_names=channelLabels['LFP_DS'])
        view_Events = None
    view_DS = TraceViewer(source=source, name = 'LFP_DS')

    #Parameters can be set in script
    view_DS.params['display_labels'] = True
    view_DS.params['scale_mode'] = 'same_for_all'
    view_DS.auto_scale()

    cmap = matplotlib.colormaps["hsv"]#Wistia"]
    nCh = len(view_DS.by_channel_params.children())
    for ch in range(nCh):
        #view_DS.by_channel_params[f'ch{ch}', 'gain'] = 0.00002
        #view_DS.by_channel_params[f'ch{ch}', 'offset'] = 0.1
        view_DS.by_channel_params[f'ch{ch}', 'color'] = matplotlib.colors.to_hex(cmap(ch/nCh), keep_alpha=False)
        pass

    win.add_view(view_DS)
else:
    view_Events=None


if 'NPX' in combined:
    sig_source = SpikeInterfaceRecordingSource(recording=combined['NPX'])
    #view3 = TraceViewer.from_numpy(combined['NPX'], sample_rates['NPX'], t_start['NPX'], 'NPX', channel_names=channelLabels['NPX'])
    view3 = TraceViewer(source=sig_source, name='NPX')
    win.add_view(view3)

    #Parameters can be set in script
    view3.params['display_labels'] = True
    view3.params['scale_mode'] = 'same_for_all'
    view3.auto_scale()

    cmap = matplotlib.colormaps["hsv"]#Wistia"]
    nCh = len(view3.by_channel_params.children())
    for ch in range(nCh):
        #view3.by_channel_params[f'ch{ch}', 'gain'] = 0.00002
        #view3.by_channel_params[f'ch{ch}', 'offset'] = 0.1
        view3.by_channel_params[f'ch{ch}', 'color'] = matplotlib.colors.to_hex(cmap(ch/nCh), keep_alpha=False)
        pass

if view_Events is not None:
    win.add_view(view_Events)


#Run
win.show()
#app.exec()  #if commented, the app is shown and fonctionnal. Maybe detecting buttons. the Python icon doesn't close any better