# Reproducing Figure 1E: dSPN Somatic Excitability Analysis

This notebook reproduces Figure 1E from **Zhai et al. 2025**, which characterizes the somatic excitability of direct pathway striatal projection neurons (dSPNs) using frequency-intensity (F-I) curves and rheobase measurements.

**Dataset**: DANDI:001538 - State-dependent modulation of spiny projection neurons controls levodopa-induced dyskinesia

**Analysis approach**:
- **F-I Curves**: Number of action potentials fired during each current step, plotted against the injected current
- **Rheobase Analysis**: Minimum current required to elicit action potential firing
- **Conditions**: Levodopa-induced dyskinesia (LID) off-state, LID on-state, and LID on-state with SCH23390 (D1R antagonist)
- **Methodology**: Current clamp recordings with 500 ms current steps
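In symbols, writing $N(I)$ for the number of action potentials counted during a step of amplitude $I$, $\Delta t = 0.5$ s for the step duration (the notebook's analysis window), and $I_1 < \dots < I_n$ for the tested current steps, the two measures are:

$$
f(I) = \frac{N(I)}{\Delta t}, \qquad
\text{rheobase} = \min\{\, I_k \in \{I_1, \dots, I_n\} : N(I_k) \ge 1 \,\}
$$

The F-I plot below shows $N(I)$ directly, so 10 APs in a 500 ms step correspond to a mean rate of 20 Hz. Because rheobase is taken from the tested steps, its resolution is limited by the step size.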

## Setup and Data Loading

### Import Libraries and Configure Plotting Style

These parameters approximate the plotting style of the original publication (small fonts, thin axes, no top or right spines) for visual consistency.

In [None]:
import os
from typing import List, Tuple

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import remfile
import seaborn as sns
from dandi.dandiapi import DandiAPIClient
from dotenv import load_dotenv
from pynwb import NWBHDF5IO
from scipy import stats
from tqdm import tqdm

# Set plotting style to match paper
plt.style.use('default')
sns.set_palette("Set2")

def setup_figure_style():
    """Setup matplotlib parameters to match paper style"""
    plt.rcParams.update({
        'font.size': 8,
        'axes.titlesize': 10,
        'axes.labelsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 12,
        'axes.linewidth': 0.8,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.major.width': 0.8,
        'ytick.major.width': 0.8,
        'xtick.minor.width': 0.6,
        'ytick.minor.width': 0.6,
    })

setup_figure_style()
print("Libraries imported and plotting style configured")

### Session ID Parsing and Filtering Functions

These utility functions parse the rich metadata encoded in DANDI file paths and filter experiments by figure, measurement type, and experimental condition.
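A defensive variant of this parsing can return `None` for paths that do not match instead of raising `IndexError`. The sketch below assumes the `sub-<id>/sub-<id>_ses-<fields>_<suffix>.nwb` layout implied by the functions in the next cell; the field name `field_2` and the example path are hypothetical, since the notebook never uses the third `++` field.

```python
from typing import Dict, Optional

# Hypothetical names for the '++'-separated fields; the notebook only relies on
# positions 0 (figure), 1 (measurement), 3 (state) and 4 (pharmacology).
FIELDS = ("figure", "measurement", "field_2", "state", "pharmacology")

# Same mapping as get_condition_label, expressed as a lookup table.
CONDITION_LABELS = {
    ("OffState", "none"): "LID off-state",
    ("OnState", "none"): "LID on-state",
    ("OnState", "D1RaSch"): "LID on-state with SCH",
}


def parse_session_fields(asset_path: str) -> Optional[Dict[str, str]]:
    """Split the 'ses-' entity of an asset path into named '++' fields.

    Returns None instead of raising when the path does not follow the
    assumed layout.
    """
    filename = asset_path.rsplit("/", 1)[-1]
    if filename.endswith(".nwb"):
        filename = filename[: -len(".nwb")]
    for entity in filename.split("_"):
        if entity.startswith("ses-"):
            parts = entity[len("ses-"):].split("++")
            return dict(zip(FIELDS, parts)) if len(parts) >= len(FIELDS) else None
    return None


def condition_label(fields: Optional[Dict[str, str]]) -> str:
    """Map parsed fields to a condition label, or 'unknown'."""
    if fields is None:
        return "unknown"
    return CONDITION_LABELS.get((fields["state"], fields["pharmacology"]), "unknown")


# Hypothetical path, for illustration only
fields = parse_session_fields("sub-M1/sub-M1_ses-F1++SomExc++X++OnState++D1RaSch_icephys.nwb")
print(fields["figure"], fields["measurement"], condition_label(fields))
```

Stripping the `.nwb` extension first keeps it from being glued onto the last field when no suffix entity follows the session entity.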

In [None]:
def get_session_id(asset_path: str) -> str:
    """Extract the session ID from a DANDI asset path.

    Assumes paths of the form ``sub-<id>/sub-<id>_ses-<session>_<...>.nwb``
    and returns "" for paths that do not match.
    """
    if not asset_path or "/" not in asset_path:
        return ""
    bottom_level_path = asset_path.split("/")[1]
    if bottom_level_path.endswith(".nwb"):
        bottom_level_path = bottom_level_path[: -len(".nwb")]
    entities = bottom_level_path.split("_")
    if len(entities) < 2 or not entities[1].startswith("ses-"):
        return ""
    # Split on the first hyphen only, so hyphens inside the session ID survive
    return entities[1].split("-", 1)[1]

def _session_field(session_id: str, index: int) -> str:
    """Return the index-th '++'-separated field of a session ID, or "" if absent."""
    parts = session_id.split("++") if session_id else []
    return parts[index] if index < len(parts) else ""

def get_figure_number(session_id: str) -> str:
    """Extract which figure this data corresponds to."""
    return _session_field(session_id, 0)

def get_measurement(session_id: str) -> str:
    """Extract measurement type."""
    return _session_field(session_id, 1)

def get_state(session_id: str) -> str:
    """Extract experimental state."""
    return _session_field(session_id, 3)

def get_pharmacology(session_id: str) -> str:
    """Extract pharmacological condition."""
    return _session_field(session_id, 4)

def is_f1_somexc(session_id: str) -> bool:
    """Check if data belongs to Figure 1 somatic excitability experiments."""
    return get_figure_number(session_id) == "F1" and get_measurement(session_id) == "SomExc"

def get_condition_label(session_id: str) -> str:
    """Convert session metadata to condition label."""
    parts = session_id.split("++")
    if len(parts) < 5:
        return "unknown"
    
    state = parts[3]
    pharm = parts[4]
    
    if state == "OffState" and pharm == "none":
        return "LID off-state"
    elif state == "OnState" and pharm == "none":
        return "LID on-state"
    elif state == "OnState" and pharm == "D1RaSch":
        return "LID on-state with SCH"
    else:
        return "unknown"

print("Utility functions defined")

### Spike Detection and Analysis Functions

#### Action Potential Counting

We count spikes by threshold crossing within the 500 ms stimulus window (200-700 ms from sweep start, which assumes the step begins 200 ms into each sweep) to match the methodology of the original analysis.
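The counting rule can also be written without a Python loop: a spike is a rising edge, i.e. a sample above threshold whose predecessor is not. The sketch below (a hypothetical helper, not part of the original analysis) applies that rule to a synthetic 1 s sweep with 1 ms rectangular "spikes", three of which fall inside the 200-700 ms window.

```python
import numpy as np


def count_upward_crossings(voltage_mV: np.ndarray, timestamps_s: np.ndarray,
                           threshold_mV: float = 0.0,
                           window_s: tuple = (0.2, 0.7)) -> int:
    """Count threshold crossings in [start, end) seconds after sweep start.

    A trace that is already above threshold at the first sample of the
    window counts as one spike, matching the loop in count_action_potentials.
    """
    t0 = timestamps_s[0]
    in_window = (timestamps_s >= t0 + window_s[0]) & (timestamps_s < t0 + window_s[1])
    above = voltage_mV[in_window] > threshold_mV
    if above.size == 0:
        return 0
    # Rising edges: above threshold now, not above on the previous sample
    rising = np.count_nonzero(above[1:] & ~above[:-1])
    return int(rising + above[0])


# Synthetic sweep: 10 kHz sampling, -80 mV baseline, five 1 ms pulses to +30 mV
fs = 10_000
t = np.arange(int(1.0 * fs)) / fs
v = np.full_like(t, -80.0)
for spike_t in (0.1, 0.3, 0.4, 0.5, 0.8):
    i = int(spike_t * fs)
    v[i:i + 10] = 30.0
print(count_upward_crossings(v, t))  # only the 0.3, 0.4 and 0.5 s pulses count
```

The boolean mask reproduces the half-open `[i0, i1)` slice that `np.searchsorted` yields in the loop version, so both count the same excursions.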

In [None]:
def count_action_potentials(
    voltage_trace_mV: np.ndarray, timestamps_s: np.ndarray, threshold_mV: float = 0.0
) -> int:
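    """Threshold-crossing spike count within a 500 ms stimulus window.

    Matches the approach in the original analysis: only samples 200-700 ms
    after sweep start are examined, and each excursion above ``threshold_mV``
    counts as one spike.
    """
    if voltage_trace_mV.size == 0:
        return 0
    # Sweeps shorter than 0.8 s are treated as not containing a full stimulus
    if timestamps_s[-1] - timestamps_s[0] < 0.8:
        return 0
    start_t = timestamps_s[0] + 0.2
    end_t = timestamps_s[0] + 0.7
    i0 = np.searchsorted(timestamps_s, start_t)
    i1 = np.searchsorted(timestamps_s, end_t)
    # Fallback: if the window cannot be located, use the middle half of the sweep
    if i0 >= i1 or i1 > voltage_trace_mV.size:
        i0 = voltage_trace_mV.size // 4
        i1 = 3 * voltage_trace_mV.size // 4
    x = voltage_trace_mV[i0:i1]
    if x.size == 0:
        return 0
    # Each contiguous run of samples above threshold counts as one spike
    spikes = 0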
    i = 0
    while i < x.size:
        if x[i] > threshold_mV:
            spikes += 1
            while i < x.size and x[i] > threshold_mV:
                i += 1
        else:
            i += 1
    return spikes

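def calculate_rheobase(current_steps: List[float], spike_counts: List[int]) -> float:
    """Return the first current step that elicits at least one spike.

    ``current_steps`` must be sorted in ascending order (the caller sorts
    them); returns NaN if no step elicits a spike.
    """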
    for current, spikes in zip(current_steps, spike_counts):
        if spikes >= 1:
            return current
    return np.nan

def safe_mean_sem(values: np.ndarray) -> Tuple[float, float]:
    """Calculate mean and standard error, handling edge cases."""
    if len(values) == 0:
        return np.nan, 0.0
    if len(values) == 1:
        return float(values[0]), 0.0
    return float(np.mean(values)), float(np.std(values, ddof=1) / np.sqrt(len(values)))

print("Analysis functions defined")

### Load DANDI Dataset

Connect to DANDI and filter for Figure 1 somatic excitability experiments across all three experimental conditions.

In [None]:
# Load environment variables
load_dotenv()
token = os.getenv("DANDI_API_TOKEN")
if not token:
    raise ValueError("DANDI_API_TOKEN environment variable not set")

# Connect to DANDI (passing the token to the constructor authenticates the client)
dandiset_id = "001538"
client = DandiAPIClient(token=token)

dandiset = client.get_dandiset(dandiset_id, "draft")
assets = dandiset.get_assets()
assets_list = list(assets)

# Filter for Figure 1 somatic excitability experiments
f1_somexc_assets = [asset for asset in assets_list if is_f1_somexc(get_session_id(asset.path))]

print(f"Found {len(f1_somexc_assets)} Figure 1 somatic excitability files")

# Show breakdown by condition
condition_counts = {}
for asset in f1_somexc_assets:
    condition = get_condition_label(get_session_id(asset.path))
    condition_counts[condition] = condition_counts.get(condition, 0) + 1

print("\nBreakdown by condition:")
for condition, count in condition_counts.items():
    print(f"  {condition}: {count} files")

## Data Processing and Analysis

### Process All NWB Files

We process each NWB file to extract current clamp recordings, analyzing:
- **Current steps**: Injected current amplitudes (pA)
- **Spike counts**: Number of action potentials fired during each step
- **F-I relationships**: Frequency-intensity curves for each condition
- **Rheobase**: Minimum current required to elicit spiking
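NWB stores each `TimeSeries` as raw `data` plus a `conversion` factor (and, in newer NWB versions, an `offset`), so the value in the series' `unit` is `data * conversion + offset`. Series sampled at a fixed `rate` may omit `timestamps`, in which case the time axis is `starting_time + n / rate`. A minimal numpy sketch of both steps on hypothetical int16 samples (the helper names are illustrative):

```python
import numpy as np


def to_physical(raw: np.ndarray, conversion: float = 1.0, offset: float = 0.0) -> np.ndarray:
    """Apply NWB-style scaling: value_in_unit = raw * conversion + offset."""
    return np.asarray(raw, dtype=float) * conversion + offset


def sweep_timestamps(n_samples: int, rate_hz: float, starting_time_s: float = 0.0) -> np.ndarray:
    """Rebuild the time axis of a regularly sampled series."""
    return starting_time_s + np.arange(n_samples, dtype=float) / rate_hz


# Hypothetical int16 voltage samples with a conversion factor to volts
raw = np.array([-2000, -2000, 1000], dtype=np.int16)
volts = to_physical(raw, conversion=4e-5)
millivolts = volts * 1000.0  # approximately -80, -80, 40 mV
t = sweep_timestamps(len(raw), rate_hz=20_000.0, starting_time_s=1.5)
print(millivolts, t)
```

Multiplying raw integer data by 1000 without the conversion factor would give values in raw ADC units rather than millivolts, so the 0 mV spike threshold would no longer be meaningful.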

In [None]:
# Initialize data collection
all_recording_data = []  # All individual recordings for F-I curves
cell_statistics = []     # Cell-level statistics for rheobase analysis

print("Processing Figure 1 somatic excitability files...\n")

for i, asset in enumerate(tqdm(f1_somexc_assets, desc="Processing F1 SomExc files")):
    session_id = get_session_id(asset.path)
    condition = get_condition_label(session_id)
    
    # Update progress with current file info
    tqdm.write(f"  {i+1}/{len(f1_somexc_assets)}: {condition} - {session_id}")
    
    try:
        # Open NWB file from DANDI
        s3_url = asset.get_content_url(follow_redirects=1, strip_query=False)
        file_system = remfile.File(s3_url)
        file = h5py.File(file_system, mode="r")
        io = NWBHDF5IO(file=file, load_namespaces=True)
        nwbfile = io.read()
        
        # Get intracellular recordings table
        try:
            if hasattr(nwbfile, "intracellular_recordings"):
                rec_df = nwbfile.intracellular_recordings.to_dataframe()
            else:
                rec_df = nwbfile.get_intracellular_recordings().to_dataframe()
        except Exception as e:
            tqdm.write(f"    No intracellular_recordings table: {e}")
            # Release the remote file before skipping it
            io.close()
            file.close()
            continue
        
        current_steps = []
        spike_counts = []
        
        # Process each recording sweep
        for _, row in rec_df.iterrows():
            # Get protocol step and current
            protocol_step = row.get(("intracellular_recordings", "protocol_step"), None)
            current_pA = row.get(("intracellular_recordings", "stimulus_current_pA"), None)
            
            # Get voltage response TimeSeries
            ts = None
            
            # Strategy A: acquisition series named by protocol_step
            if protocol_step is not None:
                series_name = f"CurrentClampSeries{int(protocol_step):03d}" if not str(protocol_step).startswith(
                    "CurrentClampSeries"
                ) else str(protocol_step)
                if series_name in nwbfile.acquisition:
                    ts = nwbfile.acquisition[series_name]
            
            # Strategy B: follow reference in table
            if ts is None:
                try:
                    response_ref = row[("responses", "response")]
                    if hasattr(response_ref, "iloc"):
                        response_ref = response_ref.iloc[0]
                    ts = response_ref.timeseries
                except Exception:
                    ts = None
            
            if ts is None:
                continue
            
            # Get timestamps and voltage data
            if ts.timestamps is not None:
                timestamps_s = np.asarray(ts.timestamps, dtype=float)
            else:
                rate = float(ts.rate)
                timestamps_s = float(ts.starting_time) + np.arange(ts.data.shape[0], dtype=float) / rate
            
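            # Apply NWB scaling (value = data * conversion + offset); assumes the series unit is volts
            offset = getattr(ts, "offset", 0.0) or 0.0
            voltage_mV = (np.asarray(ts.data, dtype=float) * ts.conversion + offset) * 1000.0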
            
            # Get current amplitude if not provided
            if current_pA is None or pd.isna(current_pA):
                try:
                    stim_ref = row[("stimuli", "stimulus")]
                    if hasattr(stim_ref, "iloc"):
                        stim_ref = stim_ref.iloc[0]
                    stim_ts = stim_ref.timeseries
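                    if stim_ts is not None:
                        current_A = np.asarray(stim_ts.data, dtype=float) * stim_ts.conversion
                        # A whole-sweep median returns the 0 pA baseline whenever the step
                        # covers less than half of the sweep; for a square command step,
                        # the largest-magnitude sample equals the step amplitude
                        if current_A.size > 0:
                            current_pA = float(current_A[np.argmax(np.abs(current_A))] * 1e12)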
                except Exception:
                    current_pA = np.nan
            
            if pd.isna(current_pA):
                continue
            
            # Count spikes and store data
            spike_count = count_action_potentials(voltage_mV, timestamps_s, threshold_mV=0.0)
            current_steps.append(float(current_pA))
            spike_counts.append(spike_count)
            
            # Store individual recording data
            all_recording_data.append({
                'session_id': session_id,
                'condition': condition,
                'current_pA': float(current_pA),
                'spike_count': spike_count,
                'nwb_file': asset.path.split('/')[-1]
            })
        
        # Calculate cell-level statistics
        if current_steps and spike_counts:
            # Sort by current for rheobase calculation
            sorted_data = sorted(zip(current_steps, spike_counts), key=lambda x: x[0])
            sorted_currents, sorted_spikes = zip(*sorted_data)
            
            rheobase_pA = calculate_rheobase(list(sorted_currents), list(sorted_spikes))
            
            # Get subject ID
            try:
                subject_id = nwbfile.subject.subject_id if nwbfile.subject is not None else session_id
            except Exception:
                subject_id = session_id
            
            cell_statistics.append({
                'session_id': session_id,
                'subject_id': subject_id,
                'condition': condition,
                'rheobase_pA': rheobase_pA,
                'n_recordings': len(current_steps),
                'nwb_file': asset.path.split('/')[-1]
            })
        
        tqdm.write(f"    Processed {len(current_steps)} sweeps")
        
        # Close file
        io.close()
        file.close()
        
    except Exception as e:
        tqdm.write(f"    Error processing file: {e}")
        continue

# Create DataFrames
df_recordings = pd.DataFrame(all_recording_data)
df_cells = pd.DataFrame(cell_statistics)

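if df_recordings.empty:
    raise RuntimeError("No recordings were extracted; check the asset filter and NWB file structure")

print("\nData processing complete:")
print(f"  Total recordings: {len(df_recordings)}")
print(f"  Total cells: {len(df_cells)}")
print(f"  Conditions: {df_recordings['condition'].nunique()}")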

print("\nRecording breakdown by condition:")
for condition in df_recordings['condition'].unique():
    n_recordings = len(df_recordings[df_recordings['condition'] == condition])
    n_cells = len(df_cells[df_cells['condition'] == condition])
    print(f"  {condition}: {n_recordings} recordings from {n_cells} cells")

## Figure 1E: Frequency-Intensity (F-I) Curves

### Action Potential Frequency vs Injected Current

This plot shows the number of action potentials fired during each 500 ms step as a function of injected current (dividing by 0.5 s gives the mean firing rate in Hz). Comparing the three experimental conditions shows how L-DOPA treatment and D1 receptor antagonism affect dSPN excitability.
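Grouping by the exact floating-point current can split one nominal step into several groups when amplitudes are read from the stimulus waveform. The sketch below is a hypothetical helper that assumes the `df_recordings` column layout and a 500 ms step; it rounds currents to the nearest pA before computing mean ± SEM and converts counts to mean firing rate:

```python
import numpy as np
import pandas as pd

STEP_DURATION_S = 0.5  # assumed 500 ms current step


def fi_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Mean ± SEM spike count and mean firing rate per condition and current step.

    Expects one row per sweep with 'condition', 'current_pA' and 'spike_count'.
    """
    # Round to the nearest pA so float jitter does not split a nominal step
    binned = df.assign(current_pA=df["current_pA"].round())
    out = (
        binned.groupby(["condition", "current_pA"])["spike_count"]
        .agg(mean_spikes="mean", sem_spikes="sem", n="count")
        .reset_index()
    )
    out["mean_rate_hz"] = out["mean_spikes"] / STEP_DURATION_S
    return out


demo = pd.DataFrame({
    "condition": ["LID on-state"] * 4,
    "current_pA": [100.0, 100.2, 200.0, 199.9],
    "spike_count": [4, 6, 10, 12],
})
print(fi_summary(demo))
```

Each row here is one sweep. If a cell contributes several sweeps at the same step, averaging within each cell first keeps the SEM based on cells rather than sweeps.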

In [None]:
# Create F-I curves plot
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Define condition plotting styles to match paper
condition_styles = {
    "LID off-state": {
        "color": "black",
        "marker": "o",
        "linestyle": "-",
        "label": "off-state",
        "markerfacecolor": "white",
        "markeredgecolor": "black"
    },
    "LID on-state": {
        "color": "black",
        "marker": "s",
        "linestyle": "-",
        "label": "on-state",
        "markerfacecolor": "black",
        "markeredgecolor": "black"
    },
    "LID on-state with SCH": {
        "color": "gray",
        "marker": "^",
        "linestyle": "--",
        "label": "on-state+D1R\nantagonist",  # newline splits the legend entry over two lines
        "markerfacecolor": "gray",
        "markeredgecolor": "gray"
    }
}

# Process each condition
for condition in ["LID off-state", "LID on-state", "LID on-state with SCH"]:
    if condition not in df_recordings["condition"].unique():
        print(f"Warning: {condition} not found in data")
        continue
    
    condition_data = df_recordings[df_recordings["condition"] == condition]
    
    # Calculate mean and SEM for each current step
    summary_data = []
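    # Round to the nearest pA so float jitter in stimulus amplitudes does not split a step
    for current, group in condition_data.groupby(condition_data["current_pA"].round()):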
        mean_spikes, sem_spikes = safe_mean_sem(group["spike_count"].values)
        summary_data.append({
            "current_pA": current, 
            "mean_spikes": mean_spikes, 
            "sem_spikes": sem_spikes
        })
    
    summary_df = pd.DataFrame(summary_data).sort_values("current_pA")
    
    # Filter to current range used in paper (0-300 pA)
    summary_df = summary_df[
        (summary_df["current_pA"] >= 0) & (summary_df["current_pA"] <= 300)
    ]
    
    if len(summary_df) == 0:
        print(f"Warning: No data in 0-300pA range for {condition}")
        continue
    
    # Plot with error bars
    style = condition_styles[condition]
    ax.errorbar(
        summary_df["current_pA"], 
        summary_df["mean_spikes"], 
        yerr=summary_df["sem_spikes"],
        marker=style["marker"], 
        color=style["color"], 
        linestyle=style["linestyle"],
        linewidth=1.5, 
        markersize=4, 
        capsize=3, 
        capthick=1, 
        label=style["label"],
        markerfacecolor=style["markerfacecolor"],
        markeredgecolor=style["markeredgecolor"], 
        markeredgewidth=1,
    )

# Formatting to match paper style
ax.set_xlabel("current (pA)", fontsize=12)
ax.set_ylabel("number of APs", fontsize=12)
ax.set_xlim(0, 300)
ax.set_ylim(0, 18)
ax.set_xticks([0, 100, 200, 300])
ax.set_yticks([0, 5, 10, 15])

# Style the axes
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
ax.tick_params(axis='both', which='major', labelsize=10, width=1.5, length=5)

# Add legend
ax.legend(loc="upper left", frameon=False, fontsize=10)
ax.set_title("Figure 1E: dSPN Somatic Excitability (F-I Curves)", fontsize=14, fontweight='bold', pad=15)

plt.tight_layout()
plt.show()

# Print summary statistics
print("=== F-I CURVE ANALYSIS SUMMARY ===\n")
for condition in df_recordings["condition"].unique():
    condition_data = df_recordings[df_recordings["condition"] == condition]
    n_recordings = len(condition_data)
    n_cells = len(df_cells[df_cells["condition"] == condition])
    
    # Calculate statistics at different current levels
    current_levels = [50, 100, 150, 200, 250]
    print(f"{condition}:")
    print(f"  Sample: {n_recordings} recordings from {n_cells} cells")
    
    for current in current_levels:
        current_data = condition_data[
            (condition_data["current_pA"] >= current - 10) & 
            (condition_data["current_pA"] <= current + 10)
        ]
        if len(current_data) > 0:
            # safe_mean_sem returns SEM = 0 for a single sweep instead of NaN
            mean_spikes, sem_spikes = safe_mean_sem(current_data["spike_count"].values)
            print(f"  {current}pA: {mean_spikes:.1f} ± {sem_spikes:.1f} APs (n={len(current_data)})")
    print()

## Figure 1E: Rheobase Comparison

### Minimum Current Required for Action Potential Generation

This box plot compares the rheobase (minimum current required to elicit at least one action potential) across experimental conditions, showing how L-DOPA treatment affects neuronal excitability thresholds.
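Rheobase can also be computed directly from the long-form sweep table without sorting each cell's steps first: it is the smallest current among the sweeps with at least one spike. A sketch assuming the `session_id`, `current_pA` and `spike_count` columns of `df_recordings` (the helper name is hypothetical):

```python
import numpy as np
import pandas as pd


def rheobase_per_cell(df: pd.DataFrame) -> pd.Series:
    """Smallest tested current with >= 1 AP for each session (cell)."""
    fired = df[df["spike_count"] >= 1]
    rheo = fired.groupby("session_id")["current_pA"].min()
    # Cells that never fire over the tested range get NaN, as in calculate_rheobase
    return rheo.reindex(df["session_id"].unique()).rename("rheobase_pA")


# Unsorted steps for cell c1; cell c2 never fires
demo = pd.DataFrame({
    "session_id": ["c1", "c1", "c1", "c2", "c2"],
    "current_pA": [150.0, 50.0, 100.0, 50.0, 100.0],
    "spike_count": [3, 0, 1, 0, 0],
})
print(rheobase_per_cell(demo))
```

Like `calculate_rheobase`, this takes the lowest firing step at face value, so a spontaneous spike during a low-current sweep would set a low rheobase.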

In [None]:
# Create rheobase comparison plot
fig, ax = plt.subplots(1, 1, figsize=(4, 5))

# Filter out cells with invalid rheobase values
valid_cells = df_cells.dropna(subset=["rheobase_pA"])

if len(valid_cells) == 0:
    print("No valid rheobase data available")
else:
    # Prepare data for box plot
    conditions_order = ["LID off-state", "LID on-state", "LID on-state with SCH"]
    condition_labels = ["off-state", "on-state", "on+SCH"]
    
    # Get data for each condition
    plot_data = []
    actual_labels = []
    
    for condition, label in zip(conditions_order, condition_labels):
        condition_data = valid_cells[valid_cells["condition"] == condition]["rheobase_pA"]
        if len(condition_data) > 0:
            plot_data.append(condition_data.values)
            actual_labels.append(label)
        else:
            print(f"Warning: No rheobase data for {condition}")
    
    if len(plot_data) > 0:
        # Create box plot; individual cells are overlaid below, so outlier
        # markers are hidden to avoid drawing the same points twice
        bp = ax.boxplot(
            plot_data,
            patch_artist=True,
            showfliers=False,
            boxprops=dict(facecolor="white", color="black", linewidth=1.5),
            whiskerprops=dict(color="black", linewidth=1.5),
            capprops=dict(color="black", linewidth=1.5),
            medianprops=dict(color="black", linewidth=2),
        )
        # Set category labels via set_xticks: boxplot's `labels` argument is
        # deprecated since Matplotlib 3.9 in favor of `tick_labels`
        ax.set_xticks(range(1, len(actual_labels) + 1), actual_labels)
        
        # Add individual data points with seeded (reproducible) horizontal jitter
        rng = np.random.default_rng(0)
        for i, data in enumerate(plot_data):
            x_vals = rng.normal(i + 1, 0.04, len(data))
            ax.scatter(x_vals, data, color="gray", alpha=0.8, s=20, zorder=3)
        
        # Formatting
        ax.set_ylabel("rheobase (pA)", fontsize=12)
        ax.set_ylim(0, 300)
        ax.set_title("Figure 1E: Rheobase Comparison", fontsize=14, fontweight='bold', pad=15)
        
        # Style the axes
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(1.5)
        ax.spines['bottom'].set_linewidth(1.5)
        ax.tick_params(axis='both', which='major', labelsize=10, width=1.5, length=5)
        ax.tick_params(axis='x', which='major', length=0)  # Remove x-axis tick marks
        
        plt.tight_layout()
        plt.show()
        
        # Statistical analysis
        print("=== RHEOBASE STATISTICAL ANALYSIS ===\n")
        
        # Iterate over conditions directly: zipping with plot_data would
        # misalign or drop conditions whenever one of them has no data
        for condition in conditions_order:
            if condition in valid_cells["condition"].unique():
                rheobase_data = valid_cells[valid_cells["condition"] == condition]["rheobase_pA"]
                n_cells = len(rheobase_data)
                mean_rheo, sem_rheo = safe_mean_sem(rheobase_data.values)
                median_rheo = rheobase_data.median()
                q25 = rheobase_data.quantile(0.25)
                q75 = rheobase_data.quantile(0.75)
                
                print(f"{condition} (n={n_cells}):")
                print(f"  Mean: {mean_rheo:.1f} ± {sem_rheo:.1f} pA")
                print(f"  Median: {median_rheo:.1f} pA")
                print(f"  IQR: {q25:.1f} - {q75:.1f} pA")
                print(f"  Range: {rheobase_data.min():.1f} - {rheobase_data.max():.1f} pA\n")
        
        # Statistical comparisons
        if len(plot_data) >= 2:
            print("Statistical Comparisons:")
            
            # Off-state vs On-state
            if "LID off-state" in valid_cells["condition"].unique() and "LID on-state" in valid_cells["condition"].unique():
                off_data = valid_cells[valid_cells["condition"] == "LID off-state"]["rheobase_pA"]
                on_data = valid_cells[valid_cells["condition"] == "LID on-state"]["rheobase_pA"]
                
                # Mann-Whitney U test
                u_stat, u_p = stats.mannwhitneyu(off_data, on_data, alternative='two-sided')
                print(f"\nOff-state vs On-state:")
                print(f"  Mann-Whitney U: {u_stat:.2f}, p = {u_p:.4f}")
                print(f"  Significantly different: {'Yes' if u_p < 0.05 else 'No'}")
                
                # Difference in means (on-state minus off-state)
                mean_diff = on_data.mean() - off_data.mean()
                print(f"  Mean difference (on - off): {mean_diff:.1f} pA")
    else:
        print("No rheobase data available for plotting")

## Summary

### Key Findings

This analysis reproduces the F-I and rheobase analyses of **Figure 1E** of Zhai et al. 2025:

1. **F-I Curves**: Number of action potentials per 500 ms step versus injected current for each experimental condition
2. **Rheobase Analysis**: Minimum current required to elicit spiking, compared between conditions
3. **L-DOPA Effects**: The off-state vs on-state comparison shows how levodopa treatment changes dSPN somatic excitability
4. **D1 Receptor Role**: The on-state with SCH23390 condition tests whether the on-state change depends on D1 receptor signaling

### Methodological Notes

- **Current Clamp**: Whole-cell patch clamp recordings in current clamp mode
- **Spike Detection**: Threshold crossing at 0 mV within the 500 ms stimulus window (200-700 ms after sweep start)
- **Current Range**: 0-300 pA injected current steps (range shown in the F-I plot)
- **Rheobase Definition**: Smallest tested current step that elicits ≥1 action potential
- **Statistics**: Two-sided Mann-Whitney U test comparing off-state and on-state rheobase
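Only the off-state vs on-state rheobase comparison is tested here. If the on-state vs on-state+SCH comparison is added, the p-values could be corrected for multiple comparisons; Holm-Bonferroni is one common choice and is shown here as an assumption, not as the procedure used in the paper. A numpy-only sketch (the helper name is hypothetical):

```python
import numpy as np


def holm_adjust(p_values) -> np.ndarray:
    """Holm-Bonferroni step-down adjusted p-values, returned in input order."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    # Multiply the k-th smallest p-value (0-based k) by (m - k)
    scaled = (m - np.arange(m)) * p[order]
    # Enforce monotonicity across ranks, then cap at 1
    adjusted_sorted = np.minimum(np.maximum.accumulate(scaled), 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted
    return adjusted


print(holm_adjust([0.01, 0.04, 0.03]))
```

In the notebook, the inputs would be the `.pvalue` attributes of two `scipy.stats.mannwhitneyu` results (off vs on, and on vs on+SCH).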

### Biological Significance

Comparing these conditions shows how L-DOPA treatment affects the intrinsic excitability of direct pathway striatal projection neurons, providing insight into the cellular mechanisms underlying levodopa-induced dyskinesia in Parkinson's disease. The D1 receptor antagonist condition tests whether these excitability changes depend on D1 receptor signaling.