<a href="https://colab.research.google.com/github/markhalka/Paper_Imlimentations/blob/main/EmergentCompositionality.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np




# What happens if there are fewer outputs than actions (half as many? one tenth?)
# and the receiver has some primitive short-term memory?
# What kind of results would there be?

# You could easily update this to include an array of senders and receivers for arbitrary depth
# (and you could also make it arbitrarily wide as well).

"""
Implementation of: CREATIVE COMPOSITIONALITY FROM REINFORCEMENT LEARNING IN SIGNALING GAMES by Michael Franke

The general idea is that creative compositionality can be explained by two mechanisms: spill-over and lateral inhibition (both biologically plausible).
"""

# use the same default parameters as the paper
DEFAULT_SIMILARITY = 0.05
DEFAULT_INHIBITION = 0.2
# Backward-compatible alias for the original (misspelled) constant name.
DEFUALT_SIMILARITY = DEFAULT_SIMILARITY


class Agent:
    """Reinforcement learner that maps an input state to an output state.

    The agent keeps a square table ``counts`` where ``counts[i][o]`` is the
    (unnormalised) weight of answering input ``states[i]`` with output
    ``states[o]``.  Actions are sampled proportionally to those weights, and
    the weights are adjusted after a rewarded round by two mechanisms:
    spill-over (similar input/output pairs are reinforced too) and lateral
    inhibition (competing pairs are weakened).

    Attributes:
        states: sequence of all possible states (used as inputs and outputs).
        name: label for this agent; it is only stored, never read here.
        n_states: number of states.
        counts: ``n_states x n_states`` array of weights, initialised to 1.
        last_input / last_output: the pair produced by the most recent
            ``get_action`` call, or ``None`` before the first call.
        i: lateral-inhibition amount subtracted from competing pairs.
        s: spill-over factor applied to partially similar pairs.
    """

    def __init__(self, states, name):
        self.states = states
        self.name = name
        self.n_states = len(states)
        self.counts = np.ones((self.n_states, self.n_states))
        self.last_input = None
        self.last_output = None
        self.i = DEFAULT_INHIBITION
        self.s = DEFAULT_SIMILARITY

    def get_index(self, state):
        """Return the position of ``state`` in ``self.states``.

        Raises:
            ValueError: if ``state`` is not one of the agent's states.
        """
        return self.states.index(state)

    def get_action(self, state):
        """Sample an output for ``state`` and remember the (input, output) pair.

        The output is drawn with probability proportional to the weights in
        the row of ``counts`` that belongs to ``state``.

        Returns:
            The sampled element of ``self.states``.
        """
        row = self.counts[self.get_index(state)]
        total = row.sum()
        if total > 0:
            probs = row / total
        else:
            # Lateral inhibition can clamp every weight in a row to zero;
            # dividing by that sum would yield NaNs and make the sampler
            # raise, so fall back to a uniform choice instead.
            probs = np.ones(self.n_states) / self.n_states
        output_index = np.random.choice(np.arange(self.n_states), p=probs)
        output = self.states[output_index]
        self.last_input = state
        self.last_output = output
        return output

    def get_sim(self, state1, state2):
        """Return the similarity of two states as a value in ``[0, 1]``.

        It is the number of symbols of the shorter state that also occur in
        the longer one, divided by the length of the longer state.  When both
        have the same length, ``state1`` plays the role of the shorter one.
        """
        if len(state1) > len(state2):
            longer, shorter = state1, state2
        else:
            longer, shorter = state2, state1
        if len(longer) == 0:
            # Both states are empty, hence identical; avoids a division by zero.
            return 1.0
        shared = sum(1 for symbol in shorter if symbol in longer)
        return shared / len(longer)

    def update_spill_over(self, reward):
        """Reinforce every (input, output) pair by its similarity to the last pair.

        The pair similarity is the product of the input and output
        similarities.  A product of exactly 1 is reinforced with the full
        reward; any other pair only receives the fraction ``self.s`` of its
        similarity-weighted reward.
        """
        # The output similarities do not depend on the input, so compute
        # them once instead of once per input row.
        output_sims = [self.get_sim(self.last_output, candidate)
                       for candidate in self.states]
        for i, candidate_input in enumerate(self.states):
            input_sim = self.get_sim(self.last_input, candidate_input)
            for o, output_sim in enumerate(output_sims):
                total_sim = input_sim * output_sim
                if total_sim != 1:
                    total_sim *= self.s
                self.counts[i, o] += total_sim * reward

    def update_lateral(self, reward):
        """Weaken the pairs that compete with the last (input, output) pair.

        Every other input loses ``self.i`` weight for the chosen output, and
        the chosen input loses ``self.i`` weight for every other output.
        Weights never drop below zero.

        ``reward`` is accepted for symmetry with ``update_spill_over`` but
        does not influence the amount of inhibition.
        """
        output_index = self.get_index(self.last_output)
        input_index = self.get_index(self.last_input)
        for s, state in enumerate(self.states):
            if state != self.last_input:
                self.counts[s, output_index] = max(
                    self.counts[s, output_index] - self.i, 0)
            if state != self.last_output:
                self.counts[input_index, s] = max(
                    self.counts[input_index, s] - self.i, 0)

    def update_counts(self, reward):
        """Apply spill-over and lateral inhibition for the last action.

        A reward of 0 leaves the weights untouched.

        Raises:
            ValueError: if a non-zero reward arrives before any
                ``get_action`` call, i.e. there is no pair to reinforce.
        """
        if reward == 0:
            return
        if self.last_input is None or self.last_output is None:
            raise ValueError("update_counts called before get_action")
        self.update_spill_over(reward)
        self.update_lateral(reward)


class Env:
    """World that draws a random state each round and scores guesses of it.

    Attributes:
        states: sequence of states the environment can be in.
        n_states: number of states.
        current_state: the state drawn by the latest ``get_state`` call,
            or ``None`` before the first call.
    """

    def __init__(self, states):
        self.current_state = None
        self.states = states
        self.n_states = len(states)

    def get_state(self):
        """Draw a state uniformly at random, store it and return it."""
        drawn = self.states[np.random.randint(self.n_states)]
        self.current_state = drawn
        return drawn

    def check_state(self, state):
        """Return 1 if ``state`` equals the current state, otherwise 0."""
        matched = self.current_state == state
        return int(matched)


class Communication:
    """Signaling game between a sender agent and a receiver agent.

    Each round the environment draws a state, the sender turns it into a
    message, the receiver turns the message into a guess, and both agents
    are rewarded when the guess equals the drawn state.

    Attributes:
        states: the shared state/message vocabulary.
        env: the ``Env`` that draws states and scores guesses.
        sender: ``Agent`` mapping environment states to messages.
        reciever: ``Agent`` mapping messages to guessed states (the
            attribute keeps its original spelling for compatibility).
    """

    def __init__(self):
        self.states = ["a", "b", "c", "ab", "ac", "bc"]
        self.env = Env(self.states)
        self.sender = Agent(self.states, "send")
        self.reciever = Agent(self.states, "rec")

    def step(self):
        """Play one round and update both agents.

        Returns:
            The reward of the round as returned by ``Env.check_state``.
        """
        state = self.env.get_state()
        sender_state = self.sender.get_action(state)
        reciever_state = self.reciever.get_action(sender_state)
        reward = self.env.check_state(reciever_state)
        self.sender.update_counts(reward)
        self.reciever.update_counts(reward)
        return reward

    def test(self):
        """Print the similarity of two states that share no symbol."""
        print(self.sender.get_sim("c", "ab"))

    def run(self, n_steps=10000, report_every=1000):
        """Play ``n_steps`` rounds, printing the mean reward per window.

        After every ``report_every`` rounds the mean reward of exactly those
        rounds is printed.  Rounds left over when ``n_steps`` is not a
        multiple of ``report_every`` are played but not reported.  Finally
        the weight tables of the sender and the receiver are printed.

        Args:
            n_steps: total number of rounds to play.
            report_every: size of the reporting window, in rounds.

        Raises:
            ValueError: if ``report_every`` is not positive.
        """
        if report_every <= 0:
            raise ValueError("report_every must be a positive integer")
        window_reward = 0.0
        for i in range(n_steps):
            window_reward += self.step()
            # Report once a window is complete.  Testing ``i`` instead of
            # ``i + 1`` would report after the very first round and never
            # report the final window.
            if (i + 1) % report_every == 0:
                print(window_reward / report_every)
                window_reward = 0.0
        print(self.sender.counts)
        print("\n")
        print(self.reciever.counts)

# Run the simulation only when executed as a script or notebook cell, so that
# importing this module does not start a full training run as a side effect.
if __name__ == "__main__":
    comm = Communication()
    comm.run()
    # comm.test()  # uncomment for a quick similarity sanity check