# Kalman Filter Tuning Playground

Use this notebook to tune the parameters of the Kalman Filter on individual tracks.  
You can adjust the Process Noise (Q) and Measurement Noise (R) matrices to see how they affect the tracking performance.

In [None]:
import os
import sys
import numpy as np
import glob
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider

# Make local modules importable: the current working directory (assumed to be
# the project root) and its AI_model_method/ sub-directory are appended to the
# module search path, in that order.
sys.path.extend([os.getcwd(), os.path.join(os.getcwd(), 'AI_model_method')])

# Load the tracker class. The package-qualified import works from the project
# root; the bare module name is the fallback, resolved through the
# AI_model_method/ entry appended above.
try:
    from AI_model_method.kalman_tracking import TrackKalmanFilter
except ImportError:
    from kalman_tracking import TrackKalmanFilter

print("Imports successful.")

## 1. Load Data Helper
Define functions to load tracks from the .npz files.

In [None]:
def load_track_raw(file_path):
    """
    Load one track file, adapting to the different NPZ layouts in use.

    Coordinates are looked up, in this order, under:
      * ``'hits'``   -- returned as stored;
      * ``'coords'`` -- if it is 2-D with 4 columns, the first column is
        dropped (presumably a batch index from collated data -- TODO confirm);
      * ``'x'``, ``'y'``, ``'z'`` -- stacked column-wise into an (N, 3) array.

    The truth direction is read from the key ``'direction'`` of a companion
    file ``<stem>_truth<ext>`` next to ``file_path``. If that file or key is
    missing, ``[0, 0, 1]`` is returned as a placeholder.

    Args:
        file_path: Path to the track ``.npz`` file.

    Returns:
        (coords, truth_dir): the hit positions (expected to be (N, 3)) and
        the truth direction as a float array.

    Raises:
        ValueError: if none of the known coordinate keys are present.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only use on trusted data files.
    with np.load(file_path, allow_pickle=True) as data:
        if 'hits' in data:
            coords = data['hits']
        elif 'coords' in data:
            coords = data['coords']
            # Processed/collated files may carry an extra leading column.
            if coords.ndim == 2 and coords.shape[1] == 4:
                coords = coords[:, 1:]
        elif 'x' in data and 'y' in data and 'z' in data:
            coords = np.stack([data['x'], data['y'], data['z']], axis=1)
        else:
            raise ValueError(f"Could not find coordinates in {file_path}. Keys: {list(data.keys())}")

    truth_dir = np.array([0.0, 0.0, 1.0])  # placeholder when no truth is available

    # Derive the companion name from the file stem only. str.replace('.npz', ...)
    # would also rewrite a '.npz' occurring in a directory name, and then the
    # truth file would silently not be found.
    stem, ext = os.path.splitext(file_path)
    truth_path = f"{stem}_truth{ext}"
    if os.path.exists(truth_path):
        with np.load(truth_path, allow_pickle=True) as tdata:
            if 'direction' in tdata:
                truth_dir = np.asarray(tdata['direction'], dtype=float)

    return coords, truth_dir

def get_sample_files(directory, limit=10):
    """
    Return up to ``limit`` track ``.npz`` paths found recursively under ``directory``.

    Companion ``*_truth.npz`` files are skipped, as is any path containing
    ``'dataset_'``. The result is sorted so that repeated calls pick the same
    files; ``glob`` itself returns paths in filesystem-dependent order.

    Args:
        directory: Root directory to search.
        limit: Maximum number of paths returned; ``None`` returns all matches.

    Returns:
        Sorted list of matching file paths.
    """
    pattern = os.path.join(directory, "**", "*.npz")
    # NOTE(review): 'dataset_' is matched against the whole path, so if the
    # search root itself contains 'dataset_' nothing is returned -- confirm
    # whether only the file name was meant to be checked.
    files = sorted(
        path
        for path in glob.glob(pattern, recursive=True)
        if not path.endswith('_truth.npz') and 'dataset_' not in path
    )
    return files[:limit]

## 2. Visualization Helper (3D)

In [None]:
def _unit_vector(v):
    """Return ``v`` as a float array scaled to unit length (unchanged if zero-length)."""
    v = np.asarray(v, dtype=float)
    norm = np.linalg.norm(v)
    return v / norm if norm > 0 else v


def plot_kalman_result(coords, true_dir, kf_smooth, kf_dir, title="Kalman Filter Result",
                       max_points=10000, sample_size=5000, arrow_scale=50.0, seed=0):
    """
    Plot the raw hits, the smoothed KF path, and the predicted vs. true directions.

    Args:
        coords: Hit positions, indexed as (N, 3).
        true_dir: Truth direction vector; drawn from the hit centroid.
        kf_smooth: Smoothed KF positions indexed as (M, 3), or None. The path
            and the predicted-direction arrow are drawn only when M > 1.
        kf_dir: Predicted direction; drawn from the first smoothed point.
            Skipped if None.
        title: Figure title.
        max_points: If there are more hits than this, only a random subset of
            them is drawn, to keep the 3D plot responsive.
        sample_size: Number of hits kept when subsampling.
        arrow_scale: Drawn length of both direction arrows. Directions are
            normalised first, so the two arrows have equal length and only
            their angle differs.
        seed: Seed for the subsampling RNG. A fixed seed shows the same hits
            on every re-run while tuning; ``None`` gives a fresh random subset.
    """
    coords = np.asarray(coords)
    data = []

    # 1. Raw hits (scatter), subsampled when there are many of them.
    if len(coords) > max_points:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(coords), size=min(sample_size, len(coords)), replace=False)
        pts = coords[idx]
    else:
        pts = coords

    data.append(go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
                             mode='markers',
                             marker=dict(size=2, color='blue', opacity=0.3),
                             name='Raw Hits'))

    # 2. KF smoothed path (line) and the predicted direction from its start.
    if kf_smooth is not None and len(kf_smooth) > 1:
        kf_smooth = np.asarray(kf_smooth)
        data.append(go.Scatter3d(x=kf_smooth[:, 0], y=kf_smooth[:, 1], z=kf_smooth[:, 2],
                                 mode='lines+markers',
                                 marker=dict(size=3, color='magenta'),
                                 line=dict(color='magenta', width=4),
                                 name='KF Smoothed'))

        if kf_dir is not None:
            start_pt = kf_smooth[0]
            end_pt = start_pt + _unit_vector(kf_dir) * arrow_scale
            data.append(go.Scatter3d(x=[start_pt[0], end_pt[0]],
                                     y=[start_pt[1], end_pt[1]],
                                     z=[start_pt[2], end_pt[2]],
                                     mode='lines',
                                     line=dict(color='cyan', width=6),
                                     name='Pred Direction'))

    # 3. True direction for reference. The true origin is not known from the
    # raw files, so the arrow is anchored at the hit centroid; only its ANGLE
    # is meaningful.
    centroid = np.mean(coords, axis=0)
    td_end = centroid + _unit_vector(true_dir) * arrow_scale
    data.append(go.Scatter3d(x=[centroid[0], td_end[0]],
                             y=[centroid[1], td_end[1]],
                             z=[centroid[2], td_end[2]],
                             mode='lines',
                             line=dict(color='red', width=6),
                             name='True Direction (Centered)'))

    layout = go.Layout(title=title,
                       scene=dict(aspectmode='data'),
                       margin=dict(l=0, r=0, b=0, t=40))
    fig = go.Figure(data=data, layout=layout)
    fig.show()

## 3. Interactive Tuning
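Here we define the core logic. To tune, we run the KF on the hits ordered along the track: each point is projected onto the track's principal (PCA) axis, and the points are sorted by that projection to simulate their time order.  
**Note**: In the Evaluation script, we sorted points by distance from the *Predicted Origin*. Here, for raw tracks, we might not have a predicted origin yet, so we assume the "start" is one of the track ends — the end with the smallest projection onto the PCA main axis. Because the sign of that axis is arbitrary, which end is treated as the start is also arbitrary.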

In [None]:
def _order_along_principal_axis(coords):
    """Return ``coords`` sorted by their projection onto the first principal axis."""
    centered = coords - np.mean(coords, axis=0)
    # Rows of vh are the principal directions, strongest first.
    _, _, vh = np.linalg.svd(centered, full_matrices=False)
    projections = np.dot(coords, vh[0])
    return coords[np.argsort(projections)]


def _axis_angle_deg(a, b):
    """
    Return the angle in degrees between the lines along ``a`` and ``b``, folded into [0, 90].

    Folding makes the result independent of the vectors' signs. The track
    ordering, and therefore the sign of the KF direction, is ambiguous.
    """
    cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9)
    angle = np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))
    return 180 - angle if angle > 90 else angle


def run_kalman_experiment(file_path, q_scale, r_scale, dt):
    """
    Fit the Kalman filter to one track with the given noise scales, then report and plot the result.

    The hits are ordered along the track's principal axis to stand in for
    time order. ``TrackKalmanFilter`` is then fitted with isotropic Q/R
    matrices scaled by the arguments. The angle error against the truth
    direction is printed, and the fit is plotted.

    Args:
        file_path: Track ``.npz`` file, read with ``load_track_raw``.
        q_scale: Scale of the process-noise matrix Q (identity * q_scale).
        r_scale: Scale of the measurement-noise matrix R (identity * r_scale).
        dt: Time step passed to ``TrackKalmanFilter``.

    Returns:
        The sign-independent angle error in degrees, or None if the track has
        fewer than 10 hits. Returning it lets the function drive parameter
        sweeps, not only the interactive plot.
    """
    print(f"Processing {os.path.basename(file_path)}...")
    coords, true_dir = load_track_raw(file_path)

    if len(coords) < 10:
        print("Track too short!")
        return None

    # 1. Order the hits to simulate time evolution. Which end comes first is
    # ambiguous without dE/dx; the angle metric below is sign-independent.
    sorted_coords = _order_along_principal_axis(coords)

    # 2. Set up the KF and inject the user-scaled noise matrices.
    # NOTE(review): assumes a 6-dim state and 3-dim measurement (presumably
    # position+velocity / position) -- confirm against TrackKalmanFilter.
    kf = TrackKalmanFilter(dt=dt)
    kf.Q = np.eye(6) * q_scale
    kf.R = np.eye(3) * r_scale

    # 3. Fit.
    kf_dir, kf_smooth = kf.fit(sorted_coords)

    # 4. Angle error between the truth and predicted axes.
    angle_err = float(_axis_angle_deg(true_dir, kf_dir))
    print(f"Estimated Angle Error: {angle_err:.2f} deg")

    # 5. Plot.
    plot_kalman_result(sorted_coords, true_dir, kf_smooth, kf_dir,
                       title=f"Q={q_scale}, R={r_scale} | Err: {angle_err:.1f} deg")
    return angle_err

# EXAMPLE USAGE
# Candidate data directories: the first one that exists on this machine is used.
# Update these paths for your setup!
DATA_DIR_CANDIDATES = [
    # SDF cluster
    "/sdf/home/b/bahrudin/gammaTPC/MLstudy723/processed_tracks906/",
    # Local PC. No trailing backslash here: in a raw string r"...\" the final
    # backslash escapes the closing quote, so the literal never terminates.
    r"C:\Users\Korisnik\PycharmProjects\gammaAIModel\data-samples\local_data\compton\E0003000\d03\data",
]
DATA_DIR = next(
    (d for d in DATA_DIR_CANDIDATES if os.path.isdir(d)),
    DATA_DIR_CANDIDATES[-1],  # none exist: keep the last entry (previous default)
)

# Find a file manually or pick one
# files = get_sample_files(DATA_DIR, limit=5)
# if files:
#     file_to_test = files[0]
#     run_kalman_experiment(file_to_test, q_scale=0.1, r_scale=0.5, dt=1.0)
# else:
#     print("No files found. Please set DATA_DIR correctly.")