In [35]:
import numpy as np
import pandas as pd

In [36]:
__author__ = 'Chris Potts'


class RSA:
    """Implementation of the core Rational Speech Acts model.

    Parameters
    ----------
    lexicon : `np.array` or `pd.DataFrame`
        Messages along the rows, states along the columns. Each entry is
        multiplied by the prior of its state in `literal_listener`, so 0
        marks a message as false of a state.
    prior : array-like
        Same length as the number of columns in `lexicon`: one weight per
        state. Every use of it is followed by row normalization, so it does
        not have to sum to 1.
    costs : array-like
        Same length as the number of rows in `lexicon`: one value per
        message. Costs are *added* to the speaker's log literal-listener
        probabilities (and then scaled by `alpha`), so a negative value
        penalizes a message and a positive value favors it.
    alpha : float
        The temperature parameter. It multiplies the speaker utilities, so
        larger values give a more peaked speaker distribution. Default: 1.0
    """
    def __init__(self, lexicon, prior, costs, alpha=1.0):
        self.lexicon = lexicon
        self.prior = np.array(prior)
        self.costs = np.array(costs)
        self.alpha = alpha

    def literal_listener(self):
        """Literal listener predictions, which corresponds intuitively
        to truth conditions with priors.

        Computes L0(s | m) proportional to lexicon[m, s] * prior[s].

        Returns
        -------
        np.array or pd.DataFrame, depending on `self.lexicon`.
        The rows correspond to messages, the columns to states.

        """
        return rownorm(self.lexicon * self.prior)

    def speaker(self):
        """Returns a matrix of pragmatic speaker predictions.

        Computes S(m | s) proportional to
        exp(alpha * (log L0(s | m) + costs[m])), i.e. a softmax over
        messages for each state.

        Returns
        -------
        np.array or pd.DataFrame, depending on `self.lexicon`.
        The rows correspond to states, the columns to messages.
        A state whose literal-listener probability is 0 under every
        message gets a row of NaN.
        """
        lit = self.literal_listener().T
        utilities = self.alpha * (safelog(lit) + self.costs)
        # Subtract each row's maximum before exponentiating. Softmax is
        # invariant to this shift, but it keeps np.exp from overflowing to
        # inf (large alpha/costs -> inf/inf = nan) or underflowing to an
        # all-zero row (very negative utilities -> 0/0 = nan).
        row_max = utilities.max(axis=1)
        # Rows whose maximum is not finite (all -inf, or nan with numpy) are
        # left unshifted, so they come out exactly as without the shift.
        row_max = np.where(np.isfinite(row_max), row_max, 0.0)
        # Transpose so the per-state maxima broadcast across messages; this
        # works for both np.array and pd.DataFrame (same trick as rownorm).
        shifted = (utilities.T - row_max).T
        return rownorm(np.exp(shifted))

    def listener(self):
        """Returns a matrix of pragmatic listener predictions.

        Computes L1(s | m) proportional to S(m | s) * prior[s].

        Returns
        -------
        np.array or pd.DataFrame, depending on `self.lexicon`.
        The rows correspond to messages, the columns to states.
        """
        spk = self.speaker().T
        return rownorm(spk * self.prior)


def rownorm(mat):
    """Normalize each row of `mat` so that it sums to 1.

    Accepts an `np.array` or a `pd.DataFrame` and returns the same type.
    The division is done on the transpose so that the vector of row totals
    broadcasts (numpy) or aligns by index label (pandas) against the columns
    of the transposed matrix. Dividing by a zero row total yields NaN (or
    inf) entries for that row.
    """
    row_totals = mat.sum(axis=1)
    per_column = mat.T / row_totals
    return per_column.T


def safelog(vals):
    """Elementwise natural log with numpy's divide-by-zero warning silenced.

    ``log(0)`` still evaluates to ``-inf``; only the warning is suppressed.
    """
    quiet_divide = np.errstate(divide='ignore')
    with quiet_divide:
        logged = np.log(vals)
    return logged


if __name__ == '__main__':
    # Examples

    from IPython.display import display


    def display_reference_game(mod):
        """Display `mod`'s lexicon as a table, extended with a `costs` column
        and with `prior` and `alpha` rows appended underneath.

        Works on a copy, so `mod.lexicon` itself is not modified. Assumes
        `mod.lexicon` is a `pd.DataFrame` (it relies on `.loc`).
        """
        table = mod.lexicon.copy()
        table['costs'] = mod.costs
        # Pad the appended rows with blanks so they span the extra
        # `costs` column.
        table.loc['prior'] = [*mod.prior, ""]
        n_states = mod.lexicon.shape[1]
        table.loc['alpha'] = [mod.alpha, *([" "] * n_states)]
        display(table)


    # Core lexicon: messages (Russian for "hat", "glasses", "moustache",
    # "scar") along the rows, referents r1-r4 along the columns; 1.0 means
    # the message is true of the referent.
    msgs = ['шляпа', 'очки', 'усы', 'шрам']
    states = ['r1', 'r2', 'r3', 'r4']
    lex = pd.DataFrame(
        [[1.0, 1.0, 1.0, 0.0],
         [1.0, 1.0, 0.0, 1.0],
         [1.0, 0.0, 1.0, 0.0],
         [0.0, 0.0, 1.0, 0.0]],
        index=msgs,
        columns=states,
    )

In [37]:
    print("="*70 + "\nEven priors and all-0 message costs\n")
    # Uniform prior over the four states, no message costs, default alpha.
    basic_mod = RSA(
        lexicon=lex,
        prior=[0.25] * 4,
        costs=[0.0] * 4,
    )

    display_reference_game(basic_mod)

    for heading, predict in [("\nLiteral listener", basic_mod.literal_listener),
                             ("\nPragmatic speaker", basic_mod.speaker),
                             ("\nPragmatic listener", basic_mod.listener)]:
        print(heading)
        display(predict())

Even priors and all-0 message costs



Unnamed: 0,r1,r2,r3,r4,costs
шляпа,1.0,1.0,1.0,0.0,0.0
очки,1.0,1.0,0.0,1.0,0.0
усы,1.0,0.0,1.0,0.0,0.0
шрам,0.0,0.0,1.0,0.0,0.0
prior,0.25,0.25,0.25,0.25,
alpha,1.0,,,,



Literal listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.333333,0.333333,0.333333,0.0
очки,0.333333,0.333333,0.0,0.333333
усы,0.5,0.0,0.5,0.0
шрам,0.0,0.0,1.0,0.0



Pragmatic speaker


Unnamed: 0,шляпа,очки,усы,шрам
r1,0.285714,0.285714,0.428571,0.0
r2,0.5,0.5,0.0,0.0
r3,0.181818,0.0,0.272727,0.545455
r4,0.0,1.0,0.0,0.0



Pragmatic listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.295302,0.516779,0.187919,0.0
очки,0.16,0.28,0.0,0.56
усы,0.611111,0.0,0.388889,0.0
шрам,0.0,0.0,1.0,0.0


In [39]:
print("="*70 + "\nEven priors and all-0 message costs; alpha = 4\n")
# Uniform prior and zero costs, with alpha raised to 4.0.
alpha_mod = RSA(lexicon=lex,
                prior=[0.25, 0.25, 0.25, 0.25],
                costs=[0.0, 0.0, 0.0, 0.0],
                alpha=4.0)

display_reference_game(alpha_mod)

for heading, method_name in (("\nLiteral listener", "literal_listener"),
                             ("\nPragmatic speaker", "speaker"),
                             ("\nPragmatic listener", "listener")):
    print(heading)
    display(getattr(alpha_mod, method_name)())

Even priors and all-0 message costs; alpha = 4



Unnamed: 0,r1,r2,r3,r4,costs
шляпа,1.0,1.0,1.0,0.0,0.0
очки,1.0,1.0,0.0,1.0,0.0
усы,1.0,0.0,1.0,0.0,0.0
шрам,0.0,0.0,1.0,0.0,0.0
prior,0.25,0.25,0.25,0.25,
alpha,4.0,,,,



Literal listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.333333,0.333333,0.333333,0.0
очки,0.333333,0.333333,0.0,0.333333
усы,0.5,0.0,0.5,0.0
шрам,0.0,0.0,1.0,0.0



Pragmatic speaker


Unnamed: 0,шляпа,очки,усы,шрам
r1,0.141593,0.141593,0.716814,0.0
r2,0.5,0.5,0.0,0.0
r3,0.011486,0.0,0.058148,0.930366
r4,0.0,1.0,0.0,0.0



Pragmatic listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.216808,0.765604,0.017587,0.0
очки,0.086253,0.304582,0.0,0.609164
усы,0.924967,0.0,0.075033,0.0
шрам,0.0,0.0,1.0,0.0


In [40]:
print("="*70 + "\nEven priors, imbalanced message costs\n")
# Uniform prior and default alpha; only the first message ('шляпа') carries
# a cost (-6.0), the others are free.
cost_most = RSA(
    lexicon=lex,
    prior=[0.25] * 4,
    costs=[-6.0] + [0.0] * 3,
)

display_reference_game(cost_most)

for label, fn in zip(["\nLiteral listener", "\nPragmatic speaker", "\nPragmatic listener"],
                     [cost_most.literal_listener, cost_most.speaker, cost_most.listener]):
    print(label)
    display(fn())

Even priors, imbalanced message costs



Unnamed: 0,r1,r2,r3,r4,costs
шляпа,1.0,1.0,1.0,0.0,-6.0
очки,1.0,1.0,0.0,1.0,0.0
усы,1.0,0.0,1.0,0.0,0.0
шрам,0.0,0.0,1.0,0.0,0.0
prior,0.25,0.25,0.25,0.25,
alpha,1.0,,,,



Literal listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.333333,0.333333,0.333333,0.0
очки,0.333333,0.333333,0.0,0.333333
усы,0.5,0.0,0.5,0.0
шрам,0.0,0.0,1.0,0.0



Pragmatic speaker


Unnamed: 0,шляпа,очки,усы,шрам
r1,0.000991,0.399604,0.599406,0.0
r2,0.002473,0.997527,0.0,0.0
r3,0.000551,0.0,0.33315,0.6663
r4,0.0,1.0,0.0,0.0



Pragmatic listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.246786,0.61605,0.137164,0.0
очки,0.166701,0.416134,0.0,0.417165
усы,0.642756,0.0,0.357244,0.0
шрам,0.0,0.0,1.0,0.0


In [41]:
print("="*70 + "\nUneven priors and all-0 message costs\n")
# Non-uniform prior over the states, zero message costs, default alpha (1.0).
# NOTE: this rebinds `alpha_mod`, replacing the alpha = 4 model built earlier.
alpha_mod = RSA(lexicon=lex, prior=[0.1, 0.1, 0.5, 0.3], costs=[0.0, 0.0, 0.0, 0.0])

display_reference_game(alpha_mod)

print("\nLiteral listener")
display(alpha_mod.literal_listener())

print("\nPragmatic speaker")
display(alpha_mod.speaker())

print("\nPragmatic listener")
display(alpha_mod.listener())

Even priors and all-0 message costs; alpha = 4



Unnamed: 0,r1,r2,r3,r4,costs
шляпа,1.0,1.0,1.0,0.0,0.0
очки,1.0,1.0,0.0,1.0,0.0
усы,1.0,0.0,1.0,0.0,0.0
шрам,0.0,0.0,1.0,0.0,0.0
prior,0.1,0.1,0.5,0.3,
alpha,1.0,,,,



Literal listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.142857,0.142857,0.714286,0.0
очки,0.2,0.2,0.0,0.6
усы,0.166667,0.0,0.833333,0.0
шрам,0.0,0.0,1.0,0.0



Pragmatic speaker


Unnamed: 0,шляпа,очки,усы,шрам
r1,0.280374,0.392523,0.327103,0.0
r2,0.416667,0.583333,0.0,0.0
r3,0.280374,0.0,0.327103,0.392523
r4,0.0,1.0,0.0,0.0



Pragmatic listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.133581,0.198516,0.667904,0.0
очки,0.098727,0.146719,0.0,0.754554
усы,0.166667,0.0,0.833333,0.0
шрам,0.0,0.0,1.0,0.0


In [47]:
print("="*70 + "\nSkewed priors, imbalanced message costs; alpha = 1\n")
# Prior concentrated on r1; costs are added to the speaker's utilities, so
# -10.0 penalizes 'шляпа' while +10.0 and +3.0 favor 'очки' and 'усы'.
alpha_mod = RSA(lexicon=lex, prior=[0.88, 0.04, 0.04, 0.04], costs=[-10.0, 10.0, 3.0, 0.0], alpha=1)

display_reference_game(alpha_mod)

print("\nLiteral listener")
display(alpha_mod.literal_listener())

print("\nPragmatic speaker")
display(alpha_mod.speaker())

print("\nPragmatic listener")
display(alpha_mod.listener())

Even priors and all-0 message costs; alpha = 4



Unnamed: 0,r1,r2,r3,r4,costs
шляпа,1.0,1.0,1.0,0.0,-10.0
очки,1.0,1.0,0.0,1.0,10.0
усы,1.0,0.0,1.0,0.0,3.0
шрам,0.0,0.0,1.0,0.0,0.0
prior,0.88,0.04,0.04,0.04,
alpha,1.0,,,,



Literal listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.916667,0.041667,0.041667,0.0
очки,0.916667,0.041667,0.0,0.041667
усы,0.956522,0.0,0.043478,0.0
шрам,0.0,0.0,1.0,0.0



Pragmatic speaker


Unnamed: 0,шляпа,очки,усы,шрам
r1,2.059194e-09,0.999049,0.000951,0.0
r2,2.061154e-09,1.0,0.0,0.0
r3,1.00981e-06,0.0,0.466178,0.533821
r4,0.0,1.0,0.0,0.0



Pragmatic listener


Unnamed: 0,r1,r2,r3,r4
шляпа,0.042852,0.00195,0.955198,0.0
очки,0.916594,0.041703,0.0,0.041703
усы,0.042936,0.0,0.957064,0.0
шрам,0.0,0.0,1.0,0.0
