# The Rational Speech Act framework

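The Rational Speech Act (RSA) framework models language understanding as recursive probabilistic reasoning: a pragmatic listener reasons about a speaker, who in turn reasons about a literal listener. This notebook implements a small RSA model in Pyro.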

Note: This notebook must be run against Pyro commit 4392d54a220c328ee356600fb69f82166330d3d6 or later.

In [11]:
import torch

import collections
import argparse

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# The package-style import below should work but does not, so the following
# cells change into the utils directory and import from there instead:
# from utils.search_inference import factor, HashingMarginal, memoize, Search

In [12]:
cd "utils"

[Errno 2] No such file or directory: 'utils'
/Users/ngoodman/pplAffComp/code/utils


In [13]:
from search_inference import factor, HashingMarginal, memoize, Search

Now define a helper.

`Marginal` takes an un-normalized stochastic function, constructs the distribution over its execution traces using `Search`, and then constructs the marginal distribution over return values (via `HashingMarginal`). The result is memoized, so repeated calls with the same arguments reuse the computed distribution.

In [14]:
torch.set_default_dtype(torch.float64)  # double precision for numerical stability

def Marginal(fn):
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

Now let's set up a simple world: there are 4 objects, any number of which may be blue, and three possible utterances: "none are blue", "some are blue", and "all are blue".

The first ingredients are priors: one over the state (the number of blue objects, from 0 to 4) and one over utterances. Both are taken to be uniform.

Next are the meanings of the utterances.

In [15]:
total_number = 4

def state_prior():
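    # Uniform over the total_number + 1 possible states (0..total_number).
    # Parentheses matter: without them the expression is ones / total_number + 1.
    n = pyro.sample("state", dist.Categorical(probs=torch.ones(total_number+1) / (total_number+1)))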
    return n

def utterance_prior():
    ix = pyro.sample("utt", dist.Categorical(probs=torch.ones(3) / 3))
    return ["none","some","all"][ix]

meanings = {
    "none": lambda N: N==0,
    "some": lambda N: N>0,
    "all": lambda N: N==total_number,
}

def meaning(utterance, state):
    return meanings[utterance](state)

And now the RSA model definitions.

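The `literal_listener` samples a state from the prior and conditions on the utterance being true of that state.

The `speaker` chooses an utterance to convey `state` to the literal listener. The `alpha` parameter controls how optimal the speaker is: when `alpha` is greater than one, the speaker concentrates more strongly on the utterances that best convey the state.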

The `pragmatic_listener` infers which state is likely, given that the speaker chose a given utterance.
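
For reference, here is a sketch of the three levels as equations. The notation is not from the notebook and is assumed here for illustration: $L_0$, $S_1$ and $L_1$ stand for the literal listener, speaker and pragmatic listener, $P(s)$ and $P(u)$ for the state and utterance priors, and $[\![u]\!](s)$ is 1 when utterance $u$ is true of state $s$ and 0 otherwise. The exponent on the speaker reflects my reading that `poutine.scale` wraps both the utterance prior and the listener observation; with a uniform utterance prior this is equivalent to raising only $L_0$ to the power $\alpha$.

```latex
\begin{aligned}
L_0(s \mid u) &\propto P(s)\,[\![u]\!](s) \\
S_1(u \mid s) &\propto \bigl(P(u)\,L_0(s \mid u)\bigr)^{\alpha} \\
L_1(s \mid u) &\propto P(s)\,S_1(u \mid s)
\end{aligned}
```

In the code, the hard truth condition is approximated by a large negative log-weight (`-999999.`) on states where the utterance is false.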

In [16]:
@Marginal
def literal_listener(utterance):
    state = state_prior()
    factor("literal_meaning", 0. if meaning(utterance, state) else -999999.)
    return state


@Marginal
def speaker(state):
    alpha = 1.
    with poutine.scale(scale=torch.tensor(alpha)):
        utterance = utterance_prior()
        pyro.sample("listener", literal_listener(utterance), obs=state)
    return utterance


@Marginal
def pragmatic_listener(utterance):
    state = state_prior()
    pyro.sample("speaker", speaker(state), obs=utterance)
    return state

Now let's see if it works: how does the pragmatic listener interpret the "some" utterance?

In [18]:
interp_dist = pragmatic_listener("some")

print([(s, interp_dist.log_prob(s).exp().item()) 
       for s in interp_dist.enumerate_support()])

[(tensor(0), 0.0), (tensor(1), 0.3125), (tensor(2), 0.3125), (tensor(3), 0.3125), (tensor(4), 0.0625)]


Yay, we get a scalar implicature: "some" is interpreted as likely not including all 4. Try looking at the `literal_listener` too: it shows no implicature.
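
As a cross-check, the same model can be computed by direct enumeration in plain Python, without Pyro or the `search_inference` helpers. This is only a sketch under the notebook's assumptions (uniform priors, hard truth conditions); the `normalize` helper and the `alpha` keyword arguments are introduced here for illustration and are not part of the notebook's code.

```python
# Exact RSA by direct enumeration (plain Python, no Pyro).
total_number = 4
states = list(range(total_number + 1))
utterances = ["none", "some", "all"]

meanings = {
    "none": lambda n: n == 0,
    "some": lambda n: n > 0,
    "all": lambda n: n == total_number,
}


def normalize(weights):
    """Rescale a dict of non-negative weights so the values sum to one."""
    total = sum(weights.values())
    return {k: v / total for k, v in weights.items()}


def literal_listener(utterance):
    # Uniform state prior, restricted to states where the utterance is true.
    return normalize({s: float(meanings[utterance](s)) for s in states})


def speaker(state, alpha=1.0):
    # Uniform utterance prior; weight each utterance by L0(state | utterance) ** alpha.
    return normalize({u: literal_listener(u)[state] ** alpha for u in utterances})


def pragmatic_listener(utterance, alpha=1.0):
    # Uniform state prior; weight each state by S1(utterance | state).
    return normalize({s: speaker(s, alpha)[utterance] for s in states})


print("literal  :", literal_listener("some"))
print("pragmatic:", pragmatic_listener("some"))
print("pragmatic, alpha=2:", pragmatic_listener("some", alpha=2.0))
```

The literal listener spreads probability evenly (0.25 each) over states 1 to 4, so it shows no implicature, while the pragmatic listener should agree with the Pyro output above up to floating-point error. Raising `alpha` shrinks the probability of state 4 further, which strengthens the implicature.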