In [76]:
import numpy as np
from collections import defaultdict
from tabulate import tabulate
import plotly.graph_objs as go
import plotly.figure_factory as ff
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)

In [77]:
# A world of infinite cheese
class World(object):
    
    def __init__(self):
        self.cheese = False
    
    def step(self, eat):
        if eat:
            self.cheese = True
        else:
            self.cheese = False
        
        return self.cheese

In [78]:
class Mouse(object):
    def __init__(self, value_learning_rate=0.1, action_learning_rate=0.01):
        self._value_learning_rate = value_learning_rate
        self._action_learning_rate = action_learning_rate
        
        self.reset()
    
    def reset(self):
        self.energy = self.initial_energy = 5
        self.cheese = self.initial_cheese = None
        self.action = self.initial_action = None
        self.state = None
        self.previous_state = None
        
        self._learned_value_table = {}
        self._action_table = {}
        for i in range(11):
            for v in [True, False]:
                k = (i, v)
                self._learned_value_table[k] = 0
                self._action_table[k] = 0.5
    
    @staticmethod
    def _update_energy(energy, cheese):
        if cheese:
            energy += 1
        else:
            energy -= 1
        
        return np.clip(energy, 0, 10)
    
    @staticmethod
    def _innate_evaluation(current_state, full_point=7):
        value = 0
        energy, cheese = current_state
        if cheese:
            value = -energy + full_point
        elif energy < 4:
            value = energy - 4
        return value
    
    def _update_learned_value_table(self, previous_state, value_difference, learning_rate, debug=True):
        previous_value = self._learned_value_table[previous_state]
        new_value = previous_value + learning_rate * value_difference
        self._learned_value_table[previous_state] = new_value
        if debug:
            print("---- Update Learned Value Table")
            print("Previous State:", previous_state)
            print("Action:", self.action)
            print("Previous Value:", previous_value)
            print("Value Difference:", value_difference)
            print("New Value:", new_value)

    def _update_action_table(self, previous_state, value_difference, learning_rate, debug=True):
        if self.action:
            previous_strength = self._action_table[previous_state]
            new_strength = previous_strength + learning_rate * value_difference
            new_strength = np.clip(new_strength, 0.01, 0.99)
            self._action_table[previous_state] = new_strength
            if debug:
                print("---- Update Action Table")
                print("Previous State:", previous_state)
                print("Previous Strength:", previous_strength)
                print("Value Difference:", value_difference)
                print("New Strength:", new_strength)
        else:
            if debug:
                print("Did not act last step. Nothing to change.")
        
    def _learn(self, previous_state, value, debug=False):
        value_difference = value - self._learned_value_table[previous_state]
        self._update_learned_value_table(
            previous_state, 
            value_difference, 
            self._value_learning_rate,
            debug=debug
        )
        self._update_action_table(
            previous_state, 
            value_difference, 
            self._action_learning_rate,
            debug=debug
        )
    
    def _get_action(self, cheese):
        eat = False
        act_chance = self._action_table[(self.energy, cheese)]
        if np.random.random() < act_chance:
            eat = True
        return eat
    
    def _dict_to_table(self, d):
        int_d = defaultdict(list)
        for k, v in d.items():
            p1, p2 = k
            int_d[p1].append(v)
        
        return [[k] + v for k, v in int_d.items()]
    
    def _display_table(self, d_table, headers):
        t = self._dict_to_table(d_table)
        print()
        print(tabulate(t, headers=headers))
        print()
    
    def display_knowledge(self):
        print("---- Learned Value Table")
        self._display_table(self._learned_value_table, headers=["Energy", "Cheese", "No Cheese"])
        print("---- Action Table")
        self._display_table(self._action_table, headers=["Energy", "Cheese", "No Cheese"])
    
    def step(self, obseravation, debug=False):
        cheese = observation # Boolean if the mouse tastes cheese or not
        if debug: print("Initial energy value", self.energy)
        next_energy = self._update_energy(self.energy, cheese)
        if debug: print("Updated energy value", next_energy)
        self.previous_state = self.state
        self.state = (next_energy, cheese)
        
        if debug: 
            print("Previous State: (Energy, Cheese)", self.previous_state)
            print("Previous Action:", self.action)
            print("Current State: (Energy, Cheese)", self.state)
        if self.previous_state is not None:
            if debug: print("Evaluating and learning ...")
            v_innate = self._innate_evaluation(self.state)
            if debug: print("Innate Value:", v_innate)
            v_learned = self._learned_value_table[self.state]
            if debug: print("Learned Value:", v_learned)
            value = v_innate + v_learned
            if debug: print("Total Value:", value)
            self._learn(self.previous_state, value, debug)
            if debug: self.display_knowledge()
        
        next_action = self._get_action(cheese)
        if debug: 
            print("Action Chosen:", next_action)
            print("=" * 50)
        self.action = next_action
        self.cheese = cheese
        self.energy = next_energy
        return self.action
        
        

In [79]:
world = World()
mouse = Mouse(value_learning_rate=0.1, action_learning_rate=0.05)

In [80]:
mouse.reset()
observation = False # No cheese to begin with
for i in range(100):
    action = mouse.step(observation, debug=False)
    observation = world.step(action)

In [81]:
mouse.display_knowledge()

---- Learned Value Table

  Energy       Cheese    No Cheese
--------  -----------  -----------
       0   0              0
       1   0              0
       2   0              0
       3   0              1.48604
       4   0.534157       0.393402
       5   0.469682       0.423839
       6  -0.00295746     0.1237
       7  -0.189083      -0.284343
       8  -0.364174      -0.331778
       9  -0.32058       -1.35468
      10  -0.993523       0

---- Action Table

  Energy    Cheese    No Cheese
--------  --------  -----------
       0  0.5          0.5
       1  0.5          0.5
       2  0.5          0.5
       3  0.5          0.99
       4  0.821618     0.807094
       5  0.734841     0.728934
       6  0.44453      0.473315
       7  0.33488      0.329746
       8  0.29279      0.309
       9  0.351        0.01
      10  0.05         0.5



In [104]:
values = np.array(mouse._dict_to_table(mouse._learned_value_table)).T
_, ys_cheese, ys_no_cheese = values
chances = np.array(mouse._dict_to_table(mouse._action_table)).T
_, ysa_cheese, ysa_no_cheese = chances

In [105]:
iplot([go.Scatter(y=ys_cheese), go.Scatter(y=ys_no_cheese)])

In [106]:
iplot([go.Scatter(y=ysa_cheese), go.Scatter(y=ysa_no_cheese)])