# Basic Lewis Signaling Games
Employs Roth-Erev (Herrnstein) learning.

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import imageio.v2 as imageio
from IPython.display import HTML
from time import time

In [4]:
# define constants
epochs = 10000
# seed = np.floor(time()).astype(int)
seed = 0
states = 3
actions = 3
signals = 3
initial_weights = 1

# constants for early stopping
stable_epochs = 7
threshold = 0.97

# Setup
Classes - in the state that produces the most basic signaling game.
Initial code influenced heavily by https://tomekkorbak.com/2019/10/08/lewis-signaling-games/

In [3]:
class World:
    def __init__(self, n_states: int, seed: int) -> None:
        self.n_states = n_states
        self.state = 0
        self.random = np.random.RandomState(seed)

    def get_state(self) -> int:
        self.state = self.random.randint(self.n_states)
        return self.state

    def evaluate(self, action: int) -> int:
        return 1 if action == self.state else 0

In [4]:
class Sender:
    def __init__(self, n_stimuli: int, n_signals: int, q_not: float = 1e-6) -> None:
        # n_stimuli: number of possible states in the world,
        #            each corresponding to a stimulus
        # n_signals: number of signals that can be sent in response,
        #            usually equal to the number of states in the world
        # q_not:     initial probabilities of sending each signal before a reward
        self.n_signals = n_signals
        self.signal_weights = np.zeros((n_stimuli, n_signals))
        self.signal_weights.fill(q_not)
        self.last_situation = (0, 0)

    def get_signal(self, stimulus: int) -> int:
        # p(i) = q(i) / sum(q)
        num = self.signal_weights[stimulus, :]
        den = np.sum(self.signal_weights[stimulus, :])
        probabilities = num / den
        signal = np.random.choice(self.n_signals, p=probabilities)
        self.last_situation = (stimulus, signal)
        return signal

    def update(self, reward: int) -> None:
        stimulus, signal = self.last_situation
        self.signal_weights[stimulus, signal] += reward

In [5]:
class Receiver:
    def __init__(self, n_signals, n_actions, q_not: float = 1e-6) -> None:
        # n_signals: number of signals that can be sent in response,
        #            usually equal to the number of states in the world
        # n_actions: number of actions that can be taken in response,
        #            usually equal to the number of states in the world
        # q_not:     initial probabilities of taking each action before a reward
        self.n_actions = n_actions
        self.action_weights = np.zeros((n_signals, n_actions))
        self.action_weights.fill(q_not)
        self.last_situation = (0, 0)

    def get_action(self, signal: int) -> int:
        # p(i) = q(i) / sum(q)
        num = self.action_weights[signal, :]
        den = np.sum(self.action_weights[signal, :])
        probabilities = num / den
        action = np.random.choice(self.n_actions, p=probabilities)
        self.last_situation = (signal, action)
        return action

    def update(self, reward: int) -> None:
        signal, action = self.last_situation
        self.action_weights[signal, action] += reward


In [1]:
def make_gif(filename_base):
    images = []
    nm = filename_base.split('-')[-1]
    for filename in [f'images/{nm}_{i}.png' for i in range(epochs) if i % 25 == 0]:
        images.append(imageio.imread(filename))
    imageio.mimsave(f'{filename_base}.gif', images, fps=10)
    # display(Image(filename=f'{filename_base}.gif'))
    display(HTML('<img src="{}">'.format(f'{filename_base}.gif')))

In [7]:
def early_stop(epochs, rewards, threshold=0.95):
    return np.sum(rewards[-epochs:]) / epochs > threshold

## Experiment

In [8]:
W = World(states, seed)
S = Sender(states, signals, q_not=initial_weights)
R = Receiver(signals, actions, q_not=initial_weights)
past_rewards = 0
history = []

In [9]:
for epoch in range(epochs):
    stimulus = W.get_state()
    signal = S.get_signal(stimulus)
    action = R.get_action(signal)
    reward = W.evaluate(action)
    past_rewards += reward
    S.update(reward)
    R.update(reward)

    if epoch % 25 == 0:
        plt.tight_layout(pad=0)
        plot = sns.heatmap(
            # np.exp(R.action_weights) /
            # np.exp(R.action_weights).sum(axis=0),
            R.action_weights / R.action_weights.sum(axis=0),
            square=True, cbar=False, annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('messages')
        plt.ylabel('actions')
        plt.title(f'Receiver\'s weights, rollout {epoch}')
        plt.savefig(f"images/receiver_{epoch}.png")
        plt.clf()

        plot = sns.heatmap(
            # np.exp(S.signal_weights) /
            # np.exp(S.signal_weights).sum(axis=0),
            S.signal_weights / S.signal_weights.sum(axis=0),
            square=True, cbar=False, annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('world states')
        plt.ylabel('messages')
        plt.title(f'Sender\'s weights, rollout {epoch}')
        plt.savefig(f"images/sender_{epoch}.png")
        plt.clf()

    if epoch % 100 == 0:
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        # print(f"Last State: {stimulus}, Signal: {signal}, Action: {action}, Reward: {reward}")
        history.append(past_rewards/100)
        past_rewards = 0

    if early_stop(stable_epochs, history, threshold):
        print(f'Early stop at epoch {epoch}')
        epochs = epoch
        break

Epoch 0, last 100 epochs reward: 0.0
Epoch 100, last 100 epochs reward: 0.43
Epoch 200, last 100 epochs reward: 0.42
Epoch 300, last 100 epochs reward: 0.52
Epoch 400, last 100 epochs reward: 0.63
Epoch 500, last 100 epochs reward: 0.67
Epoch 600, last 100 epochs reward: 0.85
Epoch 700, last 100 epochs reward: 0.75
Epoch 800, last 100 epochs reward: 0.85
Epoch 900, last 100 epochs reward: 0.79
Epoch 1000, last 100 epochs reward: 0.87
Epoch 1100, last 100 epochs reward: 0.86
Epoch 1200, last 100 epochs reward: 0.89
Epoch 1300, last 100 epochs reward: 0.9
Epoch 1400, last 100 epochs reward: 0.94
Epoch 1500, last 100 epochs reward: 0.91
Epoch 1600, last 100 epochs reward: 0.9
Epoch 1700, last 100 epochs reward: 0.93
Epoch 1800, last 100 epochs reward: 0.93
Epoch 1900, last 100 epochs reward: 0.94
Epoch 2000, last 100 epochs reward: 0.94
Epoch 2100, last 100 epochs reward: 0.89
Epoch 2200, last 100 epochs reward: 0.97
Epoch 2300, last 100 epochs reward: 0.97
Epoch 2400, last 100 epochs rew

<Figure size 640x480 with 0 Axes>

In [7]:
make_gif(f'{states}-{actions}-{signals}-sender')

In [8]:
make_gif(f'{states}-{actions}-{signals}-receiver')

In [12]:
print("Observation to message mapping:")
print(S.signal_weights.argmax(1))
print("Message to action mapping:")
print(R.action_weights.argmax(1))

Observation to message mapping:
[0 2 1]
Message to action mapping:
[0 2 1]
