# Continuous incremental RSA

In this notebook, we provide Python implementations of two variants of the RSA framework:
- _Continuous_ semantics RSA [(Degen et al. 2020)](http://alpslab.stanford.edu/papers/2020_DegenEtAl.pdf): this paper introduces the idea of using a continuous, rather than Boolean, meaning function in which the semantic value of a given lexical item is in the range [0, 1], where a higher value represents a less "noisy" signal.
- _Continuous incremental_ semantics RSA [(Waldon and Degen 2021)](https://aclanthology.org/2021.scil-1.19.pdf): this paper combines the aforementioned _continuous_ semantics with _incremental_ semantics [(Cohn-Gordon et al. 2019)](https://aclanthology.org/W19-0109.pdf), which models pragmatic reasoning at the word level rather than over complete utterances.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import string

## Part 1: Continuous RSA

### Define objects and utterances

First, we assume a world with four lexical items: the adjectives 'red' and 'blue' and the nouns 'pin' and 'dress'.

In [2]:
adjectives = ['red', 'blue']
nouns = ['pin', 'dress']

Let's define a function to create all the possible objects in the world.

In [3]:
def get_all_objects(adjectives, nouns):
    objects = []
    for adj in adjectives:
        for noun in nouns:
            obj = {}
            obj["color"] = adj
            obj["shape"] = noun
            obj["string"] = adj + " " + noun
            objects.append(obj)
            
    return objects

Now we can see all the objects:

In [4]:
objects = get_all_objects(adjectives, nouns)
objects

[{'color': 'red', 'shape': 'pin', 'string': 'red pin'},
 {'color': 'red', 'shape': 'dress', 'string': 'red dress'},
 {'color': 'blue', 'shape': 'pin', 'string': 'blue pin'},
 {'color': 'blue', 'shape': 'dress', 'string': 'blue dress'}]

Next, we define all the possible utterances: each single word, plus both orderings (adjective-noun and noun-adjective) of every adjective-noun pair. As above, we write a function to enumerate all the combinations compactly.

In [5]:
def get_all_utterances(adjectives, nouns):
    utterances = []
    utterances.extend(adjectives)
    utterances.extend(nouns)
    for adj in adjectives:
        for noun in nouns:
            adj_first = adj + " " + noun
            noun_first = noun + " " + adj
            utterances.append(adj_first)
            utterances.append(noun_first)
            
    return utterances
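The nested loops can also be written with `itertools.product` from the Python standard library. The following is a minimal sketch (the name `get_all_utterances_compact` is ours, not part of the original notebook) that should produce the same list in the same order:

```python
from itertools import product

def get_all_utterances_compact(adjectives, nouns):
    # single words first, then both orderings of every adjective-noun pair
    singles = list(adjectives) + list(nouns)
    pairs = [order
             for adj, noun in product(adjectives, nouns)
             for order in (f"{adj} {noun}", f"{noun} {adj}")]
    return singles + pairs

print(get_all_utterances_compact(['red', 'blue'], ['pin', 'dress']))
```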

In [6]:
utterances = get_all_utterances(adjectives, nouns)
utterances

['red',
 'blue',
 'pin',
 'dress',
 'red pin',
 'pin red',
 'red dress',
 'dress red',
 'blue pin',
 'pin blue',
 'blue dress',
 'dress blue']

### Define meaning functions

Now we define a meaning function that differs slightly from the one in `intro-to-rsa.ipynb`, where utterances could only be one word long. This meaning function takes an object and an utterance and returns 1 if every word of the utterance applies to the object and 0 otherwise, so it handles multi-word utterances: `blue pin` is true only of an object that is both blue and a pin, not of any blue thing.

**Note that this meaning function still uses Boolean (i.e. not continuous) semantics.** In the ensuing sections, we will define a continuous meaning function that also captures asymmetries in referential noise between adjectives and nouns, but we'll hold off on that for now.

In [7]:
def boolean_meaning(obj, utt):
    utt_as_tokens = utt.split()
    all_contained = set(utt_as_tokens).issubset(obj.values())
            
    return int(all_contained)

In [8]:
boolean_meaning({'color': 'red', 'shape': 'square', 'string': 'red square'}, 'square')

1
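To see how the Boolean meaning function handles multi-word utterances in our four-object world, here is a small self-contained check. It repeats the definition of `boolean_meaning` from above and builds the objects inline (the variable name `objs` is ours). The word 'blue' should be true of both blue objects, while 'blue pin' and 'pin blue' should each be true of the blue pin only, since the function ignores word order.

```python
def boolean_meaning(obj, utt):
    # 1 if every token of the utterance is one of the object's values, else 0
    return int(set(utt.split()).issubset(obj.values()))

objs = [{'color': c, 'shape': s, 'string': f'{c} {s}'}
        for c in ['red', 'blue'] for s in ['pin', 'dress']]

for utt in ['blue', 'blue pin', 'pin blue']:
    # list the objects this utterance is literally true of
    print(utt, '->', [o['string'] for o in objs if boolean_meaning(o, utt)])
```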

Now that we've seen how a meaning function works for a multi-word utterance, we are going to dial up the complexity and implement the **continuous meaning function**. The semantic value of an utterance for an object is the product of the semantic values of the utterance's constituent tokens: a token that applies to the object contributes its value $v$, and a token that does not apply contributes $1 - v$. We assume semantic values of 0.99 for nouns ('pin', 'dress') and 0.95 for adjectives ('red', 'blue'); in other words, adjectives are noisier than nouns. For example, the semantic value of the utterance 'red dress' for the red dress is 0.95 * 0.99 = 0.9405, while for the blue dress it is 0.05 * 0.99 = 0.0495.
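As a worked example of this product rule, the following sketch computes the semantic value of 'red dress' for each of the four objects with the same assumed values (0.95 for adjectives, 0.99 for nouns). The helper `token_value` is ours and simply encodes the "$v$ if the token applies, $1 - v$ otherwise" rule.

```python
V_ADJ, V_NOUN = 0.95, 0.99  # semantic values assumed in the text

def token_value(applies, v):
    # v if the token is true of the object, 1 - v if it is not
    return v if applies else 1 - v

# semantic value of 'red dress' for each object = value('red') * value('dress')
values = {}
for color in ['red', 'blue']:
    for shape in ['pin', 'dress']:
        values[f"{color} {shape}"] = token_value(color == 'red', V_ADJ) * token_value(shape == 'dress', V_NOUN)

for name, value in values.items():
    print(f"{name}: {value:.4f}")
```

Only the red dress, of which both words are true, gets a value close to 1; the other objects keep small nonzero values because each word may be used noisily.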

In [9]:
def token_applies_to_obj(obj, token):
    return token in obj.values()

def continuous_meaning(obj, utt, v_adj=0.95, v_noun=0.99):
    total = 1
    for token in utt.split():
        if token_applies_to_obj(obj, token):
            if token in adjectives: total *= v_adj
            if token in nouns: total *= v_noun
        else:
            if token in adjectives: total *= (1 - v_adj)
            if token in nouns: total *= (1 - v_noun)
                
    return total

In [10]:
continuous_meaning({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, 'red pin')

0.9405

### Literal listener

Our `literal_listener` function is nearly identical to the standard literal listener introduced in `intro-to-rsa.ipynb`, with one exception: we can now pass in the meaning function as `meaning_fn`. We'll compare what happens when we use Boolean vs. continuous semantics.

In [11]:
def normalize_rows(matrix):
    """
    Helper function that normalizes probabilities across rows to sum to 1
    """
    totals = np.sum(matrix, axis=1)
    return matrix / totals[:, np.newaxis]

def literal_listener(utt, meaning_fn):
    # generate the matrix of utterances x world states
    all_counts = np.zeros(shape=(len(utterances), len(objects)))
    for i in range(len(utterances)):
        for j in range(len(objects)):
            curr_utt = utterances[i]
            curr_obj = objects[j]

            all_counts[i, j] = meaning_fn(curr_obj, curr_utt)
            # if I wanted to incorporate a prior I would do it here
                
    data = normalize_rows(all_counts)
    df_cols = [obj["string"] for obj in objects]
    df = pd.DataFrame(data, columns=df_cols, index=utterances)
    return df, df.loc[utt]

In [12]:
l0_df_bool, l0_probs_bool = literal_listener('red', meaning_fn=boolean_meaning)
l0_df_bool

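| | red pin | red dress | blue pin | blue dress |
|---|---|---|---|---|
| red | 0.5 | 0.5 | 0.0 | 0.0 |
| blue | 0.0 | 0.0 | 0.5 | 0.5 |
| pin | 0.5 | 0.0 | 0.5 | 0.0 |
| dress | 0.0 | 0.5 | 0.0 | 0.5 |
| red pin | 1.0 | 0.0 | 0.0 | 0.0 |
| pin red | 1.0 | 0.0 | 0.0 | 0.0 |
| red dress | 0.0 | 1.0 | 0.0 | 0.0 |
| dress red | 0.0 | 1.0 | 0.0 | 0.0 |
| blue pin | 0.0 | 0.0 | 1.0 | 0.0 |
| pin blue | 0.0 | 0.0 | 1.0 | 0.0 |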


In [13]:
l0_probs_bool

red pin       0.5
red dress     0.5
blue pin      0.0
blue dress    0.0
Name: red, dtype: float64

In [14]:
l0_df_cont, l0_probs_cont = literal_listener('red', meaning_fn=continuous_meaning)
l0_df_cont

| | red pin | red dress | blue pin | blue dress |
|---|---|---|---|---|
| red | 0.475 | 0.475 | 0.025 | 0.025 |
| blue | 0.025 | 0.025 | 0.475 | 0.475 |
| pin | 0.495 | 0.005 | 0.495 | 0.005 |
| dress | 0.005 | 0.495 | 0.005 | 0.495 |
| red pin | 0.9405 | 0.0095 | 0.0495 | 0.0005 |
| pin red | 0.9405 | 0.0095 | 0.0495 | 0.0005 |
| red dress | 0.0095 | 0.9405 | 0.0005 | 0.0495 |
| dress red | 0.0095 | 0.9405 | 0.0005 | 0.0495 |
| blue pin | 0.0495 | 0.0005 | 0.9405 | 0.0095 |
| pin blue | 0.0495 | 0.0005 | 0.9405 | 0.0095 |


In [15]:
l0_probs_cont

red pin       0.475
red dress     0.475
blue pin      0.025
blue dress    0.025
Name: red, dtype: float64

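With the continuous meaning function, every object receives a nonzero probability, even when the utterance is literally false of it: after hearing 'red', $L_0$ still assigns 0.025 to each blue object. This is by design, since it models the possibility that a word is used noisily.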
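These numbers come from a simple renormalization. The row for 'red' holds the continuous values 0.95, 0.95, 0.05, 0.05 (red objects first), which sum to 2, and $L_0$ divides each by that total. A minimal NumPy sketch of this single row (the list `objects_in_order` is ours and just fixes the column order used in the table):

```python
import numpy as np

# continuous value of 'red' for each object: 0.95 if the object is red, 1 - 0.95 otherwise
objects_in_order = ['red pin', 'red dress', 'blue pin', 'blue dress']
red_values = np.array([0.95 if name.startswith('red') else 1 - 0.95 for name in objects_in_order])

# L0('red') normalizes this row so that it sums to 1
l0_red = red_values / red_values.sum()
print(l0_red)
```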

### Pragmatic speaker

Our pragmatic speaker is the same as in the standard RSA setup, except we specify that we want the literal listener to use continuous semantics. For simplicity, we assume 0 cost for any utterance.
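With zero cost, $\exp(\alpha \cdot \log L_0) = L_0^{\alpha}$, so for a fixed object the speaker's score for each utterance is the literal-listener probability raised to the power $\alpha$, renormalized over utterances. The sketch below reproduces the 'red pin' row for $\alpha = 3$ by hand. The probabilities are the red pin column of the continuous $L_0$ table; the entries for 'blue dress' and 'dress blue', which that table does not show, follow from the same product rule (0.05 * 0.01 = 0.0005).

```python
import numpy as np

alpha = 3
# L0(red pin | utterance) for each of the 12 utterances (continuous semantics)
l0_red_pin = {
    'red': 0.475, 'blue': 0.025, 'pin': 0.495, 'dress': 0.005,
    'red pin': 0.9405, 'pin red': 0.9405, 'red dress': 0.0095, 'dress red': 0.0095,
    'blue pin': 0.0495, 'pin blue': 0.0495, 'blue dress': 0.0005, 'dress blue': 0.0005,
}
utts = list(l0_red_pin)

# zero cost: exp(alpha * (log L0 - 0)) == L0 ** alpha
scores = np.array([l0_red_pin[u] for u in utts]) ** alpha
s1 = scores / scores.sum()

for u, p in zip(utts, s1):
    print(f"{u:>10}: {p:.4f}")
```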

In [16]:
def cost(utt):
    return 0

def pragmatic_speaker(obj, alpha=1):
    all_vals = []
    for curr_utt in utterances:
        _, probs = literal_listener(curr_utt, meaning_fn=continuous_meaning)
        utility = np.array(probs)
        val = np.exp(alpha * (np.log(utility) - cost(curr_utt)))
        all_vals.append(val)
        
    data = normalize_rows(np.array(all_vals).T)
    df_idx = [obj["string"] for obj in objects]
    df = pd.DataFrame(data, columns=utterances, index=df_idx)
    
    return df, df.loc[obj["string"]]

In [17]:
s1_df, s1_probs = pragmatic_speaker({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, alpha=3)
s1_df

| | red | blue | pin | dress | red pin | pin red | red dress | dress red | blue pin | pin blue | blue dress | dress blue |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| red pin | 0.056629 | 8e-06 | 0.0640871 | 6.604882e-08 | 0.4395734 | 0.4395734 | 4.530289e-07 | 4.530289e-07 | 6.40871e-05 | 6.40871e-05 | 6.604882e-11 | 6.604882e-11 |
| red dress | 0.056629 | 8e-06 | 6.604882e-08 | 0.0640871 | 4.530289e-07 | 4.530289e-07 | 0.4395734 | 0.4395734 | 6.604882e-11 | 6.604882e-11 | 6.40871e-05 | 6.40871e-05 |
| blue pin | 8e-06 | 0.056629 | 0.0640871 | 6.604882e-08 | 6.40871e-05 | 6.40871e-05 | 6.604882e-11 | 6.604882e-11 | 0.4395734 | 0.4395734 | 4.530289e-07 | 4.530289e-07 |
| blue dress | 8e-06 | 0.056629 | 6.604882e-08 | 0.0640871 | 6.604882e-11 | 6.604882e-11 | 6.40871e-05 | 6.40871e-05 | 4.530289e-07 | 4.530289e-07 | 0.4395734 | 0.4395734 |


In [18]:
s1_probs

red           5.662861e-02
blue          8.256102e-06
pin           6.408710e-02
dress         6.604882e-08
red pin       4.395734e-01
pin red       4.395734e-01
red dress     4.530289e-07
dress red     4.530289e-07
blue pin      6.408710e-05
pin blue      6.408710e-05
blue dress    6.604882e-11
dress blue    6.604882e-11
Name: red pin, dtype: float64

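While this is a neat exercise, notice that word order has no effect: 'red pin' and 'pin red' receive exactly the same probability (0.4396), and so do 'blue pin' and 'pin blue'. This follows from the meaning function, which multiplies the values of the individual tokens and therefore ignores their order. This is not the behavior we want: the model should be able to favor one ordering over the other. To break this symmetry, we move to the incremental setting.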
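The symmetry can be checked directly: for every object, the product of token values for 'red pin' equals the product for 'pin red'. A self-contained sketch (the helper `utterance_value` and the dictionary `V` are ours and hard-code the assumed 0.95/0.99 values):

```python
# assumed semantic values: 0.95 for adjectives, 0.99 for nouns
V = {'red': 0.95, 'blue': 0.95, 'pin': 0.99, 'dress': 0.99}

def utterance_value(utt, color, shape):
    # multiply v for tokens that apply to the object and 1 - v for tokens that do not
    total = 1.0
    for token in utt.split():
        total *= V[token] if token in (color, shape) else 1 - V[token]
    return total

# compare both word orders for every object
pairs = {}
for color in ['red', 'blue']:
    for shape in ['pin', 'dress']:
        pairs[f"{color} {shape}"] = (utterance_value('red pin', color, shape),
                                     utterance_value('pin red', color, shape))

for name, (a, b) in pairs.items():
    print(f"{name}: 'red pin' {a:.4f}, 'pin red' {b:.4f}")
```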

## Part 2: Incremental continuous RSA

In this part, we will be extending continuous semantics RSA to model utterances _incrementally_ rather than globally. 

### Define incremental continuous meaning function

Following Cohn-Gordon et al. 2019:

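For any partial sequence $c$, object $w$, and set of referents $W$, $[[c]](w) \in [0, 1]$ is the number of full-utterance extensions of $c$ that are true of $w$, divided by the number of full-utterance extensions of $c$ that are true of at least one object in $W$.

In our continuous implementation, the count in the numerator is replaced by the sum of the continuous semantic values of the extensions of $c$ for $w$. Because continuous semantics assigns every extension a nonzero value for every object, the denominator is simply the number of extensions of $c$, so $[[c]](w)$ is the average continuous value of the extensions of $c$.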
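For comparison, here is a minimal sketch of the original Boolean definition. The function names `boolean_value` and `boolean_incremental_meaning` are ours, and extensions are matched token by token. For the red pin and the context 'red', the extensions are 'red', 'red pin', and 'red dress'; all three are true of some object and two of them are true of the red pin, so the value should be 2/3.

```python
def boolean_value(obj, utt):
    # 1 if every token of utt names the object's color or shape, else 0
    return int(set(utt.split()) <= {obj['color'], obj['shape']})

def boolean_incremental_meaning(obj, context, utterances, objects):
    prefix = context.split()
    # full utterances whose first tokens are exactly the context
    extensions = [u for u in utterances if u.split()[:len(prefix)] == prefix]
    true_of_obj = [u for u in extensions if boolean_value(obj, u)]
    true_of_any = [u for u in extensions if any(boolean_value(w, u) for w in objects)]
    # guard against contexts with no extension that is true of any object
    return len(true_of_obj) / len(true_of_any) if true_of_any else 0.0

objs = [{'color': c, 'shape': s} for c in ['red', 'blue'] for s in ['pin', 'dress']]
utts = ['red', 'blue', 'pin', 'dress'] + [
    u for c in ['red', 'blue'] for s in ['pin', 'dress'] for u in (f'{c} {s}', f'{s} {c}')]

red_pin = {'color': 'red', 'shape': 'pin'}
print(boolean_incremental_meaning(red_pin, 'red', utts, objs))
```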

The following function is our implementation of that meaning function:

In [19]:
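def inc_cont_meaning(obj, utt):
    """
    Incremental continuous meaning of a (partial) utterance for obj: the
    average continuous semantic value of all full utterances that extend utt
    """
    extensions = []
    utility = 0
    for curr_utt in utterances:
        # curr_utt extends utt if it starts with it (utt itself counts if it is a full utterance)
        if curr_utt.startswith(utt):
            extensions.append(curr_utt)
            # score the full extension curr_utt, not the partial utterance utt
            utility += continuous_meaning(obj, curr_utt)

    return utility / len(extensions)

In [20]:
inc_cont_meaning({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, 'red')

0.6333333333333333

For the red pin, this is the average over the three extensions of 'red': (0.95 + 0.9405 + 0.0095) / 3.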

### Word-level literal listener

In incremental RSA, utterances are processed word by word. The word-level literal listener $L_0^{\text{WORD}}$ interprets a partial utterance (the context followed by the next word) using the incremental meaning function. Computationally, the only change from the standard literal listener is the meaning function we pass in, so this function is interchangeable with the `literal_listener` function defined earlier; we show it here for the sake of completeness.

In [21]:
def word_level_literal_listener(word, meaning_fn):
    # generate the matrix of utterances x world states
    all_counts = np.zeros(shape=(len(utterances), len(objects)))
    for i in range(len(utterances)):
        for j in range(len(objects)):
            curr_utt = utterances[i]
            curr_obj = objects[j]

            all_counts[i, j] = meaning_fn(curr_obj, curr_utt)
    
    data = normalize_rows(all_counts)
    df_cols = [obj["string"] for obj in objects]
    df = pd.DataFrame(data, columns=df_cols, index=utterances)
    return df, df.loc[word]

In [22]:
df_word_level_l0, _ = word_level_literal_listener('red', inc_cont_meaning)
df_word_level_l0

| | red pin | red dress | blue pin | blue dress |
|---|---|---|---|---|
| red | 0.475 | 0.475 | 0.025 | 0.025 |
| blue | 0.025 | 0.025 | 0.475 | 0.475 |
| pin | 0.495 | 0.005 | 0.495 | 0.005 |
| dress | 0.005 | 0.495 | 0.005 | 0.495 |
| red pin | 0.9405 | 0.0095 | 0.0495 | 0.0005 |
| pin red | 0.9405 | 0.0095 | 0.0495 | 0.0005 |
| red dress | 0.0095 | 0.9405 | 0.0005 | 0.0495 |
| dress red | 0.0095 | 0.9405 | 0.0005 | 0.0495 |
| blue pin | 0.0495 | 0.0005 | 0.9405 | 0.0095 |
| pin blue | 0.0495 | 0.0005 | 0.9405 | 0.0095 |
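This table is identical to the continuous $L_0$ table from Part 1. Under the averaged-extension meaning, that is a property of this small lexicon rather than a coincidence. A single word such as 'red' has the extensions 'red', 'red pin', and 'red dress', and the noun's contributions in the two longer extensions sum to 0.99 + 0.01 = 1, so every object's score is scaled by the same factor, which normalization cancels. Two-word utterances have no longer extensions, so they are unchanged. The self-contained sketch below (all helper names are ours; objects are `(color, shape)` tuples) checks this numerically.

```python
import numpy as np

V_ADJ, V_NOUN = 0.95, 0.99
adjs, nns = ['red', 'blue'], ['pin', 'dress']
objs = [(a, n) for a in adjs for n in nns]  # each object is a (color, shape) pair
utts = adjs + nns + [u for a in adjs for n in nns for u in (f"{a} {n}", f"{n} {a}")]

def cont(obj, utt):
    # product over tokens of v (token applies) or 1 - v (token does not apply)
    total = 1.0
    for tok in utt.split():
        v = V_ADJ if tok in adjs else V_NOUN
        total *= v if tok in obj else 1 - v
    return total

def inc_cont(obj, utt):
    # average continuous value over the full utterances that extend utt token by token
    prefix = utt.split()
    ext = [u for u in utts if u.split()[:len(prefix)] == prefix]
    return sum(cont(obj, u) for u in ext) / len(ext)

def listener_matrix(meaning):
    # rows: utterances, columns: objects; each row normalized to sum to 1
    m = np.array([[meaning(o, u) for o in objs] for u in utts])
    return m / m.sum(axis=1, keepdims=True)

print(np.allclose(listener_matrix(cont), listener_matrix(inc_cont)))
```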


### Word-level pragmatic speaker

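First, we define the helper function `get_possible_completions`, which returns every utterance that starts with the given context and is at most one word longer than it. The context itself is included when it is a complete utterance, and an empty context yields all single-word utterances.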

In [23]:
def get_possible_completions(utt):
    """
    Given a partial utterance, return all the possible completions
    """
    max_utt_len = len(utt.split()) + 1
    
    possible_completions = []
    for curr_utt in utterances:
        if curr_utt.startswith(utt) and len(curr_utt.split()) <= max_utt_len:
            possible_completions.append(curr_utt)
            
    return possible_completions

In [24]:
get_possible_completions('red')

['red', 'red pin', 'red dress']
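`get_possible_completions` compares raw strings with `startswith`, which works for this lexicon but would, for example, treat 'red' as a prefix of a hypothetical 'reddish pin'. A token-level variant avoids that; the function name `get_possible_completions_tokens` and the extra utterances 'reddish' and 'reddish pin' are hypothetical and only used for illustration. With an empty context it returns exactly the single-word utterances, which is how the first word of an utterance is chosen.

```python
def get_possible_completions_tokens(context, utterances):
    prefix = context.split()
    # keep utterances that start with the context's tokens and add at most one more token
    return [u for u in utterances
            if u.split()[:len(prefix)] == prefix and len(u.split()) <= len(prefix) + 1]

utts = ['red', 'reddish', 'pin', 'red pin', 'reddish pin', 'pin red']
print(get_possible_completions_tokens('red', utts))
print(get_possible_completions_tokens('', utts))
```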

Now we can define the word-level pragmatic speaker $S_1^{\text{WORD}}$. The key difference from the standard pragmatic speaker is that it iterates only over the possible completions of the context, _not_ over all possible utterances. Thus, we need to pass in `context` as an argument.
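For the red pin and the context 'red', the candidate continuations are 'red', 'red pin', and 'red dress'. Taking the red pin's word-level $L_0$ probabilities for those three utterances from the table above (0.475, 0.9405, 0.0095) and $\alpha = 5$, the word-level speaker probabilities can be checked by hand with this short sketch (zero cost assumed, as in the text):

```python
alpha = 5
# L0(red pin | u) for the three completions of the context 'red'
l0_red_pin = {'red': 0.475, 'red pin': 0.9405, 'red dress': 0.0095}

# zero cost, so exp(alpha * log p) == p ** alpha
scores = {u: p ** alpha for u, p in l0_red_pin.items()}
total = sum(scores.values())
s1_word = {u: s / total for u, s in scores.items()}
print(s1_word)
```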

In [25]:
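def word_level_pragmatic_speaker(obj, context, alpha, v_adj, v_noun):
    # note: v_adj and v_noun are not forwarded to the meaning function;
    # inc_cont_meaning relies on continuous_meaning's defaults (0.95, 0.99)
    all_vals = []
    # only the continuations of the context (the context itself plus one more word)
    possible_utterances = get_possible_completions(context)
    for curr_utt in possible_utterances:
        _, probs = word_level_literal_listener(curr_utt, meaning_fn=inc_cont_meaning) # specify meaning_fn here
        utility = np.array(probs)
        # exp(alpha * (log L0 - cost)) for every object
        val = np.exp(alpha * (np.log(utility) - cost(curr_utt)))
        all_vals.append(val)
    # rows: objects, columns: continuations; normalize over continuations
    data = normalize_rows(np.array(all_vals).T)
    df_idx = [obj["string"] for obj in objects]
    df = pd.DataFrame(data, columns=possible_utterances, index=df_idx)
    
    return df, df.loc[obj["string"]]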

In [26]:
df_word_level_s1, _ = word_level_pragmatic_speaker({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, context = 'red', alpha=5, v_adj=0.95, v_noun=0.99)
df_word_level_s1

| | red | red pin | red dress |
|---|---|---|---|
| red pin | 0.031815 | 0.968185 | 1.018081e-10 |
| red dress | 0.031815 | 1.018081e-10 | 0.968185 |
| blue pin | 0.031815 | 0.968185 | 1.018081e-10 |
| blue dress | 0.031815 | 1.018081e-10 | 0.968185 |


### Utterance level incremental pragmatic speaker

We can now define the utterance-level incremental pragmatic speaker $S_1^{\text{UTT}}$. Following the definition from Cohn-Gordon et al. 2019:

For a complete utterance $u = [u_1, \ldots, u_n]$ and object $w$:

$S_1^{\text{UTT}}(u \vert w) = \prod_{i=1}^{n} S_1^{\text{WORD}}(u_i \vert c = [u_1, \ldots, u_{i-1}], w)$

where $c$ is the partial utterance preceding word $u_i$ (empty when $i = 1$).
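As a worked instance of this product for the red pin with $\alpha = 5$: the first word is chosen from the four single words with an empty context, and the second from the completions of that first word. The sketch below uses the red pin's literal-listener probabilities from the continuous $L_0$ table; the helper `s1_word` is ours and assumes zero cost.

```python
alpha = 5
# L0(red pin | u) for every utterance the speaker can reach (continuous semantics)
l0_red_pin = {'red': 0.475, 'blue': 0.025, 'pin': 0.495, 'dress': 0.005,
              'red pin': 0.9405, 'red dress': 0.0095, 'pin red': 0.9405, 'pin blue': 0.0495}

def s1_word(next_utt, options):
    # word-level speaker: L0 ** alpha, normalized over the allowed continuations
    scores = {u: l0_red_pin[u] ** alpha for u in options}
    return scores[next_utt] / sum(scores.values())

first_words = ['red', 'blue', 'pin', 'dress']
p_red_pin = s1_word('red', first_words) * s1_word('red pin', ['red', 'red pin', 'red dress'])
p_pin_red = s1_word('pin', first_words) * s1_word('pin red', ['pin', 'pin red', 'pin blue'])
print(round(p_red_pin, 4), round(p_pin_red, 4))
```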

In [27]:
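def incremental_pragmatic_speaker(obj, utt, alpha=5, v_adj=0.95, v_noun=0.99):
    """
    Utterance-level incremental speaker: the product, over the words of utt,
    of the word-level speaker's probability of each word given the preceding context
    """
    words = utt.split()
    context = ''  # empty context before the first word
    val = 1
    for i in range(len(words)):
        # distribution over the continuations of the current context
        _, probs = word_level_pragmatic_speaker(obj, context, alpha, v_adj, v_noun)
        partial_utt = " ".join(words[:i+1])

        # probability of extending the context with the next word
        val *= float(probs.loc[partial_utt])

        context = partial_utt

    return val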

Finally, we can see how the continuous incremental RSA model makes predictions about the probability of the utterance 'red pin' vs. 'pin red':

In [28]:
incremental_pragmatic_speaker({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, 'red pin')

0.4343550191544007

In [29]:
incremental_pragmatic_speaker({'color': 'red', 'shape': 'pin', 'string': 'red pin'}, 'pin red')

0.5299681170538035

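Under these assumptions, in particular that adjectives are noisier than nouns, the model assigns a higher probability to 'pin red' (0.530) than to 'red pin' (0.434), i.e. it favors the postnominal ordering. The advantage comes from the first word: after hearing 'pin' alone, $L_0$ assigns the red pin a probability of 0.495, versus 0.475 after 'red' alone, so the word-level speaker is more likely to begin with the less noisy noun.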
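To probe whether the preference really comes from the noise asymmetry, the sketch below re-implements the model compactly. All names are ours. It assumes the same lexicon, zero cost, $\alpha = 5$, and the averaged-extension incremental meaning, and it passes `v_adj` and `v_noun` through to the meaning function. It compares the two orders when adjectives are noisier ($v_{adj} = 0.95$, $v_{noun} = 0.99$) and when both word classes are equally reliable ($v_{adj} = v_{noun} = 0.99$). In the second setting every step of the computation is symmetric between adjectives and nouns, so the two orders should receive the same probability.

```python
ADJS, NOUNS = ['red', 'blue'], ['pin', 'dress']
OBJS = [(a, n) for a in ADJS for n in NOUNS]  # (color, shape) pairs
UTTS = ADJS + NOUNS + [u for a in ADJS for n in NOUNS for u in (f"{a} {n}", f"{n} {a}")]

def extends(utt, prefix):
    # True if the first tokens of utt are exactly the tokens in prefix
    return utt.split()[:len(prefix)] == prefix

def cont(obj, utt, v_adj, v_noun):
    # product of v (token applies) or 1 - v (token does not apply)
    total = 1.0
    for tok in utt.split():
        v = v_adj if tok in ADJS else v_noun
        total *= v if tok in obj else 1 - v
    return total

def inc_cont(obj, utt, v_adj, v_noun):
    # average continuous value over the full utterances extending utt
    ext = [u for u in UTTS if extends(u, utt.split())]
    return sum(cont(obj, u, v_adj, v_noun) for u in ext) / len(ext)

def l0_prob(obj, utt, v_adj, v_noun):
    # word-level literal listener: probability of obj given the partial utterance utt
    scores = [inc_cont(o, utt, v_adj, v_noun) for o in OBJS]
    return inc_cont(obj, utt, v_adj, v_noun) / sum(scores)

def s1_utt(obj, utt, alpha=5, v_adj=0.95, v_noun=0.99):
    # product over words of the word-level speaker's probability (zero cost)
    words, prob = utt.split(), 1.0
    for i in range(len(words)):
        # continuations of the context words[:i]: the context itself plus one more word
        options = [u for u in UTTS if extends(u, words[:i]) and len(u.split()) <= i + 1]
        scores = {u: l0_prob(obj, u, v_adj, v_noun) ** alpha for u in options}
        prob *= scores[" ".join(words[:i + 1])] / sum(scores.values())
    return prob

target = ('red', 'pin')
for v_adj, v_noun in [(0.95, 0.99), (0.99, 0.99)]:
    a = s1_utt(target, 'red pin', v_adj=v_adj, v_noun=v_noun)
    b = s1_utt(target, 'pin red', v_adj=v_adj, v_noun=v_noun)
    print(f"v_adj={v_adj}, v_noun={v_noun}: 'red pin' {a:.4f}, 'pin red' {b:.4f}")
```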