# Extraction of SWR from CA1 recordings

Restarting from LFPwake0 and LFPwakeremoved.

LFPwakeremoved is used to determine the signal variance for threshold adjustment. 

LFPwake0 will be used for time determination. 

## Load LFP and packages

In [None]:
from scipy import signal
from scipy.signal import find_peaks
from scipy.signal import chirp, find_peaks, peak_widths
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, Cursor
from scipy import fftpack
import pandas as pd
from pathlib import Path
import os
from ipyfilechooser import FileChooser
from IPython.display import display
import ipywidgets as widgets


%matplotlib widget

from ephyviewer import mkQApp, MainViewer, TraceViewer, EventList, InMemoryEventSource
from ephyviewer import AnalogSignalSourceWithScatter
import ephyviewer

In [None]:
dpath = ""
try:
    %store -r dpath
except:
    print("data path not in strore")
    dpath = os.path.expanduser("~")

fc1 = FileChooser(dpath,select_default=True, show_only_dirs = True, title = "<b>OpenEphys Folder</b>")
display(fc1)

# Sample callback function
def update_my_folder(chooser):
    global dpath
    dpath = chooser.selected
    %store dpath
    return 

# Register callback function
fc1.register_callback(update_my_folder)

In [None]:
suffix = '_ABTEST'  # alternative used previously: '_AB'
sep = -5            # mapPath = dpath without its last 5 components
animalIDPos = -3    # index (from the end) of the path component holding the mouse ID

dirPathComponents = os.path.normpath(dpath).split(os.sep)
mapPath = os.sep.join(dirPathComponents[:sep])        # folder holding the channel-per-mouse file
folder_base = os.sep.join(dirPathComponents[sep:])
mice = dirPathComponents[animalIDPos]

for shown in (mapPath, folder_base, mice):
    print(shown)

In [None]:
class color:
    """ANSI escape sequences used to highlight console messages."""

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'  # resets all attributes

    # Foreground colours
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

# Input files (`filename` is defined but not loaded in this cell)
filename = os.path.join(dpath, f'LFPwake0{suffix}.npy')
filename3 = os.path.join(dpath, f'LFPwakeremoved{suffix}.npy')
filename2 = os.path.join(dpath, 'RawDataChannelExtractedDS.npy')
EMGbooleaninput = os.path.join(dpath, f'EMGframeBoolean{suffix}.pkl')

EMGboolean = pd.read_pickle(EMGbooleaninput)
LFPwakeremoved = np.load(filename3, mmap_mode='r')  # memory-mapped, read-only
All = np.load(filename2, mmap_mode='r')

# Electrode pairs per mouse: row 0 of the mouse's column holds the PFC pair,
# row 2 the CA1 pair, each written as "chA,chB".
try:
    Channels = os.path.join(mapPath, f'LFPChannels_perMice.xlsx')
    allchannels = pd.read_excel(Channels)
    pfc_pair = allchannels[mice][0].split(',')
    ca1_pair = allchannels[mice][2].split(',')
    PFCch1, PFCch2 = int(pfc_pair[0]), int(pfc_pair[1])
    CA1ch1, CA1ch2 = int(ca1_pair[0]), int(ca1_pair[1])
except FileNotFoundError:
    print(color.BOLD + color.YELLOW)
    print(f"File {Channels} not found!")
    print("probably you are not Aurélie... or the path to access it is wrong.")
    print("In the first case, make sure the mapping is rightfully setup in the curent cell")
    print("In the second case, you can play with the 'sep' variable of cell 3, or directly change the path")
    print(color.END)
    PFCch1, PFCch2, CA1ch1, CA1ch2 = 19, 20, 21, 22

# Bipolar derivations: difference between the two electrodes of each pair
PFC = All[:, PFCch1] - All[:, PFCch2]
CA1 = All[:, CA1ch1] - All[:, CA1ch2]
CA1wakeremoved = LFPwakeremoved[:, CA1ch1] - LFPwakeremoved[:, CA1ch2]

# Band pass filter
SWR band: 120-200 Hz

In [None]:
# Band-pass filter parameters (SWR band)
f_lowcut = 120.
f_hicut = 200.
fs = 1000                             # sampling rate (Hz)
nyq = 0.5 * fs
N = 6                                 # Butterworth order (a band-pass design has order 2*N)
Wn = [f_lowcut / nyq, f_hicut / nyq]  # cut-offs as fractions of the Nyquist frequency

# Zero-phase filtering (forward + backward pass) so ripple timing is not shifted
b, a = signal.butter(N, Wn, 'band')
filt_CA1 = signal.filtfilt(b, a, CA1)
filt_CA1wakeremoved = signal.filtfilt(b, a, CA1wakeremoved)

# Time axis in seconds, one entry per sample. It is built from integer indices
# because np.arange with a float step (1/fs) can return one element too many.
times = np.arange(CA1.size) / fs

## Continuous Wavelet Transform and projection calculation

First, on the signal with wake periods removed, to determine the standard deviation (SD) of the signal.

In [None]:
# Complex Morlet CWT over the SWR band of the wake-removed signal; it is only used
# to estimate the SD of the envelope that sets the detection thresholds.
# NOTE(review): scipy.signal.cwt / morlet2 are deprecated since SciPy 1.12 (slated
# for removal in 1.15); pin SciPy or port to another CWT implementation.
w = 10.                                    # morlet2 omega0 (time/frequency trade-off)
freq = np.linspace(f_lowcut, f_hicut, 80)  # analysed frequencies (Hz), same band as the filter
widths = w * fs / (2 * freq * np.pi)       # morlet2 scale matching each frequency
CA1NWcwt = signal.cwt(filt_CA1wakeremoved, signal.morlet2, widths, w=w)

# Projection: mean CWT magnitude across the analysed frequencies (one value per sample).
# Averaging over the axis, rather than dividing by a literal 80, stays correct if `freq` changes.
absCA1NWcwt = np.absolute(CA1NWcwt)
proj_CA1NWcwt = absCA1NWcwt.mean(axis=0)

# Thresholds as multiples of the envelope SD
sdproj_CA1cwt = np.std(proj_CA1NWcwt)
sd3proj_CA1cwt = sdproj_CA1cwt * 3
sd10proj_CA1cwt = sdproj_CA1cwt * 10
sd8proj_CA1cwt = sdproj_CA1cwt * 8
sd7proj_CA1cwt = sdproj_CA1cwt * 7
sd05proj_CA1cwt = sdproj_CA1cwt * 0.5

Second, on the signal in which wake periods have been zeroed.

In [None]:
# Zero the wake samples flagged in the EMG boolean frame, in both the band-passed
# and the raw CA1 signal. The masks are converted to plain boolean arrays so that a
# column stored as 0/1 integers cannot be interpreted as fancy (positional) indices.

# Conservative mask
BooleanCons = EMGboolean['BooleanConservative']
consMask = BooleanCons.to_numpy(dtype=bool)
fCA1wake0C = filt_CA1.copy()
fCA1wake0C[consMask] = 0
CA1wake0C = CA1.copy()
CA1wake0C[consMask] = 0

# Liberal mask
BooleanLib = EMGboolean['BooleanLiberal']
libMask = BooleanLib.to_numpy(dtype=bool)
fCA1wake0L = filt_CA1.copy()
fCA1wake0L[libMask] = 0
CA1wake0L = CA1.copy()
CA1wake0L[libMask] = 0

# CWT with the same Morlet parameters as for the wake-removed signal
CA1cwtWake0cons = signal.cwt(fCA1wake0C, signal.morlet2, widths, w=w)
CA1cwtWake0lib = signal.cwt(fCA1wake0L, signal.morlet2, widths, w=w)

# Projection: mean CWT magnitude across the analysed frequencies (one value per sample)
absCA1W0Ccwt = np.absolute(CA1cwtWake0cons)
proj_CA1W0Ccwt = absCA1W0Ccwt.mean(axis=0)
absCA1W0Lcwt = np.absolute(CA1cwtWake0lib)
proj_CA1W0Lcwt = absCA1W0Lcwt.mean(axis=0)

## Extracting SWRs and determining main properties 

First, extraction of SWR peaks, start, end and width.

In [None]:
# Detect SWR candidates on the liberal wake-zeroed envelope: peaks higher than
# 3 x SD of the wake-removed envelope, with prominence >= 1 and width >= 20 samples.
# NOTE(review): the threshold is 3*SD, not mean + 3*SD. Confirm this is intended.
peaks, properties = find_peaks(proj_CA1W0Lcwt, prominence=1, width=20, height=sd3proj_CA1cwt)

# SWR boundaries: peak_widths measures each peak at peak_height - 0.7 * prominence.
# As a result, SWRs with a small amplitude come out longer than big ones.
results_width = peak_widths(proj_CA1W0Lcwt, peaks, rel_height=0.7)
widths, width_heights, left_ips, right_ips = results_width

# SWR_prop: one column per SWR, rounded. Rows:
#   0 peak index, 1 width (samples), 2 peak height, 3 start index, 4 end index.
# Row 2 is the envelope value at the peak (find_peaks 'peak_heights'), matching the
# 'peak amp' label given to it later. `width_heights` is only the contour level at
# which the width was measured.
SWR_prop = np.vstack([peaks, widths, properties['peak_heights'], left_ips, right_ips]).round()

Display subset

Second, extraction of the dominant frequency and power of each SWR.

In [None]:
# Dominant component of the liberal wake-zeroed CWT at every sample.
# The CWT coefficients are complex, and np.max / np.argmax order complex values by
# their real part first. The magnitude |CA1cwtWake0lib| (absCA1W0Lcwt) is used instead.
projMaxP_cwtmg = np.max(absCA1W0Lcwt, axis=0)           # strongest magnitude per sample
# `freq` maps the row index to Hz. The grid is not spaced by exactly 1 Hz,
# so index + 120 would be inexact.
projMaxF_cwtmg = freq[np.argmax(absCA1W0Lcwt, axis=0)]

nb_SWR = len(peaks)
data = np.zeros((nb_SWR, 4))

# Per SWR, over [start, end): max and mean of the dominant frequency and magnitude
for tt in range(nb_SWR):
    first, last = int(SWR_prop[3, tt]), int(SWR_prop[4, tt])
    win_P = projMaxP_cwtmg[first:last]
    win_F = projMaxF_cwtmg[first:last]
    data[tt] = [win_F.max(), win_P.max(), win_F.mean(), win_P.mean()]
data = data.round()

param_SWR = pd.DataFrame(data, columns=['Max freq', 'Max int', 'Avg freq', 'Avg int'])
tSWR_prop = SWR_prop.transpose()
pd_prop_SWR = pd.DataFrame(tSWR_prop, columns=['peak time', 'Duration', 'peak amp', 'start time', 'end time'])
pd_tokeep = pd.DataFrame(np.ones(nb_SWR).astype(bool), columns=['toKeep'])  # curation flag, all kept initially
All_SWR = pd.concat([pd_tokeep, pd_prop_SWR, param_SWR], axis=1)

# Sample indices used as scatter markers in the viewer
SWR_peak = peaks
SWR_start = SWR_prop[3, :].astype(int)
SWR_end = SWR_prop[4, :].astype(int)

### Store the results in the All_SWR pandas DataFrame and save them as pkl/csv for post-processing.

End of Notebook. 

In [None]:
# SWR property table, saved as pickle (pandas round-trip) and as CSV.
# NOTE: filename2 / filename3 are re-bound here. When loading, they named the input files.
filename2 = os.path.join(dpath, f'SWRproperties{suffix}.pkl')
filename3 = os.path.join(dpath, f'SWRproperties{suffix}.csv')
All_SWR.to_pickle(filename2)
All_SWR.to_csv(filename3, sep=',')

# Conservative wake-zeroed filtered CA1 and its CWT envelope, one column each.
# NOTE(review): detection used the liberal envelope (proj_CA1W0Lcwt), and this file
# name carries no `suffix`. Confirm both are intended.
filenameC = os.path.join(dpath, 'SignalCA1.npy')
combined = np.column_stack((fCA1wake0C, proj_CA1W0Ccwt))
np.save(filenameC, combined)

sample_rate, t_start = 1000., 0.
# If finished and no visual assessment is planned, free memory with:
# %reset
# plt.close('all')

### Display and assess SWRs

#### ephys viewer to check SWR detection

In [None]:
%%capture
%gui qt

app = mkQApp()

#Create one data source with 3 event channel
all_events = []
conditions = ['All','Good','Bad']
for c,cond in enumerate(conditions):
    match cond:
        case 'All':
            selection = "All_SWR['toKeep'] | ~All_SWR['toKeep']"
        case 'Good':
            selection = "All_SWR['toKeep']"
        case 'Bad':
            selection = "~All_SWR['toKeep']"
    ev_times = All_SWR.loc[pd.eval(selection),'peak time'].values/1000
    ev_labels = [f'SWR {i}'for i in All_SWR[pd.eval(selection)].index]
    all_events.append({ 'time':ev_times, 'label':ev_labels, 'name': conditions[c] })
source_ev = InMemoryEventSource(all_events=all_events)

TTL = All[0:10000000, 11]#
combined = np.stack([CA1, fCA1wake0C, proj_CA1W0Lcwt, proj_CA1W0Ccwt, TTL], axis = 1)
sample_rate = 1000.
t_start = 0.

#create 2 familly scatters from theses 2 indexes
scatter_indexes = {0: SWR_peak, 1: SWR_start, 2: SWR_end}
#and asign them to some channels each
scatter_channels = {0: [1, 2], 1: [0, 1], 2: [0, 1]}
source = AnalogSignalSourceWithScatter(combined, sample_rate, t_start, scatter_indexes, scatter_channels)

#Create the main window that can contain several viewers
win = MainViewer(debug=True, show_auto_scale=True)

#create a viewer for signal with TraceViewer
#connected to the signal source
view1 = TraceViewer(source=source)

#Parameters can be set in script
view1.params['scale_mode'] = 'same_for_all'
view1.params['display_labels'] = True
view1.auto_scale()

nCh = len(view1.by_channel_params.children())
mult = 5
for ch in range(nCh):
    match ch%mult:
        case 0: # raw traces
            view1.by_channel_params[f'ch{ch}', 'offset'] = 2.5 + 5*int(ch/mult)
            #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.05
            view1.by_channel_params[f'ch{ch}', 'color'] = '#ffffff'
        case 1: # filtered traces
            view1.by_channel_params[f'ch{ch}', 'offset'] = 0.5 + 5*int(ch/mult)
            #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.1
            view1.by_channel_params[f'ch{ch}', 'color'] = '#0055ff'
        case 2: # envelop
            view1.by_channel_params[f'ch{ch}', 'offset'] = 0.5 + 5*int(ch/mult)
            #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.1
            view1.by_channel_params[f'ch{ch}', 'color'] = '#ff5500'
        case 3: # envelop
            view1.by_channel_params[f'ch{ch}', 'offset'] = 0.5 + 5*int(ch/mult)
            #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.1
            view1.by_channel_params[f'ch{ch}', 'color'] = '#ffffff'
        case 4: # TTL
            view1.by_channel_params[f'ch{ch}', 'offset'] = 4.5 + 5*int(ch/mult)
            #view1.by_channel_params[f'ch{ch}', 'gain'] = 0.1
            view1.by_channel_params[f'ch{ch}', 'color'] = '#ffffff'
view1.params['ylim_max']=5*int((nCh+1)/mult)
view1.params['ylim_min']=0

view2 = EventList(source=source_ev, name='event')


#put this viewer in the main window
win.add_view(view1)
win.add_view(view2, location='bottom',  orientation='horizontal')

#Run
win.show()

#### Select and deselect SWRs

In [None]:
def clicked(arg):
    selectedEvent = view2.list_widget.currentItem().text()
    selectedSWR = int(selectedEvent.split('SWR')[1])
    #print(selectedSWR)
    match arg.description:
        case 'Keep SWR':
            All_SWR.loc[selectedSWR,'toKeep']=True
            print(f'SWR {selectedSWR} restored')
        case 'Discard SWR':
            All_SWR.loc[selectedSWR,'toKeep']=False
            print(f'SWR {selectedSWR} discarded')
    #save modif
    All_SWR.to_pickle(filename2)
    All_SWR.to_csv(filename3, sep = ',')
    #Create one data source with 3 event channel
    all_events = []
    conditions = ['All','Good','Bad']
    for c,cond in enumerate(conditions):
        match cond:
            case 'All':
                selection = "All_SWR['toKeep'] | ~All_SWR['toKeep']"
            case 'Good':
                selection = "All_SWR['toKeep']"
            case 'Bad':
                selection = "~All_SWR['toKeep']"
        ev_times = All_SWR.loc[pd.eval(selection),'peak time'].values/1000
        ev_labels = [f'SWR {i}'for i in All_SWR[pd.eval(selection)].index]
        all_events.append({ 'time':ev_times, 'label':ev_labels, 'name': conditions[c] })
    source_ev = InMemoryEventSource(all_events=all_events)
    view2.source = source_ev
    view2.refresh_list(view2.combo.currentIndex())

button_Good = widgets.Button(description = 'Keep SWR')   
button_Good.on_click(clicked)

button_Bad = widgets.Button(description = 'Discard SWR')   
button_Bad.on_click(clicked)

display(button_Good, button_Bad)



In [None]:

# Per-SWR intensities from param_SWR: maximum drawn at x = 0, average at x = 1
MaxSWR = param_SWR['Max int']
AvgSWR = param_SWR['Avg int']

x_max = np.zeros(len(MaxSWR))
x_avg = np.ones(len(AvgSWR))
plt.scatter(x_max, MaxSWR)
plt.scatter(x_avg, AvgSWR)


In [None]:
# Connect each SWR's maximum (x = 0) to its average (x = 1) intensity.
# A single plot call with a (2, n_SWR) array draws one line per column. This is much
# faster than one plt.plot per SWR and does not depend on the Series index being 0..n-1.
plt.plot([0, 1], np.vstack([MaxSWR.to_numpy(), AvgSWR.to_numpy()]), c='k')

plt.xticks([0, 1], ['Max', 'Avg'])
plt.show()