***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
from haversine import haversine_vector, Unit

from traffic.core import Traffic

***
### Import of data, model and scalers
***

##### Data

In [None]:
# Test split arrays, loaded in this order: `_con` inputs, `_var` inputs and the
# target trajectories. `_con` / `_var` presumably denote constant and
# time-varying model inputs - confirm against the preprocessing code.
test_in_con, test_in_var, test_out = (
    np.load(npy_path)
    for npy_path in (
        "/mnt/beegfs/store/krum/MT/samples/test_in32_con.npy",
        "/mnt/beegfs/store/krum/MT/samples/test_in32_var.npy",
        "/mnt/beegfs/store/krum/MT/samples/test_out32.npy",
    )
)

##### Model

In [None]:
def rmse_lat():
    """Placeholder for the custom ``rmse_lat`` metric stored with the model.

    Only the name matters: it lets ``load_model`` resolve the reference while
    deserialising. It performs no computation and returns ``None``.
    """
    return


def rmse_lon():
    """Placeholder for the custom ``rmse_lon`` metric; returns ``None``."""
    return


def rmse_alt():
    """Placeholder for the custom ``rmse_alt`` metric; returns ``None``."""
    return


# The custom training loss "weighted_mse" is substituted by a plain MSE instance.
# As far as this notebook shows, the model is only used for inference, so the
# exact loss (and the metric stubs above) are never evaluated.
model = load_model(
    "/home/krum/git/MT_krum_code/models/snowy-gorge-126.keras",
    custom_objects=dict(
        rmse_lat=rmse_lat,
        rmse_lon=rmse_lon,
        rmse_alt=rmse_alt,
        weighted_mse=tf.keras.losses.MeanSquaredError(),
    ),
)

##### Scalers

In [None]:
# scaler_in
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl",
    "rb",
) as file:
    scaler_in = pickle.load(file)

# scaler_out
with open(
    "/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl",
    "rb",
) as file:
    scaler_out = pickle.load(file)

***
### Application of the model to obtain predictions on the test set
***

##### Model application

In [None]:
# Live inference on the test set; inputs are passed as a (variable, constant) tuple.
# TODO: re-enable once the predictions loaded from file in the next cell are no longer used.
# prediction = model.predict((test_in_var, test_in_con))

In [None]:
# Model predictions on the test set, cached on disk so the notebook can be
# re-run without repeating inference. If the cache file does not exist (e.g. a
# fresh checkout), the model is run once and its output is saved, so
# "Restart & Run All" works without a manually prepared file.
# NOTE(review): the cache is not invalidated when the model file changes -
# delete the .npy file after switching models.
prediction_cache_path = "/home/krum/git/MT_krum_code/03_analysis/predictions_with_mass.npy"
try:
    prediction = np.load(prediction_cache_path)
except FileNotFoundError:
    prediction = model.predict((test_in_var, test_in_con))
    np.save(prediction_cache_path, prediction)

##### Unscaling

In [None]:
# Undo the output scaling for predictions and ground truth alike (in that order).
# The scaler works on rows of 3 values (lat, lon, alt): each array is flattened
# to (-1, 3), inverse-transformed, reshaped to (samples, 37 steps, 3), and step 0
# is dropped, leaving the 36 predicted horizon steps.
# NOTE(review): step 0 is presumably the last known position - confirm.
prediction_unscaled, test_out_unscaled = (
    scaler_out.inverse_transform(scaled.reshape(-1, 3)).reshape(-1, 37, 3)[:, 1:, :]
    for scaled in (prediction, test_out)
)

***
### Analysis of altitude prediction error
***

##### Error calculation

In [None]:
# Altitude error (prediction - truth) per sample and horizon step, shape
# (samples, 36). Altitudes are in feet; the error is converted to metres with
# the exact international-foot definition 1 ft = 0.3048 m.
FT_TO_M = 0.3048
altitude_diff_m = (
    prediction_unscaled[:, :, 2] - test_out_unscaled[:, :, 2]
) * FT_TO_M

##### Visualisation

In [None]:
# Boxplots of the altitude error [m] for each of the 36 horizon steps
# (5 s to 180 s in 5 s increments); outliers are hidden.
fig, ax = plt.subplots(figsize=(20, 10))
ax.boxplot(
    [altitude_diff_m[:, step] for step in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Axis labels and ticks; tick k is labelled with its horizon (k + 1) * 5 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("Altitude prediction error [m]", fontsize=14)
ax.set_xticks(range(0, 36, 1))
ax.set_xticklabels([str(i) for i in range(5, 185, 5)], rotation=45)
ax.set_yticks(np.arange(-500, 500, 100))

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

plt.show()

***
### Analysis of position prediction error
***

##### Error calculation

In [None]:
# Latitude / longitude errors [deg] per sample and horizon step, shape
# (samples, steps). They are computed directly on the 2-D slices, so no
# flatten/reshape round-trip is needed.
lat_error = prediction_unscaled[:, :, 0] - test_out_unscaled[:, :, 0]
lon_error = prediction_unscaled[:, :, 1] - test_out_unscaled[:, :, 1]

# 2D position errors [m] using the haversine formula. haversine_vector takes
# (n, 2) arrays of (lat, lon) pairs, so the (samples, steps) grid is flattened
# (row-major) and the result is reshaped back to the same grid afterwards.
pred_lat = prediction_unscaled[:, :, 0].flatten()
pred_lon = prediction_unscaled[:, :, 1].flatten()
test_lat = test_out_unscaled[:, :, 0].flatten()
test_lon = test_out_unscaled[:, :, 1].flatten()

positions_pred = np.column_stack((pred_lat, pred_lon))
positions_actual = np.column_stack((test_lat, test_lon))

position_error_m = haversine_vector(
    positions_pred, positions_actual, unit=Unit.METERS
).reshape(prediction_unscaled.shape[:2])

##### Visualisation of latitude error

In [None]:
# Boxplots of the latitude error [deg] for each of the 36 horizon steps
# (5 s to 180 s in 5 s increments); outliers are hidden.
fig, ax = plt.subplots(figsize=(20, 10))
ax.boxplot(
    [lat_error[:, step] for step in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Axis labels and ticks; tick k is labelled with its horizon (k + 1) * 5 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("Latitude prediction error [°]", fontsize=14)
ax.set_xticks(range(0, 36, 1))
ax.set_xticklabels([str(i) for i in range(5, 185, 5)], rotation=45)
ax.set_yticks(np.arange(-0.07, 0.07, 0.005))

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

plt.show()

##### Visualisation of longitude error

In [None]:
# Boxplots of the longitude error for each of the 36 horizon steps
# (5 s to 180 s in 5 s increments); outliers are hidden.
fig, ax = plt.subplots(figsize=(20, 10))
ax.boxplot(
    [lon_error[:, step] for step in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Axis labels and ticks; tick k is labelled with its horizon (k + 1) * 5 s.
# lon_error is a difference of longitudes, i.e. in degrees (not metres), the
# same unit as the latitude plot.
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("Longitude prediction error [°]", fontsize=14)
ax.set_xticks(range(0, 36, 1))
ax.set_xticklabels([str(i) for i in range(5, 185, 5)], rotation=45)
ax.set_yticks(np.arange(-0.085, 0.085, 0.005))

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

plt.show()

##### Visualisation of 2D position error

In [None]:
# Boxplots of the 2D (haversine) position error [m] for each of the 36 horizon
# steps (5 s to 180 s in 5 s increments); outliers are hidden.
fig, ax = plt.subplots(figsize=(20, 10))
ax.boxplot(
    [position_error_m[:, step] for step in range(36)],
    positions=range(36),
    patch_artist=True,
    boxprops=dict(facecolor="lightblue", color="black"),
    medianprops=dict(color="black"),
    whiskerprops=dict(color="black"),
    capprops=dict(color="black"),
    showfliers=False,
)

# Axis labels and ticks; tick k is labelled with its horizon (k + 1) * 5 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("2D position prediction error [m]", fontsize=14)
ax.set_xticks(range(0, 36, 1))
ax.set_xticklabels([str(i) for i in range(5, 185, 5)], rotation=45)
ax.set_yticks(np.arange(0, 10000, 1000))

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

plt.show()

***
### Analysis separated by SID (Standard Instrument Departure)
***

##### Altitude error

In [None]:
# Generate subsets of altitude errors corresponding to each SID.
# Columns 19-22 of test_in_con are the SID flags. The slice test_in_con[:, :, c]
# is 2-D (samples, steps); np.where on it would return one row index per
# matching step and so duplicate samples. Reducing over the step axis with
# any() selects each sample at most once.
# Only every second horizon step is kept (indices 1, 3, ..., i.e. 10 s ... 180 s).

# DEGES
indices_DEGES = np.flatnonzero((test_in_con[:, :, 19] == 1).any(axis=1)).tolist()
alt_diff_DEGES = altitude_diff_m[indices_DEGES, 1::2]

# GERSA
indices_GERSA = np.flatnonzero((test_in_con[:, :, 20] == 1).any(axis=1)).tolist()
alt_diff_GERSA = altitude_diff_m[indices_GERSA, 1::2]

# VEBIT
indices_VEBIT = np.flatnonzero((test_in_con[:, :, 21] == 1).any(axis=1)).tolist()
alt_diff_VEBIT = altitude_diff_m[indices_VEBIT, 1::2]

# ZUE
indices_ZUE = np.flatnonzero((test_in_con[:, :, 22] == 1).any(axis=1)).tolist()
alt_diff_ZUE = altitude_diff_m[indices_ZUE, 1::2]

In [None]:
# One array of errors per plotted horizon step (i.e. per column of the subsets)
group_DEGES = [alt_diff_DEGES[:, i] for i in range(alt_diff_DEGES.shape[1])]
group_GERSA = [alt_diff_GERSA[:, i] for i in range(alt_diff_GERSA.shape[1])]
group_VEBIT = [alt_diff_VEBIT[:, i] for i in range(alt_diff_VEBIT.shape[1])]
group_ZUE = [alt_diff_ZUE[:, i] for i in range(alt_diff_ZUE.shape[1])]

# Plotting parameters
labels = ["DEGES", "GERSA", "VEBIT", "ZUE"]
colors = ["#dd6044", "#6e77f4", "#5dc999", "#a167f4"]
data_groups = [group_DEGES, group_GERSA, group_VEBIT, group_ZUE]
width = 0.1  # Width of a single boxplot

# The four boxes of one horizon step sit side by side around its integer x location
xlocations = np.arange(len(group_DEGES))
group_positions = [
    xlocations - 4 * (width / 2) - 0.08,
    xlocations - 2 * (width / 2) - 0.008,
    xlocations + 2 * (width / 2) + 0.008,
    xlocations + 4 * (width / 2) + 0.08,
]

# plt.subplots returns (figure, axes); draw explicitly on the axes
fig, ax = plt.subplots(figsize=(20, 10))
for data, positions, color in zip(data_groups, group_positions, colors):
    ax.boxplot(
        data,
        positions=positions,
        widths=width,
        boxprops=dict(facecolor=color, alpha=0.75),
        medianprops=dict(color="grey"),
        patch_artist=True,
        showfliers=False,
    )

# Axis labels and ticks; subset column k corresponds to horizon (k + 1) * 10 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("Altitude prediction error [m]", fontsize=14)
ax.set_xticks(xlocations)
ax.set_xticklabels([str(10 * (k + 1)) for k in xlocations], rotation=0)

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

# Legend: one colour swatch per SID
handles = [plt.Line2D([0], [0], color=c, lw=4) for c in colors]
ax.legend(handles, labels)

plt.show()

##### Position error

In [None]:
# Generate subsets of position errors corresponding to each SID.
# Columns 19-22 of test_in_con are the SID flags. The slice test_in_con[:, :, c]
# is 2-D (samples, steps); np.where on it would return one row index per
# matching step and so duplicate samples. Reducing over the step axis with
# any() selects each sample at most once.
# Only every second horizon step is kept (indices 1, 3, ..., i.e. 10 s ... 180 s).

# DEGES
indices_DEGES = np.flatnonzero((test_in_con[:, :, 19] == 1).any(axis=1)).tolist()
position_error_m_DEGES = position_error_m[indices_DEGES, 1::2]

# GERSA
indices_GERSA = np.flatnonzero((test_in_con[:, :, 20] == 1).any(axis=1)).tolist()
position_error_m_GERSA = position_error_m[indices_GERSA, 1::2]

# VEBIT
indices_VEBIT = np.flatnonzero((test_in_con[:, :, 21] == 1).any(axis=1)).tolist()
position_error_m_VEBIT = position_error_m[indices_VEBIT, 1::2]

# ZUE
indices_ZUE = np.flatnonzero((test_in_con[:, :, 22] == 1).any(axis=1)).tolist()
position_error_m_ZUE = position_error_m[indices_ZUE, 1::2]

In [None]:
# One array of errors per plotted horizon step (i.e. per column of the subsets)
group_DEGES = [position_error_m_DEGES[:, i] for i in range(position_error_m_DEGES.shape[1])]
group_GERSA = [position_error_m_GERSA[:, i] for i in range(position_error_m_GERSA.shape[1])]
group_VEBIT = [position_error_m_VEBIT[:, i] for i in range(position_error_m_VEBIT.shape[1])]
group_ZUE = [position_error_m_ZUE[:, i] for i in range(position_error_m_ZUE.shape[1])]

# Plotting parameters
labels = ["DEGES", "GERSA", "VEBIT", "ZUE"]
colors = ["#dd6044", "#6e77f4", "#5dc999", "#a167f4"]
data_groups = [group_DEGES, group_GERSA, group_VEBIT, group_ZUE]
width = 0.1  # Width of a single boxplot

# The four boxes of one horizon step sit side by side around its integer x location
xlocations = np.arange(len(group_DEGES))
group_positions = [
    xlocations - 4 * (width / 2) - 0.08,
    xlocations - 2 * (width / 2) - 0.008,
    xlocations + 2 * (width / 2) + 0.008,
    xlocations + 4 * (width / 2) + 0.08,
]

# plt.subplots returns (figure, axes); draw explicitly on the axes
fig, ax = plt.subplots(figsize=(20, 10))
for data, positions, color in zip(data_groups, group_positions, colors):
    ax.boxplot(
        data,
        positions=positions,
        widths=width,
        boxprops=dict(facecolor=color, alpha=0.75),
        medianprops=dict(color="grey"),
        patch_artist=True,
        showfliers=False,
    )

# Axis labels and ticks; subset column k corresponds to horizon (k + 1) * 10 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("2D position prediction error [m]", fontsize=14)
ax.set_xticks(xlocations)
ax.set_xticklabels([str(10 * (k + 1)) for k in xlocations], rotation=0)

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

# Legend: one colour swatch per SID
handles = [plt.Line2D([0], [0], color=c, lw=4) for c in colors]
ax.legend(handles, labels)

plt.show()

***
### Analysis separated by aircraft type
***

##### Altitude error

In [None]:
# Generate subsets of altitude errors corresponding to aircraft types.
# Columns 4 (A320), 10 (BCS3), 6 (A333) and 8 (B77W) of test_in_con are the type
# flags. The slice test_in_con[:, :, c] is 2-D (samples, steps); np.where on it
# would return one row index per matching step and so duplicate samples.
# Reducing over the step axis with any() selects each sample at most once.
# Only every second horizon step is kept (indices 1, 3, ..., i.e. 10 s ... 180 s).

# A320
indices_A320 = np.flatnonzero((test_in_con[:, :, 4] == 1).any(axis=1)).tolist()
alt_diff_A320 = altitude_diff_m[indices_A320, 1::2]

# BCS3
indices_BCS3 = np.flatnonzero((test_in_con[:, :, 10] == 1).any(axis=1)).tolist()
alt_diff_BCS3 = altitude_diff_m[indices_BCS3, 1::2]

# A333
indices_A333 = np.flatnonzero((test_in_con[:, :, 6] == 1).any(axis=1)).tolist()
alt_diff_A333 = altitude_diff_m[indices_A333, 1::2]

# B77W
indices_B77W = np.flatnonzero((test_in_con[:, :, 8] == 1).any(axis=1)).tolist()
alt_diff_B77W = altitude_diff_m[indices_B77W, 1::2]

In [None]:
# One array of errors per plotted horizon step (i.e. per column of the subsets)
group_A320 = [alt_diff_A320[:, i] for i in range(alt_diff_A320.shape[1])]
group_BCS3 = [alt_diff_BCS3[:, i] for i in range(alt_diff_BCS3.shape[1])]
group_A333 = [alt_diff_A333[:, i] for i in range(alt_diff_A333.shape[1])]
group_B77W = [alt_diff_B77W[:, i] for i in range(alt_diff_B77W.shape[1])]

# Plotting parameters
labels = ["A320", "BCS3", "A333", "B77W"]
colors = ["#dd6044", "#6e77f4", "#5dc999", "#a167f4"]
data_groups = [group_A320, group_BCS3, group_A333, group_B77W]
width = 0.1  # Width of a single boxplot

# The four boxes of one horizon step sit side by side around its integer x location
xlocations = np.arange(len(group_A320))
group_positions = [
    xlocations - 4 * (width / 2) - 0.08,
    xlocations - 2 * (width / 2) - 0.008,
    xlocations + 2 * (width / 2) + 0.008,
    xlocations + 4 * (width / 2) + 0.08,
]

# plt.subplots returns (figure, axes); draw explicitly on the axes
fig, ax = plt.subplots(figsize=(20, 10))
for data, positions, color in zip(data_groups, group_positions, colors):
    ax.boxplot(
        data,
        positions=positions,
        widths=width,
        boxprops=dict(facecolor=color, alpha=0.75),
        medianprops=dict(color="grey"),
        patch_artist=True,
        showfliers=False,
    )

# Axis labels and ticks; subset column k corresponds to horizon (k + 1) * 10 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("Altitude prediction error [m]", fontsize=14)
ax.set_xticks(xlocations)
ax.set_xticklabels([str(10 * (k + 1)) for k in xlocations], rotation=0)

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

# Legend: one colour swatch per aircraft type
handles = [plt.Line2D([0], [0], color=c, lw=4) for c in colors]
ax.legend(handles, labels)

plt.show()

##### Position error

In [None]:
# Generate subsets of position errors corresponding to aircraft types.
# Columns 4 (A320), 10 (BCS3), 6 (A333) and 8 (B77W) of test_in_con are the type
# flags. The slice test_in_con[:, :, c] is 2-D (samples, steps); np.where on it
# would return one row index per matching step and so duplicate samples.
# Reducing over the step axis with any() selects each sample at most once.
# Only every second horizon step is kept (indices 1, 3, ..., i.e. 10 s ... 180 s).

# A320
indices_A320 = np.flatnonzero((test_in_con[:, :, 4] == 1).any(axis=1)).tolist()
position_error_m_A320 = position_error_m[indices_A320, 1::2]

# BCS3
indices_BCS3 = np.flatnonzero((test_in_con[:, :, 10] == 1).any(axis=1)).tolist()
position_error_m_BCS3 = position_error_m[indices_BCS3, 1::2]

# A333
indices_A333 = np.flatnonzero((test_in_con[:, :, 6] == 1).any(axis=1)).tolist()
position_error_m_A333 = position_error_m[indices_A333, 1::2]

# B77W
indices_B77W = np.flatnonzero((test_in_con[:, :, 8] == 1).any(axis=1)).tolist()
position_error_m_B77W = position_error_m[indices_B77W, 1::2]

In [None]:
# One array of errors per plotted horizon step (i.e. per column of the subsets)
group_A320 = [position_error_m_A320[:, i] for i in range(position_error_m_A320.shape[1])]
group_BCS3 = [position_error_m_BCS3[:, i] for i in range(position_error_m_BCS3.shape[1])]
group_A333 = [position_error_m_A333[:, i] for i in range(position_error_m_A333.shape[1])]
group_B77W = [position_error_m_B77W[:, i] for i in range(position_error_m_B77W.shape[1])]

# Plotting parameters
labels = ["A320", "BCS3", "A333", "B77W"]
colors = ["#dd6044", "#6e77f4", "#5dc999", "#a167f4"]
data_groups = [group_A320, group_BCS3, group_A333, group_B77W]
width = 0.1  # Width of a single boxplot

# The four boxes of one horizon step sit side by side around its integer x location
xlocations = np.arange(len(group_A320))
group_positions = [
    xlocations - 4 * (width / 2) - 0.08,
    xlocations - 2 * (width / 2) - 0.008,
    xlocations + 2 * (width / 2) + 0.008,
    xlocations + 4 * (width / 2) + 0.08,
]

# plt.subplots returns (figure, axes); draw explicitly on the axes
fig, ax = plt.subplots(figsize=(20, 10))
for data, positions, color in zip(data_groups, group_positions, colors):
    ax.boxplot(
        data,
        positions=positions,
        widths=width,
        boxprops=dict(facecolor=color, alpha=0.75),
        medianprops=dict(color="grey"),
        patch_artist=True,
        showfliers=False,
    )

# Axis labels and ticks; subset column k corresponds to horizon (k + 1) * 10 s
ax.set_xlabel("Prediction horizon [s]", fontsize=14)
ax.set_ylabel("2D position prediction error [m]", fontsize=14)
ax.set_xticks(xlocations)
ax.set_xticklabels([str(10 * (k + 1)) for k in xlocations], rotation=0)

# Dashed grid on a semi-transparent axes background
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.patch.set_alpha(0.3)

# Legend: one colour swatch per aircraft type
handles = [plt.Line2D([0], [0], color=c, lw=4) for c in colors]
ax.legend(handles, labels)

plt.show()