# Notebook cell In [147]:
import numpy as np


class Agent:
    """Softmax (Boltzmann) bandit agent over ``n_machines`` machines.

    Each machine carries an integer credit: +1 for every reward equal to 1
    and -1 for every other reward.  Machines are sampled with probability
    proportional to ``exp(credit / temperature)``.

    Args:
        n_machines: Number of machines the agent chooses between.
    """

    def __init__(self, n_machines=10):
        self.t = 0  # current time step
        self.n_machines = n_machines
        # Running credit per machine, all starting at zero.
        self.machine_credits = np.zeros(n_machines, dtype=int)
        # Uniform choice until the first reward has been observed.
        self.machine_probs = np.full(n_machines, 1 / n_machines)
        self.temperature = 1.5

    def pick_machine(self):
        """Sample a machine index according to ``self.machine_probs``.

        Returns:
            Index of the chosen machine, in ``range(self.n_machines)``.
        """
        # Draw from range(n_machines) rather than a hard-coded range(10) so
        # the candidates always match the length of the probability vector.
        return np.random.choice(self.n_machines, size=1, p=self.machine_probs)[0]

    def recalculate_machine_probs(self):
        """Recompute the softmax distribution from the current credits."""
        scaled = self.machine_credits / self.temperature
        # Subtracting the maximum leaves the softmax unchanged mathematically
        # but keeps exp() from overflowing for large credits / low temperature.
        weights = np.exp(scaled - np.max(scaled))
        self.machine_probs = weights / weights.sum()

    def get_reward(self, reward, machine_index):
        """Record the outcome of one pull and refresh the probabilities.

        Args:
            reward: Observed reward; exactly 1 adds a credit, anything else
                removes one.
            machine_index: Index of the machine that produced the reward.
        """
        self.t += 1
        self.machine_credits[machine_index] += 1 if reward == 1 else -1
        self.recalculate_machine_probs()


class Environment:
    """Bernoulli multi-armed bandit with ``n_machines`` machines.

    Machine ``i`` pays a reward of 1 with probability ``params[i]`` and 0
    otherwise.  The ``params`` are sampled once, at construction, with
    ``np.random.uniform``.

    Args:
        n_machines: Number of machines in the environment.
    """

    def __init__(self, n_machines=10):
        self.n_machines = n_machines
        # Success probability p_i of each machine's Bernoulli(p_i) reward.
        self.params = np.random.uniform(size=n_machines)

    def _interact(self, machine_index):
        """Pull one machine and return its reward (0 or 1).

        Args:
            machine_index: Index of the machine to pull; must lie in
                ``range(self.n_machines)``.

        Returns:
            A single draw from Bernoulli(``params[machine_index]``).
        """
        assert 0 <= machine_index < self.n_machines, 'Bad machine index'
        # A Binomial draw with n=1 is a Bernoulli draw.
        return np.random.binomial(n=1, p=self.params[machine_index])

    def run(self, time_steps=250):
        """Let a fresh agent play for ``time_steps`` pulls.

        Args:
            time_steps: Number of pulls in the episode.

        Returns:
            Sum of the rewards collected over the episode.
        """
        # Size the agent to this environment; a default-sized agent would
        # disagree with the environment whenever n_machines is not 10.
        agent = Agent(n_machines=self.n_machines)
        total_reward = 0
        for _ in range(time_steps):
            machine_index = agent.pick_machine()
            reward = self._interact(machine_index)
            agent.get_reward(reward, machine_index)
            total_reward += reward
        return total_reward


if __name__ == "__main__":
    # Average the episode reward over many freshly sampled environments,
    # then print the total reward of one additional episode.
    n_trials = 2000
    cumulative_reward = sum(Environment().run() for _ in range(n_trials))
    print('Average:', cumulative_reward / n_trials)
    print(Environment().run())

# Sample output from one run (results vary because the simulation is random):
# Average: 203.5685
# 237
