In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
from patsy import dmatrix
from py12box_invert.paths import Paths
from py12box_invert.invert import Invert
from py12box_invert.plot import plot_mf, plot_emissions
from py12box import core

Use a set of knots and a B-spline basis to estimate emissions over time.
For now, just do a standard Normal (Gaussian) inversion with MCMC.

In [None]:
# Target species; its name also selects the example directory below.
species = "CFC-11"
# Example project directory located under Paths.data.
project_path = Paths.data / ("example/" + species)

In [None]:
# Run the analytical Gaussian inversion. It is only run so that the inversion
# matrices (inv_true.mat.H, .R, .y) and its x_hat solution are available to
# the cells below.
inv_true = Invert(
    project_path,
    species,
    method="analytical_gaussian",
    n_threads=4,
    sensitivity_freq="yearly",
    start_year=2000,
    end_year=2020,
)
# NOTE(review): the four values are presumably one prior uncertainty per box — confirm.
inv_true.run_inversion([70.0, 20.0, 10.0, 10.0])

In [None]:
# Annual time axis: 2000 through 2020 inclusive (21 values).
# Note: there is no data in 2020, so the times in the H matrix only run to 2019.
time = 2000 + np.arange(21)

In [None]:
# Start with 10 evenly spaced knot locations spanning 2000-2020
knot_list = np.linspace(2000, 2020, num=10)

# Degree-3 spline basis evaluated at every year in `time`. Only the 8 interior
# knots are handed to bs(); the two end points of knot_list are dropped.
B_i = dmatrix(
    "bs(time, knots=knots, degree=3, include_intercept=False) - 1",
    data={"time": time, "knots": knot_list[1:-1]},
)

# There are 4 boxes, so each row of the basis is repeated 4 times in a row
B = np.asarray(B_i, order="F").repeat(4, axis=0)

In [None]:
# Now do the MCMC
# For now just put a prior of N(0, 100^2) on each spline weight.
# The prior is set on a standard normal (w_) and then scaled to give w.
# Better would be to extract each knot's prior at the right time.
RANDOM_SEED = 42  # fixed so that Restart & Run All reproduces the same draws

# pm.Normal's `sigma` is a standard deviation. R is treated here as the
# observation-error covariance, whose diagonal holds variances, hence the sqrt.
# NOTE(review): confirm inv_true.mat.R stores variances rather than standard deviations.
obs_sd = np.sqrt(np.diag(inv_true.mat.R))

# Y is declared with dims="obs", so "obs" has to exist as a coordinate too.
COORDS = {
    "splines": np.arange(B.shape[1]),
    "obs": np.arange(len(inv_true.mat.y)),
}
with pm.Model(coords=COORDS) as spline_model:
    # Standard-normal weights, one per spline basis column
    w_ = pm.Normal("w_", mu=0, sigma=1, size=B.shape[1], dims="splines")
    # Scale to the intended N(0, 100^2) prior
    w = pm.Deterministic("w", 100 * w_)
    # Map spline weights to the emissions vector (B), then to observation space (H)
    x = pm.Deterministic("mu", inv_true.mat.H @ B @ w)
    Y = pm.Normal("Y", mu=x, sigma=obs_sd, observed=inv_true.mat.y, dims="obs")

    idata = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
    idata.extend(
        pm.sample(
            draws=50000,
            tune=10000,
            chains=2,
            step=pm.Metropolis(),
            random_seed=RANDOM_SEED,
        )
    )
    pm.sample_posterior_predictive(
        idata, extend_inferencedata=True, random_seed=RANDOM_SEED
    )

In [None]:
# Posterior mean of the spline weights, averaged over every chain and draw
wp = idata.posterior["w"].mean(dim=("chain", "draw")).values
# Project the mean weights through the repeated spline basis
x_hat = B @ wp

In [None]:
# Compare the two inversions.
# Each x_hat is added to the a priori (i.e. used as an offset from it) after
# being reshaped to 4 columns and summed across them.
apriori = inv_true.mod.emissions[::12].sum(1)  # every 12th row, summed over axis 1
analytical_total = apriori + inv_true.mat.x_hat.reshape(-1, 4).sum(axis=1)
spline_total = apriori + x_hat.reshape(-1, 4).sum(axis=1)

# Label everything so the figure can be read on its own; without a legend the
# two curves cannot be told apart.
fig, ax = plt.subplots()
ax.plot(time, analytical_total, label="Analytical Gaussian inversion")
ax.plot(time, spline_total, label="Spline MCMC (posterior mean)")
ax.set_xlabel("Year")
ax.set_ylabel("Emissions (sum over boxes)")  # NOTE(review): add units once confirmed
ax.set_title(f"{species} emissions")
ax.legend();
