In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import dask.array as da
import pyproj
import sys

pd.set_option("display.max_rows", 200)
import pdb

from pycontrails import Flight, Fleet, MetDataset
from pycontrails.core import models
from pycontrails.datalib.ecmwf import ERA5
from pycontrails.physics import geo, thermo, units, constants

# from pycontrails.models.ps_model import PSFlight
# from pycontrails.models.emissions import Emissions
from pycontrails.ext.flight_gen import FlightGen
from pycontrails.models.boxmodel.boxm import Boxm

# from pycontrails.models.dry_advection import DryAdvection
from pycontrails.core.met_var import (
    AirTemperature,
    RelativeHumidity,
    SpecificHumidity,
    EastwardWind,
    NorthwardWind,
    VerticalVelocity,
)

In [None]:
# Meteorological parameters: each value is held constant over the whole chemistry
# grid when the synthetic met dataset is built from this dict.
met_params = {
    "air_temperature": 235.0,  # K
    "specific_humidity": 0.003,  # kg/kg (dimensionless)
    "relative_humidity": 0.5,  # 1 (fraction, not percent)
    "eastward_wind": 0.0,  # m/s
    "northward_wind": 0.0,  # m/s
    "lagrangian_tendency_of_air_pressure": 0.0,  # Pa/s (pressure vertical velocity, omega), not m/s
}
# NOTE(review): specific and relative humidity are prescribed independently. At 235 K
# near 12.5 km, q = 0.003 kg/kg appears to be several times saturation, i.e. not
# consistent with RH = 0.5 -- confirm which of the two the downstream models read.

In [None]:
# Flight-trajectory parameters for the synthetic fleet (key order preserved).
fl_params = dict(
    # --- timing ---
    t0_fl=pd.to_datetime("2022-01-20 14:00:00"),  # flight start time
    rt_fl=pd.Timedelta(minutes=30),  # flight run time
    ts_fl=pd.Timedelta(minutes=2),  # flight time step
    # --- aircraft / initial state ---
    ac_type="A320",  # aircraft type
    fl0_speed=100.0,  # speed [m/s]
    fl0_heading=45.0,  # heading [deg]
    fl0_coords0=(47.5, -32.5, 12500),  # (lat [deg], lon [deg], alt [m])
    # --- fleet layout ---
    sep_dist=(5000, 2000, 0),  # separation (dx, dy, dz) [m]
    n_ac=5,  # number of aircraft
)

In [None]:
# Plume dispersion parameters (key order preserved).
plume_params = dict(
    dt_integration=pd.Timedelta(minutes=5),  # integration time step
    max_age=pd.Timedelta(hours=2),  # maximum age of the plume
    depth=50.0,  # initial plume depth [m]
    width=50.0,  # initial plume width [m]
    shear=0.005,  # wind shear [1/s]
)

In [None]:
# Chemistry (box model) simulation parameters (key order preserved).
# A 30-day run at a 20 s step gives roughly 1.3e5 time steps on the chemistry time axis.
chem_params = dict(
    # --- time axis ---
    t0_chem=pd.to_datetime("2022-01-20 12:00:00"),  # chemistry start time
    rt_chem=pd.Timedelta(days=30),  # chemistry runtime
    ts_chem=pd.Timedelta(seconds=20),  # chemistry time step
    # --- spatial domain (inclusive bounds) ---
    lat_bounds=(47.0, 48.0),  # [deg]
    lon_bounds=(-33.0, -32.0),  # [deg]
    alt_bounds=(12000, 13000),  # [m]
    # --- grid spacing ---
    hres_chem=0.5,  # horizontal resolution [deg]
    vres_chem=500,  # vertical resolution [m]
)

In [None]:
def _inclusive_grid(lo, hi, step, tol=1e-9):
    """Return ``lo, lo + step, lo + 2*step, ...`` up to and including ``hi``.

    Replaces ``np.arange(lo, hi + step, step)``. With a non-integer step, floating
    point rounding inside ``arange`` can add a spurious point just past ``hi`` (or
    drop ``hi``), so the number of points is computed explicitly here instead.

    Parameters
    ----------
    lo, hi : float or int
        Inclusive lower/upper bounds.
    step : float or int
        Grid spacing; must be positive.
    tol : float
        Slack (in units of ``step``) absorbing rounding when ``(hi - lo) / step``
        is meant to be a whole number.

    Returns
    -------
    numpy.ndarray
        Evenly spaced values starting at ``lo``. If the span is not a multiple of
        ``step``, the last value is the largest grid point not exceeding ``hi``.
        Integer inputs give an integer array.

    Raises
    ------
    ValueError
        If ``step`` is not positive or ``hi < lo``.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if hi < lo:
        raise ValueError(f"upper bound {hi} is below lower bound {lo}")
    n_steps = int(np.floor((hi - lo) / step + tol))
    return lo + step * np.arange(n_steps + 1)


# Chemistry grid axes (bounds inclusive).
lats = _inclusive_grid(*chem_params["lat_bounds"], chem_params["hres_chem"])
lons = _inclusive_grid(*chem_params["lon_bounds"], chem_params["hres_chem"])
alts = _inclusive_grid(*chem_params["alt_bounds"], chem_params["vres_chem"])

# Chemistry time axis; pd.date_range includes both start and end.
times = pd.date_range(
    start=chem_params["t0_chem"],
    end=chem_params["t0_chem"] + chem_params["rt_chem"],
    freq=chem_params["ts_chem"],
)

In [None]:
# Generate an artificial met dataset: every variable in ``met_params`` becomes a
# constant, lazily evaluated (dask) field on the chemistry grid. The vertical
# coordinate is the altitude grid converted to pressure levels via ``units.m_to_pl``.
# (boxm currently only supports zero-wind scenarios.)
data_vars = {
    var_name: (
        ["longitude", "latitude", "level", "time"],
        da.full(
            (len(lons), len(lats), len(alts), len(times)),
            fill_value,
            # one chunk spanning the whole spatial grid, 100 time steps per chunk
            chunks=(len(lons), len(lats), len(alts), 100),
        ),
    )
    for var_name, fill_value in met_params.items()
}

met = MetDataset(
    xr.Dataset(
        data_vars,
        coords=dict(longitude=lons, latitude=lats, level=units.m_to_pl(alts), time=times),
    )
)

met

In [None]:
# Instantiate FlightGen with the synthetic met field and the flight, plume and
# chemistry parameter dicts. Subsequent cells call its traj_gen, calc_fb_emissions,
# sim_plumes and plume_to_grid methods.
fl_gen = FlightGen(met, fl_params, plume_params, chem_params)

In [None]:
# Generate the fleet trajectories (presumably from the fl_params given to FlightGen)
# and display the result.
fl = fl_gen.traj_gen()

fl

In [None]:
# Estimate fuel burn and emissions for the fleet. Per the original note, FlightGen
# does this with the PS performance model and the emissions model (not visible here).
# NOTE(review): this rebinds ``fl``, replacing the trajectories returned by traj_gen().
fl = fl_gen.calc_fb_emissions()

In [None]:
# Visualise the fleet ground tracks over the chemistry domain.
# Assumes ``fl`` is an iterable of pycontrails Flight objects (each has .plot(ax=...)).
fig, ax = plt.subplots()
for flight in fl:
    flight.plot(ax=ax)

# Apply the domain limits after plotting so no plotting call can rescale them;
# track segments outside the chemistry grid extent are not shown.
ax.set_xlim(lons[0], lons[-1])
ax.set_ylim(lats[0], lats[-1])
ax.set(title="Generated fleet trajectories", xlabel="longitude [deg]", ylabel="latitude [deg]");

In [None]:
# Simulate plume dispersion/advection (per the original note, via the dry advection
# model). sim_plumes returns a flight dataframe and a plume dataframe.
fl_df, pl_df = fl_gen.sim_plumes()

# Raise pandas' display limits (500 rows, 30 columns) for the flight dataframe shown below.
pd.set_option("display.max_rows", 500, "display.max_columns", 30)

fl_df

In [None]:
# fl_gen.anim_fl(fl_df, pl_df)

In [None]:
# Convert the plume dataframe to an EMI (emissions) geospatial xarray dataset on the
# chemistry lat/lon/alt/time axes built earlier.
# NOTE(review): exact structure of ``emi`` is defined by FlightGen.plume_to_grid -- confirm there.
emi = fl_gen.plume_to_grid(lats, lons, alts, times)

In [None]:
# Initialise the box model on the synthetic met field with the chemistry parameters;
# per the original note this also generates the chemistry dataset (boxm.boxm_ds).
boxm = Boxm(met=met, params=chem_params)

In [None]:
# Run the box-model simulation driven by the gridded emissions.
# NOTE(review): ``chem`` is not used by the later visible cells, which read boxm.boxm_ds.
chem = boxm.eval(emi)

In [None]:
# plt.figure(figsize=(10, 6))  # Optional: Adjust figure size

# print(boxm.boxm_ds["sza"].sel(cell=0).values[500])
# print(boxm.boxm_ds["ZEN_orig"].values[500])

# boxm.boxm_ds["sza"].sel(cell=0).plot()
# boxm.boxm_ds["ZEN_orig"].plot()
# plt.legend()
# plt.show()

In [None]:
# Species tracked by the box model (the ``species`` coordinate of boxm_ds).
# No figure is created here: all plotting in this cell was commented out, so the
# former plt.figure(...) call only rendered an empty figure.
boxm.boxm_ds["species"].values


In [None]:
# fig, axs = plt.subplots(3, 3, sharex=True, figsize=(15, 12), constrained_layout=True)
# fig.suptitle('Box model validation - North Atlantic Airspace')

# axs[0, 0].plot(boxm.boxm_ds["Y"].sel(species="NO", cell=13).values, 'r.-', lw=0.2)
# axs[0, 0].plot(boxm.boxm_ds["Y_orig"].sel(species="NO").values, 'b.-', lw=0.2)
# axs[0, 0].set(title='NO', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[0, 1].plot(boxm.boxm_ds["Y"].sel(species="NO2", cell=13).values, 'r.-', lw=0.2)
# axs[0, 1].plot(boxm.boxm_ds["Y_orig"].sel(species="NO2").values, 'b.-', lw=0.2)
# axs[0, 1].set(title='NO2', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[0, 2].plot(boxm.boxm_ds["Y"].sel(species="OH", cell=13).values, 'r.-', lw=0.2)
# axs[0, 2].plot(boxm.boxm_ds["Y_orig"].sel(species="OH").values, 'b.-', lw=0.2)
# axs[0, 2].set(title='OH', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[1, 0].plot(boxm.boxm_ds["Y"].sel(species="HO2", cell=13).values, 'r.-', lw=0.2)
# axs[1, 0].plot(boxm.boxm_ds["Y_orig"].sel(species="HO2").values, 'b.-', lw=0.2)
# axs[1, 0].set(title='HO2', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[1, 1].plot(boxm.boxm_ds["Y"].sel(species="O3", cell=13).values, 'r.-', lw=0.2)
# axs[1, 1].plot(boxm.boxm_ds["Y_orig"].sel(species="O3").values, 'b.-', lw=0.2)
# axs[1, 1].set(title='O3', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[1, 2].plot(boxm.boxm_ds["Y"].sel(species="CO", cell=13).values, 'r.-', lw=0.2)
# axs[1, 2].plot(boxm.boxm_ds["Y_orig"].sel(species="CO").values, 'b.-', lw=0.2)
# axs[1, 2].set(title='CO', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[2, 0].plot(boxm.boxm_ds["Y"].sel(species="CH4", cell=13).values, 'r.-', lw=0.2)
# axs[2, 0].plot(boxm.boxm_ds["Y_orig"].sel(species="CH4").values, 'b.-', lw=0.2)
# axs[2, 0].set(title='CH4', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[2, 1].plot(boxm.boxm_ds["Y"].sel(species="HCHO", cell=13).values, 'r.-', lw=0.2)
# axs[2, 1].plot(boxm.boxm_ds["Y_orig"].sel(species="HCHO").values, 'b.-', lw=0.2)
# axs[2, 1].set(title='HCHO', xlabel='time [s]', ylabel='concentration [ppb]')

# axs[2, 2].plot(boxm.boxm_ds["Y"].sel(species="HONO", cell=13).values, 'r.-', lw=0.2)
# axs[2, 2].plot(boxm.boxm_ds["Y_orig"].sel(species="HONO").values, 'b.-', lw=0.2)
# axs[2, 2].set(title='HONO', xlabel='time [s]', ylabel='concentration [ppb]')

# plt.savefig('/home/ktait98/pycontrails_kt/pycontrails/models/files/validation/NA_Airspace.png', dpi=1000)

In [None]:
# Box-model validation: coefficient of determination (R^2 = Pearson r squared) between
# one grid cell of the box-model output ``Y`` and the reference output ``Y_orig``
# (which is not indexed by cell here), per species, after resampling both to hourly means.
from pathlib import Path

VALIDATION_CELL = 13  # boxm_ds cell compared against the reference run
VALIDATION_DIR = Path("/home/ktait98/pycontrails_kt/pycontrails/models/files/validation")

species_names = boxm.boxm_ds["species"].values
r_squared = np.zeros(len(species_names))

for s, species in enumerate(species_names):
    ds_hourly = boxm.boxm_ds["Y"].sel(species=species, cell=VALIDATION_CELL).resample(time="1h").mean()
    ds_hourly_orig = boxm.boxm_ds["Y_orig"].sel(species=species).resample(time="1h").mean()
    # xr.corr aligns the two series on time and masks steps missing in either one,
    # whereas a hand-rolled formula skips NaNs in each series independently.
    # float() turns the 0-d DataArray into a plain number for clean printing.
    r_squared[s] = float(xr.corr(ds_hourly, ds_hourly_orig)) ** 2
    print(f"R^2 for {species} is {r_squared[s]}")

# Keep the species names as the index so the CSV is self-describing.
r_squared_df = pd.DataFrame(
    {"r_squared": r_squared},
    index=pd.Index(species_names, name="species"),
)

VALIDATION_DIR.mkdir(parents=True, exist_ok=True)
r_squared_df.to_csv(VALIDATION_DIR / "r_squared_NA_Airspace.csv")

r_squared_df


# rmse = np.sqrt(((ds_hourly - ds_hourly_orig) ** 2).mean())
# nmrse = rmse / ds_hourly_orig.mean()

# nmrse


In [None]:
# Spot check: R^2 for the last species visited by the per-species validation loop
# (``ds_hourly`` / ``ds_hourly_orig`` are the series left over from that loop).
# Uses a separate name so the per-species ``r_squared`` array is not overwritten.
correlation_coefficient = float(xr.corr(ds_hourly, ds_hourly_orig))
r_squared_last = correlation_coefficient ** 2

r_squared_last

In [None]:
# Plot the NO3 field of the box-model output at time index 1000 (remaining dims
# presumably just ``cell``).
# NOTE(review): magic index -- with ts_chem = 20 s this would be ~5.6 h after t0_chem,
# assuming boxm_ds uses the chemistry time axis; confirm.
boxm.boxm_ds["Y"].sel(species="NO3").isel(time=1000).plot()

In [None]:
#boxm.anim_chem("Y_perc_diff", "O3", 410.6)

In [None]:
# Display the parameters held by the Boxm instance (constructed from chem_params;
# presumably with any defaults Boxm fills in -- confirm).
boxm.params