In [1]:
import torch
import pyro
import pyro.distributions as dist
import pdb

In [2]:
from sherlog.program import loads
from sherlog.inference import minibatch, FunctionalEmbedding, Optimizer
from sherlog.interface import print, initialize, minotaur

In [3]:
initialize(port=8006, instrumentation=None)

In [39]:
embedder = FunctionalEmbedding(evidence=lambda s: s)

In [45]:
def model():
    p = 0.1
    data = []
    treatmentDist = dist.Normal(90., 6.)
    controlDist = dist.Normal(75., 3.)
    obs = []
    for i in range(1):
        treatment = ["control", "treatment"][i] # [dist.Bernoulli(0.5).sample().long()]
        data.append(f"treatmentGroup({i}, {treatment}).")
        if treatment:
            obs.append(f"ageOfDeath({treatmentDist.sample():.2f})")
        else:
            obs.append(f"ageOfDeath({controlDist.sample():.2f})")
    return data, obs

In [46]:
data, obs = model()

In [47]:
data

['treatmentGroup(0, control).']

In [48]:
obs

['ageOfDeath(79.97)']

In [55]:
SOURCE = \
"""
!parameter mu_effective : real.
!parameter sigma_effective : positive.
!parameter mu_control : real.
!parameter sigma_control : positive.
!parameter p : unit.

observe(;normal[mu_effective, sigma_effective]).
"""

In [56]:
obs = [f"observe({torch.randn(1).item()*6 + 90:.2f})"]

## Old Stuff

In [5]:
def model():
    p = 0.1
    data = []
    treatmentDist = dist.Normal(90., 6.)
    controlDist = dist.Normal(75., 3.)
    obs = []
    for i in range(1):
        treatment = ["control", "treatment"][i] # [dist.Bernoulli(0.5).sample().long()]
#         data.append(f"treatmentGroup({i}, {treatment}).")
        if treatment:
            obs.append(f"ageOfDeath({treatmentDist.sample():.2f})")
        else:
            obs.append(f"ageOfDeath({controlDist.sample():.2f})")
    return data, obs

In [6]:
data, obs = model()

In [7]:
data

[]

In [8]:
obs

['ageOfDeath(91.36)']

In [70]:
# <- treatmentGroup(P, control).
# ageOfDeath(P; normal[mu_control, sigma_control]) <- treatmentGroup(P, treatment), effective(no).
# ageOfDeath(P; normal[mu_effective, sigma_effective]) <- treatmentGroup(P, treatment), effective(yes).

In [21]:
PROG = \
"""
!parameter mu_effective : real.
!parameter sigma_effective : positive.
!parameter mu_control : real.
!parameter sigma_control : positive.
!parameter p : unit.

effective(; {no, yes} <~ bernoulli[p]).

ageOfDeath(P; normal[mu_control, sigma_control]) <- treatmentGroup(P, _).
"""

In [9]:
PROG = \
"""
!parameter mu_effective : real.
!parameter sigma_effective : positive.
!parameter mu_control : real.
!parameter sigma_control : positive.
!parameter p : unit.

ageOfDeath(; normal[mu_control, sigma_control]).
"""

In [10]:
SOURCE = "\n".join(data) + PROG

In [11]:
print(SOURCE)

## Testing

In [58]:
program, _ = loads(SOURCE)

In [59]:
optimizer = Optimizer(program, learning_rate=1e-2, samples=10, force=True, cache=False)

In [66]:
thenormal = dist.Normal(0., 1)

In [68]:
thenormal.log_prob(torch.tensor(80.)).exp()

tensor(0.)

In [60]:
optimizer.maximize(*embedder.embed_all(obs))

Objective produced infinite result. [objective=Objective(evidence=<sherlog.program.evidence.Evidence object at 0x7fb92e3febe0>, conditional=None, parameters=None)]


In [None]:
for batch in minibatch(obs, 30, epochs=20):
    optimizer.maximize(*embedder.embed_all(batch.data))
    print(optimizer.optimize().item())