In [1]:
import torch
from itertools import chain
import pyro.distributions as dist
import pyro
from pyro import poutine
import numpy as np
from itertools import chain
import pdb
from pmextract import extract

In [2]:
from sherlog.program import loads
from sherlog.inference import minibatch, FunctionalEmbedding, Optimizer
from sherlog.interface import print, initialize, minotaur

In [3]:
initialize(port=8008, instrumentation=None)

In [4]:
embedder = FunctionalEmbedding(evidence=lambda s: s)

In [5]:
NTOPICS = 1
NJUDGES = 2
JUDGES_PER_PAPER = 1
NDOCS = 2

In [6]:
topics = torch.randint(0, NTOPICS, (NDOCS,))

In [7]:
judges = torch.tensor(np.array([
    np.random.choice(NJUDGES, JUDGES_PER_PAPER, replace=False) for _ in range(NDOCS)
])).long()

In [8]:
judges.shape

torch.Size([2, 1])

In [9]:
topics.shape

torch.Size([2])

In [10]:
yesno = ["yes", "no"]

In [19]:
def model():
    with pyro.plate("judges", NJUDGES):
        threshold = pyro.sample("threshold", dist.Normal(4, 0.5))
        with pyro.plate("topics", NTOPICS):
            expertise = pyro.sample("expertise", dist.Gamma(1., 1.))
    with pyro.plate("papers", NDOCS) as papers:
        with poutine.block():
            quality = pyro.sample("quality", dist.Normal(2.5, 1.0))
        with pyro.plate("paperjudge", 3) as paperjudge:
            paper_expertise = torch.gather(expertise[topics, :], 0, judges.long())
            score_dist = dist.Normal(quality.float(), paper_expertise.T)
            meetsThresh = pyro.sample("meetsThresh", dist.Bernoulli(probs = 1 - score_dist.cdf(threshold[judges].T))).long()
            obs = [f"meets_thresh({p}, {j}, {yesno[meetsThresh[j, p]]})" for p in papers for j in paperjudge]
    
    params = [f"!parameter paperQuality{i} : real" for i in range(NDOCS)]
    params.extend([f"!parameter judgeThreshold{i} : real" for i in range(NJUDGES)])
    params.extend([f"!parameter judgeExpertise{j}_{t} : real" for j in range(NJUDGES) for t in range(NTOPICS)])
    data = list(chain.from_iterable((f"judged({j}, {p})" for (j,p) in enumerate(judges[:,i])) for i in range(JUDGES_PER_PAPER)))
    data.extend([f"hasTopic({p}, {t})" for p, t in enumerate(topics)])
    data.extend([f"judge_expertise({j}, {t}, judgeExpertise{j}_{t})" for j in range(NJUDGES) for t in range(NTOPICS)])
    data.extend([f"paper_quality({p}, paperQuality{p})" for p in range(NDOCS)])
    data.extend([f"judge_threshold({j}, judgeThreshold{j})" for j in range(NJUDGES)])

    return obs, params + data

In [20]:
obs, data = model()

In [21]:
PROG = """
meet_thresh_prob(P, J; gaussian_survival[TH, Q, E]) <- paper_quality(P, Q),
  judge_expertise(J, T, E), hasTopic(P, T), judged(J, P), judge_threshold(J, TH).
  
meets_thresh(P, J; {yes, no} <~ bernoulli[PROB]) <- meet_thresh_prob(P, J, PROB).
"""

In [22]:
SOURCE = ".\n".join(data) + "." + PROG

In [23]:
print(SOURCE)

In [24]:
obs

['meets_thresh(0, 0, yes)',
 'meets_thresh(0, 1, yes)',
 'meets_thresh(0, 2, yes)',
 'meets_thresh(1, 0, yes)',
 'meets_thresh(1, 1, yes)',
 'meets_thresh(1, 2, yes)']

In [25]:
program, _ = loads(SOURCE)

In [26]:
optimizer = Optimizer(program, learning_rate=1e-2, samples=10, force=True, cache=False)

In [None]:
optimizer.maximize(*embedder.embed_all(obs))

Fatal error: exception Invalid_argument("Invalid sampling.")
