# 1 - Importing packages and the custom function modules

In [1]:
import os
import h5py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import importlib
import getpass
import glob
import seaborn as sns
import functions
import lfp_pre_processing_functions
import power_functions
import coherence_functions
import spectrogram_plotting_functions
import plotting_styles
import scipy.stats
import mne_connectivity
# Re-import the locally edited helper modules so that changes saved on disk
# take effect without restarting the kernel.
for _module in (functions, spectrogram_plotting_functions, plotting_styles):
    importlib.reload(_module)

# Shared plot styling, read from plotting_styles after the reload above.
colors = plotting_styles.colors
linestyle = plotting_styles.linestyles

# 2 - Loading the data files

This cell looks up the current user with `getpass`, then sets the base path, collects the data files, and defines the save path. The base path, file location, and save path must be changed to match where you keep the data and where you want the results stored. Here they point to a folder in my Dropbox account that is synced locally.

In [2]:
# Identify who is running the notebook; the Dropbox location depends on it.
user = getpass.getuser()
print("Hello", user)

# The 'CPLab' account keeps Dropbox on D:, every other account under its
# Windows user profile.
base = 'D:\\Dropbox\\CPLab' if user == 'CPLab' else f'C:\\Users\\{user}\\Dropbox\\CPLab'

# Input recordings (.mat files) and the folder that receives the results.
files = glob.glob(f'{base}\\all_data_mat_250825\\*.mat')
savepath = f'{base}\\results\\'
for _label, _value in (("Base path:", base), ("Save path:", savepath)):
    print(_label, _value)
print(files)

# Frequency bands as [low, high] pairs.
all_bands_dict = dict(total=[1, 100], beta=[12, 30], gamma=[30, 80], theta=[4, 12])


Hello CPLab
Base path: D:\Dropbox\CPLab
Save path: D:\Dropbox\CPLab\results\
['D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230529_dk1_nocontext.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230529_dk3_nocontext.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230529_dk5_nocontext.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230529_dk6_nocontext.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230531_dk1_nocontext_day2.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230531_dk3_nocontext_day2.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230531_dk5_nocontext_day2.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230531_dk6_nocontext_day2.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230609_dk1_BW_nocontext_day1.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230609_dk3_BW_nocontext_day1.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230610_dk1_BW_nocontext_day2.mat', 'D:\\Dropbox\\CPLab\\all_data_mat_250825\\20230610_dk3_BW_nocontext_day2.mat', 'D:\\Dropbox\\CPLa

In [3]:
keyboard_dict={'98':'b','119':'w','120':'nc','49':'1','48':'0'} #specifying the map of keyboard annotations to their meanings.
all_bands={'total':[1,100],'beta':[12,30], 'gamma':[30,80], 'theta':[4,12]}
importlib.reload(lfp_pre_processing_functions) #Reloading the lfp_pre_processing_functions module to ensure we have the latest version
#files=[f'C:\\Users\\{user}\\Dropbox\\CPLab\\all_data_mat_filtered\\20230615_dk6_BW_context_day1.mat', f'C:\\Users\\{user}\\Dropbox\\CPLab\\all_data_mat\\20230626_dk6_BW_nocontext_day1.mat'] #This is just for testing purposes

# Sampling rate used for every channel in this cell (Hz).
sampling_rate = 2000


def extract_baseline_data(data, time, first_event, sampling_rate, baseline_secs=2.0):
    """
    Slice the pre-task baseline out of one channel and summarise it.

    Parameters:
    - data: 1D array of samples
    - time: 1D array of timestamps used to locate first_event in `data`
    - first_event: time of the first event; the baseline ends at the first
      sample whose timestamp is greater than this
    - sampling_rate: samples per second, used to convert baseline_secs to samples
    - baseline_secs: length of the baseline window in seconds (default: 2.0)

    Returns:
    - (baseline_data, time, baseline_mean, baseline_std); `time` is returned
      unchanged

    Raises:
    - IndexError if no timestamp is greater than first_event
    """
    onset_idx = np.where(time > first_event)[0][0]
    if first_event > baseline_secs:
        # Clamp at 0: a negative start would be read as an offset from the end
        # of the array and silently produce an empty (or wrong) baseline.
        start_idx = max(onset_idx - int(baseline_secs * sampling_rate), 0)
    else:
        # Less than one full window before the first event: take everything.
        start_idx = 0
    baseline_data = data[start_idx:onset_idx]
    baseline_mean = np.mean(baseline_data)
    baseline_std = np.std(baseline_data)

    #baseline_data_norm=(baseline_data-baseline_mean)/baseline_std
    # Message kept for log compatibility; the data is returned un-normalized.
    print('normalizing data')
    return baseline_data, time, baseline_mean, baseline_std


#Initializing a few empty things to store data
events_codes_all = {}
compiled_data_all_epochs = []
compiled_data_list=[]
compiled_shuffled_data_list = []
baseline_lfp_all = []
normalization_comparison_all = []
baseline_dict = {}
for file in files: #Looping through data files

    ## Get the date, mouse_id and task from the file name
    base_name = os.path.basename(file)
    base_name, _ = os.path.splitext(base_name)
    date, mouse_id, task=lfp_pre_processing_functions.exp_params(base_name) #Using a custom made function [see functions.py]
    print(date, mouse_id, task)
    if task in ('nocontextday2', 'nocontextos2'):
        task = 'nocontext'
    if task == 'nocontext':
        continue

    # The context manager releases the HDF5 handle on every exit path,
    # including the `continue` below and any exception raised while reading.
    with h5py.File(file, 'r') as f:
        channels = list(f.keys()) ## Extract channels list from the data file
        print(base_name, channels)
        if not any("AON" in channel or "vHp" in channel for channel in channels):
            continue
        events,reference_electrode=lfp_pre_processing_functions.get_keyboard_and_ref_channels(f,channels)

        events_codes=np.array(events['codes'][0]) #saving the keyboard annotations of the events (door open, door close etc.)
        events_times=np.array(events['times'][0]) #saving when the events happened
        events_codes_all[base_name] = events_codes #saving the codes in a dictionary to be analyzed later for events other than the ones in our keyboard_dict map

        #Generating epochs from events (epochs are basically start of a trial and end of a trial)
        epochs=lfp_pre_processing_functions.generate_epochs_with_first_event(events_codes, events_times)

        # task Start time
        first_event=events_times[0]
        #finding global start and end time of all channels, since they start and end recordings at different times
        global_start_time, global_end_time=lfp_pre_processing_functions.find_global_start_end_times(f,channels)

        ## Reference electrode finding and padding
        reference_time = np.array(reference_electrode['times']).flatten()
        reference_value = np.array(reference_electrode['values']).flatten()
        padd_ref_data,padded_ref_time=lfp_pre_processing_functions.pad_raw_data_raw_time(reference_value,reference_time,global_start_time,global_end_time,sampling_rate=sampling_rate)

        for channeli in channels:
            if not ("AON" in channeli or "vHp" in channeli):
                continue

            channel_id=channeli
            # Extracting raw data and time
            data_all=f[channeli]
            raw_data=np.array(data_all['values']).flatten()
            raw_time = np.array(data_all['times']).flatten()
            print(channel_id)
            print(raw_data.shape, raw_time.shape, sampling_rate)

            padded_data,padded_time=lfp_pre_processing_functions.pad_raw_data_raw_time(raw_data,raw_time,global_start_time,global_end_time,sampling_rate)
            #ref_subtracted_data = padded_data - padd_ref_data # Subtracting the reference electrode data from the raw data

            #notch_filtered_data = lfp_pre_processing_functions.iir_notch(raw_data, sampling_rate, 60)

            # Extracting baseline data
            #data_before, time, baseline_mean, baseline_std=extract_baseline_data(ref_subtracted_data, raw_time, first_event, sampling_rate)
            # NOTE(review): padded_data is indexed with raw_time here, as in the
            # original code. If padding ever prepends samples the two are no
            # longer aligned - confirm whether padded_time should be used.
            data_before, time, baseline_mean, baseline_std=extract_baseline_data(padded_data, raw_time, first_event, sampling_rate)
            print(len(data_before))
            onset_idx = np.where(time>first_event)[0][0]
            complete_baseline_data=padded_data[0:onset_idx]
            baseline_row=[base_name, mouse_id,task,channel_id,first_event,np.array(data_before), np.array(complete_baseline_data)]
            baseline_lfp_all.append(baseline_row)
            # NOTE(review): keyed by recording only, so each channel overwrites
            # the previous one and only the last channel's row is retained.
            baseline_dict[base_name] = baseline_row
baseline_lfp_all_df=pd.DataFrame(baseline_lfp_all, columns=['base_name', 'mouse_id', 'task', 'channel_id','first_event', 'data_before','complete_baseline_data'])


20230529 dk1 nocontext
20230529 dk3 nocontext
20230529 dk5 nocontext
20230529 dk6 nocontext
20230531 dk1 nocontextday2
20230531 dk3 nocontextday2
20230531 dk5 nocontextday2
20230531 dk6 nocontextday2
20230609 dk1 BWnocontext
20230609_dk1_BW_nocontext_day1 ['LFP1_AON', 'LFP1_vHp', 'LFP2_AON', 'LFP2_vHp', 'LFP3_AON', 'LFP4_AON', 'Memory', 'Ref', 'Respirat']
Global start time: 0.0, Global end time: 1980.2432749999998
3960487
start_index: 0, end_index: 3960486
LFP1_AON
(3960487,) (3960487,) 2000
3960487
start_index: 0, end_index: 3960487
normalizing data
4000
LFP1_vHp
(3960487,) (3960487,) 2000
3960487
start_index: 0, end_index: 3960487
normalizing data
4000
LFP2_AON
(3960487,) (3960487,) 2000
3960487
start_index: 0, end_index: 3960487
normalizing data
4000
LFP2_vHp
(3960487,) (3960487,) 2000
3960487
start_index: 0, end_index: 3960487
normalizing data
4000
LFP3_AON
(3960487,) (3960487,) 2000
3960487
start_index: 0, end_index: 3960487
normalizing data
4000
LFP4_AON
(3960487,) (3960487,) 200

In [4]:
!pip install openpyxl pillow



In [4]:
import importlib
import pandas as pd
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import os
import io
from openpyxl import Workbook
from openpyxl.drawing.image import Image as OpenpyxlImage
from PIL import Image
import tempfile

# Reload the power_functions module so that edits saved to the module on disk
# are picked up without restarting the kernel.
importlib.reload(power_functions)

def compute_cwt_spectrogram(data, fs=2000, freqs=None, n_cycles_factor=3):
    """
    Compute a time-frequency representation by convolving the signal with
    complex Morlet wavelets, one wavelet per analysis frequency.

    Parameters:
    - data: 1D array of LFP data
    - fs: sampling frequency in Hz (default: 2000)
    - freqs: frequencies to analyze, any array-like
      (default: integer frequencies 1-99 Hz, linearly spaced)
    - n_cycles_factor: n_cycles = freqs / n_cycles_factor (default: 3)

    Returns:
    - dict with keys:
        'frequencies'            analysis frequencies (numpy array)
        'times'                  sample times in seconds, len(data) entries
        'spectrogram'            power in dB, shape (len(freqs), len(data))
        'power_spectral_density' linear power, same shape
        'complex_tfr'            complex convolution output, same shape
        'n_cycles'               number of cycles used for each frequency
    """
    if freqs is None:
        # Default grid: 1, 2, ..., 99 Hz
        freqs = np.arange(1, 100)
    else:
        # Accept lists/tuples as well as arrays (a plain list cannot be divided).
        freqs = np.asarray(freqs)

    n_samples = len(data)

    # Calculate n_cycles for each frequency
    n_cycles = freqs / n_cycles_factor

    # Create time vector
    times = np.arange(n_samples) / fs
    # Initialize output matrix
    tfr = np.zeros((len(freqs), n_samples), dtype=complex)

    # Compute CWT for each frequency
    for i, freq in enumerate(freqs):
        # Morlet wavelet: complex exponential modulated by a Gaussian
        sigma_t = n_cycles[i] / (2 * np.pi * freq)  # Time domain standard deviation

        # Wavelet support of +/- 4 standard deviations, forced to odd length so
        # that it has a well-defined centre sample
        t_wavelet = np.arange(-4*sigma_t, 4*sigma_t, 1/fs)
        if len(t_wavelet) % 2 == 0:
            t_wavelet = t_wavelet[:-1]

        # Morlet wavelet formula
        morlet_wavelet = (1 / np.sqrt(np.pi * sigma_t)) * np.exp(1j * 2 * np.pi * freq * t_wavelet) * np.exp(-(t_wavelet**2) / (2 * sigma_t**2))

        # np.convolve(mode='same') returns max(len(data), len(wavelet)) samples,
        # which does not fit the output row when the signal is shorter than the
        # wavelet. Taking the full convolution and cropping it around the
        # wavelet centre always yields exactly len(data) samples and equals
        # mode='same' whenever the signal is the longer input.
        full_conv = np.convolve(data, morlet_wavelet, mode='full')
        start = (len(morlet_wavelet) - 1) // 2
        tfr[i, :] = full_conv[start:start + n_samples]

    # Convert to power (magnitude squared)
    power = np.abs(tfr) ** 2

    # Convert to dB scale
    power_db = 10 * np.log10(power + 1e-12)  # Add small epsilon to avoid log(0)

    return {
        'frequencies': freqs,
        'times': times,
        'spectrogram': power_db,
        'power_spectral_density': power,
        'complex_tfr': tfr,
        'n_cycles': n_cycles
    }

def create_spectrogram_image(spectrogram_data, figsize=(6, 4), dpi=100):
    """
    Render a CWT spectrogram as a matplotlib figure and return it as PNG bytes.

    Parameters:
    - spectrogram_data: dict from compute_cwt_spectrogram(); the keys 'times',
      'frequencies' and 'spectrogram' are read
    - figsize: tuple for figure size
    - dpi: resolution for the image

    Returns:
    - bytes object containing PNG image data
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    try:
        # Create spectrogram plot
        mesh = ax.pcolormesh(
            spectrogram_data['times'],
            spectrogram_data['frequencies'],
            spectrogram_data['spectrogram'],
            shading='gouraud',
            cmap='viridis'
        )

        ax.set_ylabel('Frequency [Hz]')
        ax.set_xlabel('Time [sec]')
        ax.set_title('LFP CWT Spectrogram (Morlet Wavelets)')
        # Log scale for better visualization of CWT
        #ax.set_yscale('log')

        # Add colorbar
        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label('Power [dB]')

        # Set frequency limits
        ax.set_ylim([1, 100])  # Focus on 1-100 Hz (adjust as needed)

        # Add grid for better readability
        ax.grid(True, alpha=0.3)

        # Figure methods are used instead of pyplot calls so the result does
        # not depend on which figure pyplot currently considers active.
        fig.tight_layout()

        # Save to bytes buffer
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='PNG', bbox_inches='tight', dpi=dpi)
        return img_buffer.getvalue()
    finally:
        # Always close the figure, even if plotting or saving failed, so that
        # figures do not accumulate when this runs once per dataframe row.
        plt.close(fig)
def create_psd_image(psd_data, fmin=1, fmax=100, figsize=(6, 4), dpi=100, label=None, nyquist=1000):
    """
    Render a power spectral density (PSD) curve and return it as PNG bytes.

    Parameters:
    - psd_data: 1D sequence of PSD values; they are plotted against an evenly
      spaced frequency axis from 0 to `nyquist` (assumes the values cover that
      range uniformly - TODO confirm against power_functions.apply_welch_transform)
    - fmin, fmax: x-axis limits in Hz (default: 1-100)
    - figsize: tuple for figure size
    - dpi: resolution for the image
    - label: optional legend label for the curve. Previously the label was read
      from a notebook-level `channel_id` variable left over from an earlier
      loop, which was unrelated to the row being plotted (default: None)
    - nyquist: upper end of the frequency axis in Hz (default: 1000)

    Returns:
    - bytes object containing PNG image data
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    try:
        # Create PSD plot
        frequency = np.linspace(0, nyquist, len(psd_data))
        ax.plot(frequency, psd_data, label=label)
        if label is not None:
            ax.legend()
        # Log scale for better visualization of the PSD
        #ax.set_yscale('log')
        ax.set_xlabel('Frequency [Hz]')
        ax.set_ylabel('Power [uV^2/Hz]')
        ax.set_xlim([fmin, fmax])  # Focus on 1-100 Hz by default
        fig.tight_layout()

        # Save to bytes buffer
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='PNG', bbox_inches='tight', dpi=dpi)
        return img_buffer.getvalue()
    finally:
        # Always close the figure, even on error, to free memory.
        plt.close(fig)

# Derive the per-row power metrics from the 2 s baseline segment.
# Welch PSD of each baseline segment
baseline_lfp_all_df['welch'] = baseline_lfp_all_df['data_before'].apply(
    power_functions.apply_welch_transform
)

# baseline_lfp_all_df['welch_complete'] = baseline_lfp_all_df['complete_baseline_data'].apply(
#     lambda x: power_functions.apply_welch_transform(x)
# )

# # Add CWT spectrogram computation
# print("Computing CWT spectrograms with Morlet wavelets... (this may take a while)")
# baseline_lfp_all_df['spectrogram'] = baseline_lfp_all_df['data_before'].apply(
#     lambda x: compute_cwt_spectrogram(x, fs=2000, n_cycles_factor=3)  # Adjust fs to your actual sampling rate
# )

# # Create spectrogram images
# print("Generating spectrogram images...")
# baseline_lfp_all_df['spectrogram_image'] = baseline_lfp_all_df['spectrogram'].apply(
#     lambda x: create_spectrogram_image(x)
# )

# PSD images (PNG bytes). The column keeps the name 'spectrogram_image'
# because the Excel export helpers read that column.
baseline_lfp_all_df['spectrogram_image'] = baseline_lfp_all_df['welch'].apply(create_psd_image)

# Band power in the line-noise band (58-62) and in the full 1-100 band.
for _column, _band in (('line_power', (58, 62)), ('total_power', (1, 100))):
    baseline_lfp_all_df[_column] = baseline_lfp_all_df['welch'].apply(
        power_functions.get_band_power, args=_band
    )

# Share of the 1-100 power that falls into the line-noise band.
baseline_lfp_all_df['line_power_ratio'] = baseline_lfp_all_df['line_power'].div(
    baseline_lfp_all_df['total_power']
)

baseline_lfp_all_df = baseline_lfp_all_df.sort_values(by=['mouse_id', 'task'])

# Final table for export: identifiers, ratio and image, with a fresh 0..n-1 index.
_export_columns = ['base_name', 'mouse_id', 'task', 'channel_id', 'line_power_ratio', 'spectrogram_image']
baseline_lfp_final = baseline_lfp_all_df[_export_columns].reset_index(drop=True)

def save_excel_with_images(df, filepath):
    """
    Save dataframe to Excel with the PNG image of each row embedded next to it.

    Rows are written in the dataframe's positional order, so the layout does
    not depend on the dataframe having a 0..n-1 integer index.

    Parameters:
    - df: dataframe with the columns base_name, mouse_id, task, channel_id,
      line_power_ratio and spectrogram_image (PNG bytes)
    - filepath: path to save Excel file
    """
    print("Creating Excel file with embedded images...")

    # Create workbook and worksheet
    wb = Workbook()
    ws = wb.active
    ws.title = "LFP_Analysis_with_Spectrograms"

    # Define column headers (excluding image columns from data table)
    data_columns = ['base_name', 'mouse_id', 'task', 'channel_id', 'line_power_ratio']
    headers = data_columns + ['spectrogram_image']

    # Write headers
    for col_idx, header in enumerate(headers, 1):
        ws.cell(row=1, column=col_idx, value=header)

    # Set column widths
    ws.column_dimensions['A'].width = 20  # base_name
    ws.column_dimensions['B'].width = 12  # mouse_id
    ws.column_dimensions['C'].width = 15  # task
    ws.column_dimensions['D'].width = 12  # channel_id
    ws.column_dimensions['E'].width = 18  # line_power_ratio
    ws.column_dimensions['F'].width = 50  # spectrogram_image (wide for image)

    # Store temporary files to clean up later
    temp_files = []

    try:
        # Row 1 holds the headers, so data starts at Excel row 2. The row number
        # comes from the position, not from the index label, which may be
        # non-contiguous or non-integer.
        for row_num, (idx, row) in enumerate(df.iterrows(), start=2):

            # Write data columns
            for col_idx, col_name in enumerate(data_columns, 1):
                ws.cell(row=row_num, column=col_idx, value=row[col_name])

            # Add the image (best effort: a failure is reported in the cell)
            try:
                # The image is written to a temporary file that must outlive
                # this loop; all of them are deleted after the workbook is saved.
                with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
                    # Register before writing so the file is cleaned up even
                    # if the write itself fails.
                    temp_files.append(tmp_file.name)
                    tmp_file.write(row['spectrogram_image'])

                # Add image to Excel
                img = OpenpyxlImage(tmp_file.name)
                img.width = 400  # Adjust size as needed
                img.height = 300

                # Position image in the cell
                cell_address = f'F{row_num}'
                ws.add_image(img, cell_address)

                # Set row height to accommodate image
                ws.row_dimensions[row_num].height = 225  # Adjust as needed

            except Exception as e:
                print(f"Warning: Could not add image for row {idx}: {e}")
                ws.cell(row=row_num, column=6, value="Image generation failed")

        # Save workbook
        wb.save(filepath)
        print(f"Excel file with embedded images saved to: {filepath}")

    finally:
        # Clean up all temporary files
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except Exception as e:
                print(f"Warning: Could not delete temporary file {temp_file}: {e}")

# Alternative function to save images in separate sheet
def save_excel_with_separate_image_sheet(df, filepath):
    """
    Save dataframe to Excel with data in one sheet and images in another.

    The 'LFP_Data' sheet holds the tabular columns; the 'Spectrograms' sheet
    holds one image per row, in the same positional order, together with the
    row's index label and a mouse/task/channel identifier.

    Parameters:
    - df: dataframe with the columns base_name, mouse_id, task, channel_id,
      line_power_ratio and spectrogram_image (PNG bytes)
    - filepath: path to save Excel file
    """
    print("Creating Excel file with separate image sheet...")

    # Create data sheet without images
    df_for_excel = df[['base_name', 'mouse_id', 'task', 'channel_id', 'line_power_ratio']].copy()

    # Store temporary files to clean up later
    temp_files = []

    try:
        with pd.ExcelWriter(filepath, engine='openpyxl') as writer:
            # Write main data
            df_for_excel.to_excel(writer, sheet_name='LFP_Data', index=False)

            # Create images sheet
            wb = writer.book
            img_ws = wb.create_sheet('Spectrograms')

            # Headers for image sheet
            img_ws.cell(row=1, column=1, value='Row_Index')
            img_ws.cell(row=1, column=2, value='Identifier')
            img_ws.cell(row=1, column=3, value='Spectrogram')

            # Row 1 holds the headers, so images start at row 2. The row number
            # comes from the position so it lines up with the 'LFP_Data' sheet
            # even when the index is not 0..n-1.
            for row_num, (idx, row) in enumerate(df.iterrows(), start=2):

                # Add identifier
                identifier = f"{row['mouse_id']}_{row['task']}_ch{row['channel_id']}"
                img_ws.cell(row=row_num, column=1, value=idx)
                img_ws.cell(row=row_num, column=2, value=identifier)

                # Add image (best effort: a failure is reported in the cell)
                try:
                    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
                        # Register before writing so the file is cleaned up
                        # even if the write itself fails.
                        temp_files.append(tmp_file.name)
                        tmp_file.write(row['spectrogram_image'])

                    img = OpenpyxlImage(tmp_file.name)
                    img.width = 400
                    img.height = 300

                    cell_address = f'C{row_num}'
                    img_ws.add_image(img, cell_address)
                    img_ws.row_dimensions[row_num].height = 225

                except Exception as e:
                    print(f"Warning: Could not add image for row {idx}: {e}")
                    img_ws.cell(row=row_num, column=3, value="Image generation failed")

            # Adjust column widths
            img_ws.column_dimensions['A'].width = 12
            img_ws.column_dimensions['B'].width = 25
            img_ws.column_dimensions['C'].width = 50

        print(f"Excel file with separate image sheet saved to: {filepath}")

    finally:
        # Clean up all temporary files (the workbook has been written once the
        # ExcelWriter context above has exited)
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except Exception as e:
                print(f"Warning: Could not delete temporary file {temp_file}: {e}")

# Save the Excel file with embedded images
# Create the results folder first; saving the workbook fails if it is missing.
os.makedirs(savepath, exist_ok=True)
excel_filepath = os.path.join(savepath, 'baseline_power_analysis_with_spectrograms.xlsx')

# Choose one of the following methods:

# Method 1: Images embedded in the same sheet as data
save_excel_with_images(baseline_lfp_final, excel_filepath)

# Method 2: Images in a separate sheet (uncomment to use instead)
# excel_filepath_separate = os.path.join(savepath, 'baseline_power_analysis_separate_images.xlsx')
# save_excel_with_separate_image_sheet(baseline_lfp_final, excel_filepath_separate)

print("Processing complete!")
print(f"Final dataframe shape: {baseline_lfp_final.shape}")
print("Excel file contains actual viewable spectrogram images!")

(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)
(4000,)


In [None]:
# Pick up the latest power_functions from disk before recomputing the metrics.
importlib.reload(power_functions)

# Welch PSD of the 2 s baseline window and of the complete pre-task baseline.
for _target, _source in (('welch', 'data_before'), ('welch_complete', 'complete_baseline_data')):
    baseline_lfp_all_df[_target] = baseline_lfp_all_df[_source].apply(
        power_functions.apply_welch_transform
    )

# Band power in the line-noise band (58-62) and in the full 1-100 band.
for _target, _band in (('line_power', (58, 62)), ('total_power', (1, 100))):
    baseline_lfp_all_df[_target] = baseline_lfp_all_df['welch'].apply(
        power_functions.get_band_power, args=_band
    )

# Share of the 1-100 power that falls into the line-noise band.
baseline_lfp_all_df['line_power_ratio'] = baseline_lfp_all_df['line_power'].div(
    baseline_lfp_all_df['total_power']
)
baseline_lfp_all_df = baseline_lfp_all_df.sort_values(by=['mouse_id', 'task'])

# Summary table without images, re-indexed 0..n-1.
#baseline_lfp_final=baseline_lfp_final.sort_values(by=['line_power_ratio'], ascending=False)
baseline_lfp_final = (
    baseline_lfp_all_df[['base_name', 'mouse_id', 'task', 'channel_id', 'line_power_ratio']]
    .reset_index(drop=True)
)
#baseline_lfp_final.to_csv(os.path.join(savepath, 'baseline_power_ratio_noref.csv'))


In [None]:
# Channels in display order: the four AON electrodes, then the two vHp electrodes.
channel_list = [
    'LFP1_AON', 'LFP2_AON', 'LFP3_AON', 'LFP4_AON',
    'LFP1_vHp', 'LFP2_vHp',
]

# One entry per recording present in the baseline table.
base_names = baseline_lfp_all_df['base_name'].unique()
print(base_names)

In [None]:

# One figure per recording: the baseline PSD of every channel, stacked vertically.
for base_namei in base_names:
    fig, axs = plt.subplots(len(channel_list), 1, figsize=(6, 10), sharex=True, sharey=True)
    axs = np.atleast_1d(axs)  # keeps axs indexable if channel_list has a single entry
    fig.suptitle(base_namei)
    baseline_lfp_all_df_base = baseline_lfp_all_df[baseline_lfp_all_df['base_name'] == base_namei]
    for axi, channel_id in enumerate(channel_list):
        channel_data = baseline_lfp_all_df_base[baseline_lfp_all_df_base['channel_id'] == channel_id]
        if channel_data.empty:
            # Channel not present in this recording: leave its panel blank.
            continue
        welch_data = channel_data['welch'].values[0]
        # Assumes the PSD values span 0-1000 Hz evenly - TODO confirm against
        # power_functions.apply_welch_transform.
        frequency = np.linspace(0, 1000, len(welch_data))
        ax = axs[axi]
        ax.plot(frequency, welch_data, label=channel_id)
        ax.set_title(channel_id)
        ax.set_xlim(0, 6)
    axs[-1].set_xlabel('Frequency [Hz]')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    #fig.savefig(os.path.join(savepath, f'base_psd_{base_namei}.png'), dpi=300)
    plt.show()
    # Release the figure: under a non-interactive backend plt.show() does not
    # close it, and one figure per recording would otherwise accumulate.
    plt.close(fig)

In [None]:
# Build a recordings x channels table of the line-noise power ratio and export it.
channel_order = ['LFP1_AON', 'LFP2_AON', 'LFP3_AON', 'LFP4_AON', 'LFP1_vHp', 'LFP2_vHp']

baseline_lfp_all_df = baseline_lfp_all_df.sort_values(by=['mouse_id', 'task', 'base_name'])

# Row order: recordings sorted by mouse, then task, then file name.
ordered_base_names = (
    baseline_lfp_all_df
    .drop_duplicates('base_name')
    .sort_values(['mouse_id', 'task', 'base_name'])['base_name']
)

# reindex (rather than column selection) keeps the fixed column layout even if
# a channel is missing from every recording: that column is filled with NaN
# instead of raising KeyError.
final_table = (
    baseline_lfp_all_df
    .pivot(index=['base_name'], columns='channel_id', values='line_power_ratio')
    .reindex(index=ordered_base_names, columns=channel_order)
)

# Long-format copy of the same table (one row per recording/channel pair).
final_table_melted = final_table.reset_index().melt(id_vars='base_name', var_name='channel_id', value_name='line_power_ratio')
final_table.to_csv(os.path.join(savepath, 'baseline_channel_selection.csv'))

In [None]:
# Map the column names used in the manual-check spreadsheet to the channel ids
# used in the rest of this notebook.
_manual_check_renames = {
    'LFP1AON': 'LFP1_AON',
    'LFP2AON': 'LFP2_AON',
    'LFP3AON': 'LFP3_AON',
    'LFP4AON': 'LFP4_AON',
    'LFP1vHC': 'LFP1_vHp',
    'LFP2vHC': 'LFP2_vHp',
}

# Load the spreadsheet, keep the recording name plus the six channel columns,
# rename them, and index the table by recording.
manual_check = (
    pd.read_excel(os.path.join(savepath, 'filecheck_8325.xlsx'))
    [['base_name', *_manual_check_renames]]
    .rename(columns=_manual_check_renames)
    .set_index('base_name')
)
manual_check.to_csv(os.path.join(savepath, 'manual_check_channel_selection.csv'))


In [None]:
# Synthetic check of the baseline extraction: uniform noise with a 10 Hz burst
# placed in the second before the "first event" at t = 10 s.
fs = 2000
data_secs = 11
spike_start_sec = 9
spike_end_sec = 10

# Fixed seed so the plots are reproducible across runs.
rng = np.random.default_rng(0)

# Integer sample indices divided by fs give exactly (end - start) * fs samples;
# a float-step arange can be off by one and then not fit the slice below.
burst_time = np.arange(spike_start_sec * fs, spike_end_sec * fs) / fs
sine_wave = np.sin(2 * np.pi * 10 * burst_time) + 1  # 10 Hz sine wave, offset by +1
time = np.arange(fs * data_secs) / fs
test_data = rng.random(fs * data_secs)
test_data[spike_start_sec*fs:spike_end_sec*fs] = sine_wave  # Insert the burst between 9 s and 10 s
plt.plot(time, test_data)
#plt.xlim(0,0)
plt.ylim(-1,5)
plt.show()


data_before, time, baseline_mean, baseline_std=lfp_pre_processing_functions.baseline_data_normalization(test_data, time, 10, fs)
plt.plot(data_before)
plt.show()