In [1]:
import torch
import pyro
import pyro.distributions as dist
import pdb

In [2]:
from sherlog.program import loads
from sherlog.inference import minibatch, FunctionalEmbedding, Optimizer
from sherlog.interface import print, initialize, minotaur

In [3]:
# Set up the sherlog interface, passing port 8001 and no instrumentation (None).
# NOTE(review): presumably this has to run before loads()/Optimizer are used and
# expects a Sherlog server reachable on that port — confirm against sherlog docs.
initialize(port=8001, instrumentation=None)

In [4]:
def _identity(s):
    """Return the evidence string exactly as given."""
    return s


# Evidence strings are handed to the embedding untouched (identity mapping).
embedder = FunctionalEmbedding(evidence=_identity)

In [5]:
def model(n_patients=2):
    """Simulate a small randomized trial and render it as program/evidence strings.

    Each patient is assigned to the treatment group with probability 0.5.
    Treated patients get an age of death drawn from Normal(90, 6); control
    patients get one drawn from Normal(75, 3).

    Args:
        n_patients: Number of patients to simulate. Defaults to 2, which
            matches the previously hard-coded cohort size.

    Returns:
        A pair ``(data, obs)`` of equally long lists of strings:
        ``data[i]`` is the fact ``"treatmentGroup(i, t)."`` (with trailing
        period) and ``obs[i]`` is ``"ageOfDeath(i, age)"`` (no trailing period).
    """
    treatmentDist = dist.Normal(90., 6.)
    controlDist = dist.Normal(75., 3.)
    data, obs = [], []
    for i in range(n_patients):
        # Convert to a plain int so string formatting and truthiness do not
        # depend on how 0-d tensors happen to format themselves.
        treatment = int(dist.Bernoulli(0.5).sample().item())
        data.append(f"treatmentGroup({i}, {treatment}).")
        # Exactly one age sample is drawn per patient, after the assignment draw.
        ageDist = treatmentDist if treatment else controlDist
        obs.append(f"ageOfDeath({i}, {ageDist.sample().item()})")
    return data, obs

In [6]:
# Seed torch's global RNG so the sampled cohort is identical on every
# Restart & Run All. Outputs recorded from earlier unseeded runs will differ.
torch.manual_seed(0)
data, obs = model()

In [7]:
# Display the generated evidence strings returned by model().
obs

['ageOfDeath(0, 90.7716293334961)', 'ageOfDeath(1, 76.48066711425781)']

In [8]:
# Program text that is appended to the sampled treatmentGroup facts.
# As read from the rule text (Sherlog semantics not verified here):
#   * five parameters are declared: mean/scale for the "effective" and
#     "control" normals, plus a unit-interval parameter p;
#   * effective(...) is generated from bernoulli[p];
#   * gotTreatment(P, 1) requires treatmentGroup(P, 1) AND effective(1);
#     every other combination derives gotTreatment(P, 0);
#   * ageOfDeath(P; ...) is normal with the effective or control parameters,
#     depending on gotTreatment.
# The string intentionally starts with a newline so it can be concatenated
# directly after the joined fact lines.
PROG = """
!parameter mu_effective : real.
!parameter sigma_effective : positive.
!parameter mu_control : real.
!parameter sigma_control : positive.
!parameter p : unit.
effective(; bernoulli[p]).
gotTreatment(P, 1) <- treatmentGroup(P, 1), effective(1).
gotTreatment(P, 0) <- treatmentGroup(P, 1), effective(0).
gotTreatment(P, 0) <- treatmentGroup(P, 0), effective(0).
gotTreatment(P, 0) <- treatmentGroup(P, 0), effective(1).
ageOfDeath(P; normal[mu_effective, sigma_effective]) <- gotTreatment(P,1).
ageOfDeath(P; normal[mu_control, sigma_control]) <- gotTreatment(P,0).
"""

In [9]:
# Full program source: one treatmentGroup fact per line, followed directly by
# PROG (whose leading newline separates it from the last fact).
SOURCE = "".join(("\n".join(data), PROG))

In [10]:
# Show the assembled program text.
# NOTE: `print` here is the one imported from sherlog.interface, which shadows
# the builtin — its output format may differ from builtins.print (unverified).
print(SOURCE)

In [11]:
# Parse the program text. loads() returns a pair; only the first element is
# used in this notebook.
# The second element is bound to `_unused` rather than `_`: assigning to `_` at
# notebook top level overrides IPython's "last output" variable, which then
# stops being updated for the rest of the session.
program, _unused = loads(SOURCE)

In [16]:
# Build the optimizer for the parsed program.
# NOTE(review): the meaning of `samples`, `force` and `cache` is defined by
# sherlog.inference.Optimizer and is not visible here — confirm before tuning.
optimizer = Optimizer(
    program,
    learning_rate=0.01,
    samples=10,
    force=True,
    cache=False,
)

In [17]:
# Smoke test: register only the first observation with the optimizer.
# NOTE(review): the recorded run of this cell reported
# "Objective produced infinite result." — the cause has not been determined.
first_observation = obs[:1]
first_embedded = embedder.embed_all(first_observation)
optimizer.maximize(*first_embedded)

Objective produced infinite result. [objective=Objective(evidence=<sherlog.program.evidence.Evidence object at 0x7f6136690ac0>, conditional=None, parameters=None)]


In [None]:
# Training loop over the batches yielded by minibatch(obs, 30, epochs=20).
# NOTE(review): the batch size (30) is larger than the two observations that
# model() generates by default — confirm this is intended.
# `print` is the sherlog.interface version imported at the top of the notebook.
for batch in minibatch(obs, 30, epochs=20):
    embedded_batch = embedder.embed_all(batch.data)
    optimizer.maximize(*embedded_batch)
    step_result = optimizer.optimize()
    print(step_result.item())