In [4]:
import torch
import pyro
import pyro.distributions as dist
import pdb

In [5]:
from sherlog.program import loads
from sherlog.inference import minibatch, FunctionalEmbedding, Optimizer
from sherlog.interface import print, initialize, minotaur

In [6]:
# Set up the Sherlog interface before any program is loaded.
# NOTE(review): port 7996 is hard-coded; presumably it must match a running
# Sherlog server — confirm before re-running on another machine.
initialize(port=7996, instrumentation=None)

In [7]:
# The evidence callback is the identity function: it returns each observation
# it is given unchanged.
embedder = FunctionalEmbedding(evidence=lambda s: s)

In [8]:
# Number of simulated players. Used as the sample shape for at_bats, as the
# loop bound in the data generators, and for the p{i} parameter declarations.
NPlayers = 3

In [9]:
# At-bats per player: an integer Poisson(10) draw plus one, so every player has
# at least one at-bat.
at_bats = dist.Poisson(10).sample((NPlayers,)).long() + 1

In [10]:
# Display the sampled at-bat counts (one entry per player).
at_bats

tensor([11, 10,  7])

## Partially Pooled

In [25]:
def partialpooled():
    """Simulate at-bat outcomes under a partially pooled (hierarchical) model.

    ``kappa ~ Pareto(1, 1.5)`` and ``m ~ Uniform(0, 1)`` are drawn once. Each
    player's ``phi`` is then drawn from ``Beta(m * kappa, (1 - m) * kappa)``
    and reused for every one of that player's at-bats.

    Reads the notebook globals ``NPlayers`` and ``at_bats``.

    Returns:
        tuple: ``(phis, observations, domains)`` where ``phis`` is the list of
        per-player Beta draws, ``observations`` is a list of
        ``"atbat(player, time, result)"`` strings, and ``domains`` is the list
        of ``player(...)`` facts followed by the ``time(...)`` facts.
    """
    observations = []
    players = []
    phis = []
    max_times = 0
    kappa = dist.Pareto(1, 1.5).sample()
    m = dist.Uniform(0, 1).sample()
    # Diagnostic output: the two Beta concentration parameters for this run.
    print((m * kappa, (1 - m) * kappa))
    beta = dist.Beta(m * kappa, (1 - m) * kappa)
    for player in range(NPlayers):
        phi = beta.sample()
        phis.append(phi)
        players.append(f"player({player}, p{player}).")
        n_at_bats = at_bats[player].item()
        # Track the largest time index seen across all players.
        max_times = max(max_times, n_at_bats - 1)
        for t in range(n_at_bats):
            # A Bernoulli draw of 0 selects "hit"; a draw of 1 (probability
            # phi) selects "miss".
            # NOTE(review): phi is therefore the probability of "miss" —
            # confirm this matches the `{hit, miss} <~ bernoulli[...]` ordering.
            result = ["hit", "miss"][dist.Bernoulli(phi).sample().long()]
            observations.append(f"atbat({player}, {t}, {result})")
    # One time(...) fact per index from 0 up to the longest at-bat sequence.
    times = [f"time({t})." for t in range(max_times + 1)]
    return phis, observations, players + times

In [26]:
# Draw one partially pooled dataset. Note that `observations` and `domains`
# are rebound again by the not-pooled and fully pooled sections further down.
phis, observations, domains = partialpooled()

In [31]:
# One Sherlog parameter declaration per player (p0, p1, ...), each declared
# with type `unit`.
params = [f"!parameter p{i} : unit." for i in range(NPlayers)]

In [32]:
# Add the declaration for parameter `b` only if it is not already present, so
# re-running this cell does not append a duplicate line to `params`.
if "!parameter b : positive[2]." not in params:
    params += ["!parameter b : positive[2]."]

In [None]:
# Sherlog rule text for the partially pooled attempt. Reading of the rule
# (confirm against the Sherlog docs): for each player(P, PARAM) and time(T)
# fact, atbat(P, T) takes a value in {hit, miss} drawn from bernoulli[PARAM].
# The leading newline lets this string be appended directly after joined facts.
PROG = """
atbat(P, T; {hit, miss} <~ bernoulli[PARAM]) <- player(P, PARAM), time(T).
"""

**Open problem:** How do we express a parameter that is itself sampled from a distribution (here, each player's `phi` is drawn from a shared Beta) within Sherlog's MLE framework?

## Not Pooled

In [30]:
# Per-player generating probabilities, drawn independently from Beta(2, 3).
# notpooled() reads this notebook-level `phi`.
phi = dist.Beta(2, 3).sample((NPlayers,))

In [31]:
# Display the generating probabilities (one entry per player).
phi

tensor([0.2445, 0.5401, 0.2784])

In [32]:
def notpooled():
    """Simulate at-bat outcomes with an independent probability per player.

    Reads the notebook globals ``NPlayers``, ``at_bats`` and ``phi``; player
    ``i`` uses ``phi[i]`` for every one of its at-bats.

    Returns:
        tuple: ``(observations, domains)`` where ``observations`` is a list of
        ``"atbat(player, time, result)"`` strings and ``domains`` is the list
        of ``player(...)`` facts followed by the ``time(...)`` facts.
    """
    observations = []
    players = []
    max_times = 0
    for player in range(NPlayers):
        players.append(f"player({player}, p{player}).")
        n_at_bats = at_bats[player].item()
        # Track the largest time index seen across all players.
        max_times = max(max_times, n_at_bats - 1)
        for t in range(n_at_bats):
            # A Bernoulli draw of 0 selects "hit"; a draw of 1 (probability
            # phi[player]) selects "miss".
            # NOTE(review): phi[player] is therefore the probability of "miss"
            # — confirm this matches the `{hit, miss} <~ bernoulli[...]` ordering.
            result = ["hit", "miss"][dist.Bernoulli(phi[player]).sample().long()]
            observations.append(f"atbat({player}, {t}, {result})")
    # One time(...) fact per index from 0 up to the longest at-bat sequence.
    times = [f"time({t})." for t in range(max_times + 1)]
    return observations, players + times

In [33]:
# Draw one not-pooled dataset; this rebinds the notebook-level `observations`
# and `domains`.
observations, domains = notpooled()

In [34]:
# One Sherlog parameter declaration per player (p0, p1, ...), each declared
# with type `unit`. The names match the p{player} argument of the player facts.
params = [f"!parameter p{i} : unit." for i in range(NPlayers)]

In [35]:
# Sherlog rule text. Reading of the rule (confirm against the Sherlog docs):
# for each player(P, PARAM) and time(T) fact, atbat(P, T) takes a value in
# {hit, miss} drawn from bernoulli[PARAM].
# The leading newline lets this string be appended directly after joined facts.
PROG = """
atbat(P, T; {hit, miss} <~ bernoulli[PARAM]) <- player(P, PARAM), time(T).
"""

In [36]:
# Full program text: parameter declarations, then player/time facts (one per
# line), then the rule. PROG starts with a newline, so no extra separator is
# needed before it.
SOURCE = "\n".join(params + domains) + PROG

In [39]:
# loads returns a pair; only the first element (the program) is kept.
program, _ = loads(SOURCE)

In [46]:
# Optimizer for the not-pooled program.
# NOTE(review): the meaning of samples/force/cache is defined by
# sherlog.inference.Optimizer and is not visible here; samples=7 differs from
# the samples=10 used in the fully pooled section — confirm that is intended.
optimizer = Optimizer(program, learning_rate=1e-2, samples=7, force=True, cache=False)

In [None]:
# Training loop: the batch size equals len(observations), so each batch is
# presumably the full dataset (confirm against minibatch), repeated for 20 epochs.
# `print` here is the one imported from sherlog.interface, not the builtin.
for batch in minibatch(observations, len(observations), epochs=20):
    # Register the embedded batch as the objective to maximise, then step.
    optimizer.maximize(*embedder.embed_all(batch.data))
    print(optimizer.optimize().item())

In [48]:
# Elementwise difference between the generating phi and the values returned by
# program.parameter for p0..p{NPlayers-1}; entries near zero mean the two agree.
phi - torch.tensor([program.parameter(f"p{n}") for n in range(NPlayers)])

tensor([-0.7555, -0.4599, -0.2216])

## Fully Pooled

### Generative Process

In [45]:
def fully_pooled(phi=0.2):
    """Simulate at-bat outcomes where every player shares one probability.

    Reads the notebook globals ``NPlayers`` and ``at_bats``.

    Args:
        phi: Probability passed to ``Bernoulli`` for every at-bat of every
            player. Defaults to 0.2, the value previously hard-coded here.

    Returns:
        tuple: ``(observations, domains)`` where ``observations`` is a list of
        ``"atbat(playerI, timeT, result)"`` strings and ``domains`` is the
        list of ``player(...)`` facts followed by the ``time(...)`` facts.
    """
    observations = []
    players = []
    max_times = 0
    for player in range(NPlayers):
        players.append(f"player(player{player}).")
        n_at_bats = at_bats[player].item()
        # Track the largest time index seen across all players.
        max_times = max(max_times, n_at_bats - 1)
        for t in range(n_at_bats):
            # A Bernoulli draw of 0 selects "hit"; a draw of 1 (probability
            # phi) selects "miss".
            # NOTE(review): phi is therefore the probability of "miss" —
            # confirm this matches the `{hit, miss} <~ bernoulli[p]` ordering.
            result = ["hit", "miss"][dist.Bernoulli(phi).sample().long()]
            observations.append(f"atbat(player{player}, time{t}, {result})")
    # One time(...) fact per index from 0 up to the longest at-bat sequence.
    times = [f"time(time{t})." for t in range(max_times + 1)]
    return observations, players + times
        

In [46]:
# Draw one fully pooled dataset; this rebinds the notebook-level `observations`
# and `domains`.
observations, domains = fully_pooled()

### Sherlog Program

In [47]:
# Sherlog program text with a single shared parameter `p` of type `unit`.
# Reading of the rule (confirm against the Sherlog docs): for each player(P)
# and time(T) fact, atbat(P, T) takes a value in {hit, miss} drawn from
# bernoulli[p]. The string ends with a newline so facts can be appended directly.
SOURCE = \
"""
!parameter p : unit.
atbat(P, T; {hit, miss} <~ bernoulli[p]) <- player(P), time(T).
"""

In [48]:
# Full program text: the parameter and rule, followed by the player/time facts
# joined one per line.
all_source = SOURCE + "\n".join(domains)

In [50]:
# loads returns a pair; only the first element (the program) is kept.
program, _ = loads(all_source)

In [51]:
# Optimizer for the fully pooled program.
# NOTE(review): the meaning of samples/force/cache is defined by
# sherlog.inference.Optimizer and is not visible here — confirm before tuning.
optimizer = Optimizer(program, learning_rate=1e-2, samples=10, force=True, cache=False)

In [52]:
# Embed every observation at once and register the result as the objective to
# maximise (a single full-data pass before the mini-batch loop below).
optimizer.maximize(*embedder.embed_all(observations))

In [53]:
# Run one optimisation step and display the tensor it returns.
optimizer.optimize()

tensor(1.3863, grad_fn=<SumBackward0>)

In [None]:
# Mini-batch training: batches of 10 observations for 10 epochs.
# `print` here is the one imported from sherlog.interface, not the builtin.
for batch in minibatch(observations, 10, epochs=10):
    # Register the embedded batch as the objective to maximise, then step.
    optimizer.maximize(*embedder.embed_all(batch.data))
    print(optimizer.optimize().item())