In [1]:
import numpy as np
import pymc as pm
import pymc.sampling.jax
import jax
import jax.numpy as jnp
import jax.random as jrandom
from scipy.stats import differential_entropy as entr

from sim import models, utils

# Master seed shared by every random number generator in this notebook.
seed = 8675309

# Seed NumPy's legacy global generator (affects np.random.* calls only).
# NOTE: PyMC samplers do not read this global state; pass `seed` to them
# explicitly through their `random_seed` argument.
np.random.seed(seed=seed)

# Project JAX RNG-key helper, seeded from the same master value.
jrng = utils.JaxRKey(seed)



In [5]:
# --- Global simulation configuration ---------------------------------------
NUM_PARAM_BATCHES = 10  # number of prior draws sampled from the PyMC model
T_END = 100             # last time value on the simulation grid
SIM_STEPS = 100         # number of grid points (endpoint included)

# Uniform time grid covering [0, T_END].
# NOTE(review): because the endpoint is included, the spacing is
# T_END / (SIM_STEPS - 1) (~1.0101 here), not T_END / SIM_STEPS — confirm this
# matches the step size the dynamics model derives from T_END and SIM_STEPS.
time_points = jnp.linspace(start=0.0, stop=T_END, num=SIM_STEPS)


In [8]:
# Prior over the simulator parameters. NUM_PARAM_BATCHES joint draws from this
# prior form the parameter batch `p` handed to the simulation model.
with pm.Model() as pm_model:
    B0 = pm.Normal("B", mu=90000, sigma=2000)
    # sigma=1e-5 pins w at ~exp(1) ~= 2.718 — effectively a constant.
    # NOTE(review): confirm a near-degenerate prior is intended here.
    w = pm.LogNormal("w", mu=1.0, sigma=0.00001)
    r = pm.Normal("r", mu=0.2, sigma=0.06)
    k = pm.Normal("k", mu=100000, sigma=10000)
    # NOTE(review): Normal(0.01, 0.01) puts ~16% of its mass below zero;
    # confirm downstream models accept negative qE, or use a positive prior.
    qE = pm.Normal("qE", mu=0.01, sigma=0.01)

    # PyMC draws from its own generators and ignores NumPy's global seed, so
    # the notebook seed must be passed explicitly for the prior draws (and
    # everything simulated from them) to be reproducible across runs.
    samples = pm.sample_prior_predictive(
        samples=NUM_PARAM_BATCHES,
        random_seed=seed,
    )

# samples.prior maps each random-variable name ("B", "w", "r", "k", "qE")
# to its draws; those names become the keyword arguments of utils.Params.
p = utils.Params(**samples.prior)

Sampling: [B, k, qE, r, w]


In [9]:
# --- Simulation / Monte-Carlo settings --------------------------------------
# NOTE(review): wt_end and num_points are not used in this cell — confirm a
# later cell needs them, otherwise remove.
wt_end, num_points = 100, 20
D = 1000       # third argument of EulerMaruyamaDynamics — meaning TBD (confirm)
mc = 10        # passed to models.Model right after the parameter batch
horizon = 20   # passed to models.Model right after the dynamics
# Revenue / cost coefficients.
# NOTE(review): the revenue model appears to evaluate B ** rho; with rho < 0
# that diverges as B -> 0, which matches the RuntimeWarnings seen when this
# cell ran — confirm B == 0 is guarded.
P0, rho = 1, -0.9
C0, gamma = 1, -0.9

# --- Component models -------------------------------------------------------
dynamics = models.EulerMaruyamaDynamics(T_END, SIM_STEPS, D)
revenue_model = models.RevenueModel(P0=P0, rho=rho)
cost_model = models.CostModel(C0=C0, gamma=gamma)
policy = models.Policy(revenue_model, cost_model)
loss_model = models.LossModel()
risk_model = models.RiskModel()

# --- Assemble and run the full model ----------------------------------------
# NOTE(review): debug=True appears to be what emits the large REVENUE array
# dumps when this cell runs; consider disabling it for full runs.
model = models.Model(
    p, mc, dynamics, horizon, policy,
    revenue_model, cost_model, loss_model, risk_model,
    debug=True,
)

model()

REVENUE -- 1, <xarray.DataArray (chain: 1, draw: 10)> Size: 80B
array([[ 3.50311502,  0.        , 11.95552178, 12.15771629, 11.37556396,
        45.02305235, 27.61077706,  0.        , 42.23722278, 18.62721614]])
Coordinates:
  * chain    (chain) int64 8B 0
  * draw     (draw) int64 80B 0 1 2 3 4 5 6 7 8 9, -0.9
REVENUE -- 1, [[3.47112815e+03 1.04930723e+05 3.70156645e+03 0.00000000e+00
  1.01405897e+03 4.35150472e+04 9.48118536e+00 9.50223492e+04
  3.38973697e+04 4.96613474e+01]
 [2.46778328e+04 1.04930058e+05 1.53748606e+02 0.00000000e+00
  5.19723307e+03 4.35370576e+04 3.11890690e+01 8.28943088e+04
  3.39533824e+04 9.46059555e+01]
 [3.00037538e+04 1.04924896e+05 2.87738565e+03 2.91047762e+01
  1.54562876e+03 4.34868514e+04 2.85198795e-01 6.25621548e+04
  3.39753269e+04 8.14530962e+01]
 [2.85097068e+04 1.04935984e+05 1.67796944e+03 8.82547822e+00
  3.27625586e+03 4.35406057e+04 0.00000000e+00 6.73338066e+03
  3.39745443e+04 1.49220737e+02]
 [2.58317018e+04 1.04921824e+05 4.33414796e+0

  PB = self.P0 * B ** self.rho
  return PB * qE * B
  term1 = 1/(n-m) * np.sum(np.log((n+1)/m * difference), axis=-1)
  Es = 1 - (coef * Bp) ** inv_gamma_power
  PB = self.P0 * B ** self.rho
  return PB * qE * B


REVENUE -- 1, [[0.00000000e+00 0.00000000e+00 0.00000000e+00 8.46899328e+02
  1.67222389e+01 2.59607476e-01 3.13639014e+01 8.80408780e-01
  2.60689599e+01 4.79949492e+01]
 [0.00000000e+00 1.69665072e+00 5.15369934e+01 1.04990005e+03
  2.49558643e-01 3.16233134e+01 8.59035792e+00 9.47420663e+00
  6.15117226e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 1.36048966e+01 0.00000000e+00
  1.36243102e+01 6.10678500e+00 8.77208459e+04 1.27233705e+01
  0.00000000e+00 8.68163249e+00]
 [7.94064843e+00 0.00000000e+00 2.94852590e+01 8.45409398e+00
  0.00000000e+00 1.03842527e+01 8.71610335e+04 1.79417035e+01
  0.00000000e+00 1.17565104e+01]
 [1.46906083e+01 4.84933119e+00 0.00000000e+00 3.98236616e+00
  1.67707270e+01 3.85448524e+01 8.65144952e+04 7.34079728e+00
  1.71505037e+01 3.32942146e+01]
 [1.39255550e+01 5.71205036e+00 3.09278329e+00 1.05839985e+01
  2.37570817e+00 5.63512942e+00 8.68662345e+04 1.29175529e+01
  1.45203237e+01 2.42701262e+00]
 [0.00000000e+00 1.65184653e+01 3.85637968e+

KeyboardInterrupt: 