### Intro to Multi-Armed Bandits

In [34]:
import numpy as np
from typing import List
from rich.console import Console

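- Each machine (bandit) pays a reward drawn from a Gaussian distribution, so the environment is stochastic: pulling the same lever twice can give different rewards.
- As a result, a long-run policy cannot be inferred as easily as with deterministic rewards. A single pull says little about a machine's mean payout, so the player has to balance exploring machines against exploiting the one that looks best so far.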

In [35]:
# Shared rich console that the game uses for all of its log output.
console: Console = Console()

In [40]:
class GaussianBandit(object):
    """A slot machine whose payout is drawn from a normal distribution.

    Each pull returns a sample from N(mu, std**2), rounded to ``decimals``
    decimal places.

    Args:
        mu: Mean of the reward distribution.
        std: Standard deviation of the reward distribution; must be >= 0.
        rng: Optional ``numpy.random.Generator`` used for sampling. When
            omitted, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` controls reproducibility exactly as before.
        decimals: Number of decimal places each reward is rounded to.

    Raises:
        ValueError: If ``std`` is negative.
    """

    def __init__(self, mu, std, rng=None, decimals=2):
        # Fail at construction time rather than on the first pull.
        if np.any(np.asarray(std) < 0):
            raise ValueError(f"std must be non-negative, got {std}")
        self.mu = mu
        self.std = std
        self.rng = rng
        self.decimals = decimals

    def pull_lever(self):
        """Pull the lever once and return the sampled reward, rounded."""
        # Fall back to the global numpy RNG for backward compatibility.
        sampler = np.random if self.rng is None else self.rng
        reward = sampler.normal(self.mu, self.std)
        return np.round(reward, self.decimals)
        
class GaussiandBanditGame(object):
    """Interactive console game over a set of Gaussian slot machines.

    The player repeatedly picks a machine by its 1-based number. The game
    records every reward and reports running statistics. Machine order is
    shuffled once at construction, so the numbering does not reveal which
    machine is which.

    NOTE(review): the class name keeps its original spelling ("Gaussiand")
    so existing callers keep working.
    """

    def __init__(self, bandits: "List[GaussianBandit]"):
        """Create a game over ``bandits``.

        Args:
            bandits: Non-empty sequence of objects exposing ``pull_lever()``.
                The sequence is copied, so shuffling never reorders the
                caller's list.

        Raises:
            ValueError: If ``bandits`` is empty.
        """
        # Copy before shuffling: np.random.shuffle works in place and would
        # otherwise reorder the list the caller passed in.
        self.bandits = list(bandits)
        if not self.bandits:
            raise ValueError("At least one bandit is required.")
        self._shuffle_game()
        self._reset_game()

    def take_action(self, choice):
        """Pull machine number ``choice`` (1-based) and record its reward.

        Returns:
            The reward produced by the chosen machine.

        Raises:
            ValueError: If ``choice`` is outside ``1..len(self.bandits)``.
                Without this check, 0 or negative numbers would silently
                index from the end of the list.
        """
        if not 1 <= choice <= len(self.bandits):
            raise ValueError(
                f"choice must be between 1 and {len(self.bandits)}, got {choice}"
            )
        reward = self.bandits[choice - 1].pull_lever()
        self.rewards.append(reward)
        self.total_reward += reward
        self.n_played += 1
        return reward

    def start_game(self):
        """Run the interactive loop until the player enters an out-of-range number.

        Previous statistics are reset first. Non-integer input is reported
        and the player is prompted again instead of crashing the loop.
        """
        self._reset_game()
        console.log("Game has started. Enter 0 to end the game.")
        n_machines = len(self.bandits)
        while True:
            raw = input(f"Please select a machine to play with from 1 to {n_machines}")
            try:
                choice = int(raw)
            except ValueError:
                console.log(f"'{raw}' is not a number. Enter 0 to end the game.")
                continue
            if not 1 <= choice <= n_machines:
                break
            # Rounds are numbered from 1 for the player.
            console.log(f"Round {self.n_played + 1}")
            reward = self.take_action(choice)
            console.log(f"Machine {choice} is chosen.")
            console.log(f"Reward of this action is {reward}")
            console.log(f"Average reward so far is {self.total_reward / self.n_played:.2f}")
        console.log("Game has ended. Game stats are...")
        if self.n_played > 0:
            console.log(f"Total reward is {self.total_reward:.2f} after {self.n_played} rounds.")
            console.log(f"Average reward is {self.total_reward / self.n_played:.2f}")

    def _shuffle_game(self):
        """Shuffle the game's own copy of ``self.bandits`` in place."""
        np.random.shuffle(self.bandits)

    def _reset_game(self):
        """Clear the reward history and running totals."""
        self.rewards = []
        self.total_reward = 0
        self.n_played = 0

In [41]:
# Three machines, each given as (mean, std) of its reward distribution.
# bandit2 has the highest mean; bandit3 has the lowest mean and the most noise.
bandit1, bandit2, bandit3 = (
    GaussianBandit(mu, std) for mu, std in [(5, 3), (6, 2), (1, 5)]
)

In [42]:
# The game shuffles its machines, so the number a player types does not
# necessarily correspond to bandit1/bandit2/bandit3.
game = GaussiandBanditGame(bandits=[bandit1, bandit2, bandit3])

In [48]:
# Interactive: prompts for machine numbers via input(); any number outside the
# valid range ends the game (the transcript below ended it by entering 4).
game.start_game()

Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 2


Please select a machine to play with from 1 to 3 4


In [49]:
# Rewards collected in the most recent game, in the order they were received.
game.rewards

[4.04,
 1.84,
 2.47,
 7.3,
 4.13,
 0.26,
 10.64,
 3.69,
 4.94,
 4.33,
 0.23,
 5.38,
 2.42,
 6.49,
 3.15,
 0.37,
 4.79,
 3.42,
 4.61,
 4.15,
 2.09,
 2.5,
 1.76,
 10.61,
 4.63,
 4.56,
 7.48,
 2.02,
 3.17]