# The Rational Speech Act framework



In [1]:
import torch

import collections
import argparse

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

In [2]:
# IPython automagic for `%cd`: move into utils/, presumably so that
# `search_inference` can be imported by the next cell.
# NOTE(review): not idempotent -- re-running this cell from inside utils/
# presumably fails (there is no utils/utils); consider extending sys.path instead.
cd "utils"

/Users/ngoodman/pplAffComp/code/utils


In [4]:
from search_inference import factor, HashingMarginal, memoize, Search

Now define some helpers. 

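`Marginal` takes an unnormalized stochastic function, constructs the distribution over its execution traces using `Search`, and then builds the marginal distribution over return values (via `HashingMarginal`). The resulting function is memoized, so repeated calls with the same arguments do not redo the search.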

`project` takes a distribution over some (discrete) domain and `qud`, the name of a projection function on this domain (a key of `qud_fns`, defined below). It creates the push-forward of the distribution under that function, using `Marginal` as a Python decorator.

In [5]:
# Use 64-bit floats for every tensor created without an explicit dtype
# (torch.double is an alias of torch.float64) for numerical stability.
torch.set_default_dtype(torch.double)

def Marginal(fn):
    """Wrap a stochastic function so that calling it returns its marginal.

    Calling the returned function with ``*args`` runs ``Search`` over the
    execution traces of ``fn(*args)`` and wraps the result in a
    ``HashingMarginal``, i.e. a distribution over ``fn``'s return values.
    The wrapper goes through ``memoize``, which presumably caches the
    marginal per argument tuple.
    """
    def marginal_of(*args):
        traces = Search(fn).run(*args)
        return HashingMarginal(traces)

    return memoize(marginal_of)

@Marginal
def project(dist,qud):
    """Push a distribution forward through the QUD projection named ``qud``.

    Args:
        dist: distribution to sample a value from. Inside this function the
            name shadows the module-level ``dist`` (pyro.distributions) alias.
        qud: key into ``qud_fns`` selecting the projection to apply.

    Returns:
        The projected value. Because of ``@Marginal``, callers receive the
        marginal distribution over projected values.
    """
    sampled = pyro.sample("proj", dist)
    to_qud = qud_fns[qud]
    return to_qud(sampled)

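The following definitions set up the example domain: the `State` type, the priors over price and valence, the literal semantics, approximate rounding, the QUD (question under discussion) projection functions and their prior, and the utterance costs and prior.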

In [6]:
# World state: a price together with a boolean valence; QUD projections
# blank out the fields they ignore by setting them to None.
State = collections.namedtuple("State", ("price", "valence"))


def price_prior():
    """Sample a price from a fixed discrete prior.

    The support pairs each value with its neighbour value + 1 (50/51,
    500/501, ...). The weights sum to roughly 1; Categorical normalises them.
    """
    support = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]
    weights = torch.tensor([0.4205, 0.3865, 0.0533, 0.0538, 0.0223,
                            0.0211, 0.0112, 0.0111, 0.0083, 0.0120])
    chosen = pyro.sample("price", dist.Categorical(probs=weights))
    return support[chosen]


def valence_prior(price):
    """Sample the valence for ``price``.

    Args:
        price: one of the prices in the support of ``price_prior``. Any other
            value raises KeyError.

    Returns:
        True when the Bernoulli draw is 1, otherwise False.
    """
    # Probability that the valence is True; a price and price + 1 share a value.
    prices = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]
    p_true = [0.3173, 0.3173, 0.7920, 0.7920, 0.8933,
              0.8933, 0.9524, 0.9524, 0.9864, 0.9864]
    table = dict(zip(prices, p_true))
    draw = pyro.sample("valence", dist.Bernoulli(probs=table[price]))
    return draw.item() == 1


def meaning(utterance, price):
    """Literal semantics: a numeric utterance is true only of the exact price it names."""
    is_literally_true = (utterance == price)
    return is_literally_true

def approx(x, b=None):
    """Round ``x`` to the nearest multiple of ``b``, rounding exact halves up.

    In this notebook it decides which numbers count as "round" (see
    ``utterance_cost``) and builds the approximate-price QUD projections.

    Args:
        x: number to round.
        b: rounding base; ``None`` (the default) means 10.

    Returns:
        The rounded value. When ``b`` is integral (including the default),
        this is an exact ``int``. Otherwise it is ``b * k`` for the nearest
        integer ``k``.

    Raises:
        ZeroDivisionError: if ``b`` is 0.
    """
    import math

    if b is None:
        b = 10.
    div = float(x) / b
    # Use floor rather than int(): int() truncates toward zero, which rounds
    # negative inputs to the wrong neighbour (-16 -> -10). Halves still round
    # toward +infinity, as before.
    whole = math.floor(div)
    rounded = whole + 1 if div - whole >= 0.5 else whole
    # Integral base: keep the result an exact int, as before. Truncating a
    # fractional base (int(2.5) == 2) would return a non-multiple of b.
    if float(b).is_integer():
        return int(b) * rounded
    return b * rounded

# Question-under-discussion (QUD) projections, keyed by name. Each maps a full
# State to the coarser State the speaker cares about: ignored fields become None,
# and the "approx" variants replace the price with approx(price).
qud_fns = dict(
    price=lambda s: State(price=s.price, valence=None),
    valence=lambda s: State(price=None, valence=s.valence),
    priceValence=lambda s: State(price=s.price, valence=s.valence),
    approxPrice=lambda s: State(price=approx(s.price), valence=None),
    approxPriceValence=lambda s: State(price=approx(s.price), valence=s.valence),
)


def qud_prior():
    """Sample a QUD name uniformly from the five keys of ``qud_fns``."""
    names = ["price", "valence", "priceValence", "approxPrice", "approxPriceValence"]
    n_names = len(names)
    uniform = torch.ones(n_names) / n_names
    choice = pyro.sample("qud", dist.Categorical(probs=uniform))
    return names[choice]


def utterance_cost(numberUtt):
    """Cost of uttering ``numberUtt``: 0.0 for round numbers, 1.0 for precise ones.

    A number counts as round when ``approx`` (default base 10) leaves it unchanged.
    """
    preciseNumberCost = 1.
    is_round = approx(numberUtt) == numberUtt
    if is_round:
        return 0.
    return preciseNumberCost


def utterance_prior():
    """Sample a numeric utterance, favouring cheaper ones.

    The logits are the negated utterance costs, so P(u) is proportional to
    exp(-cost(u)).
    """
    utterances = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]
    costs = [utterance_cost(u) for u in utterances]
    utteranceLogits = torch.tensor(costs, dtype=torch.float64).neg()
    ix = pyro.sample("utterance", dist.Categorical(logits=utteranceLogits))
    return utterances[ix]

Now the RSA model itself: the literal listener, the speaker, and the pragmatic listener.

In [7]:
@Marginal
def literal_listener(utterance):
    """Literal listener: the state prior conditioned on ``utterance`` being literally true.

    Returns:
        The sampled State. Because of ``@Marginal``, callers receive the
        marginal distribution over States.
    """
    price = price_prior()
    valence = valence_prior(price)
    state = State(price=price, valence=valence)
    # A large finite negative log-weight effectively rules out states where the
    # utterance is false (factor presumably adds it to the trace's log-probability).
    log_weight = 0. if meaning(utterance, price) else -999999.
    factor("literal_meaning", log_weight)
    return state


@Marginal
def speaker(qudValue, qud):
    """Speaker: choose an utterance so that a literal listener, projected onto
    ``qud``, produces ``qudValue``.

    Args:
        qudValue: the projected State the speaker wants to convey.
        qud: key into ``qud_fns`` naming the question under discussion.

    Returns:
        The sampled utterance. Because of ``@Marginal``, callers receive the
        marginal distribution over utterances.
    """
    alpha = 1.
    # poutine.scale multiplies the log-probabilities of the sample sites inside
    # this context by alpha (1 here, so there is no effect).
    # NOTE(review): the scale presumably also reaches sites evaluated by the
    # nested Search runs of literal_listener/project when they are first
    # computed here -- verify before setting alpha != 1.
    with poutine.scale(scale=torch.tensor(alpha)):
        utt = utterance_prior()
        listener_view = project(literal_listener(utt), qud)
        pyro.sample("listener", listener_view, obs=qudValue)
    return utt


@Marginal
def pragmatic_listener(utterance):
    """Pragmatic listener: infer the State from the fact that the speaker chose ``utterance``.

    The QUD is unknown and is sampled from ``qud_prior``.

    Returns:
        The sampled State. Because of ``@Marginal``, callers receive the
        posterior marginal over States.
    """
    # Priors over the state and over which question the speaker is addressing.
    price = price_prior()
    valence = valence_prior(price)
    qud = qud_prior()

    # Condition on the speaker, answering `qud`, having produced `utterance`.
    state = State(price=price, valence=valence)
    speaker_marginal = speaker(qud_fns[qud](state), qud)
    pyro.sample("speaker", speaker_marginal, obs=utterance)
    return state

In [10]:
pragmatic_marginal = pragmatic_listener(10000)

# Posterior probability of each State after hearing the utterance 10000.
print([
    (state, pragmatic_marginal.log_prob(state).exp().item())
    for state in pragmatic_marginal.enumerate_support()
])

[(State(price=50, valence=False), 0.020944558617174744), (State(price=50, valence=True), 0.18962994295418753), (State(price=51, valence=False), 0.01925106279557203), (State(price=51, valence=True), 0.1742972008366076), (State(price=500, valence=False), 0.000808846021274365), (State(price=500, valence=True), 0.059996129350093005), (State(price=501, valence=False), 0.0008164336950199049), (State(price=501, valence=True), 0.06055894482242036), (State(price=1000, valence=False), 0.00017359794987375894), (State(price=1000, valence=True), 0.028312162297699544), (State(price=1001, valence=False), 0.00016425635615857894), (State(price=1001, valence=True), 0.026788637869123784), (State(price=5000, valence=False), 3.889558295405094e-05), (State(price=5000, valence=True), 0.015160315922876049), (State(price=5001, valence=False), 3.854830096338977e-05), (State(price=5001, valence=True), 0.015024955959278942), (State(price=10000, valence=False), 0.0030440475496016244), (State(price=10000, valence=T