***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tensorflow as tf
from tensorflow.keras.models import load_model

from traffic.core import Traffic

***
### Import of data, model and scalers
***

##### Trajectory data

In [None]:
# Test split of the encoded and scaled trajectory dataset.
t_test_path = "/mnt/beegfs/store/krum/MT/encoded_scaled_split/t_test.parquet"
t = Traffic.from_file(t_test_path)

##### Model

In [None]:
# Stand-ins for the custom training metrics that the saved model refers to by
# name (presumably recorded in its training/compile configuration). They only
# exist so the names can be resolved via ``custom_objects``; they are not real
# metrics and must never be used to score the model.


def rmse_lat():
    """Placeholder for the training-time ``rmse_lat`` metric (returns None)."""
    return None


def rmse_lon():
    """Placeholder for the training-time ``rmse_lon`` metric (returns None)."""
    return None


def rmse_alt():
    """Placeholder for the training-time ``rmse_alt`` metric (returns None)."""
    return None


# This notebook only runs inference (``model.predict``). ``compile=False`` loads
# architecture and weights without re-attaching the training loss ("weighted_mse",
# substituted by a plain MSE below) or the placeholder metrics above, so an
# accidental ``model.evaluate``/``model.fit`` fails loudly (model not compiled)
# instead of silently using the wrong loss / metrics. ``custom_objects`` is kept
# so any remaining reference to these names still resolves.
model = load_model(
    "/home/krum/git/MT_krum_code/models/snowy-gorge-126.keras",
    compile=False,
    custom_objects={
        "rmse_lat": rmse_lat,
        "rmse_lon": rmse_lon,
        "rmse_alt": rmse_alt,
        "weighted_mse": tf.keras.losses.MeanSquaredError(),
    },
)

##### Scalers

In [None]:
# Fitted scalers for the model inputs (scaler_in) and outputs (scaler_out).
# NOTE: pickle.load can execute arbitrary code - only load files you trust.


def _load_pickled_scaler(path):
    """Open ``path`` in binary mode and return the unpickled object."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


scaler_in = _load_pickled_scaler(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl"
)
scaler_out = _load_pickled_scaler(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl"
)

***
### Applying the model to a single flight
***

##### Generation of subsets

In [None]:
# Flight selection: position of the flight inside the test Traffic collection
FLIGHT_INDEX = 555
f = t[FLIGHT_INDEX]

# Time-variant input features. The model is fed the "<name>_scaled" columns,
# while the unscaled columns are kept for plotting in physical units.
# Deriving the scaled names from this single list keeps both selections in the
# same column order (the window arrays built later rely on that order).
var_features = [
    "latitude",
    "longitude",
    "altitude",
    "wind_x_2min_avg",
    "wind_y_2min_avg",
    "temperature_gnd",
    "humidity_gnd",
    "pressure_gnd",
]

# Time variant inputs unscaled
f_in_var_unscaled = f.data[var_features]

# Time variant inputs scaled
f_in_var = f.data[[f"{name}_scaled" for name in var_features]]

# Time invariant inputs: scaled take-off weight, aircraft type and SID indicator
# columns (presumably one-hot encoded - confirm against the encoding step) and
# sin/cos encodings of hour, weekday and month. The order below reproduces the
# 29-column layout expected by the model.
typecodes = [
    "A20N", "A21N", "A319", "A320", "A321", "A333", "A343", "B77W", "BCS1",
    "BCS3", "CRJ9", "DH8D", "E190", "E195", "E290", "E295", "F100", "SB20",
]
sids = ["DEGES", "GERSA", "VEBIT", "ZUE"]
con_features = (
    ["toff_weight_kg_scaled"]
    + [f"typecode_{code}" for code in typecodes]
    + [f"SID_{sid}" for sid in sids]
    + [
        f"{period}_{fn}"
        for period in ("hour", "weekday", "month")
        for fn in ("sin", "cos")
    ]
)
f_in_con = f.data[con_features]

##### Generation of sliding-window input samples along the trajectory

In [None]:
# Sliding-window samples along the trajectory.
# Sample i uses rows i .. i + N_IN - 1 as time-variant input and row i + N_IN for
# the time-invariant input. The last HORIZON rows are never used as window
# starts (presumably so the 180 s prediction horizon stays inside the recorded
# flight - confirm against the training pipeline).
N_IN = 10
HORIZON = 180

# All three input frames are column subsets of the same flight -> same length.
n_windows = len(f_in_var) - N_IN - HORIZON
if n_windows <= 0:
    raise ValueError(
        f"Flight has {len(f_in_var)} rows; at least {N_IN + HORIZON + 1} "
        "are needed to build one input window."
    )


def _stack_windows(frame):
    """Return the sliding windows of ``frame`` as an array (n_windows, N_IN, n_cols)."""
    values = frame.to_numpy()
    return np.stack([values[i : i + N_IN] for i in range(n_windows)])


# Variable inputs unscaled
input_var_unscaled = _stack_windows(f_in_var_unscaled)
print(input_var_unscaled.shape)

# Variable inputs scaled
input_var = _stack_windows(f_in_var)
print(input_var.shape)

# Constant inputs: the row right after each window, with a singleton time axis
# -> (n_windows, 1, n_cols). Forced to float64 because mixing indicator columns
# (possibly bool) with float columns would otherwise give an object array.
input_con = f_in_con.to_numpy(dtype=np.float64)[N_IN : N_IN + n_windows][
    :, np.newaxis, :
]
print(input_con.shape)

##### Model inference and inverse scaling of the outputs

In [None]:
# Predict for every input window.
output = model.predict((input_var, input_con))

# Undo the output scaling. scaler_out is applied to a 2-D array with 3 columns
# (presumably latitude/longitude/altitude, as read by the plotting cell); the
# result is then put back into a (window, 37 steps, 3) layout and the first
# predicted step is dropped.
n_out_features = 3
output_rows = output.reshape(-1, n_out_features)
output_unscaled = scaler_out.inverse_transform(output_rows)
output_unscaled = output_unscaled.reshape(-1, 37, n_out_features)[:, 1:, :]

***
### Animated simulation
***

##### Generation of plotting data

In [None]:
# Flight trajectory
df_full = f.data

# Marker colours of the two point groups in every sample
_INPUT_COLOR = "#636efa"
_PREDICTION_COLOR = "#00cc96"


def _position_frame(points, color):
    """Tabulate the first three columns of ``points`` (latitude, longitude,
    altitude) together with a constant ``color`` column."""
    return pd.DataFrame(
        {
            "latitude": points[:, 0],
            "longitude": points[:, 1],
            "altitude": points[:, 2],
            "color": color,
        }
    )


# One DataFrame per timestamp: model input rows followed by model prediction rows
input_output = [
    pd.concat(
        [
            _position_frame(input_var_unscaled[i], _INPUT_COLOR),
            _position_frame(output_unscaled[i], _PREDICTION_COLOR),
        ]
    )
    for i in range(input_var.shape[0])
]

In [None]:
# Simulation length (timestamps): one animation frame per model input window
lengths = len(df_full) - 180 - 10

# Generate subplots------------------------------------------------------------
fig = make_subplots(
    rows=2,
    cols=1,
    specs=[[{"type": "scattermapbox"}], [{}]],
    row_heights=[0.6, 0.4],
    subplot_titles=("Position", "Altitude"),
    vertical_spacing=0.07,
)


def _window_traces(k):
    """Build the three animated traces for input window ``k``.

    Returns [map markers of input + prediction, input altitude, predicted
    altitude]. This order must match ``traces=[1, 3, 4]`` in the frames below.
    """
    window = input_output[k]
    model_in = window.iloc[0:10]  # first 10 rows: model input
    model_out = window.iloc[10:]  # remaining rows: model prediction
    return [
        # Position (input and prediction, coloured per row)
        go.Scattermapbox(
            lat=window["latitude"],
            lon=window["longitude"],
            marker=dict(size=5, color=window["color"]),
            mode="markers",
            showlegend=False,
        ),
        # Altitude input: one x step per input row, starting at k
        go.Scatter(
            x=np.arange(k, k + 10),
            y=model_in["altitude"],
            marker=dict(size=3, color=model_in["color"]),
            mode="markers",
            name="Model input",
            showlegend=False,
        ),
        # Altitude output: predicted points every 5 s, starting at k + 15
        go.Scatter(
            x=np.arange(k + 15, k + 15 + 180, 5),
            y=model_out["altitude"],
            marker=dict(size=3, color=model_out["color"]),
            mode="markers",
            name="Model prediction",
            showlegend=False,
        ),
    ]


# Add initial traces-----------------------------------------------------------
# Resulting trace indices: 0 full trajectory position, 1 model input-output
# position, 2 full trajectory altitude, 3 input altitude, 4 predicted altitude.
position_0, alt_input_0, alt_output_0 = _window_traces(0)

# Full trajectory position
fig.add_trace(
    go.Scattermapbox(
        mode="lines",
        lat=df_full["latitude"],
        lon=df_full["longitude"],
        line=dict(width=2, color="lightgrey"),
        name="Full trajectory",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Model input-output position
fig.add_trace(position_0, row=1, col=1)

# Full trajectory altitude
fig.add_trace(
    go.Scatter(
        x=np.arange(0, len(df_full)),
        y=df_full["altitude"],
        line=dict(width=2, color="lightgrey"),
        mode="lines",
        name="full",
        showlegend=False,
    ),
    row=2,
    col=1,
)

# Input altitude and model output altitude
fig.add_trace(alt_input_0, row=2, col=1)
fig.add_trace(alt_output_0, row=2, col=1)

# Update Figure layout---------------------------------------------------------
# The figure has two rows; only row 2 has cartesian axes.
fig.update_xaxes(
    title_text="Elapsed time [s]", range=[0, len(df_full)], row=2, col=1
)
fig.update_yaxes(
    title_text="Baroaltitude [ft]",
    range=[df_full.altitude.min(), df_full.altitude.max()],
    row=2,
    col=1,
)

fig.update_layout(
    mapbox=dict(
        style="carto-positron",
        zoom=10,
        center=dict(
            lat=df_full["latitude"].mean(),
            lon=df_full["longitude"].mean(),
        ),
    ),
    width=1200,
    height=1200,
    margin=dict(l=50, r=0, t=40, b=40),
)

# Animation--------------------------------------------------------------------
# One frame per input window. Each frame only replaces the animated traces
# 1 (map markers), 3 (input altitude) and 4 (predicted altitude); the grey
# full-trajectory traces 0 and 2 stay unchanged.
frames = [
    go.Frame(data=_window_traces(k), name=f"fr{k}", traces=[1, 3, 4])
    for k in range(lengths)
]

# Add frames to the figure
fig.update(frames=frames)


# Animation frame arguments
def frame_args(duration):
    """Return the Plotly ``animate`` options used by the slider and buttons.

    ``duration`` is applied both to how long each frame is shown and to the
    linear transition into it (presumably milliseconds, as in Plotly's
    animation options). Playback starts from the current frame.
    """
    transition = dict(duration=duration, easing="linear")
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=transition,
    )


# Animation frame duration
fr_duration = 750

# One slider step per animation frame; choosing a step animates to that frame.
# (``frame`` is used instead of ``f`` so the name cannot be confused with the
# selected flight.)
slider_steps = [
    {
        "args": [[frame.name], frame_args(fr_duration)],
        "label": f"{idx + 1}s",
        "method": "animate",
    }
    for idx, frame in enumerate(fig.frames)
]

# Slider configuration
sliders = [
    {
        "pad": {"b": 10, "t": 50},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": slider_steps,
    }
]

# Play: animate through all frames (``None`` = every frame)
play_button = {
    "args": [None, frame_args(fr_duration)],
    "label": "&#9654;",  # play symbol
    "method": "animate",
}
# Pause: animate to ``[None]``, i.e. stop the running animation
pause_button = {
    "args": [[None], frame_args(fr_duration)],
    "label": "&#9724;",  # pause symbol
    "method": "animate",
}

# Update of slider and button layout
fig.update_layout(
    sliders=sliders,
    updatemenus=[
        {
            "buttons": [play_button, pause_button],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
)

fig.show()