### Sufficient statistics for speed

There's one more thing we can do with this model. Conceptually, there are just three groups of individuals in this trial: never-takers in the test group, compliers in the test group, and the control group. The model doesn't use any individual-level information, so we can aggregate the data at the group level and work with sufficient statistics (the size of each group and its number of survivors) to speed up the model. Effectively, instead of ~24k individual data points, we'll have just three pairs of counts, and instead of fitting Bernoulli distributions we'll fit binomial distributions.
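To see why this aggregation is safe: for observations that share a single success probability p, the summed Bernoulli log-likelihood of the individual outcomes and the binomial log-likelihood of their count differ only by the constant log C(n, k), which does not depend on p. The posterior, and every gradient the sampler uses, is therefore unchanged. The sketch below checks this with scipy on synthetic data; the sample size, seed and probability are arbitrary illustrative choices, not values from the trial.

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

# Synthetic data: n individuals sharing one survival probability (arbitrary choices)
rng = np.random.default_rng(0)
n = 1_000
y = rng.binomial(1, 0.99, size=n)  # individual 0/1 outcomes
k = int(y.sum())  # sufficient statistic: number of survivors


def bernoulli_loglik(p):
    # sum of per-individual Bernoulli log-probabilities (n terms)
    return stats.bernoulli.logpmf(y, p).sum()


def binomial_loglik(p):
    # a single binomial log-probability for the aggregated count
    return stats.binom.logpmf(k, n, p)


# log of the binomial coefficient C(n, k); it does not depend on p
log_n_choose_k = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)

for p in (0.5, 0.9, 0.99):
    diff = binomial_loglik(p) - bernoulli_loglik(p)
    print(f"p={p}: binomial - bernoulli = {diff:.6f}, log C(n, k) = {log_n_choose_k:.6f}")
```

The difference printed for each p should be the same constant, which is why only the pair (n, k) needs to be kept.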

Let's do a bit of data prep:

In [None]:
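# Group-level sufficient statistics: group size and number of survivors.
# Z = assigned to the test group, W = actually treated, Y = survived.
# Treatment is only available in the test group, so W == 1 identifies compliers.
total_compliers = df.query("W == 1").shape[0]
survived_compliers = df.query("(W == 1) & (Y == 1)").shape[0]

total_never_takers = df.query("(Z == 1) & (W == 0)").shape[0]
survived_never_takers = df.query("(Z == 1) & (W == 0) & (Y == 1)").shape[0]

total_control_group = df.query("(Z == 0)").shape[0]
survived_control_group = df.query("(Z == 0) & (Y == 1)").shape[0]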

print(f"""
total_compliers: {total_compliers}
survived_compliers: {survived_compliers}
total_never_takers: {total_never_takers}
survived_never_takers: {survived_never_takers}
total_control_group: {total_control_group}
survived_control_group: {survived_control_group}
""")


total_compliers: 9675
survived_compliers: 9663
total_never_takers: 2419
survived_never_takers: 2385
total_control_group: 11588
survived_control_group: 11514
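The same six counts can also be produced in one pass with a `groupby` over the (Z, W) cells. Below is a sketch on a small hypothetical frame `toy` with the same column meanings as `df` (Z = assigned to the test group, W = treated, Y = survived). It assumes, as above, that nobody in the control group is treated, so the three remaining cells are exactly the three groups.

```python
import pandas as pd

# Hypothetical toy data with the same column meanings as `df`
toy = pd.DataFrame({
    "Z": [1, 1, 1, 1, 1, 0, 0, 0],
    "W": [1, 1, 1, 0, 0, 0, 0, 0],
    "Y": [1, 1, 0, 1, 0, 1, 1, 0],
})

# One row per (Z, W) cell: group size and number of survivors
counts = toy.groupby(["Z", "W"])["Y"].agg(total="size", survived="sum")

# Give the cells readable names (assumes no (Z=0, W=1) cell exists)
labels = {(1, 1): "compliers", (1, 0): "never_takers", (0, 0): "control_group"}
counts.index = [labels[idx] for idx in counts.index]
print(counts)
```

A single grouped aggregation avoids scanning the frame six times, which only matters for much larger data than this.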



And here is the model, with a structure otherwise identical to the previous one:

In [None]:
import pymc as pm
import pytensor as pt  # the model below uses pt.tensor.zeros

with pm.Model(
    coords={"segment": ["control_group", "compliers", "never-takers"]}
) as sufficient_stat_model:
    # logit of the probability to comply
    γ = pm.Normal("γ", mu=0, sigma=1)
    # baseline survival on the logit scale
    α = pm.Normal("α", mu=1.5, sigma=0.5)
    # the "treatment" effect is fixed to 0 for control group;
    # β_ holds the offsets for compliers and never-takers
    β = pm.Deterministic(
        "β",
        pm.math.concatenate(
            [pt.tensor.zeros(1), pm.Normal("β_", mu=0, sigma=0.2, shape=2)]
        ),
        dims="segment",
    )

    # compliers: survivors out of all compliers
    pm.Binomial(
        "compliers", logit_p=α + β[1], n=total_compliers, observed=survived_compliers
    )

    # never-takers: survivors out of all never-takers in the test group
    pm.Binomial(
        "never_takers",
        logit_p=α + β[2],
        n=total_never_takers,
        observed=survived_never_takers,
    )

    # probability to comply: compliers out of everyone in the test group
    pm.Binomial(
        "treatment",
        logit_p=γ,
        n=(total_compliers + total_never_takers),
        observed=total_compliers,
    )

    # deterministics
    η = pm.Deterministic("survival_rate", pm.invlogit(α + β), dims="segment")
    pm.Deterministic("ATT", (η[1] - η[0]))
    pm.Deterministic("never_taker_effect", (η[2] - η[0]))
    π = pm.Deterministic("probability_to_comply", pm.invlogit(γ))

    # control group: compliance is unobserved, so mix an untreated-complier
    # component (logit survival α) with a never-taker component (α + β[2])
    pm.Mixture(
        "control_group",
        w=pm.math.stack([π, 1 - π]),
        comp_dists=pm.Binomial.dist(
            logit_p=pm.math.stack([α, α + β[2]]),
            n=pm.math.stack([total_control_group, total_control_group]),
        ),
        observed=survived_control_group,
    )
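The `treatment` Binomial is what pins down π: within the test group (compliers plus never-takers), the number of compliers is binomial with probability π. Ignoring the N(0, 1) prior on γ, its maximum-likelihood estimate is just the observed share, which we can compute from the counts printed earlier. The variable names below are illustrative and chosen so they don't overwrite the notebook's own.

```python
import math

# Counts printed earlier in this section
n_compliers, n_never_takers = 9675, 2419

# MLE of the compliance probability under the "treatment" Binomial
pi_hat = n_compliers / (n_compliers + n_never_takers)
# the same quantity on the logit scale, i.e. a point estimate for γ
gamma_hat = math.log(pi_hat / (1 - pi_hat))
print(f"pi_hat = {pi_hat:.4f}, gamma_hat = {gamma_hat:.4f}")
```

With over 12k individuals in the test group, the weakly informative prior on γ should barely move the posterior away from this share.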

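If you're not convinced the two models are equivalent, we can check by evaluating the gradient of the log-probability (`dlogp`) of both models at the same point. For the aggregated groups, the Bernoulli and binomial log-likelihoods differ only by a constant (the binomial coefficient), so the gradients should match: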

In [None]:
# gradient of each model's log-probability w.r.t. its free parameters, at the same point
logp_sufficient = sufficient_stat_model.dlogp().eval({"α": 0, "γ": 0, "β_": [0, 0]})
logp_informative = informative_prior_model.dlogp().eval({"α": 0, "γ": 0, "β_": [0, 0]})
print("Binomial model")
print(logp_sufficient)
print("Bernoulli model")
print(logp_informative)

Binomial model
[ 3628.  11727.   4825.5  4035.5]
Bernoulli model
[ 3628.  11727.   4825.5  4035.5]
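One caveat about this check: at α = 0, β_ = [0, 0] compliers and never-takers have the same survival probability, so the two components of the control-group mixture coincide. Away from that point, how the mixture is aggregated matters. If each control individual is independently a complier (probability π) or a never-taker, the summed outcomes are binomial with the mixed probability q = π·p_c + (1 − π)·p_n, which is not the same likelihood as a two-component mixture of binomials over the whole group. The scipy sketch below uses arbitrary illustrative values, not the trial's.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

n, k = 1_000, 985  # illustrative group size and survivor count
pi, p_c, p_n = 0.8, 0.99, 0.97  # illustrative weight and survival probabilities


def marginal_loglik(p_c, p_n):
    # per-individual mixture, aggregated exactly: one Binomial with mixed probability
    return stats.binom.logpmf(k, n, pi * p_c + (1 - pi) * p_n)


def binomial_mixture_loglik(p_c, p_n):
    # mixture of two Binomials over the whole group (all compliers vs all never-takers)
    return logsumexp([
        np.log(pi) + stats.binom.logpmf(k, n, p_c),
        np.log(1 - pi) + stats.binom.logpmf(k, n, p_n),
    ])


# Different survival probabilities: the two likelihoods disagree
loglik_marginal = marginal_loglik(p_c, p_n)
loglik_mixture = binomial_mixture_loglik(p_c, p_n)
print(loglik_marginal, loglik_mixture)

# Equal survival probabilities (as at α = 0, β_ = [0, 0]): they coincide
same_marginal = marginal_loglik(0.5, 0.5)
same_mixture = binomial_mixture_loglik(0.5, 0.5)
print(same_marginal, same_mixture)
```

In PyMC terms, the exactly aggregated per-individual mixture would be a single `pm.Binomial` with `p=π * pm.math.invlogit(α) + (1 - π) * pm.math.invlogit(α + β[2])`. Which form matches the earlier Bernoulli model depends on how that model wrote its control-group likelihood, so a check at a point where p_c ≠ p_n is more informative.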


To demonstrate the speed benefits, let's draw 5x more samples per chain and run 5x more chains (20 chains of 10,000 draws each).

In [None]:
with sufficient_stat_model:
    pymc_sufficient_trace = pm.sample(
        draws=10_000, tune=1_000, chains=20, nuts_sampler="nutpie"
    )

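| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 11000 | 0 | 0.97 | 3 |
| 11000 | 0 | 0.96 | 3 |
| 11000 | 0 | 0.94 | 7 |
| 11000 | 0 | 0.97 | 7 |
| 11000 | 0 | 0.97 | 7 |
| 11000 | 0 | 0.98 | 7 |
| 11000 | 0 | 0.99 | 3 |
| 11000 | 0 | 0.96 | 7 |
| 11000 | 0 | 0.96 | 3 |
| 11000 | 0 | 0.97 | 3 |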


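It barely took any time, even with 20 chains of 10,000 draws each, because every log-probability evaluation now works on a few group-level counts instead of ~24k individual rows. The results are nearly identical to those of the previous model: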

In [None]:
az.summary(
    pymc_sufficient_trace,
    hdi_prob=0.95,
    round_to=4,
    var_names=["survival_rate", "probability_to_comply", "ATT", "never_taker_effect"],
)

| variable | mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| survival_rate[control_group] | 0.9932 | 0.0007 | 0.9919 | 0.9945 | 0.0 | 0.0 | 170000.3606 | 157761.8874 | 1.0001 |
| survival_rate[compliers] | 0.9968 | 0.0004 | 0.9959 | 0.9977 | 0.0 | 0.0 | 341097.0645 | 172091.5177 | 1.0001 |
| survival_rate[never-takers] | 0.99 | 0.0015 | 0.9869 | 0.9929 | 0.0 | 0.0 | 322555.7496 | 169531.17 | 1.0001 |
| probability_to_comply | 0.7999 | 0.0036 | 0.7927 | 0.807 | 0.0 | 0.0 | 308249.2068 | 164717.6769 | 1.0 |
| ATT | 0.0036 | 0.0007 | 0.0023 | 0.0049 | 0.0 | 0.0 | 154689.0738 | 151219.9016 | 1.0001 |
| never_taker_effect | -0.0032 | 0.0015 | -0.0061 | -0.0005 | 0.0 | 0.0 | 255294.7345 | 164444.9329 | 1.0001 |
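For reference, the raw (prior-free) survival proportions implied by the counts printed earlier can be computed directly. The posterior means above are not expected to match them exactly: the model combines these counts with the informative priors and with the mixture structure of the control group. This is a sanity check on ordering and magnitude rather than a replication.

```python
# (survivors, group size) from the counts printed earlier in this section
group_counts = {
    "control_group": (11514, 11588),
    "compliers": (9663, 9675),
    "never-takers": (2385, 2419),
}

# Raw survival proportion per group, with no priors and no mixture
raw = {group: survived / total for group, (survived, total) in group_counts.items()}
for group, rate in raw.items():
    print(f"{group}: {rate:.4f}")
```

The ordering (never-takers lowest, compliers highest) agrees with the posterior survival rates in the summary.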
