In [21]:
from game import World, Sender, Receiver
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import imageio.v2 as imageio
from IPython.display import HTML
from time import time
import os

In [22]:
def early_stop(epochs, rewards, threshold=0.95):
    """Decide whether recent rewards are high enough to stop training.

    Args:
        epochs: Size of the trailing window; also the divisor of the average.
        rewards: Sequence of reward values, oldest first.
        threshold: Value the windowed average must strictly exceed.

    Returns:
        True when the sum of the last ``epochs`` entries of ``rewards``,
        divided by ``epochs``, is greater than ``threshold``.
    """
    # Dividing by ``epochs`` (not by the slice length) keeps the average low
    # while fewer than ``epochs`` entries have been collected.
    recent_window = rewards[-epochs:]
    windowed_mean = np.sum(recent_window) / epochs
    return windowed_mean > threshold

In [23]:
def make_gif(filename_base, epochs, seed):
    """Assemble saved heatmap frames into ``gifs/<seed>/<filename_base>.gif``.

    Args:
        filename_base: Output file name without extension. The text after its
            last ``'-'`` selects the frame series to read, i.e. the files
            ``images/<suffix>_<i>.png``.
        epochs: Frames are read for every 25th index in ``range(epochs)``.
        seed: Name of the sub-directory of ``gifs/`` the GIF is written to.
    """
    series = filename_base.split('-')[-1]
    frames = [
        imageio.imread(f'images/{series}_{i}.png')
        for i in range(0, epochs, 25)
    ]

    # makedirs also creates the parent ``gifs/`` directory when it is missing
    # and does not fail if the target already exists.
    out_dir = f'gifs/{seed}'
    os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(f'{out_dir}/{filename_base}.gif', frames, fps=10)

In [24]:
def _save_weight_heatmap(fig, weights, xlabel, ylabel, title, path):
    """Draw ``weights`` (each column scaled to sum to 1) as a heatmap on ``fig`` and save it."""
    fig.clf()
    ax = fig.add_subplot(111)
    sns.heatmap(
        weights / weights.sum(axis=0),
        square=True, cbar=False, annot=True, fmt='.1f', ax=ax,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Applied after drawing so the layout is computed for the heatmap axes.
    fig.tight_layout(pad=0)
    fig.savefig(path)


def stress(n, epochs, states, signals, actions, initial_weights):
    """Run ``n`` sender/receiver trials and report those that fail to converge.

    Each trial builds a fresh ``World``, ``Sender`` and ``Receiver`` and plays
    up to ``epochs`` rounds. Every 25th round the receiver's and sender's
    weights are saved as heatmaps under ``images/``. The mean reward of every
    completed 100-round window is recorded, and a trial stops early once
    ``early_stop(7, history, 0.97)`` holds.

    A trial that was not stopped early and whose last completed window has a
    mean reward below 0.8 is reported: its frames are turned into GIFs via
    ``make_gif`` and the argmax mappings of both weight matrices are printed.

    Args:
        n: Number of trials.
        epochs: Maximum number of rounds per trial.
        states: Passed to ``World`` and ``Sender``.
        signals: Passed to ``Sender`` and ``Receiver``.
        actions: Passed to ``Receiver``.
        initial_weights: Passed to ``Sender`` and ``Receiver``.

    Returns:
        None. Results are printed and written to ``images/`` and ``gifs/``.
    """
    frame_interval = 25  # must match the frame stride make_gif reads back
    window = 100         # rounds per reward-averaging window

    os.makedirs('images', exist_ok=True)

    # One wall-clock base plus the trial index gives every trial a distinct
    # seed (and gif directory) even when several trials start within the
    # same second.
    base_seed = np.floor(time()).astype(int)
    num_found = 0

    # A single figure is reused for all frames and closed at the end so no
    # empty figure is left behind for the notebook to render.
    fig = plt.figure()
    try:
        for trial in range(n):
            seed = base_seed + trial
            W = World(states, seed)
            S = Sender(states, signals, initial_weights)
            R = Receiver(signals, actions, initial_weights)
            window_reward = 0
            history = []
            stopped = False

            for epoch in range(epochs):
                stimulus = W.get_state()
                signal = S.get_signal(stimulus)
                action = R.get_action(signal)
                reward = W.evaluate(action)
                window_reward += reward
                S.update(reward)
                R.update(reward)

                if epoch % frame_interval == 0:
                    _save_weight_heatmap(
                        fig, R.action_weights, 'messages', 'actions',
                        f"Receiver's weights, rollout {epoch}",
                        f"images/receiver_{epoch}.png",
                    )
                    _save_weight_heatmap(
                        fig, S.signal_weights, 'world states', 'messages',
                        f"Sender's weights, rollout {epoch}",
                        f"images/sender_{epoch}.png",
                    )

                # Close a window only after ``window`` rounds have actually
                # been played, so every entry is a true per-round mean and
                # the final rounds of the trial are included.
                if (epoch + 1) % window == 0:
                    history.append(window_reward / window)
                    window_reward = 0
                    # history only changes here, so this is the only place
                    # the stopping criterion can newly become true.
                    if early_stop(7, history, 0.97):
                        stopped = True
                        break

            # With no completed window (epochs < 100) there is no evidence of
            # convergence; treat the trial as not converged.
            last_rate = history[-1] if history else 0.0
            if not stopped and last_rate < 0.8:
                print(f'Possibility of no convergence at seed {seed}')
                print("Making gifs...")
                make_gif(f'{states}-{actions}-{signals}-sender', epochs, seed)
                make_gif(f'{states}-{actions}-{signals}-receiver', epochs, seed)
                # NOTE(review): the heatmaps normalise each column and label
                # rows as messages/actions, while these mappings take the
                # argmax along axis 1 - confirm the orientation of the weight
                # matrices in `game`.
                print("Observation to message mapping:")
                print(S.signal_weights.argmax(1))
                print("Message to action mapping:")
                print(R.action_weights.argmax(1))
                print("-"*50)
                num_found += 1
    finally:
        plt.close(fig)

    if num_found == 0:
        print(f'No possibility of no convergence found for {states}-{actions}-{signals} after {n} {epochs}-epoch tries. Try increasing n.')
    print(f'Found {num_found} cases of no convergence, out of {n} trials.')


In [25]:
# 15 trials of up to 10000 rounds each on the 3-state / 3-signal / 3-action setup.
stress(n=15, epochs=10000, states=3, signals=3, actions=3, initial_weights=1)

No possibility of no convergence found for 3-3-3 after 15 10000-epoch tries. Try increasing n.
Found 0 cases of no convergence, out of 15 trials.


<Figure size 640x480 with 0 Axes>