# 9.6.2 Implementation of the Kalman Filter
Apply Kalman Filtering for smoothing/cleaning vessel trajectories

In [None]:
# imports
import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import geopandas as gpd
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sqlalchemy import create_engine
from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity
from stonesoup.models.measurement.linear import LinearGaussian
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.types.state import GaussianState
from stonesoup.types.detection import Detection
from stonesoup.types.array import CovarianceMatrix
from stonesoup.types.hypothesis import SingleHypothesis
import json

In [None]:
# Load database configuration from config.json
from urllib.parse import quote_plus

with open("config.json", "r", encoding="utf-8") as file:
    config = json.load(file)

# Construct the database URL for SQLAlchemy.
# The user name and password are percent-encoded so that characters such as
# '@', ':' or '/' inside the credentials cannot be mistaken for URL delimiters.
database_url = (
    f"postgresql://{quote_plus(str(config['DB_USER']))}:"
    f"{quote_plus(str(config['DB_PASS']))}@"
    f"{config['DB_HOST']}:{config['DB_PORT']}/{config['DB_NAME']}"
)

# Create the SQLAlchemy engine
engine = create_engine(database_url)

In [None]:
# Create the Dash application. The page consists of an MMSI selector
# followed by the graph that displays the selected vessel's trajectory.
app = dash.Dash(__name__)
app.layout = html.Div(
    children=[
        dcc.Dropdown(placeholder='Select MMSI', id='mmsi-dropdown'),
        dcc.Graph(id='trajectory-graph'),
    ],
)


In [None]:
# Callback that fills the MMSI dropdown with the vessels available in the
# database. The callback's input value is ignored.
@app.callback(
    Output('mmsi-dropdown', 'options'),
    Input('mmsi-dropdown', 'placeholder')
)
def set_mmsi_options(_):
    """Return the dropdown options, one entry per distinct MMSI.

    The query is restricted to a fixed list of MMSIs. Each option uses the
    MMSI as both its label and its value.
    """
    query = '''
    SELECT DISTINCT mmsi 
    FROM AISInputSample  
    WHERE MMSI IN (246541000, 636018799,311001076,304111000, 211269660, 219014579,219019011,259896000);
    '''
    vessel_ids = pd.read_sql(query, engine)['mmsi']

    options = []
    for vessel_id in vessel_ids:
        options.append({'label': vessel_id, 'value': vessel_id})
    return options


In [None]:
# Kalman filtering code
def perform_kalman_filtering(
    gdf,
    measurement_noise_std=(10.0, 10.0),
    process_noise_std=(1, 1),
):
    """Filter a trajectory with a constant-velocity Kalman filter.

    The state vector is ``[x, x_velocity, y, y_velocity]``; only the two
    positions (indices 0 and 2) are measured.

    Args:
        gdf: Table with a ``geomproj`` column of point geometries (objects
            exposing ``.x`` and ``.y``) and a ``timestamp`` column, ordered
            by time. Coordinate units are whatever ``geomproj`` uses
            (presumably a projected CRS; confirm against the data source).
        measurement_noise_std: Standard deviation of the position
            measurement noise as ``(x, y)``.
        process_noise_std: Per-axis process noise as ``(x, y)``; its square
            is passed to ``ConstantVelocity`` and is also used as the
            initial velocity variance. Modify based on application needs.

    Returns:
        ``numpy.ndarray`` of shape ``(n, 2)`` holding the filtered ``x`` and
        ``y`` positions, one row per input row. An empty input yields an
        array of shape ``(0, 2)``.
    """
    # Nothing to filter: avoid indexing the first row of an empty table.
    if gdf.empty:
        return np.empty((0, 2))

    meas_std_x, meas_std_y = measurement_noise_std
    proc_std_x, proc_std_y = process_noise_std

    # Define measurement and transition models
    measurement_model = LinearGaussian(
        ndim_state=4,  # position and velocity in 2D
        mapping=(0, 2),  # x and y sit at indices 0 and 2 of the state
        noise_covar=np.diag([meas_std_x**2, meas_std_y**2]),
    )
    transition_model = CombinedLinearGaussianTransitionModel([
        ConstantVelocity(proc_std_x**2),
        ConstantVelocity(proc_std_y**2),
    ])

    # Create detections (one per row, in row order)
    detections = [
        Detection(
            np.array([point.x, point.y]),
            timestamp=timestamp,
            measurement_model=measurement_model,
        )
        for point, timestamp in zip(gdf["geomproj"], gdf["timestamp"])
    ]

    # Initial state: first observed position, zero velocity.
    # The covariance diagonal must follow the state ordering
    # [x, x_velocity, y, y_velocity]: positions get the measurement variance
    # of their own axis, velocities get the process variance of their axis.
    first_point = gdf["geomproj"].iloc[0]
    initial_state_mean = [first_point.x, 0, first_point.y, 0]
    initial_state_covariance = np.diag([
        meas_std_x**2,
        proc_std_x**2,
        meas_std_y**2,
        proc_std_y**2,
    ])
    initial_state = GaussianState(
        initial_state_mean,
        initial_state_covariance,
        timestamp=detections[0].timestamp,
    )

    # Kalman filter execution
    predictor = KalmanPredictor(transition_model)
    updater = KalmanUpdater(measurement_model)

    # List to store filtered states
    filtered_states = []

    for detection in detections:
        if filtered_states:
            # Predict forward from the last filtered state to this detection
            predicted_state = predictor.predict(
                filtered_states[-1], timestamp=detection.timestamp
            )
        else:
            # For the first measurement, there is no prediction step
            predicted_state = initial_state

        # Associate the prediction with the detection and update the state
        hypothesis = SingleHypothesis(predicted_state, detection)
        filtered_states.append(updater.update(hypothesis))

    # Extract the filtered x/y positions from each state vector
    smoothed_coords = np.array([
        [state.state_vector[0, 0], state.state_vector[2, 0]]
        for state in filtered_states
    ])
    return smoothed_coords


In [None]:
# Callback function to update the trajectory plot based on the selected MMSI:
@app.callback(
    Output('trajectory-graph', 'figure'),
    [Input('mmsi-dropdown', 'value')]
)
def update_graph(selected_mmsi):
    """Build the figure comparing the original and the filtered trajectory.

    Args:
        selected_mmsi: Value of the MMSI dropdown, or ``None`` when nothing
            is selected.

    Returns:
        A ``go.Figure`` with two line traces ("Original Path" and
        "Smoothed Path"). An empty figure is returned when no MMSI is
        selected, when the value is not an integer, or when the query
        returns no rows.
    """
    if selected_mmsi is None:
        return go.Figure()  # Return an empty figure if no MMSI is selected

    # The value is supplied by the browser. Coerce it to an integer so that
    # nothing other than a number can be interpolated into the SQL text.
    try:
        mmsi = int(selected_mmsi)
    except (TypeError, ValueError):
        return go.Figure()

    # Fetch trajectory data
    query = f"SELECT geomproj, t AS timestamp FROM AISInputSample WHERE mmsi = {mmsi} ORDER BY t;"
    gdf = gpd.read_postgis(query, engine, geom_col='geomproj')

    # No positions for this vessel: there is nothing to filter or plot.
    if gdf.empty:
        return go.Figure()

    # Call the Kalman filtering function
    smoothed_coords = perform_kalman_filtering(gdf)

    # Prepare data for plotting
    original_x = gdf.geometry.x
    original_y = gdf.geometry.y
    smoothed_x = [coord[0] for coord in smoothed_coords]
    smoothed_y = [coord[1] for coord in smoothed_coords]

    # Plotting the trajectories
    fig = go.Figure()
    fig.add_trace(go.Scattergl(x=original_x, y=original_y, mode='lines', name='Original Path'))
    fig.add_trace(go.Scattergl(x=smoothed_x, y=smoothed_y, mode='lines', name='Smoothed Path'))
    fig.update_layout(
        xaxis_title='x-coordinate',
        yaxis_title='y-coordinate',
        xaxis=dict(
            tickmode='auto',
            tickformat=',',  # Comma-separated thousands instead of scientific notation
        ),
        yaxis=dict(
            tickmode='auto',
            tickformat=',',
        ),
        margin={'l': 80, 'b': 140, 't': 50, 'r': 10},
        font=dict(
            family="Times New Roman",
            size=18,
            color="black",
        ),
        autosize=False,
        width=1000,
        height=400,
    )
    return fig


In [None]:
# change port if another dash server is running on this port
if __name__ == '__main__':
    # Newer Dash releases start the server with `app.run` and no longer
    # provide `app.run_server`; older releases only have `run_server`.
    # Prefer `run` when it exists and fall back otherwise.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(port=8052)
