# Reproducing Figure 2G-H: oEPSC Analysis

This notebook reproduces Figure 2G-H from **Zhai et al. 2025**, comparing optogenetically evoked postsynaptic currents (oEPSCs) between OffState and OnState conditions in a Parkinson's disease model.

**Dataset**: DANDI:001538 - State-dependent modulation of spiny projection neurons controls levodopa-induced dyskinesia

**Analysis approach**:
- **Figure 2G**: Cumulative distribution of all individual oEPSC events
- **Figure 2H**: Box plot comparing mean event amplitudes per experimental session
- **Event detection**: MAD-based noise estimation with a ±5 SD threshold
- **Conditions**: OffState (control) vs OnState (L-DOPA treated)

## Setup and Data Loading

### Import Libraries and Configure Plotting Style

We use the same plotting parameters as the original publication to ensure visual consistency.

In [None]:
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import remfile
import seaborn as sns
from dandi.dandiapi import DandiAPIClient
from dotenv import load_dotenv
from pynwb import NWBHDF5IO
from scipy import stats
from tqdm import tqdm

# Set plotting style to match paper
plt.style.use('default')
sns.set_palette("Set2")

def setup_figure_style():
    """Setup matplotlib parameters to match paper style"""
    plt.rcParams.update({
        'font.size': 8,
        'axes.titlesize': 10,
        'axes.labelsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
        'axes.linewidth': 0.8,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.major.width': 0.8,
        'ytick.major.width': 0.8,
        'xtick.minor.width': 0.6,
        'ytick.minor.width': 0.6,
    })

setup_figure_style()
print("Libraries imported and plotting style configured")

### Session ID Parsing and Filtering Functions

These utility functions parse the metadata encoded in DANDI file paths, where the session ID is a `++`-separated string (figure first, measurement second, state fourth), and filter experiments by figure, measurement type, and experimental state.

In [None]:
def get_session_id(asset_path: str) -> str:
    """Extract session ID from DANDI asset path."""
    if not asset_path:
        return ""
    bottom_level_path = asset_path.split("/")[1]  
    session_id_with_ses_prefix = bottom_level_path.split("_")[1]
    session_id = session_id_with_ses_prefix.split("-")[1]
    return session_id

def get_figure_number(session_id: str) -> str:
    """Extract which figure this data corresponds to."""
    if not session_id:
        return ""
    return session_id.split("++")[0]

def get_measurement(session_id: str) -> str:
    """Extract measurement type."""
    if not session_id:
        return ""
    return session_id.split("++")[1]

def get_state(session_id: str) -> str:
    """Extract experimental state."""
    if not session_id:
        return ""
    return session_id.split("++")[3]

def is_figure_number(session_id: str, figure_number: str) -> bool:
    """Check if data belongs to a specific figure."""
    return get_figure_number(session_id) == figure_number

def is_measurement(session_id: str, measurement: str) -> bool:
    """Filter data by measurement/experiment type."""
    return get_measurement(session_id) == measurement

def is_state(session_id: str, state: str) -> bool:
    """Filter data by disease/treatment state."""
    return get_state(session_id) == state
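
As a quick check of the parsing logic, the sketch below applies the same splitting steps to a made-up asset path. The path, the `fieldA`/`fieldB` placeholders, and the helper name `parse_session_fields` are assumptions for illustration only; they are not names taken from the dataset.

```python
def parse_session_fields(asset_path: str) -> dict:
    """Hypothetical helper: split a DANDI-style asset path into its session fields."""
    file_name = asset_path.split("/")[1]                 # "sub-..._ses-..._icephys.nwb"
    session_id = file_name.split("_")[1].split("-")[1]   # text after the "ses-" prefix
    fields = session_id.split("++")
    return {
        "session_id": session_id,
        "figure": fields[0],       # first field
        "measurement": fields[1],  # second field
        "state": fields[3],        # fourth field
    }


# Made-up path; fieldA and fieldB stand in for fields this notebook does not use.
example_path = "sub-example/sub-example_ses-F2++oEPSC++fieldA++OffState++fieldB_icephys.nwb"
parsed = parse_session_fields(example_path)
print(parsed)
```

Note that this scheme relies on the session ID itself containing no `_` or `-` characters, since both are used as separators earlier in the path.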

### Event Detection Functions

#### MAD-Based Event Detection

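We use the **Median Absolute Deviation (MAD)** for robust noise estimation, because it is far less biased by the events themselves than the ordinary standard deviation. Samples deviating more than 5 SD from the median baseline (with SD estimated as 1.4826 × MAD) are treated as threshold crossings, and crossings within 1 ms of each other are merged so that a single event that crosses the threshold on several consecutive samples is counted once.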
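
To illustrate the thresholding step in isolation, here is a minimal sketch on a deterministic toy trace (not real data). The trace values and the three injected samples are made up. The factor 1.4826 is the standard scaling that makes the MAD a consistent estimate of the standard deviation for normally distributed noise.

```python
import numpy as np

# Toy baseline: a repeating pattern of small fluctuations, in pA (made-up values).
trace_pa = np.tile([-2.0, -1.0, 0.0, 1.0, 2.0], 200)

# Inject three large samples to stand in for synaptic events.
trace_pa[100] = -40.0
trace_pa[101] = -35.0
trace_pa[600] = 30.0

# Robust noise estimate: the median and MAD are barely affected by the injected samples.
baseline = np.median(trace_pa)
mad = np.median(np.abs(trace_pa - baseline))
robust_sd = 1.4826 * mad

# Symmetric 5 SD thresholds around the median baseline.
upper_threshold = baseline + 5 * robust_sd
lower_threshold = baseline - 5 * robust_sd
crossing_indices = np.where((trace_pa > upper_threshold) | (trace_pa < lower_threshold))[0]

print(f"baseline = {baseline:.2f} pA, robust SD = {robust_sd:.4f} pA")
print("threshold crossings at sample indices:", crossing_indices.tolist())
```

Only the injected samples cross the thresholds, because the noise estimate is set by the bulk of the baseline rather than by the outliers.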

In [None]:
def merge_nearby_events(event_times, event_amplitudes, merge_distance_ms=1.0):
    """Merge events within merge_distance_ms, keeping maximum amplitude."""
    if len(event_times) == 0:
        return event_times, event_amplitudes
    
    times = np.array(event_times)
    amplitudes = np.array(event_amplitudes)
    sorted_indices = np.argsort(times)
    times = times[sorted_indices]
    amplitudes = amplitudes[sorted_indices]
    
    merged_times = []
    merged_amplitudes = []
    
    i = 0
    while i < len(times):
        current_time = times[i]
        current_amp = amplitudes[i]
        
        j = i + 1
        max_amp = current_amp
        max_amp_time = current_time
        
        while j < len(times) and (times[j] - current_time) <= merge_distance_ms:
            if amplitudes[j] > max_amp:
                max_amp = amplitudes[j]
                max_amp_time = times[j]
            j += 1
        
        merged_times.append(max_amp_time)
        merged_amplitudes.append(max_amp)
        i = j
    
    return merged_times, merged_amplitudes

def process_nwb_file_for_events(asset, detection_window_shift_ms=100, event_merge_distance_ms=1.0):
    """Process a single NWB file and return event amplitudes."""
    try:
        s3_url = asset.get_content_url(follow_redirects=1, strip_query=False)
        file_system = remfile.File(s3_url)
        file = h5py.File(file_system, mode="r")
        io = NWBHDF5IO(file=file)
        nwbfile = io.read()
        
        optogenetics_table_df = nwbfile.intervals["optogenetic_epochs_table"].to_dataframe()
        stimulation_entries_df = optogenetics_table_df[optogenetics_table_df["stimulation_on"] == True]
        detection_entries_df = optogenetics_table_df[optogenetics_table_df["stage_name"] == "detection"]
        
        acquisition_keys = list(nwbfile.acquisition.keys())
        
        file_positive_amplitudes = []
        file_negative_amplitudes = []
        
        for sweep_key in acquisition_keys:
            trial_number = int(sweep_key.split('Sweep')[-1])
            
            voltage_clamp_response = nwbfile.acquisition[sweep_key]
            timestamps_in_seconds = voltage_clamp_response.get_timestamps()
            data_in_amperes = voltage_clamp_response.get_data_in_units()
            data_in_pico_amperes = data_in_amperes * 1e12
            
            if trial_number <= len(stimulation_entries_df) and trial_number <= len(detection_entries_df):
                detection_info = detection_entries_df.iloc[trial_number - 1]
                
                detection_start_ms_original = detection_info["start_time"] * 1000
                detection_stop_ms = detection_info["stop_time"] * 1000
                detection_start_ms_shifted = detection_start_ms_original + detection_window_shift_ms
                
                timestamps_in_milliseconds = timestamps_in_seconds * 1000
                detection_mask = (timestamps_in_milliseconds >= detection_start_ms_shifted) & (timestamps_in_milliseconds <= detection_stop_ms)
                detection_data = data_in_pico_amperes[detection_mask]
                detection_timestamps = timestamps_in_milliseconds[detection_mask]
                
                if len(detection_data) > 0:
                    # MAD-based noise estimation
                    noise_median = np.median(detection_data)
                    mad = np.median(np.abs(detection_data - noise_median))
                    mad_std = mad * 1.4826  # Convert MAD to std estimate
                    
                    event_threshold_positive = noise_median + 5 * mad_std
                    event_threshold_negative = noise_median - 5 * mad_std
                    
                    # Find events
                    positive_event_indices = np.where(detection_data > event_threshold_positive)[0]
                    negative_event_indices = np.where(detection_data < event_threshold_negative)[0]
                    
                    positive_event_times_raw = detection_timestamps[positive_event_indices] if len(positive_event_indices) > 0 else []
                    negative_event_times_raw = detection_timestamps[negative_event_indices] if len(negative_event_indices) > 0 else []
                    
                    positive_event_amplitudes_raw = detection_data[positive_event_indices] - noise_median if len(positive_event_indices) > 0 else []
                    negative_event_amplitudes_raw = noise_median - detection_data[negative_event_indices] if len(negative_event_indices) > 0 else []
                    
                    # Merge nearby events
                    positive_event_times_merged, positive_event_amplitudes_merged = merge_nearby_events(
                        positive_event_times_raw, positive_event_amplitudes_raw, event_merge_distance_ms)
                    negative_event_times_merged, negative_event_amplitudes_merged = merge_nearby_events(
                        negative_event_times_raw, negative_event_amplitudes_raw, event_merge_distance_ms)
                    
                    file_positive_amplitudes.extend(positive_event_amplitudes_merged)
                    file_negative_amplitudes.extend(negative_event_amplitudes_merged)
        
        io.close()
        file.close()
        
        return file_positive_amplitudes, file_negative_amplitudes
        
    except Exception as e:
        print(f"Error processing {asset.path}: {e}")
        return [], []
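
The merging rule can be checked on a handful of made-up threshold crossings. The sketch below is a compact, standalone re-implementation for illustration (the name `merge_events_sketch` and the sample values are assumptions); it follows the same rule as `merge_nearby_events`, grouping crossings that lie within the merge distance of the first crossing in the group and keeping the largest one.

```python
def merge_events_sketch(times_ms, amplitudes_pa, merge_distance_ms=1.0):
    """Group crossings within merge_distance_ms of the first crossing in each group,
    keeping the time and amplitude of the largest crossing."""
    order = sorted(range(len(times_ms)), key=lambda k: times_ms[k])
    merged = []
    i = 0
    while i < len(order):
        group_start_ms = times_ms[order[i]]
        best = order[i]
        j = i + 1
        # Extend the group while crossings stay within the window of the first one.
        while j < len(order) and times_ms[order[j]] - group_start_ms <= merge_distance_ms:
            if amplitudes_pa[order[j]] > amplitudes_pa[best]:
                best = order[j]
            j += 1
        merged.append((times_ms[best], amplitudes_pa[best]))
        i = j
    return merged


# Two bursts of crossings (made-up values): four samples near 10 ms, two near 25 ms.
times_ms = [10.0, 10.1, 10.2, 10.3, 25.0, 25.05]
amplitudes_pa = [12.0, 18.0, 15.0, 11.0, 9.0, 9.5]
merged_events = merge_events_sketch(times_ms, amplitudes_pa)
print(merged_events)
```

Because the window is anchored at the first crossing of each group, an event that stays above threshold for longer than the merge distance would be split into more than one event under this rule.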

### Load DANDI Dataset

Connect to DANDI and filter for Figure 2 optogenetic experiments, separating OffState and OnState conditions.

In [None]:
# Load environment variables
load_dotenv()
token = os.getenv("DANDI_API_TOKEN")
if not token:
    raise ValueError("DANDI_API_TOKEN environment variable not set")

# Connect to DANDI
dandiset_id = "001538"
client = DandiAPIClient(token=token)
client.authenticate(token=token)

dandiset = client.get_dandiset(dandiset_id, "draft")
assets = dandiset.get_assets()
assets_list = list(assets)

# Filter for Figure 2 oEPSC experiments
criteria_offstate = lambda asset: (is_figure_number(get_session_id(asset.path), "F2") and 
                                   is_measurement(get_session_id(asset.path), "oEPSC") and 
                                   is_state(get_session_id(asset.path), "OffState"))

criteria_onstate = lambda asset: (is_figure_number(get_session_id(asset.path), "F2") and 
                                  is_measurement(get_session_id(asset.path), "oEPSC") and 
                                  is_state(get_session_id(asset.path), "OnState"))

offstate_assets = [asset for asset in assets_list if criteria_offstate(asset)]
onstate_assets = [asset for asset in assets_list if criteria_onstate(asset)]

print(f"Found {len(offstate_assets)} OffState and {len(onstate_assets)} OnState files")
print(f"Total Figure 2 oEPSC files: {len(offstate_assets) + len(onstate_assets)}")

## Data Processing and Event Detection

### Process All NWB Files

We process each NWB file to extract oEPSC events, collecting both:
- **Individual events**: For cumulative distribution analysis
- **File means**: For box plot comparison of mean responses per session
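
The two collections can give different summaries when files contribute unequal numbers of events. A minimal sketch with made-up amplitudes (the file names and values are assumptions for illustration):

```python
from statistics import mean

# Made-up amplitudes in pA: one file with many small events, one with few large events.
events_by_file = {
    "file_a": [10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
    "file_b": [30.0, 30.0],
}

# Pooling all events weights each file by its event count.
pooled_events = [amp for amps in events_by_file.values() for amp in amps]
pooled_mean = mean(pooled_events)

# Averaging per file first gives every file equal weight.
file_means = [mean(amps) for amps in events_by_file.values()]
mean_of_file_means = mean(file_means)

print(f"pooled mean: {pooled_mean} pA, mean of file means: {mean_of_file_means} pA")
```

Here the pooled mean is pulled toward the file with more events, while the mean of file means treats both files equally.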

In [None]:
# Initialize data collections
all_offstate_events = []  # All individual events for cumulative plot
all_onstate_events = []

offstate_file_means = []  # Mean responses per file for box plot
onstate_file_means = []

# Process OffState files
print("Processing OffState files...")
for i, asset in enumerate(tqdm(offstate_assets, desc="OffState files")):
    session_id = get_session_id(asset.path)
    print(f"  {i+1}/{len(offstate_assets)}: {session_id}")
    
    pos_amps, neg_amps = process_nwb_file_for_events(asset)
    all_events = pos_amps + neg_amps
    
    all_offstate_events.extend(all_events)
    if len(all_events) > 0:
        offstate_file_means.append(np.mean(all_events))

# Process OnState files
print("\nProcessing OnState files...")
for i, asset in enumerate(tqdm(onstate_assets, desc="OnState files")):
    session_id = get_session_id(asset.path)
    print(f"  {i+1}/{len(onstate_assets)}: {session_id}")
    
    pos_amps, neg_amps = process_nwb_file_for_events(asset)
    all_events = pos_amps + neg_amps
    
    all_onstate_events.extend(all_events)
    if len(all_events) > 0:
        onstate_file_means.append(np.mean(all_events))

print(f"\nData collection complete:")
print(f"  OffState: {len(all_offstate_events)} events from {len(offstate_file_means)} files")
print(f"  OnState: {len(all_onstate_events)} events from {len(onstate_file_means)} files")

## Figure 2G: Cumulative Distribution Plot

### Individual Event Analysis

This plot shows the cumulative distribution of **all individual oEPSC events** across all experimental sessions, allowing comparison of the full event amplitude distributions between OffState and OnState conditions.
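
As a small illustration of how such a curve is built, the sketch below computes an empirical cumulative distribution for five made-up amplitudes: the values are sorted, and the k-th smallest value is assigned a cumulative percentage of k/n × 100.

```python
import numpy as np

# Made-up amplitudes in pA, for illustration only.
amplitudes_pa = np.array([12.0, 5.0, 30.0, 8.0, 20.0])

sorted_amplitudes = np.sort(amplitudes_pa)
n_events = len(sorted_amplitudes)
cumulative_percent = np.arange(1, n_events + 1) / n_events * 100

for amplitude, percent in zip(sorted_amplitudes, cumulative_percent):
    print(f"{amplitude:5.1f} pA -> {percent:5.1f} %")

# Smallest amplitude at which the curve reaches at least 50 %.
half_index = np.searchsorted(cumulative_percent, 50)
amplitude_at_half = sorted_amplitudes[half_index]
```

A curve shifted to the right means larger amplitudes at every percentile, which is how the two conditions are compared visually.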

In [None]:
# Create cumulative distribution plot
fig, ax = plt.subplots(1, 1, figsize=(5.5, 3.5))

# Convert to numpy arrays
offstate_amplitudes = np.array(all_offstate_events)
onstate_amplitudes = np.array(all_onstate_events)

if len(offstate_amplitudes) > 0 and len(onstate_amplitudes) > 0:
    # Sort the data
    offstate_sorted = np.sort(offstate_amplitudes)
    onstate_sorted = np.sort(onstate_amplitudes)
    
    # Calculate cumulative probabilities as percentages
    offstate_cumulative = np.arange(1, len(offstate_sorted) + 1) / len(offstate_sorted) * 100
    onstate_cumulative = np.arange(1, len(onstate_sorted) + 1) / len(onstate_sorted) * 100
    
    # Plot with paper-style colors and thickness
    ax.plot(offstate_sorted, offstate_cumulative, color='black', linewidth=2.5, 
            label='off-state')
    ax.plot(onstate_sorted, onstate_cumulative, color='gray', linewidth=2.5, 
            label='on-state')
    
    # Formatting to match paper exactly
    ax.set_xlabel('oEPSC amplitude (pA)', fontsize=14, fontweight='normal')
    ax.set_ylabel('Cumulative Probability (%)', fontsize=14, fontweight='normal')
    ax.set_title('Figure 2G: dSPN oEPSC Cumulative Probability', fontsize=16, fontweight='bold', pad=15)
    
    # Set axis limits and ticks to match paper
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 100)
    ax.set_xticks([0, 20, 40, 60, 80])
    ax.set_yticks([0, 25, 50, 75, 100])
    
    # Style the axes
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.tick_params(axis='both', which='major', labelsize=12, width=1.5, length=5)
    
    # Add legend in lower right
    ax.legend(loc='lower right', frameon=False, fontsize=12)
    
    # Calculate and display statistics
    off_median = np.median(offstate_amplitudes)
    on_median = np.median(onstate_amplitudes)
    off_mean = np.mean(offstate_amplitudes)
    on_mean = np.mean(onstate_amplitudes)
    
    print("=== FIGURE 2G: CUMULATIVE DISTRIBUTION ANALYSIS ===")
    print(f"OffState events: {len(offstate_amplitudes)}")
    print(f"  Mean: {off_mean:.2f} ± {np.std(offstate_amplitudes):.2f} pA")
    print(f"  Median: {off_median:.2f} pA")
    print(f"  25th percentile: {np.percentile(offstate_amplitudes, 25):.2f} pA")
    print(f"  75th percentile: {np.percentile(offstate_amplitudes, 75):.2f} pA")
    
    print(f"\nOnState events: {len(onstate_amplitudes)}")
    print(f"  Mean: {on_mean:.2f} ± {np.std(onstate_amplitudes):.2f} pA")
    print(f"  Median: {on_median:.2f} pA")
    print(f"  25th percentile: {np.percentile(onstate_amplitudes, 25):.2f} pA")
    print(f"  75th percentile: {np.percentile(onstate_amplitudes, 75):.2f} pA")
    
    print(f"\nComparison:")
    print(f"  Mean fold change (OnState/OffState): {on_mean/off_mean:.3f}")
    print(f"  Median fold change: {on_median/off_median:.3f}")
    
    # Kolmogorov-Smirnov test
    ks_stat, ks_p = stats.ks_2samp(offstate_amplitudes, onstate_amplitudes)
    print(f"\nKolmogorov-Smirnov test:")
    print(f"  KS statistic: {ks_stat:.4f}")
    print(f"  p-value: {ks_p:.2e}")
    print(f"  Significantly different: {'Yes' if ks_p < 0.05 else 'No'}")

else:
    print("No data available for plotting")

plt.tight_layout()
plt.show()

## Figure 2H: Box Plot Comparison

### Mean Response Per Session Analysis

This box plot compares the **mean event amplitudes per experimental session**, treating each NWB file as one data point. This approach controls for potential differences in the number of events recorded per session.

In [None]:
# Create box plot
fig, ax = plt.subplots(1, 1, figsize=(3.5, 4.0))

# Convert to numpy arrays
offstate_means = np.array(offstate_file_means)
onstate_means = np.array(onstate_file_means)

if len(offstate_means) > 0 and len(onstate_means) > 0:
    # Prepare data for box plot
    box_data = [offstate_means, onstate_means]
    positions = [1, 2]
    
    # Create box plot with paper-style formatting
    bp = ax.boxplot(box_data, positions=positions, patch_artist=True, 
                   widths=0.4, showfliers=True, notch=False,
                   medianprops=dict(color='black', linewidth=2),
                   whiskerprops=dict(color='black', linewidth=1.5),
                   capprops=dict(color='black', linewidth=1.5),
                   flierprops=dict(marker='o', markersize=4, alpha=0.7, markerfacecolor='gray'))
    
    # Customize box colors to match paper (both white with black edges)
    colors = ['white', 'white']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_edgecolor('black')
        patch.set_linewidth(1.5)
    
    # Add individual data points as gray dots
    for i, data in enumerate(box_data):
        # Add some jitter for visibility
        x_vals = np.random.normal(positions[i], 0.04, size=len(data))
        ax.scatter(x_vals, data, color='gray', s=25, alpha=0.8, zorder=3)
    
    # Calculate statistics for significance annotation
    off_mean = np.mean(offstate_means)
    on_mean = np.mean(onstate_means)
    
    # Statistical test
    u_stat, u_p = stats.mannwhitneyu(offstate_means, onstate_means, 
                                    alternative='two-sided')
    
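    # Add significance annotation at the top, derived from the Mann-Whitney p-value
    # (conventional thresholds: *** p < 0.001, ** p < 0.01, * p < 0.05)
    y_max = max(np.max(offstate_means), np.max(onstate_means))
    y_sig = y_max + 0.8
    if u_p < 0.001:
        significance_label = '***'
    elif u_p < 0.01:
        significance_label = '**'
    elif u_p < 0.05:
        significance_label = '*'
    else:
        significance_label = 'n.s.'
    ax.text(1.5, y_sig, significance_label, ha='center', va='center', fontsize=16, fontweight='bold')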
    
    # Formatting to match paper exactly
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['off-state', 'on-state'], fontsize=14)
    ax.set_ylabel('oEPSC amplitude (pA)', fontsize=14, fontweight='normal')
    ax.set_title('Figure 2H: dSPN oEPSC Amplitude Comparison', fontsize=16, fontweight='bold', pad=15)
    
    # Set y-axis limits and ticks
    #ax.set_ylim(0, 25)
    #ax.set_yticks([0, 5, 10, 15, 20, 25])
    
    # Style the axes
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.tick_params(axis='both', which='major', labelsize=12, width=1.5, length=5)
    ax.tick_params(axis='x', which='major', length=0)  # Remove x-axis tick marks
    
    # Calculate and display statistics
    off_median = np.median(offstate_means)
    on_median = np.median(onstate_means)
    
    print("=== FIGURE 2H: BOX PLOT STATISTICAL ANALYSIS ===")
    print(f"\nOffState file means (n={len(offstate_means)}):")
    print(f"  Median: {off_median:.2f} pA")
    print(f"  Mean: {off_mean:.2f} ± {np.std(offstate_means):.2f} pA")
    print(f"  IQR: {np.percentile(offstate_means, 25):.2f} - {np.percentile(offstate_means, 75):.2f} pA")
    print(f"  Range: {np.min(offstate_means):.2f} - {np.max(offstate_means):.2f} pA")
    
    print(f"\nOnState file means (n={len(onstate_means)}):")
    print(f"  Median: {on_median:.2f} pA")
    print(f"  Mean: {on_mean:.2f} ± {np.std(onstate_means):.2f} pA") 
    print(f"  IQR: {np.percentile(onstate_means, 25):.2f} - {np.percentile(onstate_means, 75):.2f} pA")
    print(f"  Range: {np.min(onstate_means):.2f} - {np.max(onstate_means):.2f} pA")
    
    print(f"\nComparison:")
    print(f"  Median fold change: {on_median/off_median:.3f}")
    print(f"  Mean fold change: {on_mean/off_mean:.3f}")
    print(f"  Difference in medians: {on_median - off_median:.2f} pA")
    print(f"  Difference in means: {on_mean - off_mean:.2f} pA")
    
    print(f"\nMann-Whitney U test (non-parametric):")
    print(f"  U statistic: {u_stat:.2f}")
    print(f"  p-value: {u_p:.2e}")
    print(f"  Significantly different: {'Yes' if u_p < 0.05 else 'No'}")
    
    # Welch's t-test (unequal variances)
    t_stat, t_p = stats.ttest_ind(offstate_means, onstate_means, 
                                equal_var=False)
    print(f"\nWelch's t-test (unequal variances):")
    print(f"  t statistic: {t_stat:.4f}")
    print(f"  p-value: {t_p:.2e}")
    print(f"  Significantly different: {'Yes' if t_p < 0.05 else 'No'}")
    
    # Effect size (Cohen's d); the (n-1)-weighted pooled SD requires sample variances (ddof=1)
    pooled_std = np.sqrt(((len(offstate_means)-1)*np.var(offstate_means, ddof=1) + 
                        (len(onstate_means)-1)*np.var(onstate_means, ddof=1)) / 
                       (len(offstate_means) + len(onstate_means) - 2))
    cohens_d = (on_mean - off_mean) / pooled_std
    print(f"\nEffect size (Cohen's d): {cohens_d:.3f}")
    
    if abs(cohens_d) < 0.2:
        effect_size = "negligible"
    elif abs(cohens_d) < 0.5:
        effect_size = "small"
    elif abs(cohens_d) < 0.8:
        effect_size = "medium"
    else:
        effect_size = "large"
    print(f"  Effect size interpretation: {effect_size}")

else:
    print("No data available for box plot")

plt.tight_layout()
plt.show()
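
For reference, Cohen's d with a pooled standard deviation is conventionally computed from sample variances (`ddof=1`), weighted by each group's degrees of freedom. The sketch below works through that formula on made-up per-file means; the function name `cohens_d` and the values are assumptions for illustration.

```python
import numpy as np


def cohens_d(group_a, group_b):
    """Cohen's d for group_b relative to group_a, using the pooled sample SD."""
    a = np.asarray(group_a, dtype=float)
    b = np.asarray(group_b, dtype=float)
    n_a, n_b = len(a), len(b)
    # Pool the unbiased sample variances, weighted by degrees of freedom.
    pooled_variance = ((n_a - 1) * np.var(a, ddof=1) + (n_b - 1) * np.var(b, ddof=1)) / (n_a + n_b - 2)
    return (np.mean(b) - np.mean(a)) / np.sqrt(pooled_variance)


# Made-up per-file mean amplitudes in pA.
off_means_example = [10.0, 12.0, 14.0]
on_means_example = [13.0, 15.0, 17.0]
d_example = cohens_d(off_means_example, on_means_example)
print(f"Cohen's d = {d_example:.3f}")
```

Both example groups have a sample variance of 4 pA², so the pooled SD is 2 pA and the 3 pA difference in means gives d = 1.5. Using NumPy's default population variance (`ddof=0`) in the same formula would understate the pooled SD and inflate d, most noticeably for small groups.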

## Summary

### Key Findings

This analysis reproduces the key findings from **Figure 2G-H** of Zhai et al. 2025:

1. **Cumulative Distribution (Figure 2G)**: Shows the distribution of all individual oEPSC events across experimental conditions
2. **Box Plot Comparison (Figure 2H)**: Compares mean event amplitudes per experimental session, so that sessions with many events do not dominate the comparison

### Methodological Notes

- **Event Detection**: MAD-based noise estimation with a ±5 SD threshold
- **Artifact Avoidance**: 100 ms detection window shift to avoid stimulation artifacts
- **Event Merging**: 1 ms window to handle multi-threshold crossings from single events
- **Statistical Testing**: Both parametric and non-parametric tests for robust comparison

### Biological Significance

The analysis characterizes how L-DOPA treatment (OnState) affects optogenetically evoked synaptic responses in striatal neurons, providing insight into the synaptic mechanisms underlying levodopa-induced dyskinesia in Parkinson's disease.