In [1]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import animation
import imageio

In [2]:
# Code based on https://tomekkorbak.com/2019/10/08/lewis-signaling-games/
from tkinter import * 
from PIL import Image, ImageTk



In [3]:
class Sender:
        
    def __init__(self, n_inputs: int, n_messages: int, state_action_probs: np.ndarray = None, eps: float = 1e-6):
        self.n_messages = n_messages
        
        if state_action_probs is not None:
            self.action_weights = state_probs
        else:
            self.message_weights = np.zeros((n_inputs, n_messages))
            self.message_weights.fill(eps)
        
        self.last_situation = (0, 0)
        
    def send_message(self, input: int) -> int:
        probs = np.exp(self.message_weights[input, :])/np.sum(np.exp(self.message_weights[input, :]))
        message = np.random.choice(self.n_messages, p=probs)
        self.last_situation = (input, message)
        return message

    def learn_from_feedback(self, reward: int) -> None:
        self.message_weights[self.last_situation] += reward

In [4]:
class Receiver:
        
    def __init__(self, n_messages: int, n_actions: int, state_action_probs = None, eps: float = 1e-6):
        self.n_actions = n_actions
        if state_action_probs is not None:
            self.action_weights = state_probs.T
        else:
            self.action_weights = np.ndarray((n_messages, n_actions))
            self.action_weights.fill(eps)
        self.last_situation = (0, 0)
        
    def act(self, message: int) -> int:
        probs = np.exp(self.action_weights[message, :])/np.sum(np.exp(self.action_weights[message, :]))
        action = np.random.choice(self.n_actions, p=probs)
        self.last_situation = (message, action)
        return action

    def learn_from_feedback(self, reward: int) -> None:
        self.action_weights[self.last_situation] += reward

In [5]:
class World:
    def __init__(self, n_states: int, n_messages: int, state_prob_distribution: np.ndarray = None, seed: int = 1701):
        self.n_states = n_states
        self.state = 0
        self.rng = np.random.RandomState(seed)
        self.signal_costs = np.ones(n_messages)
        self.signal_costs.fill(1)
        #self.signal_costs = [0.75, 0.1, 0.1, 0.05]
        
        self.state_probs = np.ones((n_states, n_messages))
        if state_prob_distribution is not None:
            self.state_probs = (self.state_probs.T * state_prob_distribution).T

        
    def emit_state(self) -> int:
        self.state = self.rng.randint(self.n_states)
        return self.state
    
    def evaluate_action(self, action: int) -> int:
        return 1 if action == self.state else -1
        #return self.signal_costs[action] if action == self.state else -self.signal_costs[action]

In [14]:
NUM_STATES = 10 #10
NUM_MESSAGES = 14 #10
STATE_PROBS = np.random.uniform(0, NUM_STATES, NUM_STATES)

sender, receiver = Sender(NUM_STATES, NUM_MESSAGES), Receiver(NUM_MESSAGES, NUM_STATES)
world = World(NUM_STATES, NUM_MESSAGES)
past_rewards = 0
matrices = []
for epoch in range(3000):
    world_state = world.emit_state()
    message = sender.send_message(world_state)
    action = receiver.act(message)
    reward = world.evaluate_action(action)
    receiver.learn_from_feedback(reward)
    sender.learn_from_feedback(reward)
    past_rewards += reward
    if epoch % 25 == 0:
        plt.tight_layout(pad=0)
        plot = sns.heatmap(
            np.exp(receiver.action_weights)/np.exp(receiver.action_weights).sum(axis=0), 
            square=True, cbar=False, annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('actions')
        plt.ylabel('messages')
        plt.title(f'Receiver\'s weights, rollout {epoch}')
        plt.savefig(f"receiver_{epoch}_n_by_m.png")
        plt.clf()
        
        plot = sns.heatmap(
            np.exp(sender.message_weights)/np.exp(sender.message_weights).sum(axis=0), 
            square=True, cbar=False,annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('messages')
        plt.ylabel('world states')
        plt.title(f'Sender\'s weights, rollout {epoch}')
        plt.savefig(f"sender_{epoch}_n_by_m.png")
        plt.clf()
           
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        print(world_state, message, action, reward)
        past_rewards = 0

print("Observation to message mapping:")
print(sender.message_weights.argmax(1))
print("Message to action mapping:")
print(receiver.action_weights.argmax(1))

Epoch 0, last 100 epochs reward: -0.01
4 10 7 -1
Epoch 100, last 100 epochs reward: -0.76
1 12 1 1
Epoch 200, last 100 epochs reward: -0.68
9 8 4 -1
Epoch 300, last 100 epochs reward: -0.48
9 3 9 1
Epoch 400, last 100 epochs reward: -0.3
6 13 0 -1
Epoch 500, last 100 epochs reward: -0.06
4 6 5 -1
Epoch 600, last 100 epochs reward: 0.1
8 7 8 1
Epoch 700, last 100 epochs reward: 0.16
4 7 8 -1
Epoch 800, last 100 epochs reward: 0.38
7 1 7 1
Epoch 900, last 100 epochs reward: 0.62
1 4 1 1
Epoch 1000, last 100 epochs reward: 0.7
4 6 4 1
Epoch 1100, last 100 epochs reward: 0.74
6 11 3 -1
Epoch 1200, last 100 epochs reward: 0.64
9 3 9 1
Epoch 1300, last 100 epochs reward: 0.82
7 1 7 1
Epoch 1400, last 100 epochs reward: 0.9
4 6 4 1
Epoch 1500, last 100 epochs reward: 0.78
7 1 7 1
Epoch 1600, last 100 epochs reward: 0.84
6 7 8 -1
Epoch 1700, last 100 epochs reward: 0.84
2 12 2 1
Epoch 1800, last 100 epochs reward: 0.86
4 6 4 1
Epoch 1900, last 100 epochs reward: 1.0
9 3 9 1
Epoch 2000, last 10

<Figure size 432x288 with 0 Axes>

In [16]:
print("Observation to message mapping:")
print(sender.message_weights)
print("Message to action mapping:")
print(receiver.action_weights)

Observation to message mapping:
[[-4.99999900e+00 -5.99999900e+00 -4.99999900e+00 -3.99999900e+00
  -5.99999900e+00 -5.99999900e+00 -5.99999900e+00 -4.99999900e+00
  -5.99999900e+00 -5.99999900e+00 -3.99999900e+00 -5.99999900e+00
  -4.99999900e+00  2.21000001e+02]
 [-2.99999900e+00 -2.99999900e+00 -2.99999900e+00 -9.99999000e-01
   2.86000001e+02 -2.99999900e+00 -1.99999900e+00 -2.99999900e+00
  -1.99999900e+00 -1.99999900e+00 -1.99999900e+00 -1.99999900e+00
  -2.99999900e+00 -1.99999900e+00]
 [-2.99999900e+00 -3.99999900e+00 -3.99999900e+00 -3.99999900e+00
  -2.99999900e+00 -2.99999900e+00 -3.99999900e+00 -2.99999900e+00
  -2.99999900e+00 -3.99999900e+00 -2.99999900e+00 -2.99999900e+00
   2.86000001e+02 -3.99999900e+00]
 [-5.99999900e+00 -5.99999900e+00 -5.99999900e+00 -5.99999900e+00
  -5.99999900e+00 -7.99999900e+00 -6.99999900e+00 -6.99999900e+00
   1.73000001e+02 -6.99999900e+00 -5.99999900e+00 -5.99999900e+00
  -5.99999900e+00 -5.99999900e+00]
 [-3.99999900e+00 -5.99999900e+00 -4

In [12]:
def make_gif(filename_base):
    images = []
    for filename in [f'{filename_base}_{i}_n_by_m.png' for i in range(3000) if i % 25 == 0]:
        images.append(imageio.imread(filename))
    imageio.mimsave(f'{filename_base}_n_by_m.gif', images)

In [13]:
make_gif('sender')
make_gif('receiver')

<img src="sender_n_by_m.gif" width="750" align="center">

<img src="receiver_n_by_m.gif" width="750" align="center">