In [14]:
import pymc as pm
import arviz as az
import pytensor
import pytensor.tensor as pt
import numpy as np

import sys

sys.path.append("../..")
from cge_modeling.pytensorf.optimize import root

In [15]:
variables = Y, C, L_d, K_d, P, r, resid = list(
    map(pt.dscalar, ["Y", "C", "L_d", "K_d", "P", "r", "resid"])
)
params = K_s, L_s, A, alpha, w = list(map(pt.dscalar, ["K_s", "L_s", "A", "alpha", "w"]))


def f_model(*args):
    Y, C, L_d, K_d, P, r, resid, *params = args
    K_s, L_s, A, alpha, w = params

    equations = pt.stack(
        [
            Y - A * K_d**alpha * L_d ** (1 - alpha),
            r * K_d - alpha * Y * P,
            w * L_d - (1 - alpha) * Y * P,
            Y - C,
            P * C - w * L_s - r * K_s,
            K_d - K_s,
            L_d - L_s + resid,
        ]
    )

    return equations


equations = f_model(*variables, *params)
jac = pt.stack(pytensor.gradient.jacobian(equations, variables)).T
jac_inv = pt.linalg.solve(jac, pt.identity_like(jac), check_finite=False)

f_jac = pytensor.compile.builders.OpFromGraph(variables + params, outputs=[jac], inline=True)
f_jac_inv = pytensor.compile.builders.OpFromGraph(
    variables + params, outputs=[jac_inv], inline=True
)

In [27]:
obs_data = np.array([11000, 11000, 7000, np.nan, 1, 1, 0])
not_na = ~np.isnan(obs_data)
eye = np.eye(7)
Z = eye[not_na, :]
not_Z = eye[~not_na, :]

In [28]:
def predict_unobserved(Z, not_Z, mu, cov, y_obs):
    mu_o = Z @ mu
    mu_u = not_Z @ mu
    cov_oo = Z @ cov @ Z.T
    cov_uu = not_Z @ cov @ not_Z.T
    cov_uo = not_Z @ cov @ Z.T

    cov_oo_inv = pt.linalg.solve(cov_oo, pt.identity_like(cov_oo), assume_a="pos")

    beta = cov_uo @ cov_oo_inv
    resid = y_obs - mu_o

    mu_hat = mu_u + beta @ resid
    Sigma_hat = cov_uu - beta @ cov_oo @ beta.T

    return mu_hat, Sigma_hat

In [44]:
coords = {
    "equation": np.arange(7, dtype=int),
    "equation_aux": np.arange(7, dtype=int),
    "variable": ["Y", "C", "L_d", "K_d", "P", "r", "resid"],
    "obs_variable": ["Y", "C", "L_d", "P", "r", "resid"],
    "missing_variable": ["K_d"],
    "parameter": ["K_s", "L_s", "A", "alpha", "w"],
}

with pm.Model(coords=coords) as m:
    # Parameters
    pm_A = pm.Gamma("A", 2, 1)
    pm_alpha = pm.Beta("alpha", 3, 3)
    pm_K_s = pm.Normal("K_s", 4000, 100)
    pm_L_s = pm.Normal("L_s", 7000, 100)

    pm_P = pt.ones((), dtype="float64")
    pm_r = pt.ones((), dtype="float64")
    pm_w = pt.ones((), dtype="float64")
    pm_resid = pt.zeros((), dtype="float64")

    pm_params = pm.Deterministic(
        "params", pt.stack([pm_K_s, pm_L_s, pm_A, pm_alpha, pm_w]), dims=["parameter"]
    )

    Y0 = pm.Normal("Y", 11000, 1000)
    C0 = pm.Normal("C", 11000, 1000)
    Kd0 = pm.Normal("K_d", 4000, 100)
    Ld0 = pm.Normal("L_d", 7000, 100)

    #     x0 = pt.as_tensor(np.array([11000, 11000, 4000, 7000, 1, 1, 0], dtype='float64'))
    x0 = pm.Deterministic("x0", pt.stack([Y0, C0, Ld0, Kd0, pm_P, pm_r, 0.0]), dims=["variable"])

    root_history, converged, step_size, n_steps = root(f_model, f_jac_inv, x0=x0, exog=pm_params)
    solution = pm.Deterministic("root", root_history[-1], dims=["variable"])
    error = pm.Deterministic("error", f_model(*solution, *pm_params), dims=["equation"])
    success = pm.Deterministic("success", pt.allclose(error, 0))

    sigma = pm.HalfNormal("sigma", [10, 10, 10, 10, 0.1, 0.1, 0.1], dims=["variable"])

    y_hat = pm.MvNormal(
        "y_hat", mu=Z @ solution, cov=Z @ K @ Z.T, dims=["obs_variable"], observed=obs_data[not_na]
    )

    #     missing_mu, missing_cov = predict_unobserved(Z, not_Z, solution, cov, obs_data[not_na])
    #     y_missing = pm.MvNormal('y_missing', mu=missing_mu, cov=missing_cov, dims=['missing_variable'])

    pm.Potential("optimizer_failure", pt.switch(success, 0, -np.inf))

    #     idata = pm.sample(nuts_sampler="numpyro")
    idata = pm.sample_prior_predictive()
#     idata = pm.sample_smc(kernel=pm.smc.MH,
#                           draws=5_000,
#                           correlation_threshold=1e-2,
#                           threshold=0.5,
#                           chains=4,
#                           progressbar=False)

  idata = pm.sample_prior_predictive()
Sampling: [A, C, K_d, K_s, L_d, L_s, Y, alpha, ls, sigma, y_hat]


ValueError: The input matrix must be symmetric positive semidefinite.
Apply node that caused the error: multivariate_normal_rv{1, (1, 2), floatX, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x2A223C660>), [], 11, CGemv{inplace}.0, Dot22.0)
Toposort index: 64
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(0,)), TensorType(int64, shape=()), TensorType(float64, shape=(6,)), TensorType(float64, shape=(6, 6))]
Inputs shapes: ['No shapes', (0,), (), (6,), (6, 6)]
Inputs strides: ['No strides', (0,), (), (8,), (48, 8)]
Inputs values: [Generator(PCG64) at 0x2A223C660, array([], dtype=int64), array(11), 'not shown', 'not shown']
Outputs clients: [['output'], ['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

In [41]:
var_names = ["~root", "~params", "~x0", "~error", "~success"]
az.summary(idata, var_names=var_names)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
K_s,4000.026,0.669,3999.3,4000.942,0.333,0.255,4.0,4.0,31633592.87
L_s,7000.152,0.488,6999.607,7000.909,0.243,0.186,4.0,4.0,31633592.87
Y,11000.118,0.529,10999.477,11000.821,0.263,0.202,4.0,4.0,31633592.87
C,10999.732,0.395,10999.308,11000.317,0.197,0.15,4.0,4.0,31633592.87
K_d,3999.992,0.48,3999.413,4000.657,0.239,0.183,4.0,4.0,31633592.87
L_d,6999.792,0.276,6999.533,7000.258,0.138,0.105,4.0,4.0,31633592.87
A,2.711,1.307,1.056,4.139,0.651,0.498,4.0,4.0,31633592.87
alpha,0.45,0.145,0.272,0.676,0.072,0.055,4.0,4.0,31633592.87
sigma[Y],15.826,5.336,10.419,24.581,2.657,2.035,4.0,4.0,31633592.87
sigma[C],12.117,2.908,9.562,17.051,1.448,1.109,4.0,4.0,31633592.87
