***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from geopy.distance import distance

from traffic.core import Traffic
from traffic.data import airports
from tqdm.auto import tqdm

import plotly.graph_objects as go
from plotly.subplots import make_subplots


***
### Import of data, model and scalers
***

##### Trajectory data

In [None]:
# Trajectories to evaluate: the ``t_test`` file of the encoded/scaled split.
test_set_path = "/mnt/beegfs/store/krum/MT/encoded_scaled_split/t_test.parquet"
t = Traffic.from_file(test_set_path)

##### Model

In [None]:
def rmse_lat():
    """Placeholder for the ``rmse_lat`` metric referenced by the saved model.

    It exists only so that ``load_model`` can resolve the name through
    ``custom_objects``. It presumably is never evaluated here, because only
    ``model.predict`` is called.
    """
    return


def rmse_lon():
    """Placeholder for the ``rmse_lon`` metric referenced by the saved model.

    It exists only so that ``load_model`` can resolve the name through
    ``custom_objects``. It presumably is never evaluated here, because only
    ``model.predict`` is called.
    """
    return


def rmse_alt():
    """Placeholder for the ``rmse_alt`` metric referenced by the saved model.

    It exists only so that ``load_model`` can resolve the name through
    ``custom_objects``. It presumably is never evaluated here, because only
    ``model.predict`` is called.
    """
    return


# Load the trained network for inference.
# The saved model refers to custom metrics and a custom loss by name, and every
# name must resolve at load time. The metrics map to the placeholder functions
# ``rmse_lat``/``rmse_lon``/``rmse_alt``, and ``weighted_mse`` maps to a plain
# MSE loss. This is presumably sufficient because only ``model.predict`` is used.
model = load_model(
    "/home/krum/git/MT_krum_code/models/snowy-gorge-126.keras",
    custom_objects={
        "rmse_lat": rmse_lat,
        "rmse_lon": rmse_lon,
        "rmse_alt": rmse_alt,
        "weighted_mse": tf.keras.losses.MeanSquaredError(),
    },
)

##### Scalers

In [None]:
def _unpickle(path):
    """Return the single object stored in the pickle file at *path*.

    Only use this on trusted files, because unpickling can execute arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Scalers saved alongside the split data set.
# ``scaler_out.inverse_transform`` is used later to map the model outputs back
# to the unscaled space.
scaler_in = _unpickle("/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl")
scaler_out = _unpickle(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl"
)

***
### Model application
***

In [None]:
# Number of past rows (time steps) the model sees per flight.
n_steps = 10

# Unscaled positions of the input window; kept only for plotting.
position_columns = ["latitude", "longitude", "altitude"]

# Time-varying (scaled) model inputs: one value per row of the window.
var_columns = [
    "latitude_scaled",
    "longitude_scaled",
    "altitude_scaled",
    "wind_x_2min_avg_scaled",
    "wind_y_2min_avg_scaled",
    "temperature_gnd_scaled",
    "humidity_gnd_scaled",
    "pressure_gnd_scaled",
]

# Per-flight model inputs, read from a single row of each flight.
con_columns = [
    "toff_weight_kg_scaled",
    "typecode_A20N",
    "typecode_A21N",
    "typecode_A319",
    "typecode_A320",
    "typecode_A321",
    "typecode_A333",
    "typecode_A343",
    "typecode_B77W",
    "typecode_BCS1",
    "typecode_BCS3",
    "typecode_CRJ9",
    "typecode_DH8D",
    "typecode_E190",
    "typecode_E195",
    "typecode_E290",
    "typecode_E295",
    "typecode_F100",
    "typecode_SB20",
    "SID_DEGES",
    "SID_GERSA",
    "SID_VEBIT",
    "SID_ZUE",
    "hour_sin",
    "hour_cos",
    "weekday_sin",
    "weekday_cos",
    "month_sin",
    "month_cos",
]

inputs_var_unscaled = []
inputs_var = []
inputs_con = []

for flight in tqdm(t):
    # The n_steps rows preceding the final row (positions -11 .. -2).
    window = flight.data.iloc[-(n_steps + 1):-1]
    inputs_var_unscaled.append(window[position_columns])
    inputs_var.append(
        window[var_columns].to_numpy().reshape(n_steps, len(var_columns))
    )
    # The per-flight features come from the first row of that window.
    inputs_con.append(
        flight.data[con_columns]
        .iloc[-(n_steps + 1)]
        .to_numpy()
        .reshape(1, len(con_columns))
    )

# Stack into model-ready arrays of shape (n_flights, 10, 8) and (n_flights, 1, 29).
inputs_var = np.stack(inputs_var)
inputs_con = np.stack(inputs_con)
n_flights = len(inputs_var)

# Predict
predictions = model.predict((inputs_var, inputs_con))

# Unscale: the output scaler works on rows of 3 values (plot_flight reads them
# as latitude, longitude, altitude). Flatten to (-1, 3), inverse-transform,
# then restore one (steps, 3) array per flight. Keeping axis 0 intact ensures
# predictions_unscaled[i] belongs to t[i] and inputs_var_unscaled[i].
# NOTE(review): if the first predicted time step should be discarded, slice
# [:, 1:, :] here. Never slice axis 0, because that shifts flights.
predictions_unscaled = scaler_out.inverse_transform(
    predictions.reshape(-1, 3)
).reshape(n_flights, -1, 3)

***
### Functions for plotting
***

##### Helper function to compute circle coordinates

In [None]:
# Helper required to draw a circle (e.g. the cylinder boundary) around a point.
def circle_coordinates(
    lat: float, lon: float, radius_km: float, num_points: int = 360
):
    """
    Calculate the coordinates of a circle of a given radius around a centre
    coordinate.

    Each point is computed with the great-circle destination-point formula on a
    spherical Earth of radius 6371 km. The bearings are evenly spaced over
    [0, 2*pi], starting north and turning towards east. Both end points are
    included, so the first and last points coincide and the ring is closed.
    Longitudes are not wrapped into [-180, 180].

    Parameters
    ----------
    lat : float
        Latitude of the centre point in degrees.
    lon : float
        Longitude of the centre point in degrees.
    radius_km : float
        Radius of the circle in kilometres.
    num_points : int, optional
        Number of points to generate on the circle, by default 360.

    Returns
    -------
    tuple[list[float], list[float]]
        Latitudes and longitudes of the circle points, in degrees.
    """
    earth_radius_km = 6371.0
    lat_rad = np.radians(lat)
    lon_rad = np.radians(lon)
    delta = radius_km / earth_radius_km  # angular distance in radians
    bearings = np.linspace(0, 2 * np.pi, num_points)

    # Vectorised over all bearings at once instead of a per-point Python loop.
    lat_points = np.arcsin(
        np.sin(lat_rad) * np.cos(delta)
        + np.cos(lat_rad) * np.sin(delta) * np.cos(bearings)
    )
    lon_points = lon_rad + np.arctan2(
        np.sin(bearings) * np.sin(delta) * np.cos(lat_rad),
        np.cos(delta) - np.sin(lat_rad) * np.sin(lat_points),
    )

    return np.degrees(lat_points).tolist(), np.degrees(lon_points).tolist()

##### Main plotting function

In [None]:
def plot_flight(x: int) -> go.Figure:
    """
    Plot the flight trajectory and the prediction of the model for flight x.

    The top row is a map showing the 30 NM cylinder boundary around LSZH, the
    full trajectory, the model-input window and the predicted positions. The
    bottom row shows altitude against row index, with the same elements plus
    a vertical cylinder wall at the end of the recorded trajectory and an
    upper boundary at 18000 ft.

    Parameters
    ----------
    x : int
        Index of the flight to plot. The same index is used for the
        module-level ``t``, ``inputs_var_unscaled`` and
        ``predictions_unscaled``, so these must be aligned flight by flight.

    Returns
    -------
    go.Figure
        Plotly figure object.
    """
    cylinder_radius_nm = 30
    cylinder_top_ft = 18000

    # x selection
    flight = t[x]
    data = flight.data
    n_rows = len(data)
    model_input = inputs_var_unscaled[x]
    prediction = predictions_unscaled[x]

    # Positions on the time axis (one row per step) ---------------------------
    # The model input is cut with ``iloc[-11:-1]``, so it ends one row before
    # the last recorded row. Place it on exactly those rows so the altitude
    # panel shows the same points as the map. Short flights are handled too.
    input_steps = np.arange(n_rows - len(model_input) - 1, n_rows - 1)
    # One x value per predicted step, so x and y always have the same length.
    # NOTE(review): assumes predictions are 5 rows apart and the first one is
    # 5 rows after the last recorded row. Confirm against the model's output
    # sampling.
    prediction_steps = n_rows + 4 + 5 * np.arange(len(prediction))

    # Generate subplots---------------------------------------------------------
    fig = make_subplots(
        rows=2,
        cols=1,
        specs=[[{"type": "scattermapbox"}], [{}]],
        row_heights=[0.6, 0.4],
        subplot_titles=("Position", "Altitude"),
        vertical_spacing=0.07,
    )

    # Map traces----------------------------------------------------------------
    # Circle of 30 NM (1 NM = 1.852 km) around LSZH.
    lszh = airports["LSZH"]
    circle_lat, circle_lon = circle_coordinates(
        lszh.latitude, lszh.longitude, cylinder_radius_nm * 1.852, 360
    )
    fig.add_trace(
        go.Scattermapbox(
            mode="lines",
            lat=circle_lat,
            lon=circle_lon,
            line=dict(color="black"),
            name="Cylinder boundary",
            legendgroup="fixed",
            showlegend=True,
        ),
        row=1,
        col=1,
    )

    # Full trajectory position
    fig.add_trace(
        go.Scattermapbox(
            mode="lines",
            lat=data["latitude"],
            lon=data["longitude"],
            line=dict(width=4, color="lightgrey"),
            name="Full trajectory",
            legendgroup="fixed",
            showlegend=True,
        ),
        row=1,
        col=1,
    )

    # Input position
    fig.add_trace(
        go.Scattermapbox(
            lat=model_input["latitude"],
            lon=model_input["longitude"],
            marker=dict(size=7, color="#636efa"),
            name="model input",
            mode="markers",
            showlegend=True,
            legendgroup="model",
        ),
        row=1,
        col=1,
    )

    # Prediction position
    fig.add_trace(
        go.Scattermapbox(
            lat=prediction[:, 0],
            lon=prediction[:, 1],
            marker=dict(size=7, color="#00cc96"),
            name="Prediction",
            mode="markers",
            showlegend=True,
            legendgroup="model",
        ),
        row=1,
        col=1,
    )

    # Altitude traces-----------------------------------------------------------
    # Full trajectory altitude
    fig.add_trace(
        go.Scatter(
            x=np.arange(0, n_rows),
            y=data["altitude"],
            line=dict(width=4, color="lightgrey"),
            mode="lines",
            name="full",
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # Cylinder wall, at the end of the recorded trajectory
    fig.add_shape(
        dict(
            type="line",
            x0=n_rows,
            y0=0,
            x1=n_rows,
            y1=cylinder_top_ft,
            line=dict(color="black", width=2),
        ),
        row=2,
        col=1,
    )

    # Cylinder upper boundary
    # NOTE(review): the right end (1000) is a fixed value that is independent
    # of the flight length.
    fig.add_shape(
        dict(
            type="line",
            x0=0,
            y0=cylinder_top_ft,
            x1=1000,
            y1=cylinder_top_ft,
            line=dict(color="black", width=2),
        ),
        row=2,
        col=1,
        showlegend=False,
    )

    # Input altitude: the same window that was fed to the model and drawn on the map
    fig.add_trace(
        go.Scatter(
            x=input_steps,
            y=model_input["altitude"],
            marker=dict(size=6, color="#636efa"),
            mode="markers",
            name="full",
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # Prediction altitude
    fig.add_trace(
        go.Scatter(
            x=prediction_steps,
            y=prediction[:, 2],
            marker=dict(size=6, color="#00cc96"),
            mode="markers",
            name="full",
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # Update Figure layout------------------------------------------------------
    fig.update_xaxes(
        title_text="Elapsed time [s]",
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_yaxes(
        title_text="Baroaltitude [ft]",
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            zoom=10,
            center=dict(
                lat=data["latitude"].mean(),
                lon=data["longitude"].mean(),
            ),
        ),
        width=1600,
        height=1000,
        margin=dict(l=50, r=20, t=40, b=40),
        legend=dict(font=dict(size=25)),
    )

    # Enlarge the subplot titles ("Position", "Altitude").
    fig.update_annotations(font_size=30)

    return fig

***
### Plotting
***

##### Flight 1

In [None]:
flight_index = 543  # index into the test Traffic ``t``
fig = plot_flight(flight_index)
fig.show()

##### Flight 2

In [None]:
flight_index = 1764  # index into the test Traffic ``t``
fig = plot_flight(flight_index)
fig.show()

##### Flight 3

In [None]:
flight_index = 5441  # index into the test Traffic ``t``
fig = plot_flight(flight_index)
fig.show()