# In [None]:
# Notebook 01: Preprocessing In Vivo Electrophysiology Data
# This notebook demonstrates the preprocessing steps for in vivo electrophysiology data: loading the raw data, bandpass filtering, and removing artifacts. The output is cleaned data that will be used in subsequent analyses, such as spike sorting and LFP analysis.

### Objectives:
### 1. Load raw electrophysiology data.
### 2. Apply bandpass filtering to isolate the spike band.
### 3. Perform basic artifact removal.
### 4. Save the filtered data for further analysis.

# Notebook 01: Loading Raw Data
import numpy as np
import matplotlib.pyplot as plt

# Read the raw recording from disk (expected to be stored as a NumPy .npy array).
# Example path: data/raw/electrophysiology_data.npy
data = np.load('data/raw/electrophysiology_data.npy')

# Preview the start of the recording: the first 20000 samples, which is
# the first second if the sampling rate is 20 kHz.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(data[:20000])
ax.set_title('Raw Electrophysiology Data')
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Amplitude')
plt.show()

# Notebook 01: Bandpass Filtering
from scipy.signal import butter, filtfilt

# Define a bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a digital Butterworth bandpass filter.

    Args:
        lowcut: Lower cutoff frequency, in the same units as ``fs``.
        highcut: Upper cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate of the signal the filter will be applied to.
        order: Filter order handed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` tuple of numerator and denominator coefficients
        produced by ``scipy.signal.butter``.
    """
    nyquist = fs / 2.0
    # butter() takes the band edges as fractions of the Nyquist frequency.
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')

def bandpass_filter(data, lowcut=300.0, highcut=3000.0, fs=20000.0, order=4):
    """Bandpass-filter ``data`` with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``, which runs the filter in both directions
    (along the last axis by default).

    Args:
        data: Signal to filter.
        lowcut: Lower cutoff frequency, in the same units as ``fs``.
        highcut: Upper cutoff frequency, in the same units as ``fs``.
        fs: Sampling rate of ``data``.
        order: Order of the Butterworth design.

    Returns:
        The filtered signal, as returned by ``filtfilt``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(numerator, denominator, data)

# Run the 300-3000 Hz bandpass over the raw recording
filtered_data = bandpass_filter(data, lowcut=300.0, highcut=3000.0, fs=20000.0)

# Show the first 20000 samples of the filtered trace
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(filtered_data[:20000])
ax.set_title('Filtered Electrophysiology Data (300-3000 Hz)')
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Amplitude')
plt.show()

# Notebook 01: Artifact Removal
# Amplitude threshold for artifact detection: 5 times the standard deviation
# of the filtered trace.
# NOTE(review): a 5-SD cut on spike-band data may also flag large genuine
# spikes - confirm the threshold suits the recording.
artifact_threshold = 5 * np.std(filtered_data)

# Positions (along the first axis) where the absolute amplitude exceeds the threshold
artifact_indices = np.nonzero(np.abs(filtered_data) > artifact_threshold)[0]

# Report how many samples were flagged
print(f"Detected {len(artifact_indices)} potential artifacts.")

# Overlay the flagged samples that fall inside the first 20000 samples
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(filtered_data[:20000], label='Filtered Data')
visible_artifacts = artifact_indices[artifact_indices < 20000]
ax.plot(visible_artifacts, filtered_data[visible_artifacts], 'ro', label='Detected Artifacts')
ax.set_title('Artifact Detection')
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Amplitude')
ax.legend()
plt.show()

# Replace artifacts with interpolated values
def replace_artifacts(data, artifact_indices):
    """Return a copy of ``data`` with flagged samples linearly interpolated.

    Every flagged sample is replaced by linear interpolation between the
    nearest *unflagged* samples on either side. For an isolated artifact
    this is the mean of its two neighbours; for a run of consecutive
    flagged samples the whole run is bridged from clean data instead of
    being averaged with neighbours that are themselves artifacts. Flagged
    samples at the very start or end of the array, which have a clean
    sample on one side only, take the value of the nearest clean sample.

    Args:
        data: Array-like signal with samples along the first axis. It is
            not modified.
        artifact_indices: Integer indices along the first axis of the
            samples to replace. Duplicates are allowed; negative and
            out-of-range indices are ignored.

    Returns:
        A new array with the same shape and dtype as ``data``. When no
        sample is flagged, or every sample is flagged (leaving nothing to
        interpolate from), the copy is returned unchanged.
    """
    # C order guarantees that the reshape below is a view into the copy.
    cleaned_data = np.array(data, copy=True, order='C')
    if cleaned_data.ndim == 0:
        return cleaned_data
    n_samples = cleaned_data.shape[0]

    indices = np.asarray(artifact_indices, dtype=np.intp).ravel()
    # Drop indices that do not address a sample instead of wrapping around.
    indices = indices[(indices >= 0) & (indices < n_samples)]

    is_artifact = np.zeros(n_samples, dtype=bool)
    is_artifact[indices] = True
    if not is_artifact.any() or is_artifact.all():
        return cleaned_data

    bad_positions = np.flatnonzero(is_artifact)
    good_positions = np.flatnonzero(~is_artifact)

    # Treat any trailing axes as independent columns so that 1-D traces and
    # (samples, channels) arrays are handled by the same code path.
    columns = cleaned_data.reshape(n_samples, -1)
    for col in range(columns.shape[1]):
        columns[bad_positions, col] = np.interp(
            bad_positions, good_positions, columns[good_positions, col]
        )
    return cleaned_data

# Replace the flagged samples in the filtered trace
cleaned_data = replace_artifacts(filtered_data, artifact_indices)

# Show the first 20000 samples after artifact replacement
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(cleaned_data[:20000], label='Cleaned Data')
ax.set_title('Filtered and Cleaned Electrophysiology Data (After Artifact Removal)')
ax.set_xlabel('Time (samples)')
ax.set_ylabel('Amplitude')
ax.legend()
plt.show()

# Notebook 01: Saving the Processed Data
# Save the cleaned data for further analysis
from pathlib import Path

output_path = Path('data/processed/cleaned_data.npy')
# np.save does not create missing parent directories, so make sure the
# output folder exists before writing.
output_path.parent.mkdir(parents=True, exist_ok=True)
np.save(output_path, cleaned_data)