***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from geopy.distance import distance

from traffic.core import Traffic
from tqdm.auto import tqdm

import plotly.graph_objects as go
from plotly.subplots import make_subplots


***
### Import of data, model and scalers
***

##### Trajectory data

In [None]:
# Test split of the encoded + scaled trajectories (one Flight per sample).
test_split_path = "/mnt/beegfs/store/krum/MT/encoded_scaled_split/t_test.parquet"
t = Traffic.from_file(test_split_path)

##### Model

In [None]:
def rmse_lat():
    return None


def rmse_lon():
    return None


def rmse_alt():
    return None


model = load_model(
    f"/home/krum/git/MT_krum_code/models/snowy-gorge-126.keras",
    custom_objects={
        "rmse_lat": rmse_lat,
        "rmse_lon": rmse_lon,
        "rmse_alt": rmse_alt,
        "weighted_mse": tf.keras.losses.MeanSquaredError(),
    },
)

##### Scalers

In [None]:
# scaler_in
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl",
    "rb",
) as file:
    scaler_in = pickle.load(file)

# scaler_out
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl",
    "rb",
) as file:
    scaler_out = pickle.load(file)

***
### Model application
***

In [None]:
# Model inputs are built from the window of N_INPUT_STEPS rows that precedes
# each flight's final row (rows -11 .. -2 when N_INPUT_STEPS == 10).
N_INPUT_STEPS = 10

# Unscaled positions of the window; kept only for plotting the inputs.
POSITION_COLUMNS = ["latitude", "longitude", "altitude"]

# Per-time-step (variable) model inputs: one row per step of the window.
VARIABLE_INPUT_COLUMNS = [
    "latitude_scaled",
    "longitude_scaled",
    "altitude_scaled",
    "wind_x_2min_avg_scaled",
    "wind_y_2min_avg_scaled",
    "temperature_gnd_scaled",
    "humidity_gnd_scaled",
    "pressure_gnd_scaled",
]

# Per-flight (constant) model inputs, read from the first row of the window.
CONSTANT_INPUT_COLUMNS = [
    "toff_weight_kg_scaled",
    "typecode_A20N",
    "typecode_A21N",
    "typecode_A319",
    "typecode_A320",
    "typecode_A321",
    "typecode_A333",
    "typecode_A343",
    "typecode_B77W",
    "typecode_BCS1",
    "typecode_BCS3",
    "typecode_CRJ9",
    "typecode_DH8D",
    "typecode_E190",
    "typecode_E195",
    "typecode_E290",
    "typecode_E295",
    "typecode_F100",
    "typecode_SB20",
    "SID_DEGES",
    "SID_GERSA",
    "SID_VEBIT",
    "SID_ZUE",
    "hour_sin",
    "hour_cos",
    "weekday_sin",
    "weekday_cos",
    "month_sin",
    "month_cos",
]

inputs_var_unscaled = []  # one DataFrame (N_INPUT_STEPS x 3) per flight
var_windows = []  # one array (N_INPUT_STEPS, 8) per flight
con_rows = []  # one array (29,) per flight

for i, flight in enumerate(tqdm(t)):
    n_rows = len(flight.data)
    if n_rows < N_INPUT_STEPS + 1:
        # Fail with a clear message instead of an opaque IndexError / short window.
        raise ValueError(
            f"flight #{i} has {n_rows} rows; at least {N_INPUT_STEPS + 1} are required"
        )
    window = flight.data.iloc[-(N_INPUT_STEPS + 1) : -1]
    inputs_var_unscaled.append(window[POSITION_COLUMNS])
    var_windows.append(window[VARIABLE_INPUT_COLUMNS].to_numpy())
    con_rows.append(window[CONSTANT_INPUT_COLUMNS].iloc[0].to_numpy())

# Stack into model-ready arrays. The explicit float cast guards against a row
# taken across mixed-dtype columns coming back as an object array.
inputs_var = np.stack(var_windows).astype(np.float32)  # (n_flights, N_INPUT_STEPS, 8)
inputs_con = np.stack(con_rows).astype(np.float32)[:, np.newaxis, :]  # (n_flights, 1, 29)

# Predict
predictions = model.predict((inputs_var, inputs_con))

# Unscale: scaler_out works on rows of 3 output features (used below as
# latitude, longitude, altitude); reshape back to (n_flights, n_pred_steps, 3).
# NOTE(review): axis 0 is left intact so predictions_unscaled[i] belongs to the
# same flight as inputs_var_unscaled[i]. If the intent is to drop the first
# predicted time step, slice axis 1 instead: [:, 1:, :].
predictions_unscaled = scaler_out.inverse_transform(
    predictions.reshape(-1, 3)
).reshape(predictions.shape[0], -1, 3)

assert predictions_unscaled.shape[0] == len(inputs_var_unscaled), (
    "predictions and inputs are not aligned per flight"
)

***
### Plotting
***

In [None]:
# Sample selection
sample = 8

# Input window (DataFrame with latitude/longitude/altitude) and prediction
# (array of shape (n_pred_steps, 3): latitude, longitude, altitude) of ONE sample.
sample_inputs = inputs_var_unscaled[sample]
sample_preds = predictions_unscaled[sample]
n_input_steps = len(sample_inputs)
n_pred_steps = len(sample_preds)

# x positions for the altitude panel; each array matches its y values in length.
# NOTE(review): inputs use unit spacing while predictions start right after them
# with spacing PRED_X_STEP — confirm both sampling intervals and put them on a
# common time axis.
PRED_X_STEP = 5
input_x = np.arange(n_input_steps)
pred_x = n_input_steps + PRED_X_STEP * np.arange(n_pred_steps)

# Generate subplots------------------------------------------------------------
fig = make_subplots(
    rows=2,
    cols=1,
    specs=[[{"type": "scattermapbox"}], [{}]],
    row_heights=[0.6, 0.4],
    subplot_titles=("Position", "Altitude"),
    vertical_spacing=0.07,
)

# Add traces (add_trace replaces the deprecated append_trace)------------------
# Input position
fig.add_trace(
    go.Scattermapbox(
        lat=sample_inputs["latitude"],
        lon=sample_inputs["longitude"],
        marker=dict(size=5, color="red"),
        mode="markers",
        name="input",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Prediction position
fig.add_trace(
    go.Scattermapbox(
        lat=sample_preds[:, 0],
        lon=sample_preds[:, 1],
        marker=dict(size=5, color="blue"),
        mode="markers",
        name="prediction",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Input altitude (same sample as the position panel)
fig.add_trace(
    go.Scatter(
        x=input_x,
        y=sample_inputs["altitude"],
        marker=dict(size=3, color="red"),
        mode="markers",
        name="input",
        showlegend=False,
    ),
    row=2,
    col=1,
)

# Prediction altitude
fig.add_trace(
    go.Scatter(
        x=pred_x,
        y=sample_preds[:, 2],
        marker=dict(size=3, color="blue"),
        mode="markers",
        name="prediction",
        showlegend=False,
    ),
    row=2,
    col=1,
)

# Update Figure layout---------------------------------------------------------
fig.update_layout(
    mapbox=dict(
        style="carto-positron",
        zoom=10,
        # Without an explicit center the map opens at (0, 0); center on the inputs.
        center=dict(
            lat=float(sample_inputs["latitude"].mean()),
            lon=float(sample_inputs["longitude"].mean()),
        ),
    ),
    width=1200,
    height=1200,
    margin=dict(l=50, r=0, t=40, b=40),
)
fig.update_xaxes(title_text="step", row=2, col=1)
fig.update_yaxes(title_text="altitude", row=2, col=1)

fig.show()