***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from haversine import haversine_vector, Unit

from traffic.core import Traffic

***
### Import of data, model and scalers
***

##### Data

In [None]:
# Directory holding the pre-built test sample arrays. Change it here (once) to
# relocate the data instead of editing every path below.
SAMPLES_DIR = "/mnt/beegfs/store/krum/MT/samples"

# Model inputs; the model is later called as model.predict((test_in_var, test_in_con)).
test_in_con = np.load(f"{SAMPLES_DIR}/test_in32_con.npy")
test_in_var = np.load(f"{SAMPLES_DIR}/test_in32_var.npy")
# Ground-truth outputs in scaled space (inverse-transformed with scaler_out later).
test_out = np.load(f"{SAMPLES_DIR}/test_out32.npy")

##### Model

In [None]:
def rmse_lat():
    """Placeholder for the custom "rmse_lat" metric the model was saved with.

    Exists only so that `load_model` can resolve the name via `custom_objects`.
    It takes no arguments and returns None, so it cannot compute a metric.
    """
    return None


def rmse_lon():
    """Placeholder for the custom "rmse_lon" metric (see `rmse_lat`)."""
    return None


def rmse_alt():
    """Placeholder for the custom "rmse_alt" metric (see `rmse_lat`)."""
    return None


# Load the trained model. The custom names stored in the saved file are mapped to
# the placeholders above, and "weighted_mse" (presumably the custom training loss)
# is mapped to a plain, unweighted MeanSquaredError.
# NOTE(review): this is fine for `predict`, but metrics/loss reported by
# evaluate/fit on this object would not match training (the metric stubs cannot be
# called with (y_true, y_pred)). If only inference is needed, consider
# `load_model(..., compile=False)` — confirm.
model = load_model(
    f"/home/krum/git/MT_krum_code/models/snowy-gorge-126.keras",
    custom_objects={
        "rmse_lat": rmse_lat,
        "rmse_lon": rmse_lon,
        "rmse_alt": rmse_alt,
        "weighted_mse": tf.keras.losses.MeanSquaredError(),
    },
)

##### Scalers

In [None]:
# scaler_in
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl",
    "rb",
) as file:
    scaler_in = pickle.load(file)

# scaler_out
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl",
    "rb",
) as file:
    scaler_out = pickle.load(file)

***
### Application of the model to obtain predictions on the test set
***

##### Model application

In [None]:
# to be uncommented !
# prediction = model.predict((test_in_var, test_in_con))

In [None]:
# Model predictions on the test set, cached on disk to avoid re-running inference.
# If the cache file does not exist yet (e.g. fresh checkout), the model is run and
# the result is written to the cache, so "Restart & Run All" works from scratch.
# NOTE: delete the cache file after changing the model or the test data,
# otherwise stale predictions are loaded.
PREDICTION_CACHE = "/home/krum/git/MT_krum_code/03_analysis/predictions_with_mass.npy"

try:
    prediction = np.load(PREDICTION_CACHE)
except FileNotFoundError:
    prediction = model.predict((test_in_var, test_in_con))
    np.save(PREDICTION_CACHE, prediction)

##### Unscaling

In [None]:
# Predictions
prediction_unscaled = scaler_out.inverse_transform(
    prediction.reshape(-1, 3)
).reshape(-1, 37, 3)[:, 1:, :]

# True values
test_out_unscaled = scaler_out.inverse_transform(
    test_out.reshape(-1, 3)
).reshape(-1, 37, 3)[:, 1:, :]

***
### Analysis of altitude prediction error
***

##### Error calculation

In [None]:
# Computation of altitude error and conversion from feet to meters
# Exact conversion factor: 1 international foot = 0.3048 m by definition
# (dividing by the rounded value 3.281 introduces a ~5e-5 relative error).
FEET_TO_METERS = 0.3048

# Signed altitude error (prediction - truth) per sample and horizon, converted
# from feet to metres; feature index 2 is altitude.
altitude_diff_m = (
    prediction_unscaled[:, :, 2] - test_out_unscaled[:, :, 2]
) * FEET_TO_METERS

##### Visualisation

In [None]:
# Generate boxplots
plt.figure(figsize=(20, 10))

# Create boxplot for each timestep (prediction horizon)
plt.boxplot(
    [altitude_diff_m[:, i] for i in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Set plot axis labels and ticks
plt.xlabel("Prediction horizon [s]", fontsize=14)
plt.ylabel("Altitude prediction error [m]", fontsize=14)
plt.xticks(
    ticks=range(0, 36, 1),
    labels=[str(i) for i in range(5, 185, 5)],
    rotation=45,
)
plt.yticks(ticks=np.arange(-500, 500, 100))

# Set grid
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.gca().patch.set_alpha(0.3)

# Show plot
plt.show()

***
### Analysis of position prediction error
***

##### Error calculation

In [None]:
# Extraction of latitudes and longitudes
pred_lat = prediction_unscaled[:, :, 0].flatten()
pred_lon = prediction_unscaled[:, :, 1].flatten()
test_lat = test_out_unscaled[:, :, 0].flatten()
test_lon = test_out_unscaled[:, :, 1].flatten()

# Computation of latitude and longitude errors
lat_error = (pred_lat - test_lat).reshape(prediction_unscaled.shape[0], 36)
lon_error = (pred_lon - test_lon).reshape(prediction_unscaled.shape[0], 36)


# Computation of 2d position errors [m] using haversine formula
pred_lat = prediction_unscaled[:, :, 0].flatten()
pred_lon = prediction_unscaled[:, :, 1].flatten()
test_lat = test_out_unscaled[:, :, 0].flatten()
test_lon = test_out_unscaled[:, :, 1].flatten()

positions_pred = np.column_stack((pred_lat, pred_lon))
positions_actual = np.column_stack((test_lat, test_lon))

position_error_m = haversine_vector(positions_pred, positions_actual, unit=Unit.METERS)
position_error_m = position_error_m.reshape(prediction_unscaled.shape[0], 36)

##### Visualisation of latitude error

In [None]:
# Generate boxplots
plt.figure(figsize=(20, 10))

# Create boxplot for each timestep (prediction horizon)
plt.boxplot(
    [lat_error[:, i] for i in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Set plot axis labels and ticks
plt.xlabel("Prediction horizon [s]", fontsize=14)
plt.ylabel("Latitude prediction error [°]", fontsize=14)
plt.xticks(
    ticks=range(0, 36, 1),
    labels=[str(i) for i in range(5, 185, 5)],
    rotation=45,
)
plt.yticks(ticks=np.arange(-0.07, 0.07, 0.005))

# Set grid
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.gca().patch.set_alpha(0.3)

# Show plot
plt.show()

##### Visualisation of longitude error

In [None]:
# Generate boxplots
plt.figure(figsize=(20, 10))

# Create boxplot for each timestep (prediction horizon)
plt.boxplot(
    [lon_error[:, i] for i in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Set plot axis labels and ticks
plt.xlabel("Prediction horizon [s]", fontsize=14)
plt.ylabel("Longitude prediction error [m]", fontsize=14)
plt.xticks(
    ticks=range(0, 36, 1),
    labels=[str(i) for i in range(5, 185, 5)],
    rotation=45,
)
plt.yticks(ticks=np.arange(-0.085, 0.085, 0.005))

# Set grid
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.gca().patch.set_alpha(0.3)

# Show plot
plt.show()

##### Visualisation of 2D position error

In [None]:
# Generate boxplots
plt.figure(figsize=(20, 10))

# Create boxplot for each timestep (prediction horizon)
plt.boxplot(
    [position_error_m[:, i] for i in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Set plot axis labels and ticks
plt.xlabel("Prediction horizon [s]", fontsize=14)
plt.ylabel("2D position prediction error [m]", fontsize=14)
plt.xticks(
    ticks=range(0, 36, 1),
    labels=[str(i) for i in range(5, 185, 5)],
    rotation=45,
)
plt.yticks(ticks=np.arange(0, 10000, 1000))

# Set grid
plt.grid(True, which="both", linestyle="--", linewidth=0.5)
plt.gca().patch.set_alpha(0.3)

# Show plot
plt.show()