In [1]:
import torch
import pyro
import pyro.distributions as dist

In [2]:
import matplotlib.pyplot as plt

### Examples

##### Teacher/Learner (recursive hierarchical model):
Original: <br>
[1] http://v1.probmods.org/inference-about-inference.html#communication-and-language

##### Implementation

Priors:

In [59]:
dies = ['A', 'B']
colors = ['red', 'green', 'blue']

# p(side|die):
A_probs = torch.tensor([0.0, 0.2, 0.8])
B_probs =torch.tensor([0.1, 0.3, 0.6])
roll = (lambda die_idx: dist.Categorical(A_probs)
              if die_idx==0 else dist.Categorical(B_probs))

# p(side):
side_prior = dist.Categorical(torch.tensor([0.3, 0.3, 0.3]))

# p(die):
die_prior = dist.Bernoulli(torch.tensor(0.5))

# helpers:
to_color = lambda color_idx: colors[int(color_idx)]
to_color_idx = lambda color: torch.tensor(
    colors.index(color))
to_die_name = lambda die_idx: dies[int(die_idx)]
to_die_idx = lambda die_name: torch.tensor(dies.index(die_name))


##### Test
priors:

In [60]:
to_color(roll(to_die_idx('A')).sample())

'blue'

##### Define
init for recursion:

In [73]:
def learner_init(side, count=0):
    # counter:
    count += 1;
    
    # p(die):
    die = pyro.sample("ldie", die_prior)

    # p(lside|die):
    side_dist = roll(die)
    lside = pyro.sample("lside", side_dist)

    if lside == side:
        return(die)
    return(learner_init(side, count=count))

##### Testing
init conditioning:

In [74]:
to_die_name(learner_init(to_color_idx('green')))

'B'

In [75]:
data = [int(learner_init(to_color_idx('green'))) for i in range(100)]
# print(data)
classes = set(data)
print(classes)

%matplotlib
# plt.figure(figsize=(10,1))
plt.hist(data, len(classes), # density=True, 
        orientation='horizontal', stacked=True,
        rwidth=0.1, label = [str(c) for c in classes])


{0, 1}
Using matplotlib backend: Qt5Agg


(array([36., 64.]), array([0. , 0.5, 1. ]), <a list of 2 Patch objects>)

##### Main recursion loop:

Here we define learner and teacher loop. The learner, given side, must guess the die. The teacher want for the learner to guess correctly, so it try to get learner side with use of which the learner can better to judge about the die.<br><br>
`
learner(side): {die| die = die_prior(), equal(side, teacher(die)) 
teacher(die): {side| side = side_prior(), equal(side, learner(side))}
`<br><br>
So we have "pair of mutually recursive functions" [1] here.
The `depth` and `learner_init` used for fix this.

In [80]:
def learner(tside, depth, count=0):
    if depth == 0:
        # itit:
        die = learner_init(tside, count=count)
        return(die)
    depth -= 1
    die = learner_loop(tside, depth, count=count)
    return(die)

def learner_loop(tside, depth, count=0):

    # counter:
    count += 1;
       
    # ldie (sim)|(throu teacher)|tside (fixed)=ttside (sim):
    ldie = pyro.sample("ldie", die_prior)
    ttside = teacher(ldie, depth)
    if ttside == tside:
        return(ldie)
    return(learner_loop(tside, depth, count=count))

def teacher(ldie, depth):
    # tside (sim)|(throu lerner)|ldie (fixed)=lldie (sim):
    tside = pyro.sample("tside", side_prior)
    lldie = learner(tside, depth)
    if lldie == ldie:
        return(tside)
    return(teacher(ldie, depth))

### Test:

If `depth=0` means that the learner not use teacher and just assume more probable from `learner_init` (which is 'B').

In [77]:
to_die_name(learner(to_color_idx('green'), 0))

'A'

In [66]:
print(A_probs)
print(B_probs)

tensor([0.0000, 0.2000, 0.8000])
tensor([0.1000, 0.3000, 0.6000])


In [78]:
data = [int(learner(to_color_idx('green'), 0)) for i in range(100)]
# print(data)
classes = set(data)
print(classes)

%matplotlib
# plt.figure(figsize=(10,1))
plt.hist(data, len(classes), # density=True, 
        orientation='horizontal', stacked=True,
        rwidth=0.1, label = dies)


{0, 1}
Using matplotlib backend: Qt5Agg


(array([44., 56.]), array([0. , 0.5, 1. ]), <a list of 2 Patch objects>)

Now, with `depth=1`, learner lerns with judge about teacher motivs, so
"Now die A becomes the better inference, because “if the teacher had meant to communicate B, they would have shown the red side because that can never come from A.”" [1]:

In [81]:
to_die_name(learner(to_color_idx('green'), 1))

'B'

In [72]:
print(A_probs)
print(B_probs)

tensor([0.0000, 0.2000, 0.8000])
tensor([0.1000, 0.3000, 0.6000])


In [87]:
data = [int(learner(to_color_idx('green'), 1)) for i in range(300)]
# print(data)
classes = set(data)
print(classes)

%matplotlib
# plt.figure(figsize=(10,1))
plt.hist(data, len(classes), # density=True, 
        orientation='horizontal', stacked=True,
        rwidth=0.1, label = dies)


{0, 1}
Using matplotlib backend: Qt5Agg


(array([187., 113.]), array([0. , 0.5, 1. ]), <a list of 2 Patch objects>)

##### TODO:
   Implement using `pyro.infer`

In [3]:
# setup the optimizer:
adam_params = {"lr": 0.05, "betas": (0.9, 0.999)}
optimizer = pyro.optim.Adam(adam_params)

# setup the loss:
loss = pyro.infer.Trace_ELBO()

In [None]:
def rd_inference(guess):
    def rd_model(guess):
        pyro.param('ldie_a', torch.tensor(1.0))
        pyro.sample("ldie", dist.Categorical(torch.ones(3)))
        
    def rd_guide(guess):
        pyro.param('ldie', torch.tensor())
        return(pyro.sample("ldie", die_prior))
    
    def roll_die(die):
        
        
    side = roll_die()
    rd_cond = pyro.condition(rd_guide, data={'lside': side,
                                             'tside':side})

In [None]:
def t_guide(tdie):
    aside = pyro.param("aside", torch.tensor(1))
    bside = pyro.param("bside", torch.tensor(1))
    cside = pyro.param("cside", torch.tensor(1))

    return(pyro.sample("tside",
                       dist.Categorical(torch.tensor([aside, bside, cside]))))

def t_cond(tdie):
    tdie = pyro.sample("tdie", die_prior, obs=pyro.param('ldie'))
    pyro.sample("tside", side_prior,obs=tside)
    
def l_guide(tside):
    a = pyro.param('a', torch.tensor(1))
    res = pyro.sample('ldie', dist.Bernoulli(a))
    
def l_cond(tside):
    
    ldie = pyro.sample("ldie", die_prior)
    tside = pyro.sample("tside",
                        dist.Categorical(torch.tensor([
                            pyro.param('aside'),
                            pyro.param('bside'),
                            pyro.param('cside')]))
    lside = pyro.sample("lside", side_prior, obs=tside)
    

In [None]:
def lerner_inference(guess):
    def lerner_model(guess):
        ldie = pyro.sample("ldie", die_prior)
        lside = pyro.sample("lside", side_prior)
        return(lside)
    
    def lerner_guide(guess):
        return(pyro.sample("ldie", die_prior))

    lerner_cond = pyro.condition(lerner_model,
                                  data={"lside": pyro.param('tside').item()})
    svi = pyro.infer.SVI(model=lerner_cond,
                     guide=lerner_guide,
                     optim=optimizer,
                     loss=loss)
    return(svi)

In [3]:
def teacher_inference(guess):
    def teacher_guide(guess):
        return(pyro.sample("tside", side_prior))

    teacher_cond = pyro.condition(teacher_guide,
                                  data={"tdie": pyro.param('ldie').item()})
    svi = pyro.infer.SVI(model=teacher_cond,
                     guide=teacher_guide,
                     optim=optimizer,
                     loss=loss)
    return(svi)

In [4]:
pyro.clear_param_store()
svi = teacher_inference(1.0)
svi.step(1.0)

KeyError: 'ldie'