In [1]:
# IPython autoreload in mode 2: re-import modules before each cell is
# executed, so edits to the library sources are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the parent directory importable.
# NOTE(review): presumably that is where the `tracking_v2` package lives - confirm.
sys.path.append(os.path.realpath('..'))

In [2]:
import numpy as np
import pandas as pd
import plotly.express as ex
import plotly.graph_objects as go

from copy import deepcopy

from tracking_v2.target import NearConstantVelocityTarget
from tracking_v2.tracker import Track, Tracker
from tracking_v2.kalman import linear_ncv
from tracking_v2.sensor import GeometricSensor
from tracking_v2.evaluation import TrackerRunner, StateFilterRunner, evaluate_nees, evaluate_runner, plot_nscore
import tracking_v2.evaluation.runner as runner

from tracking_v2.util import to_df, display

In [3]:
# NOTE(review): presumably switches figure output from static PNG images to
# interactive rendering - confirm against tracking_v2.util.display.
display.as_png = False

In [None]:
# Identifier attached to the single track reported by
# ScoringTracker.estimate_tracks().
TRACK_ID = 0

class TrackState(object):
    """A Kalman filter paired with the timestamp of its current estimate.

    Attributes:
        time: Time of the most recent state held by ``kf``.
        kf: Filter object; must provide ``predict(dt)``,
            ``prepare_update(z, R)`` and ``update()``.
        score: Initialised to 0; not modified by this class.
    """

    def __init__(self, time: float, kf):
        self.time = time
        self.kf = kf
        self.score = 0

    def update(self, m):
        """Propagate the filter to ``m.time`` and fold in measurement ``m``.

        Args:
            m: Measurement exposing ``time``, ``z`` (measured value) and
                ``R`` (measurement covariance).

        Raises:
            AssertionError: If ``m.time`` is earlier than the track time.
        """
        assert m.time >= self.time, "measurement in the past"
        self.kf.predict(m.time - self.time)
        self.kf.prepare_update(m.z, m.R)
        self.kf.update()
        # Advance the track clock.  Without this, every later predict() would
        # be given the time elapsed since the track was created instead of the
        # time since the last update, over-propagating the filter.
        self.time = m.time


class ScoringTracker(Tracker):
    """Single-track tracker initialised from the first two measurements.

    The first measurement is buffered, the second one is combined with it to
    build a position/velocity state for a ``linear_ncv`` filter, and every
    later measurement is passed to that track's ``update``.  Exactly one
    3-component measurement per call is supported.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop the buffered first measurement and the current track."""
        self.first_meas = None
        self.track: "TrackState | None" = None

    def add_measurements(self, ms):
        """Consume one measurement, creating or updating the track.

        Args:
            ms: Sequence holding exactly one measurement with attributes
                ``time``, ``z`` and ``R``; ``z`` must squeeze to length 3.

        Raises:
            AssertionError: If ``ms`` does not hold exactly one 3-component
                measurement, or if the second measurement is not strictly
                later than the first.
        """
        assert len(ms) == 1, "more than 1 measurement not supported"
        assert len(ms[0].z.squeeze()) == 3, "only 3D spaces are supported"

        m = ms[0]

        if self.first_meas is None:
            # Nothing to difference against yet; wait for a second measurement.
            self.first_meas = m
        elif self.track is None:
            self.track = self._two_point_init(self.first_meas, m)
        else:
            self.track.update(m)

    @staticmethod
    def _two_point_init(first, second):
        """Build a TrackState at ``second.time`` from two measurements."""
        dt = second.time - first.time
        # Guard the divisions below: dt == 0 would fill the state and
        # covariance with inf/NaN (NumPy only warns), and dt < 0 would accept
        # an out-of-order measurement.
        assert dt > 0, "second measurement must be later than the first"

        vel = (second.z - first.z) / dt

        # Velocity is a difference of two measurements scaled by 1/dt, so its
        # covariance is (R1 + R2) / dt^2; only the second measurement is
        # shared with the position estimate, giving cross-covariance R2 / dt.
        # (This treats the two measurement errors as independent.)
        P_vel = (first.R + second.R) / (dt * dt)
        P_pos_vel = second.R / dt

        x = np.concatenate((second.z.squeeze(), vel.squeeze()))
        # NOTE(review): assumes R is a 3x3 matrix - confirm against the sensor.
        P = np.zeros((6, 6))
        P[:3, :3] = second.R
        P[3:, 3:] = P_vel
        P[:3, 3:] = P_pos_vel
        P[3:, :3] = P_pos_vel

        kf = linear_ncv(noise_intensity=1)
        kf.initialize(x, P)

        return TrackState(second.time, kf)

    def estimate_tracks(self, t: float):
        """Return the current track estimate, or ``[]`` if no track exists.

        Args:
            t: Query time; must equal ``self.track.time`` because the filter
                is not propagated here.

        Raises:
            AssertionError: If a track exists and ``t`` differs from its time.
        """
        if self.track is None:
            return []

        assert t == self.track.time, "cannot estimate tracks for arbitrary t"
        return [Track(TRACK_ID, self.track.kf.x_hat, self.track.kf.P_hat)]