In [1]:
import torch
import pyro
import pyro.distributions as dist
import pdb

In [2]:
from sherlog.program import loads
from sherlog.inference import minibatch, FunctionalEmbedding, Optimizer
from sherlog.interface import print, initialize, minotaur

In [3]:
# Start the Sherlog interface on a fixed port with no instrumentation.
# NOTE(review): presumably this connects to / launches a Sherlog server on
# SHERLOG_PORT — confirm against sherlog.interface.initialize.
SHERLOG_PORT = 8006
initialize(port=SHERLOG_PORT, instrumentation=None)

In [39]:
def _identity_evidence(s):
    """Return the evidence item unchanged (identity mapping)."""
    return s


# The evidence mapping is the identity, so each observation string is passed
# through as-is (assumes FunctionalEmbedding applies `evidence` per item — confirm).
embedder = FunctionalEmbedding(evidence=_identity_evidence)

In [82]:
def model(n=2, seed=None):
    """Simulate a small two-arm experiment as Sherlog facts and evidence.

    Individuals ``0 .. n-1`` alternate between the ``control`` group (even
    indices) and the ``treatment`` group (odd indices). Each individual gets
    one happiness score, drawn from its own group's distribution:
    ``Normal(1.0, 0.5)`` for treatment and ``Normal(0.0, 1.0)`` for control.

    Parameters
    ----------
    n : int, default 2
        Number of simulated individuals.
    seed : int or None, default None
        If given, ``torch.manual_seed(seed)`` is called before sampling so the
        draws are reproducible. ``None`` leaves the global RNG state untouched.

    Returns
    -------
    data : list of str
        One ``treatmentGroup(i, group).`` fact per individual.
    obs : list of str
        One ``happiness(i, value)`` evidence string per individual. The sampled
        value is formatted to two decimals.

    NOTE(review): the simulation does not draw the latent ``effective`` choice
    that PROG declares. Treatment-group scores always come from the "effective"
    distribution. Confirm this is the intended data-generating process.
    """
    if seed is not None:
        torch.manual_seed(seed)
    treatment_dist = dist.Normal(1., 0.5)
    control_dist = dist.Normal(0., 1.)
    data, obs = [], []
    for i in range(n):
        group = ("control", "treatment")[i % 2]
        data.append(f"treatmentGroup({i}, {group}).")
        # Select by group name. A bare truthiness test on `group` is always
        # true (non-empty string), so it would always pick the treatment
        # distribution.
        sampler = treatment_dist if group == "treatment" else control_dist
        obs.append(f"happiness({i}, {sampler.sample():.2f})")
    return data, obs

In [83]:
# `model()` returns a (facts, evidence) pair. `data` holds the
# treatmentGroup facts and `obs` holds the happiness evidence strings.
simulated = model()
data, obs = simulated

In [97]:
# Sherlog program text. It is appended after the generated facts, and the
# leading "\n" keeps the last fact and the first declaration on separate lines.
PROG = (
    "\n"
    # Parameter declarations (name : domain).
    "!parameter mu_effective : real.\n"
    "!parameter sigma_effective : positive.\n"
    "!parameter mu_control : real.\n"
    "!parameter sigma_control : positive.\n"
    "!parameter p : unit.\n"
    "\n"
    # A single latent `effective` choice over {no, yes} with a bernoulli[p] prior.
    "effective(; {no, yes} <~ bernoulli[p]).\n"
    "\n"
    # Treatment group: control-like scores unless the treatment is effective;
    # control group: always the control distribution.
    "happiness(P; normal[mu_control, sigma_control]) <- treatmentGroup(P, treatment), effective(no).\n"
    "happiness(P; normal[mu_effective, sigma_effective]) <- treatmentGroup(P, treatment), effective(yes).\n"
    "happiness(P ;normal[mu_control, sigma_control]) <- treatmentGroup(P, control).\n"
)

In [98]:
# Full program = generated facts followed by the rule text. PROG begins with a
# newline, so no separator is needed between the two parts.
facts_text = "\n".join(data)
SOURCE = facts_text + PROG

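## Testing: fit the parameters to the synthetic data

Load the combined program, then fit its parameters to the simulated happiness observations with minibatch optimization.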

In [99]:
# `loads` returns a pair, and only the program is used here. The second value
# gets a descriptive throwaway name instead of `_`: assigning `_` at the top
# level of an IPython session stops IPython from updating its `_`/`__`/`___`
# output-history variables.
program, _loads_extra = loads(SOURCE)

In [100]:
# Build the optimizer for `program`.
# NOTE(review): the meaning of `samples`, `force` and `cache` is defined in
# sherlog.inference.Optimizer and is not visible here — confirm before tuning.
optimizer = Optimizer(
    program,
    learning_rate=1e-2,
    samples=10,
    force=True,
    cache=False,
)

In [101]:
# Register an objective over the full observation set.
# NOTE(review): this cell runs before the minibatch loop. If `maximize`
# accumulates objectives until `optimize()` is called, the first optimization
# step would also include this full-data term. Confirm whether this cell is
# intentional or a leftover.
full_evidence = embedder.embed_all(obs)
optimizer.maximize(*full_evidence)

In [102]:
# Minibatch training. Each step registers the batch objective, takes one
# optimizer step, and prints the scalar returned by `optimize()`
# (presumably the objective value — confirm).
# NOTE(review): BATCH_SIZE is larger than the number of observations that
# `model()` produces by default (2).
BATCH_SIZE, EPOCHS = 30, 20
for batch in minibatch(obs, BATCH_SIZE, epochs=EPOCHS):
    batch_evidence = embedder.embed_all(batch.data)
    optimizer.maximize(*batch_evidence)
    step_value = optimizer.optimize()
    print(step_value.item())