# Introduction to the Rational Speech Acts (RSA) Framework in Python

This tutorial on the Rational Speech Acts (RSA) framework is adapted from the first chapter of the ProbLang textbook: https://www.problang.org/chapters/01-introduction.html

The primary difference is that everything here is implemented in plain Python. Probabilistic programming languages like WebPPL provide a nice abstraction for inferring and sampling from distributions, but for simple RSA models we don't need anything that fancy. We can compute posteriors by exhaustive enumeration, that is, by applying the exact mathematical formulas to every possible world state and utterance.
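To make "enumeration" concrete: over a finite set of states, Bayes' rule amounts to multiplying the prior by the likelihood for every state and dividing by the total. The helper below, `posterior_by_enumeration`, is a hypothetical name used only for illustration; it is a minimal sketch of the idea, not code from the textbook.

```python
import numpy as np

def posterior_by_enumeration(prior, likelihood):
    """Exact posterior over a finite set of states.

    prior: P(s) for each state s
    likelihood: P(data | s) for each state s
    """
    unnormalized = np.asarray(likelihood, dtype=float) * np.asarray(prior, dtype=float)
    return unnormalized / unnormalized.sum()

# Three states with a uniform prior; the observation is compatible with states 0 and 2 only.
posterior = posterior_by_enumeration([1/3, 1/3, 1/3], [1, 0, 1])
print(posterior)  # posterior is [0.5, 0.0, 0.5]
```

Both listeners in this tutorial follow this pattern, with the likelihood supplied by the meaning function ($L_0$) or by the pragmatic speaker ($L_1$).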

If this is your first time with RSA, I would highly recommend reading the linked chapter. This is mostly just an exercise in translating to Python.

## Overview: what is RSA?

The RSA framework is situated within the larger tradition of probabilistic (Bayesian) models of cognition. Its particular application is pragmatics: the subfield of linguistics that studies how people use external context to enrich the literal meaning of words.

One example of pragmatic language inference can be seen in the following discourse:

**Person 1**: What did you think of John? \
**Person 2**: John was nice.

Here, Person 1 will likely infer that Person 2 thought John was just okay, because if Person 2 had a really high opinion of John, they would have said something stronger (e.g., "John was fantastic").

RSA (Frank & Goodman, 2012) formalizes the principles that underlie these kinds of inferences. At its core, RSA models pragmatic communication between speakers and listeners under the following assumptions:

- Speakers and listeners **recursively reason** about each other's mental states
- Speakers and listeners are **approximately rational**
- Speakers and listeners are **approximately Bayesian**

In this tutorial, we will see how these assumptions are implemented in precise mathematical terms, allowing us to make quantitative predictions which we can test against experimental data.

```python
import numpy as np
import pandas as pd
```

## Define objects and utterances

We're going to use the same example as the textbook: a reference game, in which a speaker tries to convey an intended referent to a listener who does not know which object it is. Let's say our reference world consists of the following objects:

![rsa_scene.png](attachment:rsa_scene.png)

So the set of world states, $S$, consists of the following:

$S = \{$blue square, blue circle, green square$\}$.

We'll also assume that the set of utterances that the speaker can use, $U$, is the following:

$U = \{$"blue", "green", "square", "circle"$\}$

<div class="alert alert-block alert-info">
<b>Intuition check:</b> Let's say that you're playing this reference game, in the role of a listener. If the speaker said "blue" to you, which object do you think they're referring to? Why?
</div>

Now with those preliminaries out of the way, let's translate $S$ and $U$ to code.

```python
objects = [
    {"color": "blue", "shape": "square", "string": "blue square"},
    {"color": "blue", "shape": "circle", "string": "blue circle"},
    {"color": "green", "shape": "square", "string": "green square"},
]
utterances = ["blue", "green", "square", "circle"]
```

## Literal listener

The __literal listener__ is the bottom-most part of our recursive RSA model. It captures a listener who only understands meanings literally, i.e. without any kind of pragmatic inference. Formally, it is defined via a function that maps utterances to a probability distribution over world states. The function is defined as follows:

$$P_{L_0}(s \mid u) \propto [[u]](s) \cdot P(s)$$

where $[[u]](s)$ is the meaning function. Intuitively, the meaning function should capture whether utterance $u$ is true of world state $s$. For example, a meaning function would tell us that saying "blue" to refer to the blue square is true, while saying "blue" to refer to anything green is false.

Here we define it in code:

```python
def meaning(obj, utt):
    """Return True if utterance `utt` is literally true of object `obj`."""
    return (utt == obj["color"]) or (utt == obj["shape"])
```

Run the following cell to see how the meaning function works:

```python
meaning({"color": "blue", "shape": "square"}, "circle")
```

```
False
```
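Calling `meaning` on every utterance/object pair gives the full truth table $[[u]](s)$ that the literal listener normalizes. This is a sketch for inspection only: the cell repeats the `objects`, `utterances`, and `meaning` definitions from above so that it runs on its own, and the name `lexicon` is an illustrative choice.

```python
import pandas as pd

# Same objects, utterances, and meaning function as above (repeated so this cell is self-contained).
objects = [
    {"color": "blue", "shape": "square", "string": "blue square"},
    {"color": "blue", "shape": "circle", "string": "blue circle"},
    {"color": "green", "shape": "square", "string": "green square"},
]
utterances = ["blue", "green", "square", "circle"]

def meaning(obj, utt):
    return (utt == obj["color"]) or (utt == obj["shape"])

# Rows are utterances u, columns are objects s; a cell is 1 if [[u]](s) is true, else 0.
lexicon = pd.DataFrame(
    [[int(meaning(obj, utt)) for obj in objects] for utt in utterances],
    index=utterances,
    columns=[obj["string"] for obj in objects],
)
print(lexicon)
```

Each row's 1s mark exactly the objects that utterance is literally true of, so normalizing each row (under a uniform prior) yields the literal listener's distribution.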

Before we implement the literal listener, we'll define a helper function, `normalize_rows`, that will be useful in subsequent calculations.

```python
def normalize_rows(matrix):
    """
    Helper function that normalizes each row of a matrix to sum to 1
    """
    totals = np.sum(matrix, axis=1)
    return matrix / totals[:, np.newaxis]
```
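A quick check of `normalize_rows` on a small hand-made matrix (the cell repeats the definition so it runs on its own). It also shows one caveat: a row that sums to zero divides 0 by 0 and becomes `nan`. Every utterance in this tutorial is true of at least one object, so that never happens here, but it can if you extend the lexicon.

```python
import numpy as np

def normalize_rows(matrix):
    """Divide each row by its sum so every row sums to 1 (same as the helper above)."""
    totals = np.sum(matrix, axis=1)
    return matrix / totals[:, np.newaxis]

counts = np.array([[1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
print(normalize_rows(counts))  # rows become [0.5, 0.5, 0.0] and [0.0, 0.0, 1.0]

# Caveat: an all-zero row gives 0/0 = nan; errstate silences NumPy's RuntimeWarning here.
with np.errstate(invalid="ignore"):
    print(normalize_rows(np.array([[0.0, 0.0, 0.0]])))  # row becomes nan everywhere
```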

Now we can implement our literal listener: given an utterance, it will produce a probability distribution over states.

```python
def literal_listener(utt):
    """
    Simulate a literal listener (assumes a uniform prior over objects)

    Arguments:
    utt: string that represents what is heard by the listener

    Returns:
    df: pd.DataFrame of object probabilities for every possible utterance
    probs: pd.Series of the object probabilities for the given utterance
    """
    # Build the utterances x world states matrix of truth values [[u]](s)
    all_counts = np.zeros(shape=(len(utterances), len(objects)))
    for i, curr_utt in enumerate(utterances):
        for j, curr_obj in enumerate(objects):
            if meaning(curr_obj, curr_utt):
                all_counts[i, j] = 1
            # To use a non-uniform prior P(s), multiply in P(objects[j]) here

    data = normalize_rows(all_counts)
    df_cols = [o["string"] for o in objects]
    df = pd.DataFrame(data, columns=df_cols, index=utterances)
    return df, df.loc[utt]
```

```python
df_l0, probs_l0 = literal_listener("square")
df_l0
```

| | blue square | blue circle | green square |
|---|---|---|---|
| blue | 0.5 | 0.5 | 0.0 |
| green | 0.0 | 0.0 | 1.0 |
| square | 0.5 | 0.0 | 0.5 |
| circle | 0.0 | 1.0 | 0.0 |


```python
probs_l0
```

```
blue square     0.5
blue circle     0.0
green square    0.5
Name: square, dtype: float64
```

So if the literal listener hears "square", there is a 50% chance they think the intended referent is the blue square, a 50% chance they think it's the green square, and a 0% chance they think it's the blue circle.
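The `literal_listener` above assumes a uniform prior $P(s)$, which is why the prior never appears in its code. As a minimal sketch, assuming a hypothetical prior in which the blue square is twice as salient as each of the other objects, $P_{L_0}(s \mid u) \propto [[u]](s) \cdot P(s)$ can be computed for all utterances at once with NumPy broadcasting. The function `literal_listener_with_prior`, the hard-coded `truth` matrix, and the prior values are illustrative, not part of the original tutorial.

```python
import numpy as np

object_names = ["blue square", "blue circle", "green square"]
utterances = ["blue", "green", "square", "circle"]

# Truth table [[u]](s): rows are utterances, columns are objects (matches `meaning` above).
truth = np.array([
    [1, 1, 0],  # "blue"
    [0, 0, 1],  # "green"
    [1, 0, 1],  # "square"
    [0, 1, 0],  # "circle"
], dtype=float)

def literal_listener_with_prior(prior):
    """P_L0(s | u) ∝ [[u]](s) * P(s), for every utterance at once."""
    scores = truth * np.asarray(prior, dtype=float)  # prior broadcasts across rows
    return scores / scores.sum(axis=1, keepdims=True)

# Hypothetical prior: the blue square is twice as likely as each other object.
prior = [0.5, 0.25, 0.25]
l0 = literal_listener_with_prior(prior)
print(object_names)
print(np.round(l0[utterances.index("square")], 3))  # ≈ [0.667, 0.0, 0.333]
```

With this prior, hearing "square" favors the blue square (2/3 vs 1/3) instead of the 50/50 split above; passing a uniform prior reproduces the original table.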

## Pragmatic speaker

The pragmatic speaker is the next layer of the model. Intuitively, it captures the idea that the speaker acts approximately rationally, choosing utterances that maximize the probability that the literal listener recovers the intended target. This is accomplished by introducing a speaker utility function

$$U_{S_1}(u;s) = \log P_{L_0}(s \mid u) - C(u)$$

which incorporates the literal listener $L_0$ (the first instance of recursion in our model). The pragmatic speaker maps each world state $s$ to a probability distribution over utterances. Formally, it is defined as

$$P_{S_1}(u \mid s) \propto \exp(\alpha \cdot U_{S_1}(u;s)),$$
which expands to
$$P_{S_1}(u \mid s) \propto \exp(\alpha \cdot (\log P_{L_0}(s \mid u) - C(u))).$$

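Exponentiating the $\alpha$-scaled utility and normalizing over utterances is a softmax decision rule, which captures the approximate rationality component: as $\alpha$ grows, the speaker increasingly favors the highest-utility utterance, while $\alpha = 0$ makes all utterances equally likely.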
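As a worked instance with $\alpha = 1$ and no cost, take the target $s$ = blue circle. The literal listener assigns it probability 0.5 after "blue", 1 after "circle", and 0 after "green" or "square", so (ignoring the small `epsilon` the code adds to avoid $\log 0$):

$$P_{S_1}(u \mid \text{blue circle}) \propto \exp(\log P_{L_0}(\text{blue circle} \mid u)) = P_{L_0}(\text{blue circle} \mid u)$$

$$P_{S_1}(\text{"blue"} \mid \text{blue circle}) = \frac{0.5}{0.5 + 1} = \frac{1}{3}, \qquad P_{S_1}(\text{"circle"} \mid \text{blue circle}) = \frac{1}{0.5 + 1} = \frac{2}{3}$$

These are the values in the blue circle row of the speaker output.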

<div class="alert alert-block alert-info">
<b>Side note:</b> We're ignoring the cost term $C(u)$ for now, because we don't have time to talk about it in detail, but I encourage you to read more about it in the textbook!
</div>
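If you want to experiment with the cost term anyway, here is a minimal sketch of $P_{S_1}(u \mid s) \propto \exp(\alpha \cdot (\log P_{L_0}(s \mid u) - C(u)))$ for a single target, the blue circle. The function `speaker_probs` and the cost values are hypothetical and purely illustrative; the literal-listener column is read off the literal listener table, and the rest of the tutorial keeps $C(u) = 0$.

```python
import numpy as np

utterances = ["blue", "green", "square", "circle"]

# P_L0(blue circle | u) for each utterance, read off the literal listener table.
l0_blue_circle = np.array([0.5, 0.0, 0.0, 1.0])

# HYPOTHETICAL costs C(u), for illustration only: "circle" is made costlier to say.
cost = np.array([0.0, 0.0, 0.0, 1.0])

def speaker_probs(l0_column, alpha=1.0, cost=None, epsilon=1e-9):
    """P_S1(u | s) ∝ exp(alpha * (log P_L0(s | u) - C(u))) for one target s."""
    if cost is None:
        cost = np.zeros_like(l0_column)
    utility = np.log(l0_column + epsilon) - cost
    scores = np.exp(alpha * utility)
    return scores / scores.sum()

print(np.round(speaker_probs(l0_blue_circle), 3))             # no cost
print(np.round(speaker_probs(l0_blue_circle, cost=cost), 3))  # "circle" penalized
```

With a cost of 1 on "circle", probability for the blue circle shifts from "circle" toward "blue" (about 0.58 vs 0.42 instead of 1/3 vs 2/3).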

```python
def pragmatic_speaker(obj, alpha=1):
    """
    Simulate the pragmatic speaker (the cost term C(u) is omitted)

    Arguments:
    obj: dict representing the object in the world that the speaker wishes to refer to
    alpha: float for speaker optimality (default 1)

    Returns:
    df: pd.DataFrame of utterance probabilities for every object in the world
    probs: pd.Series of utterance probabilities for the specified object
    """
    epsilon = 0.000000001  # avoids log(0) for utterances that are false of an object
    all_vals = []
    for curr_utt in utterances:
        _, probs = literal_listener(curr_utt)  # P_L0(s | u) for every object s
        utility = np.log(np.array(probs) + epsilon)  # U_S1(u; s) with C(u) = 0
        all_vals.append(np.exp(alpha * utility))

    # Transpose to objects x utterances, then normalize over utterances
    data = normalize_rows(np.array(all_vals).T)
    df_idx = [o["string"] for o in objects]
    df = pd.DataFrame(data, columns=utterances, index=df_idx)
    return df, df.loc[obj["string"]]
```

```python
df_s1, probs_s1 = pragmatic_speaker({"color": "blue", "shape": "square", "string": "blue square"})
df_s1
```

| | blue | green | square | circle |
|---|---|---|---|---|
| blue square | 0.5 | 1e-09 | 0.5 | 1e-09 |
| blue circle | 0.3333333 | 6.666667e-10 | 6.666667e-10 | 0.6666667 |
| green square | 6.666667e-10 | 0.6666667 | 0.3333333 | 6.666667e-10 |


```python
probs_s1
```

```
blue      5.000000e-01
green     1.000000e-09
square    5.000000e-01
circle    1.000000e-09
Name: blue square, dtype: float64
```

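So if the intended target is the blue square, the speaker has a 0.5 chance of saying "blue", a 0.5 chance of saying "square", and essentially zero chance of saying anything else (the tiny nonzero values, about $10^{-9}$, come from the `epsilon` added to avoid taking the log of zero).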
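The output above uses the default $\alpha = 1$. To see how $\alpha$ sharpens the speaker's choices, here is a small vectorized sketch that prints the blue circle's utterance distribution for a few values of $\alpha$. The function name `speaker_matrix` is hypothetical, and the hard-coded `l0` matrix is the literal listener table from earlier.

```python
import numpy as np

utterances = ["blue", "green", "square", "circle"]
# P_L0(s | u) with a uniform prior: rows are utterances,
# columns are [blue square, blue circle, green square].
l0 = np.array([
    [0.5, 0.5, 0.0],  # "blue"
    [0.0, 0.0, 1.0],  # "green"
    [0.5, 0.0, 0.5],  # "square"
    [0.0, 1.0, 0.0],  # "circle"
])

def speaker_matrix(alpha, epsilon=1e-9):
    """Rows are objects, columns are utterances: P_S1(u | s) ∝ exp(alpha * log P_L0(s | u))."""
    scores = np.exp(alpha * np.log(l0.T + epsilon))
    return scores / scores.sum(axis=1, keepdims=True)

for alpha in [0.5, 1, 5]:
    blue_circle_row = speaker_matrix(alpha)[1]
    print(alpha, np.round(blue_circle_row, 3))
```

As $\alpha$ increases, the speaker describing the blue circle shifts toward the more informative "circle" (about 0.59, 0.67, and 0.97 for $\alpha$ = 0.5, 1, and 5). The blue square's row does not change with $\alpha$, because "blue" and "square" have equal utility for it.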

## Pragmatic listener

The pragmatic listener is the third and final component of our model. It captures the idea that the listener reasons according to Bayes' rule: the probability that they believe the state of the world is $s$, given that they heard utterance $u$, is proportional to the probability of saying $u$ given the target is $s$, times the prior probability of $s$. The probability of $u$ given $s$ is exactly what the pragmatic speaker computes, so we plug it in here (the second instance of recursion in the model). Formally, the pragmatic listener is defined as follows:

$$P_{L_1}(s \vert u) \propto P_{S_1}(u \vert s) \cdot P(s)$$
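For example, take a uniform prior $P(s) = 1/3$ and the speaker probabilities of "blue" from the speaker table ($\tfrac{1}{2}$, $\tfrac{1}{3}$, and essentially 0 for the blue square, blue circle, and green square). The prior cancels during normalization and the epsilon terms are negligible, so:

$$P_{L_1}(\text{blue square} \mid \text{"blue"}) = \frac{\tfrac{1}{2}}{\tfrac{1}{2} + \tfrac{1}{3} + 0} = \frac{3}{5}, \qquad P_{L_1}(\text{blue circle} \mid \text{"blue"}) = \frac{\tfrac{1}{3}}{\tfrac{1}{2} + \tfrac{1}{3} + 0} = \frac{2}{5}$$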

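```python
def pragmatic_listener(utt, alpha=1):
    """
    Simulate the pragmatic listener (uniform prior over objects)

    Returns a pd.DataFrame of object probabilities for every utterance,
    and a pd.Series of object probabilities for the given utterance `utt`.
    """
    epsilon = 0.0000001  # small smoothing term added to every speaker probability
    all_vals = []
    for curr_obj in objects:
        _, probs = pragmatic_speaker(curr_obj, alpha=alpha)  # P_S1(u | s) for all u
        all_vals.append(probs + epsilon)

    # Transpose to utterances x objects, then normalize over objects (Bayes' rule)
    data = normalize_rows(np.array(all_vals).T)
    df_cols = [o["string"] for o in objects]
    df = pd.DataFrame(data, columns=df_cols, index=utterances)
    return df, df.loc[utt]
```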

```python
df_l1, probs_l1 = pragmatic_listener("blue")
df_l1
```

| | blue square | blue circle | green square |
|---|---|---|---|
| blue | 0.5999999 | 0.4 | 1.208e-07 |
| green | 1.514999e-07 | 1.509999e-07 | 0.9999997 |
| square | 0.5999999 | 1.208e-07 | 0.4 |
| circle | 1.514999e-07 | 0.9999997 | 1.509999e-07 |


```python
probs_l1
```

```
blue square     5.999999e-01
blue circle     4.000000e-01
green square    1.208000e-07
Name: blue, dtype: float64
```

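So this model predicts that a pragmatic listener who hears "blue" is more likely to think the intended referent is the blue square (about 0.6) than the blue circle (about 0.4), even though "blue" is literally true of both. The reasoning: a speaker who meant the blue circle could have said "circle", which picks it out uniquely, so hearing "blue" is better evidence for the blue square.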
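Since all three layers are just normalizations of matrices, the whole model can also be written in a few vectorized lines. The sketch below (hypothetical names `rsa`, `normalize`, `truth`, `object_names`) computes $L_0$, $S_1$, and $L_1$ in one pass with an explicit prior and $\alpha$, omitting the cost term and the extra listener smoothing used above. With a uniform prior it should agree with the tables above up to the epsilon terms.

```python
import numpy as np

object_names = ["blue square", "blue circle", "green square"]
utterances = ["blue", "green", "square", "circle"]

# Truth table [[u]](s): rows are utterances, columns are objects.
truth = np.array([
    [1, 1, 0],  # "blue"
    [0, 0, 1],  # "green"
    [1, 0, 1],  # "square"
    [0, 1, 0],  # "circle"
], dtype=float)

def normalize(matrix, axis):
    """Rescale so that entries along `axis` sum to 1."""
    return matrix / matrix.sum(axis=axis, keepdims=True)

def rsa(prior, alpha=1.0, epsilon=1e-9):
    """Return (L0, S1, L1): L0 and L1 are utterances x objects, S1 is objects x utterances."""
    prior = np.asarray(prior, dtype=float)
    l0 = normalize(truth * prior, axis=1)            # P_L0(s | u) ∝ [[u]](s) P(s)
    utility = np.log(l0.T + epsilon)                 # U_S1(u; s), cost omitted
    s1 = normalize(np.exp(alpha * utility), axis=1)  # P_S1(u | s)
    l1 = normalize(s1.T * prior, axis=1)             # P_L1(s | u) ∝ P_S1(u | s) P(s)
    return l0, s1, l1

_, _, l1 = rsa(prior=[1/3, 1/3, 1/3])
print(object_names)
print(np.round(l1[utterances.index("blue")], 3))  # ≈ [0.6, 0.4, 0.0]
```

Passing a non-uniform `prior` is one way to explore how prior salience interacts with the pragmatic inference.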