# Speed Filtering Comparison

This notebook compares two approaches to handling speed filtering:

1. **Filtered model (OLD)**: Removes low-speed data points, creating temporal discontinuities
2. **Continuous model (NEW)**: Keeps all data points, uses speed-dependent observation noise

**Expected outcomes:**
- Continuous model should have much better fit (higher R², lower RMSE)
- α (gain parameter) should be closer to 1.0 in continuous model
- β_speed > 0 confirms that speed-dependent noise is needed

## Setup

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directory to path
import sys
sys.path.append('../..')

# Import heading model modules
from heading_model import data_preprocessing as dp
from heading_model import bayesian_model as bm
from heading_model import visualization as viz

# Plotting settings
plt.style.use('default')
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3
plt.rcParams['figure.dpi'] = 100

%matplotlib inline

## Load Data with Different Filter Modes

When `TEST_MODE` is `True`, only a small subset (3 animals) is used so the comparison runs quickly; set it to `False` to use the full dataset.

In [None]:
# Configuration: the knobs that control the scope of this comparison.
TEST_MODE = True  # Set to False to run on full dataset
if TEST_MODE:
    MAX_ANIMALS = 3      # small subset so the fits finish in reasonable time
else:
    MAX_ANIMALS = None   # forwarded to structure_hierarchical_data (presumably "no cap") — TODO confirm
CONDITION = 'all_light'
MIN_SPEED = 2.0  # cm/s; threshold passed to dp.load_and_filter_data

print(f"Running in {'TEST' if TEST_MODE else 'FULL'} mode")
print(f"Condition: {CONDITION}")
print(f"Speed threshold: {MIN_SPEED} cm/s")

### Load data with OLD approach (remove low-speed points)

In [None]:
print("\n" + "="*80)
print("LOADING DATA: FILTERED (OLD APPROACH)")
print("="*80)

# OLD approach: samples slower than MIN_SPEED are dropped at load time, before
# trials are assembled, which can leave temporal gaps inside a trial.
load_options_filtered = dict(
    conditions=[CONDITION],
    sessions=dp.DEFAULT_USEABLE_SESSIONS,
    min_speed=MIN_SPEED,
    filter_mode='remove',  # OLD: Remove low-speed points
    verbose=True,
)
df_filtered = dp.load_and_filter_data(**load_options_filtered)

# Group the surviving samples per animal and per trial (same structuring
# parameters as the continuous dataset so only the filtering differs).
data_dict_filtered = dp.structure_hierarchical_data(
    df_filtered,
    verbose=True,
    max_animals=MAX_ANIMALS,
    min_trial_length=50,
    smooth_sigma=1.0,
)

summary_filtered = dp.get_data_summary(data_dict_filtered)
print("\nData Summary (Filtered):")
print(summary_filtered.to_string(index=False))

### Load data with NEW approach (keep all points)

In [None]:
print("\n" + "="*80)
print("LOADING DATA: CONTINUOUS (NEW APPROACH)")
print("="*80)

# NEW approach: every sample is kept regardless of speed; low-speed samples are
# handled later by the model's speed-dependent observation noise.
load_options_continuous = dict(
    conditions=[CONDITION],
    sessions=dp.DEFAULT_USEABLE_SESSIONS,
    min_speed=MIN_SPEED,
    filter_mode='none',  # NEW: Keep all points
    verbose=True,
)
df_continuous = dp.load_and_filter_data(**load_options_continuous)

# Group samples per animal and per trial with the same structuring parameters
# as the filtered dataset.
data_dict_continuous = dp.structure_hierarchical_data(
    df_continuous,
    verbose=True,
    max_animals=MAX_ANIMALS,
    min_trial_length=50,
    smooth_sigma=1.0,
)

summary_continuous = dp.get_data_summary(data_dict_continuous)
print("\nData Summary (Continuous):")
print(summary_continuous.to_string(index=False))

## Compare Data Sizes

In [None]:
print("\n" + "="*80)
print("DATA SIZE COMPARISON")
print("="*80)


def count_observations(data_dict):
    """Return the total number of heading observations in a structured dataset.

    Parameters
    ----------
    data_dict : dict
        Output of ``dp.structure_hierarchical_data``. Must provide ``'animals'``,
        ``'trials_per_animal'`` (indexable by animal) and ``'theta_obs'``
        (indexable as ``[animal][trial]``).

    Returns
    -------
    int
        Sum of ``len(theta_obs)`` over every animal and trial.
    """
    return sum(
        len(data_dict['theta_obs'][animal][trial])
        for animal in data_dict['animals']
        for trial in data_dict['trials_per_animal'][animal]
    )


n_obs_filtered = count_observations(data_dict_filtered)
n_obs_continuous = count_observations(data_dict_continuous)
n_obs_extra = n_obs_continuous - n_obs_filtered

print(f"Filtered observations:   {n_obs_filtered:,}")
print(f"Continuous observations: {n_obs_continuous:,}")
# Guard the percentage: a very aggressive MIN_SPEED can leave the filtered set empty,
# which previously raised ZeroDivisionError.
if n_obs_filtered > 0:
    print(f"Difference:              {n_obs_extra:,} ({(n_obs_continuous/n_obs_filtered - 1)*100:.1f}% more data)")
else:
    print(f"Difference:              {n_obs_extra:,} (filtered dataset is empty; percentage undefined)")

# Sanity check: keeping low-speed samples should never yield fewer observations.
# Warn rather than assert, since the two datasets are structured independently.
if n_obs_extra < 0:
    print("WARNING: continuous dataset has fewer observations than the filtered one; "
          "check that both contain the same animals/trials.")

## Visualize Heading Continuity

Check a single trial's decoded (observed) heading time series for temporal discontinuities introduced by speed filtering.

In [None]:
# Visualize one trial in both datasets. Use the first trial that exists in BOTH:
# the two data dicts are built independently (different filtering), so a trial
# kept in the filtered dict may not be indexable in the continuous one.
shared_trials = [
    (a, t)
    for a in data_dict_filtered['animals']
    if a in data_dict_continuous['animals']
    for t in data_dict_filtered['trials_per_animal'][a]
    if t in data_dict_continuous['trials_per_animal'][a]
]
if not shared_trials:
    raise RuntimeError("No trial is present in both the filtered and continuous datasets.")
animal, trial_id = shared_trials[0]

fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Filtered: markers joined by lines, so long straight segments reveal where
# low-speed samples were removed.
t_filt = data_dict_filtered['time'][animal][trial_id]
theta_filt = data_dict_filtered['theta_obs'][animal][trial_id]
axes[0].plot(t_filt, theta_filt, 'o-', markersize=3, alpha=0.7)
axes[0].set_ylabel('Decoded Heading (rad)')
axes[0].set_title(f'FILTERED: {trial_id}\n(May show discontinuities from speed filtering)')
axes[0].axhline(np.pi, color='red', linestyle='--', alpha=0.3)
axes[0].axhline(-np.pi, color='red', linestyle='--', alpha=0.3)
axes[0].grid(True, alpha=0.3)

# Continuous: every sample, colored by speed to show where the filter would cut.
t_cont = data_dict_continuous['time'][animal][trial_id]
theta_cont = data_dict_continuous['theta_obs'][animal][trial_id]
speed_cont = data_dict_continuous['speed'][animal][trial_id]

sc = axes[1].scatter(t_cont, theta_cont, c=speed_cont, s=10, alpha=0.7, cmap='viridis')
axes[1].set_ylabel('Decoded Heading (rad)')
axes[1].set_xlabel('Time (s)')
axes[1].set_title('CONTINUOUS: Same trial with all data points (colored by speed)')
axes[1].axhline(np.pi, color='red', linestyle='--', alpha=0.3)
axes[1].axhline(-np.pi, color='red', linestyle='--', alpha=0.3)
axes[1].grid(True, alpha=0.3)

cbar = fig.colorbar(sc, ax=axes[1])
cbar.set_label('Speed (cm/s)')

plt.tight_layout()
plt.show()

print(f"\nFiltered data points: {len(t_filt)}")
print(f"Continuous data points: {len(t_cont)}")
print(f"Low-speed points removed: {len(t_cont) - len(t_filt)} ({(1 - len(t_filt)/len(t_cont))*100:.1f}%)")

## Fit Model 1: Filtered (OLD)

⚠️ This will take 10-20 minutes even with 3 animals.

In [None]:
print("\n" + "="*80)
print("FITTING MODEL: FILTERED (OLD APPROACH)")
print("="*80)

# Prior and sampler settings. They are identical to the continuous-model fit, so any
# difference in results comes from the data preparation, not the configuration.
priors_filtered = dict(
    alpha_prior_mean=1.0,
    alpha_prior_sd=0.5,
    gamma_prior_sd=0.1,
    theta0_prior_sd=1.0,
    obs_noise_beta=0.1,
    use_noncentered=False,
)
sampler_filtered = dict(draws=1000, tune=500, chains=4, target_accept=0.9, cores=4)
# NOTE(review): no random seed is passed, so reruns presumably differ slightly;
# consider forwarding one if HeadingHierarchicalModel.fit supports it.

model_filtered = bm.HeadingHierarchicalModel(data_dict_filtered, condition_name='filtered')
model_filtered.build_model(**priors_filtered)
trace_filtered = model_filtered.fit(**sampler_filtered)

In [None]:
# Check convergence of the filtered model's fit; the returned summary is kept in
# `summary_filt`. NOTE(review): its exact contents (e.g. R-hat / ESS columns) are
# defined in heading_model.bayesian_model — confirm there.
summary_filt = model_filtered.check_convergence()

In [None]:
# Fit statistics and population-level parameter summaries for the filtered model.
fit_stats_filtered = model_filtered.compute_model_fit_stats()

print("\n" + "="*80)
print("FILTERED MODEL RESULTS")
print("="*80)
results_filt = model_filtered.extract_parameters()

print(f"\nPopulation parameters:")
pop_filt = results_filt['population']
print(f"  mu_alpha:    {pop_filt['mu_alpha']['mean']:.4f}")
print(f"  sigma_base:  {pop_filt['sigma_base']['mean']:.4f}")
print(f"  beta_speed:  {pop_filt['beta_speed']['mean']:.4f}")

print(f"\nModel fit:")
overall_filt = fit_stats_filtered['overall']
print(f"  R²:          {overall_filt['r_squared']:.4f}")
print(f"  RMSE:        {overall_filt['rmse']:.4f} rad")
print(f"  MAE:         {overall_filt['mae']:.4f} rad")

## Fit Model 2: Continuous (NEW)

⚠️ This will take 10-20 minutes even with 3 animals.

In [None]:
print("\n" + "="*80)
print("FITTING MODEL: CONTINUOUS (NEW APPROACH)")
print("="*80)

# Prior and sampler settings. They are identical to the filtered-model fit, so any
# difference in results comes from the data preparation, not the configuration.
priors_continuous = dict(
    alpha_prior_mean=1.0,
    alpha_prior_sd=0.5,
    gamma_prior_sd=0.1,
    theta0_prior_sd=1.0,
    obs_noise_beta=0.1,
    use_noncentered=False,
)
sampler_continuous = dict(draws=1000, tune=500, chains=4, target_accept=0.9, cores=4)
# NOTE(review): no random seed is passed, so reruns presumably differ slightly;
# consider forwarding one if HeadingHierarchicalModel.fit supports it.

model_continuous = bm.HeadingHierarchicalModel(data_dict_continuous, condition_name='continuous')
model_continuous.build_model(**priors_continuous)
trace_continuous = model_continuous.fit(**sampler_continuous)

In [None]:
# Check convergence of the continuous model's fit; the returned summary is kept in
# `summary_cont`. NOTE(review): its exact contents (e.g. R-hat / ESS columns) are
# defined in heading_model.bayesian_model — confirm there.
summary_cont = model_continuous.check_convergence()

In [None]:
# Fit statistics and population-level parameter summaries for the continuous model.
fit_stats_continuous = model_continuous.compute_model_fit_stats()

print("\n" + "="*80)
print("CONTINUOUS MODEL RESULTS")
print("="*80)
results_cont = model_continuous.extract_parameters()

print(f"\nPopulation parameters:")
pop_cont = results_cont['population']
print(f"  mu_alpha:    {pop_cont['mu_alpha']['mean']:.4f}")
print(f"  sigma_base:  {pop_cont['sigma_base']['mean']:.4f}")
print(f"  beta_speed:  {pop_cont['beta_speed']['mean']:.4f}")

print(f"\nModel fit:")
overall_cont = fit_stats_continuous['overall']
print(f"  R²:          {overall_cont['r_squared']:.4f}")
print(f"  RMSE:        {overall_cont['rmse']:.4f} rad")
print(f"  MAE:         {overall_cont['mae']:.4f} rad")

## Direct Comparison

In [None]:
print("\n" + "="*80)
print("SIDE-BY-SIDE COMPARISON")
print("="*80)


def _comparison_values(fit_stats, results):
    """Return [R², RMSE, MAE, mu_alpha, sigma_base, beta_speed] for one fitted model.

    Parameters
    ----------
    fit_stats : dict
        Output of ``compute_model_fit_stats()``; reads ``['overall']['r_squared' | 'rmse' | 'mae']``.
    results : dict
        Output of ``extract_parameters()``; reads ``['population'][name]['mean']``.
    """
    overall = fit_stats['overall']
    population = results['population']
    return [
        overall['r_squared'],
        overall['rmse'],
        overall['mae'],
        population['mu_alpha']['mean'],
        population['sigma_base']['mean'],
        population['beta_speed']['mean'],
    ]


comparison_df = pd.DataFrame({
    'Metric': ['R²', 'RMSE (rad)', 'MAE (rad)', 'mu_alpha', 'sigma_base', 'beta_speed'],
    'Filtered (OLD)': _comparison_values(fit_stats_filtered, results_filt),
    'Continuous (NEW)': _comparison_values(fit_stats_continuous, results_cont),
})

# Signed relative change (NEW - OLD) / |OLD| in percent. For error metrics (RMSE, MAE)
# a negative value is an improvement; for parameters it is simply a change.
comparison_df['Improvement'] = ((comparison_df['Continuous (NEW)'] - comparison_df['Filtered (OLD)']) /
                                 comparison_df['Filtered (OLD)'].abs() * 100)

# Concatenate instead of passing two args to print(), which would indent the header row.
print("\n" + comparison_df.to_string(index=False))

print("\n" + "="*80)
print("INTERPRETATION")
print("="*80)

# Look metrics up by name rather than by row position.
change_by_metric = comparison_df.set_index('Metric')['Improvement']
alpha_filt = results_filt['population']['mu_alpha']['mean']
alpha_cont = results_cont['population']['mu_alpha']['mean']
beta_cont = results_cont['population']['beta_speed']['mean']


def _mark(passed):
    """Return '✓' if the expected outcome holds, '✗' otherwise."""
    return "✓" if passed else "✗"


# Each line reports whether the expected outcome listed at the top of the notebook
# actually holds, instead of printing an unconditional check mark.
r2_change = change_by_metric['R²']
rmse_change = change_by_metric['RMSE (rad)']
print(f"{_mark(r2_change > 0)} R² improvement: {r2_change:.1f}%")
print(f"{_mark(rmse_change < 0)} RMSE reduction: {-rmse_change:.1f}%")
print(f"{_mark(abs(alpha_cont - 1.0) < abs(alpha_filt - 1.0))} alpha closer to 1.0: "
      f"{alpha_cont:.3f} vs {alpha_filt:.3f}")
# NOTE(review): this checks only the posterior mean; a stronger test would require
# the credible interval of beta_speed to exclude 0.
print(f"{_mark(beta_cont > 0)} beta_speed > 0: {beta_cont:.3f} "
      f"(a positive value supports speed-dependent noise)")

## Visualize Parameter Posteriors

In [None]:
import arviz as az

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

params = ['mu_alpha', 'mu_gamma', 'sigma_base', 'beta_speed', 'sigma_alpha', 'sigma_gamma']
titles = ['Gain (α)', 'Drift (γ)', 'Base Noise (σ_base)', 'Speed Effect (β_speed)', 'σ_α', 'σ_γ']

# Extract the posterior group once per model. convert_to_dataset accepts InferenceData
# as well as the other trace types ArviZ can convert.
posterior_filt = az.convert_to_dataset(trace_filtered, group='posterior')
posterior_cont = az.convert_to_dataset(trace_continuous, group='posterior')

for ax, param, title in zip(axes.flat, params, titles):
    # Pool all chains and draws into one 1-D sample vector per model.
    samples_filt = np.asarray(posterior_filt[param]).ravel()
    samples_cont = np.asarray(posterior_cont[param]).ravel()

    # Shared bin edges and semi-transparent bars let the two posteriors be compared
    # directly. Labelled artists also keep the legend tied to the right histogram;
    # `ax.legend([...])` would label artists by creation order, including ArviZ's
    # HDI lines.
    bins = np.histogram_bin_edges(np.concatenate([samples_filt, samples_cont]), bins=40)
    ax.hist(samples_filt, bins=bins, density=True, alpha=0.5, color='C0', label='Filtered (OLD)')
    ax.hist(samples_cont, bins=bins, density=True, alpha=0.5, color='C1', label='Continuous (NEW)')

    ax.set_title(title)
    ax.set_xlabel(param)
    ax.set_ylabel('Posterior density')
    ax.legend()

plt.tight_layout()
plt.show()

## Visualize Example Fits

In [None]:
# Plot four example trials per model, using the same seed for both.
# NOTE(review): the models are built from different data dicts, so the same
# random_seed may not select the same trials in each — confirm in viz.plot_model_fits.
example_fit_figures = [
    (model_filtered, 'FILTERED MODEL: Example Fits'),
    (model_continuous, 'CONTINUOUS MODEL: Example Fits'),
]
for fitted_model, fig_title in example_fit_figures:
    fig = viz.plot_model_fits(fitted_model, n_trials=4, random_seed=42, figsize=(14, 10))
    fig.suptitle(fig_title, fontsize=16, y=1.00)
    plt.show()

## Conclusion

**Expected findings** (confirm against the comparison table and interpretation printed above):

1. **Data continuity**: The continuous model uses more data and avoids temporal gaps
2. **Model fit**: The continuous model should show substantially better fit (higher R², lower RMSE)
3. **Parameter interpretability**: α should be closer to 1.0, indicating animals integrate angular velocity
4. **Speed-dependent noise**: β_speed > 0 confirms that observation noise depends on speed

**Next Steps:**
- Run full analysis with all animals if test results are promising
- Compare across conditions (light vs dark, search vs homing)
- Investigate any remaining discrepancies