# Setup 1: Predefined Regions and Sender Stimulus Generalization

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import seaborn as sns
import os
from IPython.display import HTML

In [4]:
def stimgen(n: int) -> float:
    """Return the Gaussian generalization weight for a neighbour *n* steps away.

    The weight is 1 at ``n == 0`` and drops to one half at ``n == ±2``
    (i.e. the curve has a half-width at half-maximum of 2).
    """
    # Equivalent to exp(-n**2 * ln(2) / 4); the denominator is kept in the
    # same form so the floating-point result is unchanged.
    spread = 4 / np.log(2)
    return np.exp(-(n ** 2) / spread)

In [5]:
class World:
    """Environment that draws random states and scores the receiver's action.

    The states ``0 .. n_states - 1`` are split into equal, consecutive
    regions; an action counts as correct when it equals the index of the
    region that contains the current state.
    """

    def __init__(self, n_states: int,
                 n_signals: int, n_actions: int,
                 seed: int = 0, reward = (1, -1)) -> None:
        # reward: (payoff for a correct action, penalty for a wrong one);
        #         the sign of the penalty is ignored, see evaluate().
        # seed:   seeds the private RandomState used for state draws.
        self.random = np.random.RandomState(seed)
        self.reward = reward
        self.n_states = n_states
        self.setup = (n_signals, n_actions)
        self.state = 0

    def get_state(self) -> int:
        """Draw a new state uniformly from ``0 .. n_states - 1`` and remember it."""
        drawn = self.random.randint(self.n_states)
        self.state = drawn
        return drawn

    def evaluate(self, action: int) -> int:
        """Return the payoff if *action* matches the current state's region, else the penalty."""
        # NOTE(review): the number of regions comes from n_signals (setup[0]),
        # not n_actions -- confirm this is intended when the two differ.
        region_width = self.n_states / self.setup[0]
        target = self.state // region_width
        if action == target:
            return self.reward[0]
        # The penalty is always returned as a negative number.
        return -abs(self.reward[1])

In [6]:
class Sender:
    """Signalling agent that picks a signal for each stimulus by softmax choice.

    One weight (propensity) is stored per (signal, stimulus) pair. The last
    row belongs to the null signal ("send nothing"); it is never reinforced,
    so it keeps its initial value.
    """

    def __init__(self, n_stimuli: int, n_signals: int,
                  q_not: float = 1e-6, stimgen: bool = False) -> None:
        # n_stimuli: how many distinct stimuli (world states) can be shown
        # n_signals: how many real signals exist; a null signal is appended
        # q_not:     starting propensity of every (signal, stimulus) pair
        # stimgen:   when True, a reward also spills over to nearby stimuli
        self.stimgen = stimgen
        self.n_states = n_stimuli
        self.n_signals = n_signals + 1  # extra row is the null signal
        self.signal_weights = np.full(
            (self.n_signals, n_stimuli), q_not, dtype=float)
        self.rew_hist = np.zeros_like(self.signal_weights)
        self.last_situation = (0, 0)

    def get_signal(self, stimulus: int) -> int:
        """Sample a signal for *stimulus*; return -1 when the null signal is drawn."""
        # Softmax over this stimulus' column of weights.
        exp_weights = np.exp(self.signal_weights[:, stimulus])
        probabilities = exp_weights / np.sum(exp_weights)
        choice = np.random.choice(self.n_signals, p=probabilities)
        null_signal = self.n_signals - 1
        if choice != null_signal:
            # Only a real signal is remembered for the next update().
            self.last_situation = (stimulus, choice)
            return choice
        return -1

    def _reinforce(self, signal: int, stimulus: int, amount: float) -> None:
        # Add `amount` to one propensity, keeping it inside [-323, 308]
        # (bounds chosen by the original author to avoid overflow errors),
        # and record the same amount in the reward log.
        raised = self.signal_weights[signal, stimulus] + amount
        self.signal_weights[signal, stimulus] = max(-323, min(raised, 308))
        self.rew_hist[signal, stimulus] += amount

    def update(self, reward: int) -> None:
        """Apply *reward* to the last (stimulus, signal) pair and, optionally, its neighbours."""
        stimulus, signal = self.last_situation
        self._reinforce(signal, stimulus, reward)
        if not self.stimgen:
            return

        # Stimulus generalization: up to three stimuli on each side receive
        # a Gaussian-scaled share of the reward.  The walk wraps around the
        # ends ONLY when there are more than 2 real signals, because regions
        # cannot border each other twice.
        wraps = self.n_signals > 3
        right = left = stimulus
        for distance in (1, 2, 3):
            share = reward * stimgen(distance)

            right += 1
            if wraps and right >= self.n_states:
                right = 0
            if right < self.n_states:
                self._reinforce(signal, right, share)

            left -= 1
            if wraps and left < 0:
                left = self.n_states - 1
            if left >= 0:
                self._reinforce(signal, left, share)

    def checkpoint(self):
        """Return the reward log divided by 100 and start a fresh, zeroed log."""
        snapshot = self.rew_hist / 100
        self.rew_hist = np.zeros_like(self.signal_weights)
        return snapshot

In [7]:
class Receiver:
    """Acting agent that picks an action for each received signal by softmax choice.

    One weight (propensity) is stored per (signal, action) pair, together
    with a running log of the rewards applied since the last checkpoint.
    """

    def __init__(self, n_signals: int, n_actions: int,
                 q_not: float = 1e-6, stimgen: bool = False) -> None:
        # n_signals: number of signals that can be received,
        #            usually equal to the number of states in the world
        # n_actions: number of actions that can be taken in response,
        #            usually equal to the number of states in the world
        # q_not:     initial action propensity value
        # stimgen:   when True, a reward also spills over to nearby actions
        self.n_actions = n_actions
        self.n_signals = n_signals
        self.action_weights = np.full((n_signals, n_actions), q_not, dtype=float)
        self.rew_hist = np.zeros_like(self.action_weights)
        self.last_situation = (0, 0)
        self.stimgen = stimgen

    def get_action(self, signal: int) -> int:
        """Sample an action for *signal* and remember the pair for update()."""
        # Softmax over this signal's row of weights.
        exp_weights = np.exp(self.action_weights[signal, :])
        probabilities = exp_weights / np.sum(exp_weights)
        action = np.random.choice(self.n_actions, p=probabilities)
        self.last_situation = (signal, action)
        return action

    def _reinforce(self, signal: int, action: int, amount: float) -> None:
        # Add `amount` to one propensity, keeping it inside [-323, 308],
        # and accumulate the same amount in the reward log.
        raised = self.action_weights[signal, action] + amount
        self.action_weights[signal, action] = max(-323, min(raised, 308))
        self.rew_hist[signal, action] += amount

    def update(self, reward: int) -> None:
        """Apply *reward* to the last (signal, action) pair and, optionally, its neighbours."""
        signal, action = self.last_situation
        self._reinforce(signal, action, reward)
        if not self.stimgen:  # Should always be False in setup 1
            return

        # Generalization: up to three actions on each side receive a
        # Gaussian-scaled share of the reward; the walk wraps around the
        # ends only when there are more than 2 signals.
        wraps = self.n_signals > 2
        right = left = action
        for distance in (1, 2, 3):
            share = reward * stimgen(distance)

            right += 1
            if wraps and right >= self.n_actions:
                right = 0
            if right < self.n_actions:
                self._reinforce(signal, right, share)

            left -= 1
            if wraps and left < 0:
                left = self.n_actions - 1
            if left >= 0:
                # The reward log must accumulate here as well; it used to be
                # overwritten for left-hand neighbours.
                self._reinforce(signal, left, share)

    def checkpoint(self):
        """Return the reward log divided by 100 and start a fresh, zeroed log."""
        snapshot = self.rew_hist / 100
        self.rew_hist = np.zeros_like(self.action_weights)
        return snapshot

In [8]:
class History:
    """Recorder for weight snapshots and reward checkpoints of one training run.

    ``add_25`` stores one sender/receiver weight snapshot per call and
    ``add_100`` stores one reward checkpoint per call; the remaining methods
    render or print what has been recorded so far.
    """

    def __init__(self, epochs, states, signals, actions):
        # The training loop in this file records on epochs 0, 25, 50, ...
        # (and 0, 100, 200, ...), so the slot counts are rounded UP; with
        # floor division a run whose length is not a multiple of 25/100
        # would overflow the buffers.
        n_snapshots = -(-epochs // 25)
        n_checkpoints = -(-epochs // 100)
        # +1 row for the sender's null signal.
        self.send_hist = np.zeros((n_snapshots, signals + 1, states))
        self.reci_hist = np.zeros((n_snapshots, signals, actions))
        # Reward buffers keep the snapshot-sized shape although they are
        # filled once per add_100() call.
        self.send_rew = np.zeros_like(self.send_hist)
        self.reci_rew = np.zeros_like(self.reci_hist)
        self.slows = np.zeros(n_checkpoints)
        self.epochs = epochs
        self.ep = 0   # number of weight snapshots stored so far
        self.ep2 = 0  # number of reward checkpoints stored so far

    def add_25(self, send_weights, reci_weights):
        """Store a copy of both weight tables in the next snapshot slot."""
        self.send_hist[self.ep] = send_weights
        self.reci_hist[self.ep] = reci_weights
        self.ep += 1

    def add_100(self, send_rew, reci_rew, slow):
        """Store both reward logs and the scalar *slow* in the next checkpoint slot."""
        self.send_rew[self.ep2] = send_rew
        self.reci_rew[self.ep2] = reci_rew
        self.slows[self.ep2] = slow
        self.ep2 += 1

    @staticmethod
    def _softmax(weights, axis):
        # Turn a weight table into choice probabilities along `axis`.
        exp_weights = np.exp(weights)
        return exp_weights / exp_weights.sum(axis=axis, keepdims=True)

    def make_gif(self, fps, seed, filename_base, html=False):
        """Render every recorded snapshot to ./images and join the frames into a GIF.

        fps:           frame rate handed to imageio.
        seed:          accepted for call compatibility; not used here.
        filename_base: output path without the ``.gif`` suffix.
        html:          when True, also show the GIF inline through IPython.
        """
        os.makedirs('images', exist_ok=True)
        frame_paths = [f'./images/game_{i*25}.png' for i in range(self.ep)]

        for i, frame_path in enumerate(frame_paths):
            fig, axs = plt.subplots(2, 1, figsize=(8, 6))
            plt.tight_layout(pad=3)

            # Sender: probability of each message (row) given a world state.
            sns.heatmap(
                self._softmax(self.send_hist[i], axis=0),
                linewidth=0.5, linecolor='white',
                square=True, cbar=False, annot=True, fmt='.1f', ax=axs[0])
            axs[0].set_ylabel('messages')
            axs[0].set_xlabel('world states')
            axs[0].set_title('Sender\'s weights')

            # Receiver: probability of each action (column) given a message.
            sns.heatmap(
                self._softmax(self.reci_hist[i], axis=1),
                linewidth=0.5, linecolor='white',
                square=True, cbar=False, annot=True, fmt='.1f', ax=axs[1])
            axs[1].set_xlabel('actions')
            axs[1].set_ylabel('messages')
            axs[1].set_title('Receiver\'s weights')

            fig.suptitle(f'Rollout {i*25}')
            plt.savefig(frame_path)
            plt.close(fig)

        images = [imageio.imread(frame_path) for frame_path in frame_paths]
        imageio.mimsave(f'{filename_base}.gif', images, fps=fps)
        if html:
            # Imported here so the method also works outside a notebook,
            # where `display` is not predefined.
            from IPython.display import display
            display(HTML('<img src="{}">'.format(f'{filename_base}.gif')))
        # no return

    @staticmethod
    def _print_map(corner, best):
        # Print a 0/1 table with a column-index header and a row-index column.
        print(corner + ''.join(f'{col:2} ' for col in range(best.shape[1])))
        for idx, row in enumerate(best):
            print(f'{idx:2} ' + ''.join(f'{flag:2} ' for flag in row))

    def print_send_map(self):
        """Print, for the last recorded snapshot, which signal row holds the maximum weight in each state column."""
        latest = self.send_hist[self.ep - 1]
        best = (latest == latest.max(axis=0)[None, :]).astype(int)
        self._print_map('a|s', best)

    def print_reci_map(self):
        """Print, for the last recorded snapshot, which action column holds the maximum weight in each signal row."""
        # Use the last RECORDED snapshot (as print_send_map does), not the
        # last allocated slot, which is still all zeros in a partial run.
        latest = self.reci_hist[self.ep - 1]
        best = (latest == latest.max(axis=1)[:, None]).astype(int)
        self._print_map('m|a', best)

## Setup 1: Experiment

In [39]:
st = 20           # number of world states (also the number of sender stimuli)
si = 2            # number of real signals; Sender adds one null signal on top
ac = 2            # number of actions available to the receiver
seed = 0          # random seed handed to World (state draws) and to make_gif
pos = 1           # reward for a correct action
neg = 1           # penalty for a wrong action; World.evaluate returns -abs(neg)
initial = 25      # initial value (q_not) of every sender and receiver weight
epochs = 20000    # number of rollouts in the training loop
gif_fps = 20      # frame rate passed to History.make_gif

In [40]:
# Build the environment, both agents and the recorder for this run.
# Only the sender generalizes rewards to neighbouring stimuli in setup 1.
W = World(n_states=st, n_signals=si, n_actions=ac, seed=seed, reward=(pos, neg))
S = Sender(n_stimuli=st, n_signals=si, q_not=initial, stimgen=True)
R = Receiver(n_signals=si, n_actions=ac, q_not=initial, stimgen=False)
H = History(epochs=epochs, states=st, signals=si, actions=ac)

In [41]:
# Training loop: one rollout per epoch (state -> signal -> action -> reward).
# past_rewards sums the rewards since the last checkpoint; slow is that sum / 100.
slow = past_rewards = 0

for epoch in range(epochs):
    stimulus = W.get_state()
    signal = S.get_signal(stimulus)
    if signal != -1:
        # A real signal was sent: the receiver acts and both agents learn
        # from the same reward.
        action = R.get_action(signal)
        reward = W.evaluate(action)
        past_rewards += reward
        S.update(reward)
        R.update(reward)
    # else null action: the sender stayed silent, so nobody acts or learns
    
    if epoch % 25 == 0:
        # save history: snapshot both weight tables (epoch 0 included)
        H.add_25(S.signal_weights, R.action_weights)
        

    if epoch % 100 == 0:
        # Checkpoint the reward statistics and restart the running sum.
        # Epoch 0 also matches, so the first checkpoint covers one rollout.
        slow = past_rewards / 100
        past_rewards = 0
        H.add_100(S.checkpoint(), R.checkpoint(), slow)

In [42]:
# final signal-state mapping of the sender: rows are signals (the last row is
# the null signal), columns are world states; 1 marks the top-weighted signal
H.print_send_map()

a|s 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 
 0  0  0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  1  1  1  1 
 1  1  1  1  1  1  1  1  1  1  0  0  0  0  0  0  0  0  0  0  0 
 2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 


In [43]:
# final message-action mapping of the receiver: rows are messages, columns
# are actions; 1 marks the top-weighted action for each message
H.print_reci_map()

m|a 0  1 
 0  0  1 
 1  1  0 


In [44]:
# Render every recorded weight snapshot as a frame, write the animation to
# ./images/<states>-<signals>-<actions>_setup1.gif and show it inline.
H.make_gif(gif_fps, seed, f'./images/{st}-{si}-{ac}_setup1', html=True)