***
### Import of required libraries
***

In [None]:
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tensorflow as tf
from tensorflow.keras.models import load_model

from traffic.core import Traffic, Flight
from traffic.data import navaids

***
### Import of data, model and scalers
***

##### Trajectory data

In [None]:
# Test-set trajectories; individual flights are later selected from `t` by id.
t = Traffic.from_file("/mnt/beegfs/store/krum/MT/encoded_scaled_split/t_test.parquet")

##### Model

In [None]:
# The three functions below are placeholders that are only registered under
# their names in `custom_objects` when the saved model is loaded. They accept
# arbitrary arguments so that they do not raise a TypeError should they ever
# be invoked with (y_true, y_pred); they compute nothing.
def rmse_lat(*args, **kwargs):
    """Placeholder registered as the custom object ``rmse_lat``; returns None."""
    return None


def rmse_lon(*args, **kwargs):
    """Placeholder registered as the custom object ``rmse_lon``; returns None."""
    return None


def rmse_alt(*args, **kwargs):
    """Placeholder registered as the custom object ``rmse_alt``; returns None."""
    return None


# Alternative checkpoint (not loaded):
#   /home/krum/git/MT_krum_code/models/robust-glitter-237.keras
model = load_model(
    "/home/krum/git/MT_krum_code/models/revived-snowflake-235.keras",
    custom_objects={
        "rmse_lat": rmse_lat,
        "rmse_lon": rmse_lon,
        "rmse_alt": rmse_alt,
        # A plain MSE loss is registered under the name "weighted_mse".
        "weighted_mse": tf.keras.losses.MeanSquaredError(),
    },
)

##### Scalers

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted
# location.

# Scaler stored as scaler_in.pkl
with open("/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_in.pkl", "rb") as file:
    scaler_in = pickle.load(file)

# Scaler stored as scaler_out.pkl; used later to inverse-transform predictions
with open("/mnt/beegfs/store/krum/MT/encoded_scaled_split/scaler_out.pkl", "rb") as file:
    scaler_out = pickle.load(file)

***
### Erroneous SID
***

##### Functions for plotting

In [None]:
# Function to generate lists of latitudes and longitudes for navaids
def generate_latlon_list(navaid_list):
    """
    Convert a route description into parallel latitude and longitude lists.

    Parameters
    ----------
    navaid_list : iterable
        Route elements in order. Each element is either a tuple (or tuple
        subclass) whose first two items are latitude and longitude, or a
        navaid name that is looked up in ``navaids.extent("Switzerland")``.

    Returns
    -------
    tuple[list, list]
        ``(lat, lon)`` with one entry per element of ``navaid_list``.

    Raises
    ------
    ValueError
        If a navaid name cannot be resolved.
    """
    lat = []
    lon = []
    # The navaid subset is only needed for named points; it is fetched once,
    # on first use, instead of once per name.
    swiss_navaids = None
    for navaid in navaid_list:
        # isinstance (not type(...) ==) so that tuple subclasses such as
        # namedtuples are treated as coordinates as well.
        if isinstance(navaid, tuple):
            lat.append(navaid[0])
            lon.append(navaid[1])
            continue
        if swiss_navaids is None:
            swiss_navaids = navaids.extent("Switzerland")
        dat = swiss_navaids.get(navaid)
        if dat is None:
            raise ValueError(
                f"Navaid {navaid!r} not found within the extent 'Switzerland'."
            )
        lat.append(dat.latitude)
        lon.append(dat.longitude)
    return lat, lon

In [None]:
# SID definitions
# Each route is an ordered list of (latitude, longitude) tuples and navaid
# names; generate_latlon_list turns it into separate latitude/longitude lists.
# All four routes start with the same two coordinate pairs.

# GERSA SID
gersa_route = [
    (47.45892, 8.53771),
    (47.46233, 8.48958),
    "BREGO",
    "ARTAG",
    "GERSA",
]
gersa_lat, gersa_lon = generate_latlon_list(gersa_route)

# DEGES SID
deges_route = [
    (47.45892, 8.53771),
    (47.46233, 8.48958),
    "ZH552",
    (47.40351, 8.39876),
    (47.4128, 8.4559),
    "KLO",
    "KOLUL",
    "DEGES",
]
deges_lat, deges_lon = generate_latlon_list(deges_route)

# VEBIT SID
vebit_route = [
    (47.45892, 8.53771),
    (47.46233, 8.48958),
    "BREGO",
    "VEBIT",
]
vebit_lat, vebit_lon = generate_latlon_list(vebit_route)

# ZUE SID
zue_route = [
    (47.45892, 8.53771),
    (47.46233, 8.48958),
    "ZH552",
    (47.40351, 8.39876),
    (47.4128, 8.4559),
    "KLO",
    "ZUE",
]
zue_lat, zue_lon = generate_latlon_list(zue_route)

In [None]:
# Main function for visualization
def generate_plot(flightdata: Flight, sid: str = "GERSA") -> go.Figure:
    """
    Generate an animated plotly figure for a given flight data object, including
    the model inputs, true outputs and the model prediction.

    The upper panel is a map showing the selected SID, the full trajectory and,
    per animation frame, the input points, the true output points and the
    predicted points. The lower panel shows the corresponding altitudes over the
    row index of the flight data. One animation frame is created per input
    window. The function uses the module-level ``model`` and ``scaler_out`` as
    well as the module-level SID coordinate lists.

    Parameters
    ----------
    flightdata : Flight
        Flight data object containing the data of a single flight.
    sid : str (optional)
        Standard Instrument Departure (SID) name. either "GERSA", "DEGES",
        "VEBIT" or "ZUE". Will be added to the plot. Default is "GERSA".

    Returns
    -------
    go.Figure
        Plotly figure object containing the animated plot.

    Raises
    ------
    ValueError
        If ``sid`` is not one of the supported names, or if the flight has too
        few rows to build a single input/output sample.
    """
    # SID selection-------------------------------------------------------------
    # Resolved first so that an unsupported name fails before the model is run
    # (previously it surfaced later as an unbound local variable).
    sid_routes = {
        "GERSA": (gersa_lat, gersa_lon),
        "DEGES": (deges_lat, deges_lon),
        "VEBIT": (vebit_lat, vebit_lon),
        "ZUE": (zue_lat, zue_lon),
    }
    if sid not in sid_routes:
        raise ValueError(
            f"Unknown SID {sid!r}; expected one of {sorted(sid_routes)}."
        )
    sid_lat, sid_lon = sid_routes[sid]
    sid_color = "#dd6044"  # the same colour is used for every SID

    # Sample geometry (all values are row counts/offsets in the flight data)
    n_in = 10  # rows per model input window
    horizon = 180  # span from which the true output is sampled
    out_step = 5  # stride between true output rows
    out_start = 15  # first true output row, relative to the window start
    n_out = 36  # true output rows per sample (horizon / out_step)

    # Flight trajectory
    df_full = flightdata.data

    # Simulation length (number of samples / animation frames)
    n_samples = len(df_full) - horizon - n_in
    if n_samples <= 0:
        raise ValueError(
            f"Flight has {len(df_full)} rows; more than {horizon + n_in} are "
            "required to build at least one sample."
        )
    starts = range(n_samples)

    # Generation of subsets-----------------------------------------------------
    # Time variant inputs unscaled
    f_in_var_unscaled = df_full[
        [
            "latitude",
            "longitude",
            "altitude",
            "wind_x_2min_avg",
            "wind_y_2min_avg",
            "temperature_gnd",
            "humidity_gnd",
            "pressure_gnd",
        ]
    ]

    # Time variant inputs scaled
    f_in_var = df_full[
        [
            "latitude_scaled",
            "longitude_scaled",
            "altitude_scaled",
            "wind_x_2min_avg_scaled",
            "wind_y_2min_avg_scaled",
            "temperature_gnd_scaled",
            "humidity_gnd_scaled",
            "pressure_gnd_scaled",
        ]
    ]

    # Time invariant inputs
    f_in_con = df_full[
        [
            "toff_weight_kg_scaled",
            "typecode_A20N",
            "typecode_A21N",
            "typecode_A319",
            "typecode_A320",
            "typecode_A321",
            "typecode_A333",
            "typecode_A343",
            "typecode_B77W",
            "typecode_BCS1",
            "typecode_BCS3",
            "typecode_CRJ9",
            "typecode_DH8D",
            "typecode_E190",
            "typecode_E195",
            "typecode_E290",
            "typecode_E295",
            "typecode_F100",
            "typecode_SB20",
            "SID_DEGES",
            "SID_GERSA",
            "SID_VEBIT",
            "SID_ZUE",
            "hour_sin",
            "hour_cos",
            "weekday_sin",
            "weekday_cos",
            "month_sin",
            "month_cos",
        ]
    ]

    # True outputs
    f_out_unscaled = df_full[["latitude", "longitude", "altitude"]]

    ##### Generation of input samples along trajectory--------------------------
    # Variable inputs unscaled: one (n_in, 8) window per sample
    input_var_unscaled = np.stack(
        [f_in_var_unscaled.iloc[i : i + n_in].to_numpy() for i in starts]
    ).reshape(-1, n_in, 8)

    # Variable inputs scaled: one (n_in, 8) window per sample
    input_var = np.stack(
        [f_in_var.iloc[i : i + n_in].to_numpy() for i in starts]
    ).reshape(-1, n_in, 8)

    # Constant inputs: the row directly after each input window.
    # The values are deliberately stacked scalar by scalar, as before, so that
    # the resulting array dtype is the same as it has always been when the
    # selected columns have mixed dtypes.
    con_values = [
        value for i in starts for value in f_in_con.iloc[i + n_in].to_numpy()
    ]
    input_con = np.stack(con_values).reshape(-1, 1, 29)

    # True outputs unscaled: every out_step-th row within the horizon
    output_true = np.stack(
        [
            f_out_unscaled.iloc[
                i + out_start : i + out_start + horizon : out_step
            ].to_numpy()
            for i in starts
        ]
    ).reshape(-1, n_out, 3)

    # Application of model------------------------------------------------------
    output = model.predict((input_var, input_con))
    # The prediction is reshaped to 37 steps per sample; the first step is
    # dropped so that 36 steps remain, matching the true output.
    output_unscaled = scaler_out.inverse_transform(output.reshape(-1, 3)).reshape(
        -1, 37, 3
    )[:, 1:, :]

    # Generation of plotting data-----------------------------------------------
    def as_points(xyz, color):
        # Columns 0, 1, 2 of xyz are latitude, longitude and altitude.
        return pd.DataFrame(
            {
                "latitude": xyz[:, 0],
                "longitude": xyz[:, 1],
                "altitude": xyz[:, 2],
                "color": color,
            }
        )

    # Input output pairs: per sample, rows 0-9 are the model input, rows 10-45
    # the true output and rows 46-81 the model prediction.
    input_output = [
        pd.concat(
            [
                as_points(input_var_unscaled[i], "#636efa"),  # model input
                as_points(output_true[i], "grey"),  # true output
                as_points(output_unscaled[i], "#00cc96"),  # model prediction
            ]
        )
        for i in starts
    ]

    def position_trace(k, legendgroup=None):
        # Map markers (input, true output and prediction) of sample k.
        sample = input_output[k]
        return go.Scattermapbox(
            lat=sample["latitude"],
            lon=sample["longitude"],
            marker=dict(size=7, color=sample["color"]),
            mode="markers",
            showlegend=False,
            legendgroup=legendgroup,
        )

    def altitude_traces(k):
        # Altitude markers of sample k in the order: true output, model input,
        # prediction. The x values have exactly as many entries as the y values.
        sample = input_output[k]
        in_x = np.arange(k, k + n_in)
        out_x = np.arange(k + out_start, k + out_start + horizon, out_step)
        parts = (
            ("True Output", out_x, sample.iloc[n_in : n_in + n_out]),
            ("Model input", in_x, sample.iloc[:n_in]),
            ("Prediction", out_x, sample.iloc[n_in + n_out :]),
        )
        return [
            go.Scatter(
                x=x_values,
                y=part["altitude"],
                marker=dict(size=6, color=part["color"]),
                mode="markers",
                name=label,
                showlegend=True,
                legendgroup="model",
            )
            for label, x_values, part in parts
        ]

    # Plotting------------------------------------------------------------------
    # Generate subplots------------------------
    fig = make_subplots(
        rows=2,
        cols=1,
        specs=[[{"type": "scattermapbox"}], [{}]],
        row_heights=[0.6, 0.4],
        subplot_titles=("Position", "Altitude"),
        vertical_spacing=0.07,
    )

    # Add initial traces------------------------
    # Trace 0: SID
    fig.add_trace(
        go.Scattermapbox(
            mode="lines",
            lat=sid_lat,
            lon=sid_lon,
            line=dict(width=7, color=sid_color),
            opacity=0.3,
            name=f"{sid} SID",
            showlegend=True,
            legendgroup="fix",
        ),
        row=1,
        col=1,
    )

    # Trace 1: full trajectory position
    fig.add_trace(
        go.Scattermapbox(
            mode="lines",
            lat=df_full["latitude"],
            lon=df_full["longitude"],
            line=dict(width=4, color="lightgrey"),
            name="Full trajectory",
            showlegend=True,
            legendgroup="fix",
        ),
        row=1,
        col=1,
    )

    # Trace 2: model input-output position
    fig.add_trace(position_trace(0), row=1, col=1)

    # Trace 3: full trajectory altitude
    fig.add_trace(
        go.Scatter(
            x=np.arange(0, len(df_full)),
            y=df_full["altitude"],
            line=dict(width=4, color="lightgrey"),
            mode="lines",
            name="full",
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # Traces 4, 5, 6: true, input and predicted altitude
    for trace in altitude_traces(0):
        fig.add_trace(trace, row=2, col=1)

    # Update Figure layout----------------------
    fig.update_xaxes(
        title_text="Elapsed time [s]",
        range=[0, len(df_full)],
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_yaxes(
        title_text="Baroaltitude [ft]",
        range=[df_full.altitude.min(), df_full.altitude.max()],
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            zoom=10,
            center=dict(
                lat=df_full["latitude"].mean(),
                lon=df_full["longitude"].mean(),
            ),
        ),
        width=1600,
        height=1000,
        margin=dict(l=50, r=20, t=40, b=40),
    )

    # Animation--------------------------------
    # Creation of animation frames: each frame replaces the four dynamic traces
    # (indices 2, 4, 5 and 6 of the figure), one data entry per index.
    frames = [
        go.Frame(
            data=[position_trace(k, legendgroup="model"), *altitude_traces(k)],
            name=f"fr{k}",
            traces=[2, 4, 5, 6],
        )
        for k in starts
    ]

    # Add frames to the figure
    fig.update(frames=frames)

    # Animation frame arguments
    def frame_args(duration):
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    # Animation frame duration
    fr_duration = 750

    # Slider configuration
    sliders = [
        {
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "font": {"size": 20},
            "steps": [
                {
                    "args": [[frame.name], frame_args(fr_duration)],
                    "label": f"{k+1}s",
                    "method": "animate",
                }
                for k, frame in enumerate(fig.frames)
            ],
        }
    ]

    # Update of slider and button layout
    fig.update_layout(
        sliders=sliders,
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, frame_args(fr_duration)],
                        "label": "&#9654;",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(fr_duration)],
                        "label": "&#9724;",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
        ],
        legend=dict(font=dict(size=25)),
    )

    # Enlarge the subplot titles (they are stored as layout annotations)
    for annotation in fig["layout"]["annotations"]:
        annotation["font"] = dict(size=30)

    return fig

##### Visualisation example VEBIT altered to GERSA

In [None]:
# Selection of flight, with altered SID info.
# `assign` returns a new Flight, so the data held by `t` is not modified in
# place and the cell gives the same result every time it is re-run.
flight = t["SWR55T_232979"].assign(SID_VEBIT=0.0, SID_GERSA=1.0)

# Visualisation
fig = generate_plot(flightdata=flight, sid="GERSA")
fig.show()

##### Visualisation example DEGES altered to VEBIT

In [None]:
# Selection of flight, with altered SID info.
# `assign` returns a new Flight, so the data held by `t` is not modified in
# place and the cell gives the same result every time it is re-run.
flight = t["DLH7C_116899"].assign(SID_DEGES=0.0, SID_VEBIT=1.0)

# Visualisation
fig = generate_plot(flightdata=flight, sid="VEBIT")
fig.show()

***
### Erroneous Typecode
***

##### Function for visualisation

In [None]:
def generate_plot(flightdata: Flight, sid: str = "GERSA") -> go.Figure:
    """
    Generate an animated plotly figure for a given flight data object, including
    the model inputs, true outputs and the model prediction.

    The upper panel is a map showing the full trajectory and, per animation
    frame, the input points, the true output points and the predicted points.
    The lower panel shows the corresponding altitudes over the row index of the
    flight data. One animation frame is created per input window. The function
    uses the module-level ``model`` and ``scaler_out``.

    Parameters
    ----------
    flightdata : Flight
        Flight data object containing the data of a single flight.
    sid : str (optional)
        Accepted so that the signature matches the SID-plotting variant defined
        earlier in this notebook, but not used: this variant does not draw a
        SID route. Default is "GERSA".

    Returns
    -------
    go.Figure
        Plotly figure object containing the animated plot.

    Raises
    ------
    ValueError
        If the flight has too few rows to build a single input/output sample.
    """
    # Sample geometry (all values are row counts/offsets in the flight data)
    n_in = 10  # rows per model input window
    horizon = 180  # span from which the true output is sampled
    out_step = 5  # stride between true output rows
    out_start = 15  # first true output row, relative to the window start
    n_out = 36  # true output rows per sample (horizon / out_step)

    # Flight trajectory
    df_full = flightdata.data

    # Simulation length (number of samples / animation frames)
    n_samples = len(df_full) - horizon - n_in
    if n_samples <= 0:
        raise ValueError(
            f"Flight has {len(df_full)} rows; more than {horizon + n_in} are "
            "required to build at least one sample."
        )
    starts = range(n_samples)

    # Generation of subsets-----------------------------------------------------
    # Time variant inputs unscaled
    f_in_var_unscaled = df_full[
        [
            "latitude",
            "longitude",
            "altitude",
            "wind_x_2min_avg",
            "wind_y_2min_avg",
            "temperature_gnd",
            "humidity_gnd",
            "pressure_gnd",
        ]
    ]

    # Time variant inputs scaled
    f_in_var = df_full[
        [
            "latitude_scaled",
            "longitude_scaled",
            "altitude_scaled",
            "wind_x_2min_avg_scaled",
            "wind_y_2min_avg_scaled",
            "temperature_gnd_scaled",
            "humidity_gnd_scaled",
            "pressure_gnd_scaled",
        ]
    ]

    # Time invariant inputs
    f_in_con = df_full[
        [
            "toff_weight_kg_scaled",
            "typecode_A20N",
            "typecode_A21N",
            "typecode_A319",
            "typecode_A320",
            "typecode_A321",
            "typecode_A333",
            "typecode_A343",
            "typecode_B77W",
            "typecode_BCS1",
            "typecode_BCS3",
            "typecode_CRJ9",
            "typecode_DH8D",
            "typecode_E190",
            "typecode_E195",
            "typecode_E290",
            "typecode_E295",
            "typecode_F100",
            "typecode_SB20",
            "SID_DEGES",
            "SID_GERSA",
            "SID_VEBIT",
            "SID_ZUE",
            "hour_sin",
            "hour_cos",
            "weekday_sin",
            "weekday_cos",
            "month_sin",
            "month_cos",
        ]
    ]

    # True outputs
    f_out_unscaled = df_full[["latitude", "longitude", "altitude"]]

    ##### Generation of input samples along trajectory--------------------------
    # Variable inputs unscaled: one (n_in, 8) window per sample
    input_var_unscaled = np.stack(
        [f_in_var_unscaled.iloc[i : i + n_in].to_numpy() for i in starts]
    ).reshape(-1, n_in, 8)

    # Variable inputs scaled: one (n_in, 8) window per sample
    input_var = np.stack(
        [f_in_var.iloc[i : i + n_in].to_numpy() for i in starts]
    ).reshape(-1, n_in, 8)

    # Constant inputs: the row directly after each input window.
    # The values are deliberately stacked scalar by scalar, as before, so that
    # the resulting array dtype is the same as it has always been when the
    # selected columns have mixed dtypes.
    con_values = [
        value for i in starts for value in f_in_con.iloc[i + n_in].to_numpy()
    ]
    input_con = np.stack(con_values).reshape(-1, 1, 29)

    # True outputs unscaled: every out_step-th row within the horizon
    output_true = np.stack(
        [
            f_out_unscaled.iloc[
                i + out_start : i + out_start + horizon : out_step
            ].to_numpy()
            for i in starts
        ]
    ).reshape(-1, n_out, 3)

    # Application of model------------------------------------------------------
    output = model.predict((input_var, input_con))
    # The prediction is reshaped to 37 steps per sample; the first step is
    # dropped so that 36 steps remain, matching the true output.
    output_unscaled = scaler_out.inverse_transform(output.reshape(-1, 3)).reshape(
        -1, 37, 3
    )[:, 1:, :]

    # Generation of plotting data-----------------------------------------------
    def as_points(xyz, color):
        # Columns 0, 1, 2 of xyz are latitude, longitude and altitude.
        return pd.DataFrame(
            {
                "latitude": xyz[:, 0],
                "longitude": xyz[:, 1],
                "altitude": xyz[:, 2],
                "color": color,
            }
        )

    # Input output pairs: per sample, rows 0-9 are the model input, rows 10-45
    # the true output and rows 46-81 the model prediction.
    input_output = [
        pd.concat(
            [
                as_points(input_var_unscaled[i], "#636efa"),  # model input
                as_points(output_true[i], "grey"),  # true output
                as_points(output_unscaled[i], "#00cc96"),  # model prediction
            ]
        )
        for i in starts
    ]

    def position_trace(k, legendgroup=None):
        # Map markers (input, true output and prediction) of sample k.
        sample = input_output[k]
        return go.Scattermapbox(
            lat=sample["latitude"],
            lon=sample["longitude"],
            marker=dict(size=7, color=sample["color"]),
            mode="markers",
            showlegend=False,
            legendgroup=legendgroup,
        )

    def altitude_traces(k):
        # Altitude markers of sample k in the order: true output, model input,
        # prediction. The x values have exactly as many entries as the y values.
        sample = input_output[k]
        in_x = np.arange(k, k + n_in)
        out_x = np.arange(k + out_start, k + out_start + horizon, out_step)
        parts = (
            ("True Output", out_x, sample.iloc[n_in : n_in + n_out]),
            ("Model input", in_x, sample.iloc[:n_in]),
            ("Prediction", out_x, sample.iloc[n_in + n_out :]),
        )
        return [
            go.Scatter(
                x=x_values,
                y=part["altitude"],
                marker=dict(size=6, color=part["color"]),
                mode="markers",
                name=label,
                showlegend=True,
                legendgroup="model",
            )
            for label, x_values, part in parts
        ]

    # Plotting------------------------------------------------------------------
    # Generate subplots------------------------
    fig = make_subplots(
        rows=2,
        cols=1,
        specs=[[{"type": "scattermapbox"}], [{}]],
        row_heights=[0.6, 0.4],
        subplot_titles=("Position", "Altitude"),
        vertical_spacing=0.07,
    )

    # Add traces------------------------
    # Trace 0: full trajectory position
    fig.add_trace(
        go.Scattermapbox(
            mode="lines",
            lat=df_full["latitude"],
            lon=df_full["longitude"],
            line=dict(width=4, color="lightgrey"),
            name="Full trajectory",
            showlegend=True,
            legendgroup="fix",
        ),
        row=1,
        col=1,
    )

    # Trace 1: model input-output position
    fig.add_trace(position_trace(0), row=1, col=1)

    # Trace 2: full trajectory altitude
    fig.add_trace(
        go.Scatter(
            x=np.arange(0, len(df_full)),
            y=df_full["altitude"],
            line=dict(width=4, color="lightgrey"),
            mode="lines",
            name="full",
            showlegend=False,
        ),
        row=2,
        col=1,
    )

    # Traces 3, 4, 5: true, input and predicted altitude
    for trace in altitude_traces(0):
        fig.add_trace(trace, row=2, col=1)

    # Update Figure layout----------------------
    fig.update_xaxes(
        title_text="Elapsed time [s]",
        range=[0, len(df_full)],
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_yaxes(
        title_text="Baroaltitude [ft]",
        range=[df_full.altitude.min(), df_full.altitude.max()],
        row=2,
        col=1,
        title_font={"size": 25},
        tickfont={"size": 20},
    )

    fig.update_layout(
        mapbox=dict(
            style="carto-positron",
            zoom=10,
            center=dict(
                lat=df_full["latitude"].mean(),
                lon=df_full["longitude"].mean(),
            ),
        ),
        width=1600,
        height=1000,
        margin=dict(l=50, r=20, t=40, b=40),
    )

    # Animation--------------------------------
    # Creation of animation frames: each frame replaces the four dynamic traces
    # (indices 1, 3, 4 and 5 of the figure), one data entry per index.
    frames = [
        go.Frame(
            data=[position_trace(k, legendgroup="model"), *altitude_traces(k)],
            name=f"fr{k}",
            traces=[1, 3, 4, 5],
        )
        for k in starts
    ]

    # Add frames to the figure
    fig.update(frames=frames)

    # Animation frame arguments
    def frame_args(duration):
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    # Animation frame duration
    fr_duration = 750

    # Slider configuration
    sliders = [
        {
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "font": {"size": 20},
            "steps": [
                {
                    "args": [[frame.name], frame_args(fr_duration)],
                    "label": f"{k+1}s",
                    "method": "animate",
                }
                for k, frame in enumerate(fig.frames)
            ],
        }
    ]

    # Update of slider and button layout
    fig.update_layout(
        sliders=sliders,
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, frame_args(fr_duration)],
                        "label": "&#9654;",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(fr_duration)],
                        "label": "&#9724;",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
        ],
        legend=dict(font=dict(size=25)),
    )

    # Enlarge the subplot titles (they are stored as layout annotations)
    for annotation in fig["layout"]["annotations"]:
        annotation["font"] = dict(size=30)

    return fig

##### Selection of flight

In [None]:
# Flight used for all typecode examples below
flight = t["SWR73VK_759186"]

##### Unaltered plot

In [None]:
# Plot of the selected flight
fig = generate_plot(flightdata=flight)
fig.show()

##### Altered to A340

In [None]:
# Altering typecode info on a copy: `assign` returns a new Flight, so `flight`
# itself is left untouched and the other cells of this section can be re-run
# in any order.
flight_a343 = flight.assign(typecode_B77W=0, typecode_BCS3=0, typecode_A343=1)

# Visualisation
fig = generate_plot(flight_a343)
fig.show()

##### Altered to A220

In [None]:
# Altering typecode info on a copy: `assign` returns a new Flight, so `flight`
# itself is left untouched and the other cells of this section can be re-run
# in any order.
flight_bcs3 = flight.assign(typecode_B77W=0, typecode_BCS3=1, typecode_A343=0)

# Visualisation
fig = generate_plot(flight_bcs3)
fig.show()