In [1]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import animation
import imageio

In [2]:
# Code based on https://tomekkorbak.com/2019/10/08/lewis-signaling-games/
from tkinter import * 
from PIL import Image, ImageTk



In [3]:
class Sender:
        
    def __init__(self, n_inputs: int, n_messages: int, state_action_probs: np.ndarray = None, eps: float = 1e-6):
        self.n_messages = n_messages
        
        if state_action_probs is not None:
            self.action_weights = state_probs
        else:
            self.message_weights = np.zeros((n_inputs, n_messages))
            self.message_weights.fill(eps)
        
        self.last_situation = (0, 0)
        
    def send_message(self, input: int) -> int:
        probs = np.exp(self.message_weights[input, :])/np.sum(np.exp(self.message_weights[input, :]))
        message = np.random.choice(self.n_messages, p=probs)
        self.last_situation = (input, message)
        return message

    def learn_from_feedback(self, reward: int) -> None:
        self.message_weights[self.last_situation] += reward

In [4]:
class Receiver:
        
    def __init__(self, n_messages: int, n_actions: int, state_action_probs = None, eps: float = 1e-6):
        self.n_actions = n_actions
        if state_action_probs is not None:
            self.action_weights = state_probs.T
        else:
            self.action_weights = np.ndarray((n_messages, n_actions))
            self.action_weights.fill(eps)
        self.last_situation = (0, 0)
        
    def act(self, message: int) -> int:
        probs = np.exp(self.action_weights[message, :])/np.sum(np.exp(self.action_weights[message, :]))
        action = np.random.choice(self.n_actions, p=probs)
        self.last_situation = (message, action)
        return action

    def learn_from_feedback(self, reward: int) -> None:
        self.action_weights[self.last_situation] += reward

In [18]:
class World:
    def __init__(self, n_states: int, n_messages: int, signal_cost_dist:np.ndarray = None, state_prob_distribution: np.ndarray = None, seed: int = 1701):
        self.n_states = n_states
        self.state = 0
        self.rng = np.random.RandomState(seed)
        self.signal_costs = np.ones(n_messages)
        if signal_cost_dist is not None:
            self.signal_costs = signal_cost_dist
        
        self.state_probs = np.ones((n_states, n_messages))
        if state_prob_distribution is not None:
            self.state_probs = (self.state_probs.T * state_prob_distribution).T
        
    def emit_state(self) -> int:
        self.state = self.rng.randint(self.n_states)
        return self.state
    
    def evaluate_action(self, action: int) -> int:
        return self.signal_costs[action] if action == self.state else -self.signal_costs[action]

In [23]:
NUM_STATES = 4
NUM_MESSAGES = 4
STATE_PROBS = [0.96, 0.01, 0.01, 0.01] # np.random.uniform(0, NUM_STATES, NUM_STATES)
COST_DIST = [0.96, 0.01, 0.01, 0.01]

sender, receiver = Sender(NUM_STATES, NUM_MESSAGES), Receiver(NUM_MESSAGES, NUM_STATES)
world = World(NUM_STATES, NUM_MESSAGES)
past_rewards = 0
matrices = []
for epoch in range(2800):
    world_state = world.emit_state()
    message = sender.send_message(world_state)
    action = receiver.act(message)
    reward = world.evaluate_action(action)
    receiver.learn_from_feedback(reward)
    sender.learn_from_feedback(reward)
    past_rewards += reward
    if epoch % 25 == 0:
        plt.tight_layout(pad=0)
        plot = sns.heatmap(
            np.exp(receiver.action_weights)/np.exp(receiver.action_weights).sum(axis=0), 
            square=True, cbar=False, annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('messages')
        plt.ylabel('actions')
        plt.title(f'Receiver\'s weights, rollout {epoch}')
        plt.savefig(f"receiver_{epoch}.png")
        plt.clf()
        
        plot = sns.heatmap(
            np.exp(sender.message_weights)/np.exp(sender.message_weights).sum(axis=0), 
            square=True, cbar=False,annot=True, fmt='.1f'
        ).get_figure()
        plt.xlabel('world states')
        plt.ylabel('messages')
        plt.title(f'Sender\'s weights, rollout {epoch}')
        plt.savefig(f"sender_{epoch}.png")
        plt.clf()
           
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        print(world_state, message, action, reward)
        past_rewards = 0

print("Observation to message mapping:")
print(sender.message_weights.argmax(1))
print("Message to action mapping:")
print(receiver.action_weights.argmax(1))

Epoch 0, last 100 epochs reward: 0.01
2 3 2 1.0
Epoch 100, last 100 epochs reward: 0.52
1 0 1 1.0
Epoch 200, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 300, last 100 epochs reward: 1.0
2 3 2 1.0
Epoch 400, last 100 epochs reward: 1.0
0 2 0 1.0
Epoch 500, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 600, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 700, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 800, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 900, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 1000, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 1100, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 1200, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 1300, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 1400, last 100 epochs reward: 1.0
0 2 0 1.0
Epoch 1500, last 100 epochs reward: 1.0
2 3 2 1.0
Epoch 1600, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 1700, last 100 epochs reward: 1.0
3 1 3 1.0
Epoch 1800, last 100 epochs reward: 1.0
2 3 2 1.0
Epoch 1900, last 100 epochs reward: 1.0
1 0 1 1.0
Epoch 2000

<Figure size 432x288 with 0 Axes>

In [24]:
def make_gif(filename_base):
    images = []
    for filename in [f'{filename_base}_{i}.png' for i in range(2800) if i % 25 == 0]:
        images.append(imageio.imread(filename))
    imageio.mimsave(f'{filename_base}.gif', images)

In [25]:
make_gif('sender')
make_gif('receiver')

<img src="sender.gif" width="750" align="center">

<img src="receiver.gif" width="750" align="center">