<a href="https://colab.research.google.com/github/haribharadwaj/notebooks/blob/main/AUDIOLOGY/AEPanalysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Auditory Neural Response Analysis

In this notebook, we will:
1. Load neural response data from an Excel file hosted on Dropbox, where the first column is time.
2. Randomly select one participant’s data for each run.
3. Set various parameters (filter type, cutoff frequencies, time window). **You can change these to emphasize ABRs, MLRs, or later cortical responses!**
4. Filter and plot the data.
5. Filter every participant's data and plot the grand average with a ±1 SEM band.

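We assume the first column in the spreadsheet holds time (with a header such as "Time") and the subsequent columns are participant response waveforms.
The code treats time as seconds: it computes the sampling rate as 1/dt, so a time column in milliseconds must be divided by 1000 first, or the sampling rate, and hence the effective cutoff frequencies, will be off by a factor of 1000.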
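A minimal sketch of the assumed layout, using a small made-up DataFrame (the column names `Time`, `P01`, `P02` and all values are hypothetical). The helper `time_to_seconds` is also hypothetical, not part of the notebook: it converts the time column to seconds from an explicitly stated unit instead of guessing the unit from the values.

```python
import numpy as np
import pandas as pd

def time_to_seconds(time_values, unit):
    """Return time values in seconds; `unit` must be 's' or 'ms' (hypothetical helper)."""
    time_values = np.asarray(time_values, dtype=float)
    if unit == 's':
        return time_values
    if unit == 'ms':
        return time_values / 1000.0
    raise ValueError("unit must be 's' or 'ms'")

# Toy DataFrame with the assumed layout: time first, then one column per participant.
toy_df = pd.DataFrame({
    'Time': [0.0, 0.05, 0.10, 0.15],   # milliseconds in this made-up example
    'P01': [0.1, 0.3, -0.2, 0.0],
    'P02': [0.0, 0.2, -0.1, 0.1],
})

time_s = time_to_seconds(toy_df.iloc[:, 0].values, unit='ms')  # first column -> seconds
participant_cols = toy_df.columns[1:]                          # everything after time
print(time_s)
print(list(participant_cols))
```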

For more background on filtering and auditory evoked potentials, see:
- Picton, T. W. (2011). *Human auditory evoked potentials*. Plural Publishing.


In [None]:
#@title Run this to import some code libraries and load the data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, firwin

# If running in Jupyter, enable inline plotting
%matplotlib inline

# --- DATA LOADING CELL ---
dropbox_url = 'https://www.dropbox.com/scl/fi/pamqgiibn5lo69bp5okhw/ABRMLRAEP_Rawtraces.xlsx?rlkey=hna569bokwyfhkn96pj6tw1l6&st=poj1k8pv&dl=1'

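# Read the Excel file directly from the Dropbox link (pandas needs the openpyxl package for .xlsx files).
# The 'dl=1' parameter at the end of the URL makes Dropbox return the file itself instead of a preview page.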
df = pd.read_excel(dropbox_url)

print("Data loaded successfully!")
print("DataFrame shape:", df.shape)
print("Columns in the DataFrame:", df.columns.tolist())
df.head()


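The filtering cells compute the sampling rate from the first two time points, which assumes the time column is in seconds and uniformly sampled. A slightly more defensive check is sketched here; `estimate_fs` is a hypothetical helper (not part of the notebook), and the 1e-3 relative tolerance is an arbitrary choice meant to allow for rounding of time values in the spreadsheet.

```python
import numpy as np

def estimate_fs(time_s, rtol=1e-3):
    """Estimate the sampling rate (Hz) from a time vector in seconds.

    Raises ValueError if there are fewer than two samples, if time does not
    strictly increase, or if the spacing varies by more than `rtol` (relative).
    """
    time_s = np.asarray(time_s, dtype=float)
    if time_s.size < 2:
        raise ValueError("Need at least two time points.")
    diffs = np.diff(time_s)
    if np.any(diffs <= 0):
        raise ValueError("Time values must be strictly increasing.")
    dt = np.median(diffs)
    if not np.allclose(diffs, dt, rtol=rtol, atol=0.0):
        raise ValueError("Time column does not look uniformly sampled.")
    return 1.0 / dt

# Example: 0 to 12 ms sampled every 50 microseconds (assumed values)
t_demo = np.arange(241) * 5e-5
print(estimate_fs(t_demo))
```

With the loaded data, the call would be `estimate_fs(df.iloc[:, 0].values)`.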
# Filter Parameters and Time Window

We will define:
1. The filter type: 'fir1' (FIR with Blackman window) or 'iir' (Butterworth).
2. The lower and upper cutoff frequencies (in Hz).
3. The time window (in the same units as your time column).
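Both `firwin` and `butter`, as called in `apply_filter` later, take cutoffs normalized by the Nyquist frequency `fs / 2`, and both reject normalized values outside the open interval (0, 1). The sketch below is a hypothetical helper, not part of the notebook, that checks the chosen cutoffs before filtering; the 20 kHz sampling rate in the usage lines is an assumption for illustration.

```python
def normalized_band(low_cut, high_cut, fs):
    """Return (low, high) normalized to Nyquist after basic sanity checks."""
    nyquist = fs / 2.0
    if not (0 < low_cut < high_cut < nyquist):
        raise ValueError(
            f"Need 0 < low_cut < high_cut < Nyquist ({nyquist:.1f} Hz); "
            f"got {low_cut} Hz and {high_cut} Hz."
        )
    return low_cut / nyquist, high_cut / nyquist

fs_assumed = 20000.0  # assumed sampling rate (Hz), for illustration only
low, high = normalized_band(70.0, 3000.0, fs_assumed)
print(low, high)
```

At an assumed 10 kHz sampling rate, for instance, a 3000 Hz upper cutoff would still be valid, but 6000 Hz would exceed the 5000 Hz Nyquist frequency and raise an error.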


# How to change parameters

- **Changing Filter Parameters**: Adjust `low_cutoff`, `high_cutoff`, and `filter_type` in the FILTER PARAMETER CELL.
- **Changing Time Window**: Adjust `start_time` and `end_time` in the TIME WINDOW PARAMETER CELL.
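One way to switch between response classes is to keep a few named presets. Only the `'abr'` entry copies this notebook's defaults; the `'mlr'` and `'cortical'` values are illustrative starting points chosen as assumptions, not settings from the notebook or from Picton (2011), so check them against your own protocol and against the length and sampling rate of the recordings. Low cutoffs such as 1 Hz also need long time windows, and the notebook's 64-tap FIR would not actually realize them.

```python
# Illustrative presets (assumed values, except 'abr', which copies the notebook defaults).
PRESETS = {
    'abr':      {'low_cutoff': 70.0, 'high_cutoff': 3000.0, 'start_time': 0.0, 'end_time': 0.012},
    'mlr':      {'low_cutoff': 10.0, 'high_cutoff': 200.0,  'start_time': 0.0, 'end_time': 0.080},
    'cortical': {'low_cutoff': 1.0,  'high_cutoff': 30.0,   'start_time': 0.0, 'end_time': 0.500},
}

def load_preset(name):
    """Return (low_cutoff, high_cutoff, start_time, end_time) for a named preset."""
    p = PRESETS[name]
    return p['low_cutoff'], p['high_cutoff'], p['start_time'], p['end_time']

# Separate names so the notebook's own parameters are not overwritten.
lo, hi, t0, t1 = load_preset('mlr')
print(lo, hi, t0, t1)
```

To use a preset, copy its values into the FILTER PARAMETER CELL and the TIME WINDOW PARAMETER CELL, since those cells set the parameters the filtering code reads.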


In [None]:
# --- FILTER PARAMETER CELL ---

filter_type = 'fir1'  # Options: 'fir1' or 'iir'

# Example bandpass range in Hz (adjust these to target brainstem, middle-latency, or cortical responses)
low_cutoff = 70.0
high_cutoff = 3000.0

print("Filter type:", filter_type)
print("Cutoff frequencies:", low_cutoff, "Hz to", high_cutoff, "Hz")


In [None]:
# --- TIME WINDOW PARAMETER CELL ---

# Set the time window in the same units as the first column of the spreadsheet.
# These are in seconds.
start_time = 0.0
end_time = 0.012

print("Time window for plotting:", start_time, "to", end_time, "(s)")


# Filtering and Plotting

Steps performed by the code:
1. Extract the time column from the DataFrame.
2. Randomly pick one participant’s column.
3. Extract only the portion of the data between `start_time` and `end_time`.
4. Compute the sampling rate from the time column (assuming uniform sampling).
5. Apply the selected filter (FIR or IIR).
6. Plot the filtered signal.


In [None]:
#@title This section does the actual filtering and plotting: **Re-run to pick a random participant each time**
from scipy.signal import butter, filtfilt, firwin

# 1. Extract the time column
#    We'll assume it's the first column in df.
time_col_name = df.columns[0]        # e.g., "Time"
time_array = df[time_col_name].values

# 2. Randomly pick one participant column
# exclude the first column (Time) and columns having mean and SEM
all_participants = df.columns[1:22]
random_participant = np.random.choice(all_participants)
signal = df[random_participant].values

print(f"Randomly selected participant: {random_participant}")

# 3. Extract only the portion of the data between start_time and end_time
#    First, find indices where (time >= start_time) & (time <= end_time)
mask = (time_array >= start_time) & (time_array <= end_time)
time_segment = time_array[mask]
signal_segment = signal[mask]

# 4. Compute the sampling rate from the time column.
#    This assumes the time column is in seconds and uniformly sampled:
#      dt = time_segment[1] - time_segment[0]   (seconds)
#      fs = 1 / dt                               (Hz)

if len(time_segment) < 2:
    raise ValueError("Selected time window is too small or invalid. Check your start/end times.")

dt = time_segment[1] - time_segment[0]  # difference in consecutive time points
fs = 1.0 / dt
print(f"Approx. sampling rate (fs): {fs:.2f} Hz")

# 5. Apply the chosen filter

def apply_filter(sig, fs, low_cut, high_cut, ftype='fir1'):
    """
    Apply zero-phase bandpass filtering to the input signal.

    Parameters
    ----------
    sig : array
        1D time-series signal to be filtered.
    fs : float
        Sampling rate in Hz.
    low_cut : float
        Lower cutoff frequency in Hz.
    high_cut : float
        Upper cutoff frequency in Hz.
    ftype : str, optional
        'fir1' for FIR filter (using Blackman window) or 'iir' for IIR (Butterworth).

    Returns
    -------
    filtered_signal : array
        The zero-phase filtered signal.
    """

    # Compute normalized cutoff frequencies
    nyquist = fs / 2.0
    low = low_cut / nyquist
    high = high_cut / nyquist

    if ftype == 'fir1':
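        # Create 64 FIR coefficients using a Blackman window.
        # With only 64 taps the transition band is several times fs/64 Hz wide
        # (hundreds of Hz at sampling rates of several kHz), so a low cutoff such
        # as 70 Hz is only nominal. More taps would sharpen the edges, but filtfilt
        # then needs a segment longer than 3 * numtaps samples.
        b = firwin(64, [low, high], pass_zero=False, window='blackman')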
        # Zero-phase filtering by setting a = [1] and calling filtfilt(b, a, sig)
        filtered_signal = filtfilt(b, [1.0], sig)
        return filtered_signal

    elif ftype == 'iir':
        # 4th-order Butterworth prototype; with btype='band' the resulting filter is 8th order
        b, a = butter(4, [low, high], btype='band')
        # filtfilt also provides zero-phase filtering for IIR
        filtered_signal = filtfilt(b, a, sig)
        return filtered_signal

    else:
        raise ValueError("Unknown filter type. Choose 'fir1' or 'iir'.")


filtered_segment = apply_filter(signal_segment, fs, low_cutoff, high_cutoff, filter_type)

# 6. Plot the filtered signal
plt.figure()
plt.plot(time_segment, filtered_segment)
plt.xlabel("Time (seconds)")
plt.ylabel("Amplitude (microvolts)")
plt.title(f"Filtered Response ({filter_type.upper()}): {random_participant}")
plt.grid(True)
plt.show()

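Before trusting a filtered waveform, it can help to look at what each design in `apply_filter` actually passes. The sketch below rebuilds the same two designs (64-tap Blackman-window `firwin` and `butter(4, ..., btype='band')`) and evaluates them with `scipy.signal.freqz` at a few frequencies. The 20 kHz sampling rate is an assumption; substitute the `fs` printed above. Because `filtfilt` runs each filter forward and backward, the magnitude is squared before converting to dB. If the FIR's gain at 10 Hz comes out close to its passband gain, the nominal 70 Hz lower cutoff is not being realized by the 64-tap design.

```python
import numpy as np
from scipy.signal import butter, firwin, freqz

fs_assumed = 20000.0              # assumed sampling rate (Hz); use the fs printed above
low_cut, high_cut = 70.0, 3000.0  # the notebook's default band
nyquist = fs_assumed / 2.0

# Same designs as apply_filter
b_fir = firwin(64, [low_cut / nyquist, high_cut / nyquist], pass_zero=False, window='blackman')
b_iir, a_iir = butter(4, [low_cut / nyquist, high_cut / nyquist], btype='band')

freqs_hz = np.array([10.0, 70.0, 300.0, 1000.0, 3000.0, 6000.0])
designs = {'fir1': (b_fir, [1.0]), 'iir': (b_iir, a_iir)}
gains_db = {}
for name, (b, a) in designs.items():
    # With fs given, worN is interpreted as frequencies in Hz
    _, h = freqz(b, a, worN=freqs_hz, fs=fs_assumed)
    # Square the magnitude because filtfilt applies the filter twice
    gains_db[name] = 20 * np.log10(np.abs(h) ** 2 + 1e-12)
    print(name, np.round(gains_db[name], 1))
```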

# Grand Average of Filtered Responses

Previously, we demonstrated filtering a **single participant’s** response. However, if we want
a **grand average** that reflects the same processing steps for every participant, we need to:

1. **Extract and filter** each participant’s data over the same time window.
2. **Store** these filtered waveforms.
3. **Compute** the grand average (mean across participants) and standard error (SEM).

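Plotting the grand average of the **filtered** signals (as opposed to filtering the average) ensures that each participant's waveform is processed in the same way before averaging. Because the filter is linear and identical for every participant, the grand-average waveform itself is the same either way (up to rounding error); the practical gain is that the individual filtered waveforms are available, so the ±1 SEM band describes the variability of the *filtered* responses and the same waveforms can be used for per-participant measurements.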
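Zero-phase filtering with fixed coefficients is a linear operation, so filtering each participant and then averaging should match averaging first and filtering once. The sketch below checks this on synthetic noise, using the notebook's 64-tap Blackman FIR design; the random data, the 20 kHz sampling rate, and the 241-sample length are assumptions standing in for real recordings. It also computes the SEM from the sample standard deviation (`ddof=1`).

```python
import numpy as np
from scipy.signal import firwin, filtfilt

rng = np.random.default_rng(0)
fs_assumed = 20000.0                      # assumed sampling rate (Hz)
n_participants, n_samples = 21, 241       # 21 participants as in the notebook; length assumed
raw = rng.standard_normal((n_participants, n_samples))  # synthetic stand-in data

nyquist = fs_assumed / 2.0
b = firwin(64, [70.0 / nyquist, 3000.0 / nyquist], pass_zero=False, window='blackman')

filtered = filtfilt(b, [1.0], raw, axis=-1)          # filter each participant (row)
avg_of_filtered = filtered.mean(axis=0)              # then average

filtered_avg = filtfilt(b, [1.0], raw.mean(axis=0))  # average first, then filter once

print(np.allclose(avg_of_filtered, filtered_avg))

# The SEM band needs the per-participant filtered waveforms.
sem_demo = filtered.std(axis=0, ddof=1) / np.sqrt(n_participants)
```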




In [None]:
#@title This section shows the grand average across individuals

import numpy as np
import matplotlib.pyplot as plt

# (1) Identify the time column and compute sampling rate
time_col_name = df.columns[0]   # e.g., "Time"
time_array = df[time_col_name].values


dt = time_array[1] - time_array[0]
fs = 1.0 / dt
# (2) Create a mask for the desired time window
mask = (time_array >= start_time) & (time_array <= end_time)
time_segment = time_array[mask]

# (3) Filter each participant over the same time window, then store the result
filtered_data = []  # will become a list of 1D arrays (one per participant)

# We'll re-use the apply_filter function you defined earlier (with filtfilt for zero-phase):
# def apply_filter(sig, fs, low_cut, high_cut, ftype='fir1'):


# Suppose the first column is time, columns 1..21 are actual participants
valid_participant_cols = df.columns[1:22]  # i.e., up to 21 inclusive

for participant in valid_participant_cols:
    # Extract the participant's signal in the time window
    raw_signal = df[participant].values[mask]
    # Filter the signal
    filtered_signal = apply_filter(raw_signal, fs, low_cutoff, high_cutoff, filter_type)
    filtered_data.append(filtered_signal)

# Convert the list to a 2D NumPy array: shape (num_participants, num_timepoints)
filtered_data = np.array(filtered_data)

# (4) Compute grand average and standard error across participants
grand_average = filtered_data.mean(axis=0)
# Sample standard deviation (ddof=1), the usual basis for the SEM
grand_std = filtered_data.std(axis=0, ddof=1)
num_participants = filtered_data.shape[0]
sem = grand_std / np.sqrt(num_participants)

# (5) Plot the grand average and a shaded region for ±1 SEM
plt.figure()
plt.plot(time_segment, grand_average, label='Grand Average (Filtered)')
plt.fill_between(
    time_segment,
    grand_average - sem,
    grand_average + sem,
    alpha=0.3,
    label='±1 SEM'
)
plt.xlabel("Time (seconds)")
plt.ylabel("Amplitude (microvolts)")
plt.title("Grand Average of Filtered Responses")
plt.legend()
plt.grid(True)
plt.show()


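Both filtering cells crop to the plotting window before filtering, so `filtfilt` sees only the windowed samples and its edge transients fall inside the plotted interval; this matters most for low cutoffs and short windows. A common alternative is to filter the full-length trace and crop afterwards, which helps only to the extent that the recording extends beyond the window. The sketch below does that with a Butterworth design in second-order-sections form (`butter(..., output='sos')` with `sosfiltfilt`), which the SciPy documentation recommends over the `(b, a)` form for numerical robustness. The function name `filter_then_crop` and the synthetic tone-plus-offset signal are hypothetical.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def filter_then_crop(sig, time_s, fs, low_cut, high_cut, start_time, end_time, order=4):
    """Zero-phase Butterworth bandpass on the full trace, then crop to [start_time, end_time]."""
    sos = butter(order, [low_cut, high_cut], btype='band', fs=fs, output='sos')
    filtered_full = sosfiltfilt(sos, sig)          # filter the whole record
    mask = (time_s >= start_time) & (time_s <= end_time)
    return time_s[mask], filtered_full[mask]       # crop afterwards

# Synthetic example: 100 ms at an assumed 20 kHz, a 500 Hz tone plus a DC offset of 5.
fs_demo = 20000.0
t_demo = np.arange(2000) / fs_demo
sig_demo = np.sin(2 * np.pi * 500.0 * t_demo) + 5.0
t_win, y_win = filter_then_crop(sig_demo, t_demo, fs_demo, 70.0, 3000.0, 0.020, 0.032)
print(len(t_win), round(float(np.mean(y_win)), 3))
```

The bandpass removes the DC offset while leaving the in-band tone essentially unchanged, and only the cropped 12 ms segment is returned.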