# Interactive Theta Analysis Notebook

This notebook provides an interactive environment for running theta oscillation analysis around Near Mistake (NM) events. 

**Key Benefits:**
- Load `all_eeg_data.pkl` once and keep in memory
- Run single session or multi-session analyses quickly
- Easy parameter modification
- Interactive results visualization

**Workflow:**
1. Run the setup cell to load data and import functions
2. Use the analysis cells to run analyses with different parameters
3. Modify parameters as needed and re-run

## 1. Setup and Data Loading

**Run this cell once to load all data into memory.**

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import sys
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union
from collections import defaultdict
import warnings

# Add src/core to path for imports
sys.path.append(os.path.join('src', 'core'))

# Import analysis functions
from nm_theta_analysis import analyze_session_nm_theta_roi, load_session_data
from nm_theta_multi_session import analyze_rat_multi_session, load_all_sessions_for_rat
from electrode_utils import get_channels, load_electrode_mappings, ROI_MAP

print("📦 Loading all EEG data into memory...")
print("This may take a moment but will speed up all subsequent analyses.")

# Load all data once
pkl_path = 'data/processed/all_eeg_data.pkl'
with open(pkl_path, 'rb') as f:
    all_eeg_data = pickle.load(f)

print(f"✅ Data loaded successfully!")
print(f"   Total sessions: {len(all_eeg_data)}")

# Load electrode mappings
electrode_mappings = load_electrode_mappings()
print(f"   Electrode mappings loaded for {len(electrode_mappings)} rats")

# Display available rats
rat_ids = set()
for session in all_eeg_data:
    rat_ids.add(session.get('rat_id', 'unknown'))

print(f"\n🐭 Available rats: {sorted(rat_ids)}")
print(f"🧠 Available ROIs: {list(ROI_MAP.keys())}")

# Display session summary
session_summary = {}
for i, session in enumerate(all_eeg_data):
    rat_id = session.get('rat_id', 'unknown')
    if rat_id not in session_summary:
        session_summary[rat_id] = []
    session_summary[rat_id].append({
        'index': i,
        'date': session.get('session_date', 'unknown'),
        'nm_events': len(session.get('nm_peak_times', [])),
        'eeg_shape': session['eeg'].shape if 'eeg' in session else 'N/A'
    })

print(f"\n📊 Session Summary:")
for rat_id, sessions in session_summary.items():
    print(f"   Rat {rat_id}: {len(sessions)} sessions, total NM events: {sum(s['nm_events'] for s in sessions)}")

📦 Loading all EEG data into memory...
This may take a moment but will speed up all subsequent analyses.
✅ Data loaded successfully!
   Total sessions: 306
   Electrode mappings loaded for 14 rats

🐭 Available rats: ['10501', '1055', '10592', '10593', '422', '441', '442', '531', '532', '9151', '9441', '9442', '9591', '9592']
🧠 Available ROIs: ['frontal', 'motor', 'ss', 'hippocampus', 'visual']

📊 Session Summary:
   Rat 10501: 18 sessions, total NM events: 2538
   Rat 1055: 23 sessions, total NM events: 3725
   Rat 10592: 6 sessions, total NM events: 1087
   Rat 10593: 12 sessions, total NM events: 2099
   Rat 422: 35 sessions, total NM events: 6585
   Rat 441: 30 sessions, total NM events: 5551
   Rat 442: 37 sessions, total NM events: 7960
   Rat 531: 36 sessions, total NM events: 5241
   Rat 532: 28 sessions, total NM events: 5787
   Rat 9151: 19 sessions, total NM events: 3200
   Rat 9441: 8 sessions, total NM events: 1611
   Rat 9442: 24 sessions, total NM events: 4618
   Rat 9591:

## 2. Quick Data Explorer

Use this cell to quickly explore specific sessions.

In [2]:
# === PARAMETERS ===
session_index = 180  # Change this to explore different sessions

# === EXPLORATION ===
session = all_eeg_data[session_index]
print(f"📋 Session {session_index} Details:")
print(f"   Rat ID: {session.get('rat_id', 'unknown')}")
print(f"   Date: {session.get('session_date', 'unknown')}")
print(f"   EEG shape: {session['eeg'].shape}")
print(f"   Sampling rate: ~{1/np.median(np.diff(session['eeg_time'].flatten())):.0f} Hz")
print(f"   Duration: {session['eeg_time'].flatten()[-1]:.1f} seconds")
print(f"   NM events: {len(session['nm_peak_times'])}")
print(f"   NM sizes: {np.unique(session['nm_sizes'])}")

# Show available channels for this rat
rat_id = session.get('rat_id', 'unknown')
if rat_id in electrode_mappings.index:
    print(f"\n🧠 Available ROI channels for rat {rat_id}:")
    for roi_name, roi_channels in ROI_MAP.items():
        try:
            channels = get_channels(rat_id, roi_name, electrode_mappings)
            print(f"   {roi_name}: {channels}")
        except:
            print(f"   {roi_name}: Not available")
else:
    print(f"\n⚠️  No electrode mapping found for rat {rat_id}")

📋 Session 180 Details:
   Rat ID: 531
   Date: 170820
   EEG shape: (32, 476189)
   Sampling rate: ~200 Hz
   Duration: 2340.6 seconds
   NM events: 31
   NM sizes: [1. 2. 3.]

⚠️  No electrode mapping found for rat 531


## 3. Single Session Analysis

Analyze theta oscillations for a single session. Modify parameters as needed.

In [4]:
%matplotlib qt

In [5]:
# === ANALYSIS PARAMETERS ===
session_index = 70                    # Which session to analyze
roi_specification = [2,3,5]      # ROI name or list of channels [1, 2, 3]
freq_range = (1, 45)                 # Frequency range in Hz
n_freqs = 30                         # Number of log-spaced frequencies
window_duration = 2.0                # Event window duration (±0.5s)
n_cycles_factor = 3.0                # Cycles per frequency factor
save_results = True                  # Whether to save results
show_plots = True                    # Whether to show plots

print(f"🔬 Running Single Session Analysis...")
print(f"   Session: {session_index}")
print(f"   ROI: {roi_specification}")
print(f"   Frequency range: {freq_range[0]}-{freq_range[1]} Hz ({n_freqs} frequencies)")
print(f"   Window: ±{window_duration/2:.1f}s around events")

# Get session data (no need to load from file!)
session_data = all_eeg_data[session_index]
rat_id = session_data.get('rat_id', 'unknown')
session_date = session_data.get('session_date', 'unknown')

print(f"   Rat ID: {rat_id}")
print(f"   Date: {session_date}")
print(f"   NM events: {len(session_data['nm_peak_times'])}")

# Run analysis
try:
    results = analyze_session_nm_theta_roi(
        session_data=session_data,
        roi_or_channels=roi_specification,
        freq_range=freq_range,
        n_freqs=n_freqs,
        window_duration=window_duration,
        n_cycles_factor=n_cycles_factor,
        save_path=f'results/single_session_{rat_id}_{session_date}' if save_results else None,
        mapping_df=electrode_mappings,
        show_plots=show_plots
    )
    
    print(f"\n✅ Analysis completed successfully!")
    print(f"   ROI channels used: {results['roi_channels']}")
    print(f"   Frequencies: {results['freqs'][0]:.2f} - {results['freqs'][-1]:.2f} Hz")
    print(f"   Event windows analyzed: {list(results['normalized_windows'].keys())}")
    
    # Store results for later use
    last_single_session_results = results
    
except Exception as e:
    print(f"❌ Analysis failed: {e}")
    import traceback
    traceback.print_exc()

🔬 Running Single Session Analysis...
   Session: 70
   ROI: [2, 3, 5]
   Frequency range: 1-45 Hz (30 frequencies)
   Window: ±1.0s around events
   Rat ID: 422
   Date: 100720
   NM events: 163
NM THETA ROI ANALYSIS
Step 1: Determining ROI channels
🔍 ELECTRODE MAPPING VERIFICATION
✓ Custom channel specification: [2, 3, 5]
✓ Resulting channel indices: [21, 25, 27]
Step 2: Computing ROI theta spectrogram (1-45 Hz)
Computing ROI theta spectrogram for 3 channels...
ROI channels: [21, 25, 27]
📊 PROCESSING VERIFICATION:
   Using 3 channels: [21, 25, 27]
   EEG data shape: (32, 407082)
   Each channel will be z-score normalized individually, then averaged
📊 Using 30 logarithmically spaced frequencies:
   Range: 1.00 - 45.00 Hz
   Frequencies: ['1.00', '1.14', '1.30', '1.48', '1.69', '1.93', '2.20', '2.51', '2.86', '3.26', '3.72', '4.24', '4.83', '5.51', '6.28', '7.16', '8.17', '9.31', '10.62', '12.11', '13.81', '15.75', '17.95', '20.47', '23.34', '26.62', '30.35', '34.61', '39.46', '45.00']


In [None]:
for channel_index in range(1, 33): 
    results = analyze_session_nm_theta_roi(
        session_data=session_data,
        roi_or_channels=[channel_index],
        freq_range=freq_range,
        n_freqs=n_freqs,
        window_duration=window_duration,
        n_cycles_factor=n_cycles_factor,
        save_path=f'results/single_session_{rat_id}_{session_date}_ch{channel_index}' if save_results else None,
        mapping_df=electrode_mappings,
        show_plots=show_plots
    )

## 4. Multi-Session Analysis

Analyze theta oscillations across all sessions for a specific rat.

In [None]:
# === MULTI-SESSION ANALYSIS PARAMETERS ===
rat_id = "10593"                       # Which rat to analyze
roi_specification = [1,3,5,26,28,30]        # ROI name or list of channels
freq_range = (1, 45)                  # Frequency range in Hz
n_freqs = 30                         # Number of log-spaced frequencies
window_duration = 2.0                # Event window duration
n_cycles_factor = 3.0                # Cycles per frequency factor
save_results = True                  # Whether to save results
show_plots = True                    # Whether to show plots

print(f"🔬 Running Multi-Session Analysis...")
print(f"   Rat ID: {rat_id}")
print(f"   ROI: {roi_specification}")
print(f"   Frequency range: {freq_range[0]}-{freq_range[1]} Hz ({n_freqs} frequencies)")
print(f"   Window: ±{window_duration/2:.1f}s around events")

# Find sessions for this rat
rat_sessions = []
for i, session in enumerate(all_eeg_data):
    if session.get('rat_id') == rat_id:
        rat_sessions.append((i, session))

print(f"   Found {len(rat_sessions)} sessions for rat {rat_id}")

if len(rat_sessions) == 0:
    print(f"❌ No sessions found for rat {rat_id}")
    print(f"Available rats: {sorted(rat_ids)}")
else:
    # Show session details
    total_events = 0
    for session_idx, session in rat_sessions:
        nm_events = len(session.get('nm_peak_times', []))
        total_events += nm_events
        print(f"     Session {session_idx}: {session.get('session_date', 'unknown')} - {nm_events} events")
    print(f"   Total NM events: {total_events}")
    
    # Run multi-session analysis using our pre-loaded data
    try:
        # We need to create a modified version that uses our pre-loaded data
        # Instead of loading from pickle file
        
        print(f"\n🚀 Starting multi-session analysis...")
        
        # Create a simple wrapper that uses our pre-loaded data
        def analyze_rat_multi_session_preloaded(rat_id, roi_or_channels, all_data, **kwargs):
            """Modified version that uses pre-loaded data."""
            # Filter sessions for this rat
            rat_sessions = []
            for i, session_data in enumerate(all_data):
                if session_data.get('rat_id') == rat_id:
                    rat_sessions.append((i, session_data))
            
            if not rat_sessions:
                raise ValueError(f"No sessions found for rat {rat_id}")
            
            # Analyze each session
            session_results = []
            for orig_session_idx, session_data in rat_sessions:
                print(f"Analyzing session {orig_session_idx}...")
                try:
                    result = analyze_session_nm_theta_roi(
                        session_data=session_data,
                        roi_or_channels=roi_or_channels,
                        mapping_df=electrode_mappings,
                        show_plots=True,  # Don't show individual session plots
                        **kwargs
                    )
                    session_results.append(result)
                except Exception as e:
                    print(f"⚠️  Session {orig_session_idx} failed: {e}")
                    continue
            
            if not session_results:
                raise ValueError("No sessions analyzed successfully")
            
            # Combine results (simplified version)
            combined_results = {
                'rat_id': rat_id,
                'roi_or_channels': roi_or_channels,
                'session_results': session_results,
                'n_sessions': len(session_results),
                'total_events': sum(len(r['normalized_windows']) for r in session_results)
            }
            
            return combined_results
        
        # Run the analysis
        multi_results = analyze_rat_multi_session_preloaded(
            rat_id=rat_id,
            roi_or_channels=roi_specification,
            all_data=all_eeg_data,
            freq_range=freq_range,
            n_freqs=n_freqs,
            window_duration=window_duration,
            n_cycles_factor=n_cycles_factor
        )
        
        print(f"\n✅ Multi-session analysis completed successfully!")
        print(f"   Rat: {multi_results['rat_id']}")
        print(f"   Sessions analyzed: {multi_results['n_sessions']}")
        print(f"   Total events: {multi_results['total_events']}")
        
        # Store results for later use
        last_multi_session_results = multi_results
        
        # Show summary statistics
        if multi_results['session_results']:
            roi_channels = multi_results['session_results'][0]['roi_channels']
            freqs = multi_results['session_results'][0]['freqs']
            print(f"   ROI channels: {roi_channels}")
            print(f"   Frequency range: {freqs[0]:.2f} - {freqs[-1]:.2f} Hz")
        
    except Exception as e:
        print(f"❌ Multi-session analysis failed: {e}")
        import traceback
        traceback.print_exc()

## 5. Parameter Testing

Different ROI or frequencies for one session

In [None]:
# === PARAMETER TESTING ===
test_session = 0
test_roi = 'hippocampus'  # Change this to test different ROIs

# Test different frequency ranges
freq_ranges_to_test = [
    (3, 8),    # Classic theta
    (2, 10),   # Extended theta
    (4, 12),   # Theta + low alpha
]

print(f"🧪 Testing different frequency ranges on session {test_session} with ROI '{test_roi}'")

session_data = all_eeg_data[test_session]
print(f"Session: Rat {session_data.get('rat_id')}, {len(session_data['nm_peak_times'])} NM events\n")

for i, freq_range in enumerate(freq_ranges_to_test):
    print(f"📊 Test {i+1}: {freq_range[0]}-{freq_range[1]} Hz")
    
    try:
        # Quick analysis without saving or plotting
        results = analyze_session_nm_theta_roi(
            session_data=session_data,
            roi_or_channels=test_roi,
            freq_range=freq_range,
            n_freqs=20,
            window_duration=1.0,
            mapping_df=electrode_mappings,
            show_plots=False,
            save_path=None
        )
        
        # Print summary
        roi_channels = results['roi_channels']
        freqs = results['freqs']
        windows = results['normalized_windows']
        
        print(f"   ✅ Success: {len(roi_channels)} channels, {len(freqs)} frequencies")
        print(f"      ROI channels: {roi_channels}")
        print(f"      Windows: {list(windows.keys())}")
        
        # Show some power statistics
        all_power = np.concatenate([w.flatten() for w in windows.values()])
        print(f"      Z-score range: {all_power.min():.2f} to {all_power.max():.2f}")
        print(f"      Z-score mean±std: {all_power.mean():.2f} ± {all_power.std():.2f}")
        
    except Exception as e:
        print(f"   ❌ Failed: {e}")
    
    print()

## 7. Utility Functions

Helper functions for common tasks.

In [None]:
def quick_session_info(session_idx):
    """Get quick info about a session."""
    if session_idx >= len(all_eeg_data):
        print(f"❌ Session {session_idx} out of range (0-{len(all_eeg_data)-1})")
        return
    
    session = all_eeg_data[session_idx]
    print(f"Session {session_idx}:")
    print(f"  Rat: {session.get('rat_id', 'unknown')}")
    print(f"  Date: {session.get('session_date', 'unknown')}")
    print(f"  Duration: {session['eeg_time'].flatten()[-1]:.1f}s")
    print(f"  NM events: {len(session['nm_peak_times'])}")
    print(f"  NM sizes: {np.unique(session['nm_sizes'])}")

def list_sessions_for_rat(rat_id):
    """List all sessions for a specific rat."""
    sessions = []
    for i, session in enumerate(all_eeg_data):
        if session.get('rat_id') == rat_id:
            sessions.append({
                'index': i,
                'date': session.get('session_date', 'unknown'),
                'nm_events': len(session.get('nm_peak_times', []))
            })
    
    if sessions:
        print(f"Sessions for rat {rat_id}:")
        for s in sessions:
            print(f"  [{s['index']}] {s['date']} - {s['nm_events']} events")
    else:
        print(f"No sessions found for rat {rat_id}")
    
    return sessions

def test_roi_channels(rat_id, roi_name):
    """Test which channels are available for a ROI."""
    try:
        channels = get_channels(rat_id, roi_name, electrode_mappings)
        print(f"ROI '{roi_name}' for rat {rat_id}: {channels}")
        return channels
    except Exception as e:
        print(f"ROI '{roi_name}' for rat {rat_id}: Not available ({e})")
        return None

# Example usage:
print("🛠️  Utility functions loaded. Examples:")
print("   quick_session_info(0)")
print("   list_sessions_for_rat(10501)")
print("   test_roi_channels(10501, 'frontal')")

# Quick demo:
if len(all_eeg_data) > 0:
    quick_session_info(0)

# RAT ID

In [11]:
from src.core.nm_theta_multi_session import (
    analyze_rat_multi_session,
    plot_multi_session_results,
    save_multi_session_results
)

In [None]:
# Process multiple rats
rat_ids = ["531"]  # Add all your rat IDs here
roi_specification = 'frontal'

for rat_id in rat_ids:
    print(f"\n=== Processing Rat {rat_id} ===")
    
    # Run multi-session analysis using pre-loaded data
    multi_results = analyze_rat_multi_session_preloaded(
        rat_id=rat_id,
        roi_or_channels=roi_specification,
        all_data=all_eeg_data,
        freq_range=freq_range,
        n_freqs=n_freqs,
        window_duration=window_duration,
        n_cycles_factor=n_cycles_factor
    )
    
    # Save results using organized structure
    save_path = f'results/multi_session/rat_{rat_id}'
    save_multi_session_results(multi_results, save_path)
    
    # Plot averaged results (but not individual sessions)
    plot_multi_session_results(multi_results, save_path)
    
    print(f"✅ Completed analysis for rat {rat_id}")

In [None]:
# Process multiple rats with better memory management
rat_ids = ["531"]  # Add other rat IDs as needed
roi_specification = 'frontal'
save_base_path = 'results/multi_session'

for rat_id in rat_ids:
    print(f"\n=== Processing Rat {rat_id} ===")
    save_path = f'{save_base_path}/rat_{rat_id}'
    
    # Find sessions for this rat
    rat_sessions = []
    for i, session in enumerate(all_eeg_data):
        if session.get('rat_id') == rat_id:
            rat_sessions.append((i, session))
    
    # Process each session individually but don't keep all results in memory
    print(f"Found {len(rat_sessions)} sessions for rat {rat_id}")
    session_results = []
    session_metadata = []
    
    for session_idx, (orig_idx, session_data) in enumerate(rat_sessions):
        print(f"Analyzing session {orig_idx}...")
        try:
            # Create a session-specific save path
            session_save_path = f'{save_path}/session_{orig_idx}'
            
            # Analyze this session
            result = analyze_session_nm_theta_roi(
                session_data=session_data,
                roi_or_channels=roi_specification,
                freq_range=freq_range,
                n_freqs=n_freqs,
                window_duration=window_duration,
                n_cycles_factor=n_cycles_factor,
                save_path=session_save_path,
                show_plots=False  # Don't plot individual sessions
            )
            
            # Save individual session results immediately
            with open(f'{session_save_path}/session_result.pkl', 'wb') as f:
                pickle.dump(result, f)
                
            # Store only metadata and minimal results for averaging
            light_result = {
                'roi_channels': result['roi_channels'],
                'freqs': result['freqs'],
                'roi_specification': result['roi_specification'],
                'normalized_windows': result['normalized_windows']
            }
            session_results.append(light_result)
            
            # Store metadata
            metadata = {
                'original_session_index': orig_idx,
                'session_date': session_data.get('session_date', 'unknown'),
                'rat_id': session_data.get('rat_id'),
                'roi_channels': result['roi_channels'],
                'total_nm_events': sum(data['n_events'] for data in result['normalized_windows'].values()),
                'nm_sizes': list(result['normalized_windows'].keys())
            }
            session_metadata.append(metadata)
            
            # Explicitly free memory after each session
            del result
            import gc
            gc.collect()
            
            print(f"✓ Session {session_idx + 1} completed successfully")
        
        except Exception as e:
            print(f"⚠️ Error processing session {orig_idx}: {e}")
            continue
    
    # Now average results across sessions
    if session_results:
        # Import the necessary functions
        from src.core.nm_theta_multi_session import average_session_results, save_multi_session_results, plot_multi_session_results
        
        # Average the results
        print("\nAveraging results across sessions...")
        averaged_results = average_session_results(session_results, session_metadata, rat_id)
        
        # Save multi-session results
        print("Saving multi-session results...")
        save_multi_session_results(averaged_results, save_path)
        
        # Plot multi-session results
        print("Plotting multi-session results...")
        plot_multi_session_results(averaged_results, save_path)
        
        print(f"✅ Completed multi-session analysis for rat {rat_id}")
    else:
        print(f"❌ No successful sessions to average for rat {rat_id}")