In [None]:
# ============================================================================
# COMPARISON: FFT CONVOLUTION vs SIMPLE EXPONENTIAL FILTER
# ============================================================================
# Filters one trial's spike trains with two implementations of an exponential
# filter and times each one:
#   1. apply_exponential_filter   (used by the task experiments)
#   2. filter_spikes_exp_kernel   (FFT convolution, used by the encoding
#                                  experiments) with an explicit kernel.
# Timing uses time.perf_counter(): time.time() is a wall clock whose
# resolution can be coarse (~15 ms on Windows), so a fast call may measure
# 0 s and make the speed-up ratio divide by zero.

print("\n[BONUS] Comparing filtering methods...\n")

# Import both methods
from experiments.experiment_utils import apply_exponential_filter
from analysis.encoding_analysis import filter_spikes_exp_kernel

# Exponential kernel for the FFT method, truncated at 5 * tau
# (exp(-5) is ~0.7% of the peak, so the dropped tail is small).
tau = exp.tau_syn
dt = exp.dt
kernel_fft = np.exp(-np.arange(0, 5 * tau, dt) / tau)

print(f"Kernel length: {len(kernel_fft)} timesteps ({len(kernel_fft)*dt:.1f} ms)")
print(f"Tau: {tau} ms")

# Select one trial for comparison
trial_idx = 0
spike_times_trial = all_trial_results[trial_idx]['spike_times']

# Convert spike times to a dense matrix; it is indexed below as
# spike_matrix[time_step, neuron].
# NOTE(review): int() truncates, so floating-point error in the division
# (e.g. int(0.3 / 0.1) == 2) can drop the last bin. Kept as-is so the shape
# stays consistent with traces_all / ground_truth_all built elsewhere — confirm.
n_timesteps = int(exp.stimulus_duration / dt)
spike_matrix = spikes_to_matrix(spike_times_trial, n_timesteps, exp.n_neurons, dt)

print(f"\nSpike matrix shape: {spike_matrix.shape}")
print(f"Total spikes: {spike_matrix.sum()}")

# Method 1: Simple exponential filter (task experiments)
print("\nMethod 1: Simple exponential filter...")
start_time = time.perf_counter()
filtered_simple = apply_exponential_filter(spike_matrix, tau, dt)
time_simple = time.perf_counter() - start_time
print(f"  Time: {time_simple*1000:.2f} ms")

# Method 2: FFT convolution (encoding experiments)
print("\nMethod 2: FFT convolution...")
spike_matrix_3d = spike_matrix[np.newaxis, :, :]  # add a leading trial axis (a view, no copy)
start_time = time.perf_counter()
filtered_fft = filter_spikes_exp_kernel(spike_matrix_3d, kernel_fft)[0]  # drop the trial axis
time_fft = time.perf_counter() - start_time
print(f"  Time: {time_fft*1000:.2f} ms")

# Single-shot timings are noisy (first-call overhead, caches), so treat the
# ratio as indicative only. Guard against a zero denominator.
if time_fft > 0:
    faster_label = '(FFT faster)' if time_fft < time_simple else '(Simple faster)'
    print(f"\nSpeedup: {time_simple/time_fft:.2f}x {faster_label}")
else:
    print("\nSpeedup: n/a (FFT time below timer resolution)")

# ============================================================================
# Visualize the difference
# ============================================================================
# Layout: rows 0-1 hold up to four per-neuron trace panels (both filtered
# traces plus spike ticks); row 2 holds a difference heatmap and a text
# summary of agreement statistics and timings.

fig, axes = plt.subplots(3, 2, figsize=(16, 12))
fig.suptitle('Comparison: Simple Exponential Filter vs FFT Convolution', fontsize=16, fontweight='bold')

# Time axis in ms, shared by the trace panels and the heatmap extent
# (hoisted: it does not depend on the neuron being plotted).
time_axis = np.arange(n_timesteps) * dt
duration_ms = n_timesteps * dt

# Select a few neurons to show. Indices outside the network are skipped
# rather than raising IndexError on small networks.
neurons_to_show = [0, 10, 20, 50]
n_neurons_total = spike_matrix.shape[1]
valid_neurons = [idx for idx in neurons_to_show if idx < n_neurons_total]
trace_axes = axes[:2].ravel()  # the four panels in rows 0-1

for ax, neuron_idx in zip(trace_axes, valid_neurons):
    # Plot both filtered traces
    ax.plot(time_axis, filtered_simple[:, neuron_idx],
            label='Simple Exponential', alpha=0.7, linewidth=2)
    ax.plot(time_axis, filtered_fft[:, neuron_idx],
            label='FFT Convolution', alpha=0.7, linewidth=2, linestyle='--')

    # Mark spike times as ticks on the zero line
    spike_mask = spike_matrix[:, neuron_idx] > 0
    spike_times_neuron = time_axis[spike_mask]
    ax.scatter(spike_times_neuron, np.zeros_like(spike_times_neuron),
               color='red', marker='|', s=100, label='Spikes', zorder=3)

    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Filtered Activity')
    ax.set_title(f'Neuron {neuron_idx}')
    ax.legend()
    ax.grid(alpha=0.3)

# Blank out trace panels that ended up with no neuron assigned.
for ax in trace_axes[len(valid_neurons):]:
    ax.axis('off')

# Difference heatmap: neurons on the y-axis, time on the x-axis. `extent`
# maps the x-axis to ms so it matches the 'Time (ms)' label; without it
# imshow labels the axis with timestep indices. The colour scale is
# symmetric around zero.
ax_diff = axes[2, 0]
difference = filtered_simple - filtered_fft
n_heat = min(50, difference.shape[1])
vlim = np.abs(difference).max()
im = ax_diff.imshow(difference[:, :n_heat].T, aspect='auto', cmap='RdBu_r',
                    vmin=-vlim, vmax=vlim, interpolation='nearest',
                    extent=[0, duration_ms, n_heat, 0])
ax_diff.set_xlabel('Time (ms)')
ax_diff.set_ylabel(f'Neuron (first {n_heat})')
ax_diff.set_title('Difference (Simple - FFT)')
plt.colorbar(im, ax=ax_diff, label='Difference')

# Show statistics
ax_stats = axes[2, 1]
ax_stats.axis('off')

# Agreement statistics over all timesteps and neurons
mae = np.mean(np.abs(difference))
rmse = np.sqrt(np.mean(difference**2))
max_diff = np.max(np.abs(difference))
correlation = np.corrcoef(filtered_simple.ravel(), filtered_fft.ravel())[0, 1]
# Guard the ratio: an FFT timing below timer resolution would divide by zero.
speedup_str = f"{time_simple/time_fft:.2f}x" if time_fft > 0 else "n/a"

stats_text = f"""
Comparison Statistics:

Mean Absolute Error: {mae:.6f}
Root Mean Square Error: {rmse:.6f}
Max Absolute Difference: {max_diff:.6f}
Correlation: {correlation:.6f}

Simple Filter Range: [{filtered_simple.min():.3f}, {filtered_simple.max():.3f}]
FFT Filter Range: [{filtered_fft.min():.3f}, {filtered_fft.max():.3f}]

Computation Time:
  Simple: {time_simple*1000:.2f} ms
  FFT: {time_fft*1000:.2f} ms
  Speedup: {speedup_str}
"""

ax_stats.text(0.1, 0.5, stats_text, fontsize=12, family='monospace',
              verticalalignment='center')

plt.tight_layout()
plt.show()

# ============================================================================
# Compare readout performance with both methods
# ============================================================================
# Trains the same readout on fold 0 with two feature sets: traces_all (simple
# exponential filter, built in an earlier cell) and FFT-convolved traces built
# here. Only the filtering differs; split, targets and lambda_reg are shared.

print("\n[BONUS] Training readouts with both filtering methods...\n")

# FFT-filtered traces for every trial, stacked along a leading trial axis.
traces_all_fft = np.array([
    filter_spikes_exp_kernel(
        spikes_to_matrix(r['spike_times'], n_timesteps, exp.n_neurons, dt)[np.newaxis, :, :],
        kernel_fft
    )[0]
    for r in all_trial_results
])

print(f"FFT traces shape: {traces_all_fft.shape}")
print(f"Simple traces shape: {traces_all.shape}")

# Both feature sets are paired with the same ground_truth_all rows, so their
# shapes must agree. Otherwise the comparison is not like-for-like, or it
# fails obscurely inside the readout.
if traces_all_fft.shape != traces_all.shape:
    raise ValueError(
        f"Shape mismatch: FFT traces {traces_all_fft.shape} vs simple traces "
        f"{traces_all.shape}; check n_timesteps / dt against the earlier cell."
    )

# Train with FFT filtering
from experiments.experiment_utils import train_task_readout, predict_task_readout, evaluate_temporal_task

# Use fold 0's split so both methods see identical train/test trials.
train_idx = fold_results[0]['train_idx']
test_idx = fold_results[0]['test_idx']

X_train_simple = traces_all[train_idx]
X_test_simple = traces_all[test_idx]

X_train_fft = traces_all_fft[train_idx]
X_test_fft = traces_all_fft[test_idx]

y_train = ground_truth_all[train_idx]
y_test = ground_truth_all[test_idx]

# Train both readouts
print("Training with simple filter...")
W_simple = train_task_readout(X_train_simple, y_train, exp.lambda_reg)
y_pred_simple = predict_task_readout(X_test_simple, W_simple)
metrics_simple = evaluate_temporal_task(y_pred_simple, y_test)

print("Training with FFT filter...")
W_fft = train_task_readout(X_train_fft, y_train, exp.lambda_reg)
y_pred_fft = predict_task_readout(X_test_fft, W_fft)
metrics_fft = evaluate_temporal_task(y_pred_fft, y_test)

# Compare performance (same report format for each method)
print("\nReadout Performance Comparison:")
for method_label, method_metrics in (("Simple Filter", metrics_simple),
                                     ("FFT Convolution", metrics_fft)):
    print(f"  {method_label}:")
    print(f"    RMSE: {method_metrics['rmse_mean']:.4f}")
    print(f"    R²:   {method_metrics['r2_mean']:.4f}")
    print(f"    Corr: {method_metrics['correlation_mean']:.4f}")

# Plot comparison: one bar chart per metric, Simple vs FFT.
# Lower RMSE is better; higher R² and correlation are better.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

metrics_comparison = {
    'RMSE': [metrics_simple['rmse_mean'], metrics_fft['rmse_mean']],
    'R² Score': [metrics_simple['r2_mean'], metrics_fft['r2_mean']],
    'Correlation': [metrics_simple['correlation_mean'], metrics_fft['correlation_mean']]
}

for ax, (name, values) in zip(axes, metrics_comparison.items()):
    ax.bar(['Simple', 'FFT'], values, alpha=0.8, color=['C0', 'C1'])
    # R² and correlation may be negative; draw the baseline so
    # negative bars read correctly.
    if min(values) < 0:
        ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel(name)
    ax.set_title(f'{name} Comparison')
    ax.grid(axis='y', alpha=0.3)

    # Add values at the bar tips: above positive bars, below negative ones,
    # so a label never overlaps its own bar.
    for i, v in enumerate(values):
        ax.text(i, v, f'{v:.4f}', ha='center', va='bottom' if v >= 0 else 'top')

plt.tight_layout()
plt.show()

print("\n✓ Filtering method comparison complete!")