# PyMC code for *BUGS in Bayesian stock assessments* by Meyer and Millar (1999)

Original BUGS code available here: https://www.stat.auckland.ac.nz/~millar/Bayesian/Surtuna.bugs

Corresponding paper is available here: https://cdnsciencepub.com/doi/pdf/10.1139/f99-043

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc as pm
import pytensor.tensor as pt
import pytensor
from pymc.pytensorf import collect_default_updates
import arviz as az

# Data

In [2]:
# Catch time series (N=23)
C_data = np.array([
    15.9, 25.7, 28.5, 23.7, 25.0, 33.3, 28.2, 19.7, 17.5, 19.3, 
    21.6, 23.1, 22.5, 22.5, 23.6, 29.1, 14.4, 13.2, 28.4, 34.6,
    37.5, 25.9, 25.3
], dtype=float)

# Index time series (N=23)
I_data = np.array([
    61.89, 78.98, 55.59, 44.61, 56.89, 38.27, 33.84, 36.13, 41.95, 36.63,
    36.33, 38.82, 34.32, 37.64, 34.01, 32.16, 26.88, 36.61, 30.07, 30.75,
    23.36, 22.36, 21.91
], dtype=float)

# Number of observed years:
N = len(I_data)

# Number of future projection years:
nyrs = 10

# For projection years, assume constant catch of 19 (as in the JAGS code):
C_proj = np.repeat(19.0, nyrs)

# Concatenate for total catch vector up to N+nyrs
C_total = np.concatenate([C_data, C_proj])  # length = N + nyrs


In [3]:
lags = 1  # Number of lags
coords = {
    "lags": range(-1, 0),
    "steps": range(N - lags),
    "timeseries_length": range(N),
}

# Model

In [4]:
with pm.Model(coords=coords, check_bounds=False) as surplus_model:
    C = pm.Data('C',C_data)
    PM = pt.vector("PM")
    
    ##############################################################################
    # 2.1 Priors
    ##############################################################################

    # r ~ dlnorm(-1.38, 3.845) => log(r) ~ Normal(-1.38, 1/sqrt(3.845))
    #r = pm.Lognormal("r", mu=-1.38, sigma=1/(3.845**0.5))
    r_ = pm.Normal("r_", mu=-1.38, sigma=1/(3.845**0.5))
    r = pm.Deterministic('r', pt.exp(r_))

    # k ~ dlnorm(-5.042905, 3.7603664) => log(k) ~ Normal(-5.042905, 1/sqrt(3.7603664))
    k = pm.Lognormal("k", mu=-5.042905, sigma=1/(3.7603664**0.5))
    K = pm.Deterministic("K", 1.0/k)

    # iq ~ dgamma(0.001, 0.001); then q = 1/iq; Q = q*K
    iq = pm.Gamma("iq", alpha=0.001, beta=0.001)
    q_ = pm.Deterministic("q", 1.0 / iq)
    q10e4 = pm.Deterministic("q10e4", q_*1000)
    Q = pm.Deterministic("Q", q_ * K)

    # isigma2 ~ dgamma(a0,b0); itau2 ~ dgamma(c0,d0)
    # with the same parameters as in JAGS:
    a0, b0 = 3.785518, 0.010223
    c0, d0 = 1.708603, 0.008613854

    #isigma2 = pm.Gamma("isigma2", alpha=a0, beta=1/b0)  # state process precision
    #itau2   = pm.Gamma("itau2", alpha=c0, beta=1/d0)    # observation precision

    isigma2 = pm.Exponential("isigma2", 1)  # state process precision
    itau2   = pm.Exponential("itau2", 1)    # observation precision

    # Derived standard deviations if you want to monitor them:
    sigma2_ = pm.Deterministic("Sigma2", 1.0 / isigma2)  # state variance
    tau2_   = pm.Deterministic("Tau2",   1.0 / itau2)    # obs variance

    # The "MSP = r*K/4" as in the JAGS code
    MSP = pm.Deterministic("MSP", r*K/4.0)
    # The "EMSP = r/(2*q)" as in the JAGS code
    EMSP = pm.Deterministic("EMSP", r/(2*q_))

    # Initial P
    P_init_LN = pm.Lognormal.dist(mu=0.0, sigma=1.0)
    P_init = pm.Truncated('P_init', P_init_LN, lower=0.01, upper=2)

    # Scan functions to recursively estimate P
    def step_p(i, p_prev, C, r, k, isigma2):
        mu = pt.log(pt.maximum(p_prev + r * p_prev * (1 - p_prev) - k * C[i-1], 0.01))
        x =  pm.Lognormal.dist(mu=mu, sigma=1/(isigma2**0.5))
        return x, collect_default_updates([x])
        
    def p_dist(P_init, C, r, k, isigma2, size):
        p_means, _ = pytensor.scan(
            fn=step_p,
            outputs_info=[dict(initial=P_init, taps=[-1])],
            sequences=pt.arange(N-lags),
            non_sequences=[C,r,k,isigma2],
            n_steps=N-lags
        )
        return p_means
        
    # PYMC recursion stochastic
    p_steps = pm.CustomDist(
            "p_steps",
            P_init,C,r,k,isigma2,
            dist=p_dist,
            dims=("steps"),
        )

    # Store full array for P
    P = pm.Deterministic('P', var=pt.concatenate([P_init[None], p_steps]), dims=("timeseries_length"))
    
    
    # Mean I
    Imean = pm.Deterministic('Imean', Q*P)
    # B
    B = pm.Deterministic('B', K*P)
    # Likelihood
    pm.Lognormal('I', Imean, 1.0/(itau2**0.5), observed=I_data)
    
    # Biomass in next year (1990)
    P1990_mu = pm.Deterministic('P1990_mu', pt.log(pt.maximum(p_steps[-1] + r * p_steps[-1] * (1 - p_steps[-1]) - k * C[-1], 0.01)))
    P1990 = pm.Lognormal('P1990', P1990_mu, sigma=1/(isigma2**0.5))
    B1990 = pm.Deterministic('B1990', K*P1990)

In [5]:
# Now sample from this model:
with surplus_model:
    trace = pm.sample(
        target_accept=0.95,
    )

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r_, k, iq, isigma2, itau2, P_init, p_steps, P1990]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 65 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


In [6]:
az.summary(trace, var_names=['K','r','q10e4','B1990','P1990','MSP','EMSP','Sigma2','Tau2'])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
K,288.105,112.663,119.705,505.943,4.031,2.851,756.0,1219.0,1.01
r,0.347,0.187,0.073,0.682,0.008,0.005,549.0,1170.0,1.01
q10e4,19.587,9.531,4.579,37.491,0.759,0.538,143.0,738.0,1.02
B1990,193.791,144.342,18.26,446.954,7.889,5.583,269.0,799.0,1.01
P1990,0.662,0.359,0.122,1.307,0.024,0.017,189.0,400.0,1.02
MSP,23.467,13.233,4.383,47.213,0.385,0.272,897.0,1125.0,1.0
EMSP,10.437,7.143,1.934,22.96,0.352,0.249,334.0,1067.0,1.0
Sigma2,0.134,0.048,0.059,0.218,0.001,0.001,1287.0,1824.0,1.0
Tau2,0.372,0.202,0.112,0.728,0.008,0.005,662.0,1283.0,1.0
