In [None]:
import os
import yaml
import datetime
import numpy as np
import pandas as pd
import xarray as xr
import colorcet as cc
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from period1.utils import (
    plot_field,
    sum_total_emissions,
    plot_time_series,
    get_posterior_emissions,
    get_period_mean_emissions,
)

In [None]:
def to_datetime_list(date_str_list):
    """
    Parse a sequence of YYYYMMDD dates into datetime objects.

    Every entry is passed through ``str()`` before parsing, so integer dates
    are accepted alongside strings.

    Parameters
    ----------
    date_str_list : iterable
        Dates whose string form matches the ``%Y%m%d`` format.

    Returns
    -------
    list of datetime.datetime
        One datetime per input entry, in input order.
    """
    parsed_dates = []
    for raw_date in date_str_list:
        parsed_dates.append(datetime.datetime.strptime(str(raw_date), "%Y%m%d"))
    return parsed_dates

In [None]:
# Read the configuration file *update config_path if not on aws*
config_path = "/home/ubuntu/integrated_methane_inversion/config.yml"

# Open the file in a context manager so the handle is closed as soon as the
# YAML has been parsed instead of lingering until garbage collection.
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# Load the state vector file and pull out the per-grid-cell element labels
state_vector_filepath = "./../StateVector.nc"
state_vector = xr.load_dataset(state_vector_filepath)
state_vector_labels = state_vector["StateVector"]

# The highest label minus the number of buffer clusters gives the last
# element that belongs to the region of interest (NaN labels are ignored)
highest_label = np.nanmax(state_vector_labels.values)
last_ROI_element = int(highest_label - config["nBufferClusters"])

# Boolean mask that is True wherever a grid cell lies in the region of interest
mask = state_vector_labels <= last_ROI_element

In [None]:
# Paths to prior emissions, inversion results, GEOS/satellite data, posterior simulation
# Current working directory of the kernel; its sub-directories are listed
# further down to discover the per-period result directories.
cwd = os.getcwd()


def extract_number(s):
    """
    Return the integer spelled by the digits of the last whitespace-separated
    token of ``s``.

    All digit characters of that token are concatenated in order, so
    ``"./period12/"`` gives ``12`` and ``"a1b2"`` also gives ``12``. A token
    without any digit raises ``ValueError`` (from ``int("")``).
    """
    last_token = s.split()[-1]
    digit_chars = [char for char in last_token if char.isdigit()]
    return int("".join(digit_chars))


# Period boundaries come from periods.csv: the first column is read as the
# start dates and the second column as the end dates
periods_df = pd.read_csv("./../periods.csv")
start_dates = periods_df.iloc[:, 0].tolist()
end_dates = periods_df.iloc[:, 1].tolist()
num_periods = len(start_dates)

# Output files carry an "_ln" suffix when lognormal errors are enabled
lognormal_suffix = "_ln" if config["LognormalErrors"] else ""
inv_result_filename = f"inversion_result{lognormal_suffix}.nc"
gridded_posterior_filename = f"gridded_posterior{lognormal_suffix}.nc"

prior_cache_path = "./../hemco_prior_emis/OutputDir/"

# Collect the sub-directories of the working directory, ordered by the number
# embedded in their name. Directories without any digit (for example
# ".ipynb_checkpoints" or "__pycache__") are skipped: they cannot be ordered by
# extract_number, which would otherwise raise ValueError on them.
results_prefixes = sorted(
    (
        f"./{name}/"
        for name in os.listdir(cwd)
        if os.path.isdir(os.path.join(cwd, name))
        and any(char.isdigit() for char in name)
    ),
    key=extract_number,
)

# Per-directory files and folders, in the same order as results_prefixes
results_paths = [prefix + gridded_posterior_filename for prefix in results_prefixes]
satdat_dirs = [prefix + "data_converted" for prefix in results_prefixes]
inversion_result_paths = [prefix + inv_result_filename for prefix in results_prefixes]
posterior_dirs = [prefix + "data_converted_posterior" for prefix in results_prefixes]
visualization_dirs = [prefix + "data_visualization" for prefix in results_prefixes]
posterior_viz_dirs = [
    prefix + "data_visualization_posterior" for prefix in results_prefixes
]

# Archived posterior scale factors, one file per period (1-based numbering)
sf_paths = [
    f"./../archive_sf/posterior_sf_period{i}.nc" for i in range(1, num_periods + 1)
]

In [None]:
# Set latitude/longitude bounds for plots

# Trim four grid cells from every edge to remove the GEOS-Chem buffer zone.
# config["Res"] has the form "<lat>x<lon>" in degrees (e.g. "0.25x0.3125").
# Parsing it gives the same trim as the previously hard-coded 0.25x0.3125,
# 0.5x0.625 and 2.0x2.5 cases, while any other grid no longer leaves
# degx/degy undefined.
try:
    lat_res, lon_res = (float(part) for part in str(config["Res"]).split("x"))
except ValueError as err:
    raise ValueError(
        f"Cannot parse config['Res']={config['Res']!r}; expected '<lat>x<lon>'"
    ) from err

degx = 4 * lon_res
degy = 4 * lat_res

lon_bounds = [
    np.min(state_vector.lon.values) + degx,
    np.max(state_vector.lon.values) - degx,
]
lat_bounds = [
    np.min(state_vector.lat.values) + degy,
    np.max(state_vector.lat.values) - degy,
]

# State Vector

In [None]:
# Plot the state vector elements, one arbitrary colour per element
fig = plt.figure(figsize=(8, 8))
plt.rcParams.update({"font.size": 16})
ax = fig.subplots(1, 1, subplot_kw={"projection": ccrs.PlateCarree()})

# Number of colours = highest element label inside the region of interest
num_colors = state_vector_labels.where(mask).max().item()

# Draw the random colours from a seeded generator so that the figure is
# identical after every Restart & Run All
color_rng = np.random.default_rng(0)
sv_cmap = matplotlib.colors.ListedColormap(color_rng.random((int(num_colors), 3)))
plot_field(
    ax,
    state_vector_labels,
    cmap=sv_cmap,
    title="State vector elements",
    cbar_label="Element Id",
)

# Calculate emissions for each inversion interval

In [None]:
# Prior emissions: one dataset per period, periods numbered from 1
priors_ds = [
    get_period_mean_emissions(prior_cache_path, period_number, "./../periods.csv")
    for period_number in range(1, periods_df.shape[0] + 1)
]
priors = [prior_ds["EmisCH4_Total"] for prior_ds in priors_ds]

# Optimized scale factors, loaded from the archived file of each period
scales = [xr.load_dataset(path) for path in sf_paths]

# Posterior emissions: each period's prior combined with its scale factors
posteriors_ds = [
    get_posterior_emissions(prior_ds, scale_ds)
    for prior_ds, scale_ds in zip(priors_ds, scales)
]
posteriors = [posterior_ds["EmisCH4_Total"] for posterior_ds in posteriors_ds]

In [None]:
# Total emissions per interval inside the region of interest
areas = [prior_ds["AREA"] for prior_ds in priors_ds]

total_prior_emissions_per_period = [
    sum_total_emissions(prior, area, mask) for prior, area in zip(priors, areas)
]
total_posterior_emissions_per_period = [
    sum_total_emissions(posterior, area, mask)
    for posterior, area in zip(posteriors, areas)
]

# One row per period, dated by the end of the period
posterior_columns = {
    "Date": to_datetime_list(end_dates),
    "Emissions": total_posterior_emissions_per_period,
}
posterior_df = pd.DataFrame(posterior_columns)

# Centered moving average over 4 intervals; min_periods=1 keeps the edges,
# where fewer than 4 intervals are available, from becoming NaN
smoothing_window = 4
smoothing_num_days = config["UpdateFreqDays"] * smoothing_window
emissions_rolling = posterior_df["Emissions"].rolling(
    window=smoothing_window, min_periods=1, center=True
)
posterior_df["MovingAverage"] = emissions_rolling.mean()

In [None]:
# Averaging kernel of each period, cropped to the region-of-interest elements
A_ROIs = [
    inversion_ds["A"].values[:last_ROI_element, :last_ROI_element]
    for inversion_ds in map(xr.load_dataset, inversion_result_paths)
]

# DOFS of a period is the trace of its cropped averaging kernel
DOFS = [np.trace(kernel) for kernel in A_ROIs]

# Plot emission variability over inversion period

In [None]:
# Plot time series with emissions, moving average, and DOFS
y_data = [posterior_df["Emissions"], posterior_df["MovingAverage"]]

# The inversion interval is set by config["UpdateFreqDays"], so only call the
# per-interval series "Weekly" when the interval really is 7 days
update_freq_days = config["UpdateFreqDays"]
if update_freq_days == 7:
    emission_label = "Weekly Emission"
else:
    emission_label = f"{update_freq_days}-day Emission"

line_labels = [emission_label, f"{smoothing_num_days}-day Moving Average"]
plot_time_series(
    posterior_df["Date"],
    y_data,
    line_labels,
    "Posterior Emissions Time Series",
    "Methane Emissions Tg/yr",
    DOFS=DOFS,
)