In [2]:
from spike_sorting_gui import launch_spike_sorting_gui
import neo
import numpy as np
from get_path import get_path

# PLXファイル読み込み
plx_file = get_path(file_type='plx', initial_dir=r'\\Nagoya_AMCBNAS2\wakelab7\okita\whisker_analysis\20260114-1_Wildtype_S1BF')
plx = neo.io.PlexonIO(filename=plx_file)
data = plx.read()
seg = data[0].segments[0]

# 40kHzデータを取得
for sig in seg.analogsignals:
    if float(sig.sampling_rate) >= 20000:
        wideband = np.array(sig)
        fs = float(sig.sampling_rate)
        break



 選択: 260114-4_Wildtype_S1BF_920_wisker.plx


Parsing data blocks: 100%|█████████▉| 157886004/157976764 [00:04<00:00, 33362145.05it/s]
Finalizing data blocks for type 1: 100%|██████████| 32/32 [00:00<00:00, 31994.69it/s]
Finalizing data blocks for type 4: 100%|██████████| 43/43 [00:00<00:00, 3341.15it/s]
Finalizing data blocks for type 5: 100%|██████████| 128/128 [00:05<00:00, 23.60it/s]
Finalizing data blocks: 100%|██████████| 3/3 [00:05<00:00,  1.82s/it]
Parsing signal channels: 100%|██████████| 128/128 [00:00<?, ?it/s]
Parsing spike channels: 16it [00:00, ?it/s]
Parsing event channels: 100%|██████████| 43/43 [00:00<?, ?it/s]


In [3]:
# GUI起動（自動ソーティング実行）
launch_spike_sorting_gui(wideband, fs)

自動ソーティング実行中...
=== スパイクソーティング ===
チャンネル: 16, fs: 40000.0 Hz
フィルタ: 300.0-3000.0 Hz

フィルタリング...

--- Channel 0 ---
  Ch0: 検出 849 スパイク
  Ch0: PCA寄与率 [0.30, 0.28, 0.18]
  Ch0: 4 クラスター
    Unit1: n= 490, amp=-0.0511, SNR=4.0, ISI=1.4% ✓
    Unit2: n= 216, amp=-0.0556, SNR=3.2, ISI=3.7% ⚠
    Unit3: n=  45, amp=-0.0637, SNR=1.8, ISI=4.5% ⚠
    Unit4: n=  73, amp=-0.0647, SNR=3.9, ISI=0.0% ✓

--- Channel 1 ---
  Ch1: 検出 948 スパイク
  Ch1: PCA寄与率 [0.31, 0.25, 0.18]
  Ch1: 4 クラスター
    Unit1: n= 249, amp=-0.0519, SNR=3.3, ISI=6.5% ⚠
    Unit2: n= 519, amp=-0.0475, SNR=3.8, ISI=1.9% ✓
    Unit3: n=  25, amp=-0.0634, SNR=1.6, ISI=8.3% ⚠
    Unit4: n= 141, amp=-0.0519, SNR=2.7, ISI=0.0% ✓

--- Channel 2 ---
  Ch2: 検出 869 スパイク
  Ch2: PCA寄与率 [0.30, 0.28, 0.17]
  Ch2: 3 クラスター
    Unit1: n= 182, amp=-0.0581, SNR=2.4, ISI=6.1% ⚠
    Unit2: n=  76, amp=-0.0625, SNR=3.8, ISI=0.0% ✓
    Unit3: n= 582, amp=-0.0490, SNR=3.7, ISI=3.3% ⚠

--- Channel 3 ---
  Ch3: 検出 1050 スパイク
  Ch3: PCA寄与率 [0.32, 0.25, 0.15]
  Ch

In [5]:
import numpy as np
import matplotlib.pyplot as plt
from spike_sorting import sort_channel, SortingConfig, compute_isi_histogram, compute_autocorrelogram
from scipy import signal
import neo


for sig in seg.analogsignals:
    if float(sig.sampling_rate) >= 20000:
        wideband = np.array(sig)
        fs = float(sig.sampling_rate)
        break

# フィルタリング
def bandpass_filter(data, fs, low=300, high=3000):
    nyq = 0.5 * fs
    sos = signal.butter(4, [low/nyq, high/nyq], btype='band', output='sos')
    return signal.sosfiltfilt(sos, data, axis=0)

filtered = bandpass_filter(wideband, fs)

# ソーティング（Ch0）
config = SortingConfig()
result = sort_channel(filtered[:, 0], fs, 0, config, verbose=True)

# Unit 1を取得（一番スパイク数が多いもの）
unit = result.units[0]
spike_times = unit.spike_times  # 秒

print(f"\n=== 発火パターン分析 ===")
print(f"スパイク数: {len(spike_times)}")
print(f"記録時間: {spike_times[-1] - spike_times[0]:.1f} 秒")
print(f"平均発火率: {len(spike_times) / (spike_times[-1] - spike_times[0]):.1f} Hz")

# ISI詳細
isi_ms = np.diff(spike_times) * 1000
print(f"\n=== ISI統計 ===")
print(f"平均ISI: {np.mean(isi_ms):.1f} ms")
print(f"中央値ISI: {np.median(isi_ms):.1f} ms")
print(f"最小ISI: {np.min(isi_ms):.2f} ms")
print(f"最大ISI: {np.max(isi_ms):.1f} ms")

# 特定の周期に対応するISIの数を確認
print(f"\n=== 特定周期のISI ===")
print(f"10Hz (100ms) 付近 (90-110ms): {np.sum((isi_ms > 90) & (isi_ms < 110))} 個")
print(f"17Hz (59ms) 付近 (50-70ms): {np.sum((isi_ms > 50) & (isi_ms < 70))} 個")
print(f"34Hz (29ms) 付近 (25-35ms): {np.sum((isi_ms > 25) & (isi_ms < 35))} 個")
print(f"短いISI (<10ms): {np.sum(isi_ms < 10)} 個 → バースト?")

# プロット
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# 1. ISIヒストグラム（細かいビン）
ax1 = axes[0, 0]
ax1.hist(isi_ms[isi_ms < 200], bins=200, color='steelblue', edgecolor='black', alpha=0.7)
ax1.axvline(x=100, color='red', linestyle='--', linewidth=2, label='10Hz (100ms)')
ax1.axvline(x=59, color='orange', linestyle='--', linewidth=2, label='17Hz (59ms)')
ax1.axvline(x=29, color='green', linestyle='--', linewidth=2, label='34Hz (29ms)')
ax1.axvline(x=2, color='purple', linestyle='--', linewidth=1, label='2ms (refractory)')
ax1.set_xlabel('ISI (ms)')
ax1.set_ylabel('Count')
ax1.set_title('ISI Histogram (detailed)')
ax1.legend()

# 2. Autocorrelogram（細かいビン）
ax2 = axes[0, 1]
bins_ac, autocorr = compute_autocorrelogram(spike_times, bin_size_ms=1.0, window_ms=150)
ax2.bar(bins_ac, autocorr, width=1.0, color='steelblue', alpha=0.7)
ax2.axvline(x=100, color='red', linestyle='--', linewidth=2, label='10Hz')
ax2.axvline(x=-100, color='red', linestyle='--', linewidth=2)
ax2.axvline(x=59, color='orange', linestyle='--', linewidth=2, label='17Hz')
ax2.axvline(x=-59, color='orange', linestyle='--', linewidth=2)
ax2.axvline(x=29, color='green', linestyle='--', linewidth=2, label='34Hz')
ax2.axvline(x=-29, color='green', linestyle='--', linewidth=2)
ax2.set_xlabel('Lag (ms)')
ax2.set_ylabel('Count')
ax2.set_title('Autocorrelogram (detailed)')
ax2.legend()

# 3. 短いISI（バースト検出）
ax3 = axes[0, 2]
short_isi = isi_ms[isi_ms < 20]
if len(short_isi) > 0:
    ax3.hist(short_isi, bins=20, color='red', edgecolor='black', alpha=0.7)
    ax3.axvline(x=2, color='purple', linestyle='--', label='2ms (refractory)')
ax3.set_xlabel('ISI (ms)')
ax3.set_ylabel('Count')
ax3.set_title('Short ISI (<20ms) - Burst detection')
ax3.legend()

# 4. スパイク発生時刻のラスタープロット（最初の5秒）
ax4 = axes[1, 0]
t_show = 5.0
spike_times_show = spike_times[spike_times < t_show]
ax4.eventplot([spike_times_show], lineoffsets=0, linelengths=0.8, colors='black')

# 10Hz刺激のタイミング（仮）
stim_times = np.arange(0, t_show, 0.1)  # 10Hz = 100ms間隔
for t in stim_times:
    ax4.axvline(x=t, color='red', alpha=0.3, linewidth=0.5)
ax4.set_xlabel('Time (s)')
ax4.set_title(f'Raster plot (first {t_show}s) - Red: 10Hz stim timing')
ax4.set_xlim(0, t_show)

# 5. 連続するISIの関係（バーストパターン確認）
ax5 = axes[1, 1]
if len(isi_ms) > 1:
    ax5.scatter(isi_ms[:-1], isi_ms[1:], alpha=0.3, s=5)
    ax5.axhline(y=100, color='red', linestyle='--', alpha=0.5)
    ax5.axvline(x=100, color='red', linestyle='--', alpha=0.5)
    ax5.set_xlabel('ISI_n (ms)')
    ax5.set_ylabel('ISI_n+1 (ms)')
    ax5.set_title('Return map (ISI_n vs ISI_n+1)')
    ax5.set_xlim(0, 200)
    ax5.set_ylim(0, 200)

# 6. 発火率の時間変化
ax6 = axes[1, 2]
bin_size = 0.1  # 100ms bins
bins = np.arange(0, spike_times[-1], bin_size)
rate, _ = np.histogram(spike_times, bins=bins)
rate = rate / bin_size  # Hz
ax6.plot(bins[:-1], rate, 'b-', linewidth=0.5)
ax6.set_xlabel('Time (s)')
ax6.set_ylabel('Firing rate (Hz)')
ax6.set_title('Instantaneous firing rate')

plt.tight_layout()
plt.savefig('firing_pattern_analysis.png', dpi=150)
plt.show()

# バースト検出
print(f"\n=== バースト分析 ===")
burst_threshold = 10  # ms
burst_isi = isi_ms < burst_threshold
n_burst_events = np.sum(burst_isi)
print(f"バースト内ISI (<{burst_threshold}ms): {n_burst_events} 回")
print(f"全ISIに占める割合: {n_burst_events / len(isi_ms) * 100:.1f}%")

# バーストのパターンを確認
if n_burst_events > 0:
    # 連続するバーストを検出
    burst_starts = []
    burst_lengths = []
    in_burst = False
    current_length = 1
    
    for i, is_burst in enumerate(burst_isi):
        if is_burst:
            if not in_burst:
                burst_starts.append(i)
                in_burst = True
                current_length = 2
            else:
                current_length += 1
        else:
            if in_burst:
                burst_lengths.append(current_length)
                in_burst = False
    
    if in_burst:
        burst_lengths.append(current_length)
    
    print(f"バースト回数: {len(burst_lengths)}")
    if burst_lengths:
        print(f"バースト内スパイク数: 平均 {np.mean(burst_lengths):.1f}, 最大 {max(burst_lengths)}")

  Ch0: 検出 849 スパイク
  Ch0: PCA寄与率 [0.30, 0.28, 0.18]
  Ch0: 4 クラスター
    Unit1: n= 490, amp=-0.0511, SNR=4.0, ISI=1.4% ✓
    Unit2: n= 216, amp=-0.0556, SNR=3.2, ISI=3.7% ⚠
    Unit3: n=  45, amp=-0.0637, SNR=1.8, ISI=4.5% ⚠
    Unit4: n=  73, amp=-0.0647, SNR=3.9, ISI=0.0% ✓

=== 発火パターン分析 ===
スパイク数: 490
記録時間: 109.5 秒
平均発火率: 4.5 Hz

=== ISI統計 ===
平均ISI: 223.9 ms
中央値ISI: 16.6 ms
最小ISI: 1.00 ms
最大ISI: 6170.9 ms

=== 特定周期のISI ===
10Hz (100ms) 付近 (90-110ms): 13 個
17Hz (59ms) 付近 (50-70ms): 27 個
34Hz (29ms) 付近 (25-35ms): 23 個
短いISI (<10ms): 192 個 → バースト?

=== バースト分析 ===
バースト内ISI (<10ms): 192 回
全ISIに占める割合: 39.3%
バースト回数: 61
バースト内スパイク数: 平均 4.1, 最大 16


In [6]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# === 既にフィルタ済みデータとスパイク検出結果があるとする ===
# filtered: フィルタ済みデータ（Ch0）
# result: sort_channelの結果
# unit: result.units[0]

unit = result.units[0]
spike_times = unit.spike_times
spike_indices = unit.spike_indices
waveforms = unit.waveforms

# 波形時間軸
pre_ms = 0.5
post_ms = 1.0
pre_samples = int(pre_ms / 1000 * fs)
post_samples = int(post_ms / 1000 * fs)
time_ms = np.linspace(-pre_ms, post_ms, pre_samples + post_samples)

fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# ========================================
# 1. バーストの最初のスパイク vs 2番目のスパイク
# ========================================
ax1 = axes[0, 0]

isi_ms = np.diff(spike_indices) / fs * 1000
burst_threshold = 10  # ms

first_in_burst = []  # バーストの1発目
later_in_burst = []  # バーストの2発目以降
isolated = []        # 孤立スパイク

for i in range(len(isi_ms)):
    prev_long = (i == 0) or (isi_ms[i-1] > burst_threshold)
    next_short = isi_ms[i] < burst_threshold
    
    if prev_long and next_short:
        first_in_burst.append(i)
    elif not prev_long and (i < len(isi_ms)):
        later_in_burst.append(i)
    elif prev_long and not next_short:
        isolated.append(i)

print(f"バースト1発目: {len(first_in_burst)}")
print(f"バースト2発目以降: {len(later_in_burst)}")
print(f"孤立スパイク: {len(isolated)}")

# 各カテゴリの平均波形
if first_in_burst:
    wf_first = waveforms[first_in_burst]
    ax1.plot(time_ms, np.mean(wf_first, axis=0), 'b-', linewidth=2, 
             label=f'Burst 1st (n={len(first_in_burst)})')

if later_in_burst:
    wf_later = waveforms[later_in_burst]
    ax1.plot(time_ms, np.mean(wf_later, axis=0), 'r-', linewidth=2, 
             label=f'Burst 2nd+ (n={len(later_in_burst)})')

if isolated:
    wf_iso = waveforms[isolated]
    ax1.plot(time_ms, np.mean(wf_iso, axis=0), 'g-', linewidth=2, 
             label=f'Isolated (n={len(isolated)})')

ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax1.set_xlabel('Time (ms)')
ax1.set_ylabel('Amplitude')
ax1.set_title('Mean waveform by burst position')
ax1.legend()

# ========================================
# 2. 高発火率 vs 低発火率 区間の波形比較
# ========================================
ax2 = axes[0, 1]

# 前半（疎な区間）vs 後半（密な区間）
mid_time = np.median(spike_times)
early_mask = spike_times < mid_time
late_mask = spike_times >= mid_time

if np.sum(early_mask) > 5:
    ax2.plot(time_ms, np.mean(waveforms[early_mask], axis=0), 'b-', linewidth=2,
             label=f'Early (n={np.sum(early_mask)})')

if np.sum(late_mask) > 5:
    ax2.plot(time_ms, np.mean(waveforms[late_mask], axis=0), 'r-', linewidth=2,
             label=f'Late (n={np.sum(late_mask)})')

ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('Time (ms)')
ax2.set_ylabel('Amplitude')
ax2.set_title('Early vs Late waveforms')
ax2.legend()

# ========================================
# 3. 振幅の時間変化
# ========================================
ax3 = axes[0, 2]

peak_amps = np.min(waveforms, axis=1)
ax3.scatter(spike_times, peak_amps, s=5, alpha=0.5)
ax3.axhline(y=np.mean(peak_amps), color='red', linestyle='--', label='Mean')
ax3.set_xlabel('Time (s)')
ax3.set_ylabel('Peak Amplitude')
ax3.set_title('Amplitude stability')
ax3.legend()

# ========================================
# 4. 生波形でバースト区間を拡大表示
# ========================================
ax4 = axes[1, 0]

# 密な発火区間を見つける
dense_start_idx = np.searchsorted(spike_times, mid_time)
if dense_start_idx < len(spike_indices):
    center_sample = spike_indices[dense_start_idx]
    window = int(0.05 * fs)  # 50ms
    start = max(0, center_sample - window)
    end = min(len(filtered), center_sample + window)
    
    t_plot = np.arange(start, end) / fs * 1000  # ms
    ax4.plot(t_plot - t_plot[0], filtered[start:end, 0] if filtered.ndim > 1 else filtered[start:end], 
             'b-', linewidth=0.5)
    
    # スパイクマーカー
    for idx in spike_indices:
        if start <= idx < end:
            ax4.axvline(x=(idx - start) / fs * 1000, color='red', alpha=0.5, linewidth=0.5)
    
    ax4.set_xlabel('Time (ms)')
    ax4.set_ylabel('Amplitude')
    ax4.set_title(f'Raw signal (50ms window) at t={center_sample/fs:.1f}s')

# ========================================
# 5. バーストサイズの分布
# ========================================
ax5 = axes[1, 1]

burst_sizes = []
current_burst = 1

for i, isi in enumerate(isi_ms):
    if isi < burst_threshold:
        current_burst += 1
    else:
        burst_sizes.append(current_burst)
        current_burst = 1
burst_sizes.append(current_burst)

ax5.hist(burst_sizes, bins=range(1, max(burst_sizes) + 2), 
         color='steelblue', edgecolor='black', alpha=0.7, align='left')
ax5.set_xlabel('Spikes per burst')
ax5.set_ylabel('Count')
ax5.set_title('Burst size distribution')
ax5.set_xticks(range(1, min(max(burst_sizes) + 1, 15)))

print(f"\nバーストサイズ: 平均 {np.mean(burst_sizes):.1f}, 最大 {max(burst_sizes)}")
print(f"孤立スパイク(=1): {burst_sizes.count(1)}")
print(f"ダブレット(=2): {burst_sizes.count(2)}")
print(f"トリプレット(=3): {burst_sizes.count(3)}")

# ========================================
# 6. パワースペクトル（周期性の確認）
# ========================================
ax6 = axes[1, 2]

# スパイクトレインからパワースペクトル
spike_train = np.zeros(int(spike_times[-1] * 1000) + 1)  # 1ms bins
for t in spike_times:
    idx = int(t * 1000)
    if idx < len(spike_train):
        spike_train[idx] = 1

freqs = np.fft.rfftfreq(len(spike_train), d=0.001)  # 1ms bin
power = np.abs(np.fft.rfft(spike_train))**2

# 0-50Hzを表示
mask = freqs < 50
ax6.plot(freqs[mask], power[mask], 'b-', linewidth=0.5)
ax6.axvline(x=10, color='red', linestyle='--', label='10Hz')
ax6.axvline(x=17, color='orange', linestyle='--', label='17Hz')
ax6.axvline(x=34, color='green', linestyle='--', label='34Hz')
ax6.set_xlabel('Frequency (Hz)')
ax6.set_ylabel('Power')
ax6.set_title('Spike train power spectrum')
ax6.legend()

plt.tight_layout()
plt.savefig('burst_analysis.png', dpi=150)
plt.show()

バースト1発目: 61
バースト2発目以降: 192
孤立スパイク: 236

バーストサイズ: 平均 1.6, 最大 16
孤立スパイク(=1): 237
ダブレット(=2): 30
トリプレット(=3): 9


In [7]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import neo



# ========================================
# 1. 刺激タイミングを探す
# ========================================
print("=== Event channels ===")
for event in seg.events:
    print(f"  {event.name}: {len(event.times)} events")
    if len(event.times) > 0:
        times = np.array(event.times)
        print(f"    First 10: {times[:10]}")
        print(f"    Time range: {times[0]:.3f} - {times[-1]:.3f}")

print("\n=== Epoch channels ===")
for epoch in seg.epochs:
    print(f"  {epoch.name}: {len(epoch.times)} epochs")

# もしイベントが見つからない場合、PLXのraw eventを確認
print("\n=== All analog signals ===")
for i, sig in enumerate(seg.analogsignals):
    sr = float(sig.sampling_rate)
    dur = len(sig) / sr
    print(f"  Signal {i}: shape={sig.shape}, fs={sr:.0f} Hz, duration={dur:.1f}s")
    # annotations確認
    print(f"    annotations: {sig.annotations}")


# ========================================
# 仮に刺激タイミングが見つかった場合（stim_timesに秒単位で入る）
# 見つからない場合は手動で推定する
# ========================================

# パターン1: Eventから取得
stim_times = None
for event in seg.events:
    if len(event.times) > 5:  # 刺激イベントっぽいもの
        stim_times = np.array(event.times).astype(float)
        print(f"\nFound stimulus events: {event.name}, n={len(stim_times)}")
        break

# パターン2: 刺激イベントがない場合、スパイク発火から推定
if stim_times is None:
    print("\n刺激イベントが見つかりません。発火パターンから推定します。")


# ========================================
# 2. 刺激構造の分析
# ========================================
if stim_times is not None and len(stim_times) > 0:
    
    stim_isi = np.diff(stim_times) * 1000  # ms
    
    print(f"\n=== 刺激構造 ===")
    print(f"総刺激数: {len(stim_times)}")
    print(f"刺激ISI 平均: {np.mean(stim_isi):.1f} ms")
    print(f"刺激ISI 中央値: {np.median(stim_isi):.1f} ms")
    print(f"刺激ISI 範囲: {np.min(stim_isi):.1f} - {np.max(stim_isi):.1f} ms")
    
    # Trial構造を検出
    long_gaps = np.where(stim_isi > 500)[0]  # 500ms以上のギャップ = trial間
    print(f"\nTrial間ギャップ数: {len(long_gaps)}")
    print(f"→ Trial数: {len(long_gaps) + 1}")
    
    if len(long_gaps) > 0:
        gap_durations = stim_isi[long_gaps]
        print(f"ITI: {np.mean(gap_durations):.0f} ± {np.std(gap_durations):.0f} ms")
    
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    
    # 刺激ISIヒストグラム
    ax = axes[0, 0]
    ax.hist(stim_isi, bins=100, color='steelblue', edgecolor='black')
    ax.set_xlabel('Stim ISI (ms)')
    ax.set_ylabel('Count')
    ax.set_title('Stimulus ISI distribution')
    ax.axvline(x=100, color='red', linestyle='--', label='100ms (10Hz)')
    ax.legend()
    
    # PSTH (Peri-Stimulus Time Histogram)
    ax = axes[0, 1]
    
    # unit = result.units[0]  # ソーティング結果から
    # spike_times_sec = unit.spike_times
    # ↑ 実際にはresultのspike_timesを使う
    
    psth_bins = np.arange(-50, 150, 1)  # -50ms to +150ms, 1ms bins
    psth_counts = np.zeros(len(psth_bins) - 1)
    
    for stim_t in stim_times:
        # 各刺激に対するスパイクの相対時刻
        relative_times = (spike_times - stim_t) * 1000  # ms
        hist, _ = np.histogram(relative_times, bins=psth_bins)
        psth_counts += hist
    
    bin_centers = (psth_bins[:-1] + psth_bins[1:]) / 2
    ax.bar(bin_centers, psth_counts, width=1, color='steelblue', alpha=0.7)
    ax.axvline(x=0, color='red', linewidth=2, label='Stimulus')
    ax.set_xlabel('Time from stimulus (ms)')
    ax.set_ylabel('Spike count')
    ax.set_title('PSTH (all stimuli)')
    ax.legend()
    
    # Trial別のラスタープロット
    ax = axes[0, 2]
    
    # Trial開始時刻を検出
    trial_starts = [stim_times[0]]
    for gap_idx in long_gaps:
        trial_starts.append(stim_times[gap_idx + 1])
    trial_starts = np.array(trial_starts)
    
    for trial_i, trial_start in enumerate(trial_starts):
        # このtrialの刺激範囲
        trial_end = trial_start + 1.5  # 10stim × 100ms + margin
        
        # このtrial内のスパイク
        trial_spikes = spike_times[(spike_times >= trial_start) & 
                                    (spike_times < trial_end)]
        
        # 相対時刻
        relative = (trial_spikes - trial_start) * 1000  # ms
        ax.scatter(relative, np.ones_like(relative) * trial_i, 
                  s=2, c='black', marker='|')
        
        # 刺激タイミング
        trial_stims = stim_times[(stim_times >= trial_start) & 
                                  (stim_times < trial_end)]
        for st in trial_stims:
            ax.axvline(x=(st - trial_start) * 1000, color='red', alpha=0.2, linewidth=0.5)
    
    ax.set_xlabel('Time from trial start (ms)')
    ax.set_ylabel('Trial #')
    ax.set_title('Raster plot by trial')
    ax.set_xlim(-50, 1200)
    
    # 刺激番号別の応答率
    ax = axes[1, 0]
    
    response_per_stim = []
    for trial_i, trial_start in enumerate(trial_starts):
        trial_stims = stim_times[(stim_times >= trial_start) & 
                                  (stim_times < trial_start + 1.5)]
        
        for stim_i, stim_t in enumerate(trial_stims):
            # 刺激後 5-50ms にスパイクがあるか
            response_window = spike_times[(spike_times > stim_t + 0.005) & 
                                           (spike_times < stim_t + 0.050)]
            response_per_stim.append({
                'trial': trial_i,
                'stim_num': stim_i,
                'n_spikes': len(response_window)
            })
    
    import pandas as pd
    df = pd.DataFrame(response_per_stim)
    if len(df) > 0:
        response_by_stim = df.groupby('stim_num')['n_spikes'].mean()
        ax.bar(response_by_stim.index + 1, response_by_stim.values, 
               color='steelblue', edgecolor='black')
        ax.set_xlabel('Stimulus # in trial')
        ax.set_ylabel('Mean spikes per stimulus')
        ax.set_title('Response by stimulus position')
    
    # Autocorrelogram（刺激構造を除いた補正版）
    ax = axes[1, 1]
    
    # 各trial内でのautocorrelogramを計算
    from spike_sorting import compute_autocorrelogram
    
    all_ac_bins = None
    all_ac_counts = None
    
    for trial_start in trial_starts:
        trial_end = trial_start + 1.5
        trial_spikes = spike_times[(spike_times >= trial_start) & 
                                    (spike_times < trial_end)]
        
        if len(trial_spikes) > 2:
            bins_ac, ac = compute_autocorrelogram(trial_spikes, bin_size_ms=1.0, window_ms=120)
            if all_ac_bins is None:
                all_ac_bins = bins_ac
                all_ac_counts = ac
            else:
                all_ac_counts += ac
    
    if all_ac_bins is not None:
        ax.bar(all_ac_bins, all_ac_counts, width=1.0, color='steelblue', alpha=0.7)
        ax.axvline(x=0, color='gray', linestyle='--')
        ax.axvline(x=100, color='red', linestyle='--', alpha=0.7, label='100ms')
        ax.axvline(x=59, color='orange', linestyle='--', alpha=0.7, label='59ms')
        ax.set_xlabel('Lag (ms)')
        ax.set_ylabel('Count')
        ax.set_title('Autocorrelogram (within-trial only)')
        ax.legend()
    
    # パワースペクトル（trial内のみ）
    ax = axes[1, 2]
    
    # Trial内スパイクのみでスペクトル
    trial_spike_train = []
    for trial_start in trial_starts:
        trial_end = trial_start + 1.5
        trial_spikes = spike_times[(spike_times >= trial_start) & 
                                    (spike_times < trial_end)]
        # Trial内の相対時刻
        relative = trial_spikes - trial_start
        trial_spike_train.extend(relative.tolist())
    
    # 1.5秒のスパイクトレイン（全trial重ね合わせ）
    bin_ms = 1  # 1ms
    bins = np.arange(0, 1500, bin_ms)
    train, _ = np.histogram(np.array(trial_spike_train) * 1000, bins=bins)
    
    freqs = np.fft.rfftfreq(len(train), d=bin_ms / 1000)
    power = np.abs(np.fft.rfft(train))**2
    
    mask = (freqs > 1) & (freqs < 80)
    ax.plot(freqs[mask], power[mask], 'b-', linewidth=0.8)
    ax.axvline(x=10, color='red', linestyle='--', linewidth=2, label='10Hz')
    ax.axvline(x=20, color='red', linestyle='--', linewidth=1, alpha=0.5, label='20Hz')
    ax.axvline(x=17, color='orange', linestyle='--', linewidth=2, label='17Hz')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power spectrum (within-trial spikes)')
    ax.legend()
    
    plt.tight_layout()
    plt.savefig('stimulus_analysis.png', dpi=150)
    plt.show()

else:
    print("\n刺激タイミングが見つかりません。")
    print("PLXファイルのイベントチャンネルを確認してください。")
    print("または、以下のように手動で設定してください：")
    print("  stim_times = np.array([...])  # 秒単位")


=== Event channels ===
  EVT01: 90 events
    First 10: [17.0859   17.185875 17.28585  17.385825 17.4858   17.5858   17.685775
 17.78575  17.885725 17.985725]
    Time range: 17.086 - 98.374
  EVT02: 9 events
    First 10: [17.04585  27.124475 37.20415  47.2793   57.279725 67.355875 77.359525
 87.36135  97.434275]
    Time range: 17.046 - 97.434
  EVT03: 1242 events
    First 10: [6.329625 7.7129   7.786275 7.86595  7.946075 8.02575  8.10565  8.185775
 8.265675 8.34625 ]
    Time range: 6.330 - 106.894
  EVT04: 0 events
  EVT05: 0 events
  EVT06: 0 events
  EVT07: 0 events
  EVT08: 0 events
  EVT09: 0 events
  EVT10: 0 events
  EVT11: 0 events
  EVT12: 0 events
  EVT13: 0 events
  EVT14: 0 events
  EVT15: 0 events
  EVT16: 0 events
  EVT17: 0 events
  EVT18: 0 events
  EVT19: 0 events
  EVT20: 0 events
  EVT21: 0 events
  EVT22: 0 events
  EVT23: 0 events
  EVT24: 0 events
  EVT25: 0 events
  EVT26: 0 events
  EVT27: 0 events
  EVT28: 0 events
  EVT29: 0 events
  EVT30: 0 events
  EVT3