In [1]:
from src.environment import Environment
from src.qLearning import Q
from src.valueTable import ValueTable
from src.greedyExplorer import GreedyExplorer
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# Build the environment, the Q-learning update rule and a value table sized
# by the environment's state and action counts.
env = Environment(
    arenaStates=12,
    numberActions=5
)
learner = Q(
    alpha=0.1,
    gamma=0.9
)
table = ValueTable(env._numberStates(), env._numberActions())

# Start every table entry at 0.0 before any learning happens.
table.initializeTable(initValue=0.0)

# Exploration policy bound to the environment and table above; it is created
# with epsilon=0.30 and a decay of 0.9967.
explorer = GreedyExplorer(
    goalEnvironment=env,
    valueTable=table,
    epsilon=0.30,
    decay=0.9967
)

# Initialize defense positions: call updateTable(row, 2, 100) for the rows
# 0, 13, 26, ..., 143 (row = 12*i + i).
# NOTE(review): presumably these rows are the states where the keeper stands
# on the ball's position and column 2 is the "stay" action -- confirm.
for i in range(12):
    table.updateTable(i * 12 + i, 2, 100)

# Ask the user how many training episodes to run (Python 2 raw_input).
episodesNumber = int(raw_input("Enter the number of episodes:"))

# Episode indices 1..episodesNumber; within this cell only the commented-out
# plotting code refers to x.
x = np.arange(1, episodesNumber + 1)

# Per-episode history: number of steps taken and MSE between the value table
# before and after the episode.
stepsList, mseList = [], []

# Largest step count observed so far; starts below any possible count.
stepmax = -1

# Train once per starting ball position (0..11), episodesNumber episodes each.
for ballInPos in range(12):
    env.setBallPos(ballInPos)
    # Restart exploration at epsilon = 0.3 for every new ball position.
    explorer.updateEpsilon(0.3)

    for i in range(episodesNumber):
        env.randomKeeper()
        # Snapshot of the table before the episode, used for the MSE below.
        oldTable = table.returnTable().copy()
        steps = 0

        # Act until the environment reports a goal; at least one step is taken.
        reachedGoal = False
        while not reachedGoal:
            state = env.returnState()
            action = explorer.chooseAction()
            env.performAction(action)
            reward = env.getReward()
            newState = env.returnState()

            # The table column is the chosen action shifted by 2.
            # NOTE(review): presumably actions run from -2 to 2 -- confirm.
            column = action + 2
            currentValue = table.stateValue(state, column)
            bestNext = table.maxState(newState)
            table.updateTable(state, column, learner.learn(currentValue, reward, bestNext))

            steps += 1
            reachedGoal = env.isGoal()

        # Apply the explorer's decay once per finished episode.
        explorer.useDecay()
        newTable = table.returnTable().copy()
        # print "Episode :", i+1, ", number of steps:", steps
        # How much the whole table changed during this episode.
        mseList.append(mean_squared_error(oldTable, newTable))

        stepsList.append(steps)
        # Keep the largest step count seen in any episode.
        stepmax = max(stepmax, steps)


# Print floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
# Parenthesised form: prints exactly what the Python 2 print statement did.
print(table.printTable())

# Legend handles for the (currently disabled) plot below.
blue_patch = mpatches.Patch(label='Steps', color='blue')
orange_patch = mpatches.Patch(label='MSE', color='orange')

# Plotting the episodes x steps graph (disabled).
# NOTE(review): x holds episodesNumber points, while stepsList and mseList
# receive one entry per episode for each of the 12 ball positions, so their
# lengths differ; re-enabling these lines as-is needs a matching x axis.
# plt.plot(x, stepsList)
# plt.plot(x, [i*stepmax for i in mseList])

# plt.autoscale(tight=True)
# plt.legend(handles=[blue_patch, orange_patch])
# plt.grid(True)
# plt.xlabel('Episodes')

# With the plot calls above disabled, nothing has been drawn at this point.
plt.show()




Enter the number of episodes:1000
Value Table: [[  0.           0.         620.79920006   0.          88.30580792]
 [  0.         629.59987561  82.51517546  41.81164417  69.91035254]
 [657.89643854 223.26956841 190.94576649 142.01832827 165.90885627]
 [ 55.69623485 560.92801877  33.14904124  50.59535083  32.31420633]
 [588.14887345 209.21105672 232.07828825 163.36606343 169.65868444]
 [ 38.22011365 501.77430174  17.85927434  30.1104634    4.20795021]
 [523.61539549 127.71235379 141.84706035  75.85944845 120.46434839]
 [ -0.19       421.60336434  -0.199       37.12574546  13.78076946]
 [464.17528263  12.90867743 124.02731315 104.30412762  74.68137615]
 [  2.02764764 376.55085309  12.34896053  22.21705123   9.7209374 ]
 [405.60723221  -0.393238    23.59505124  48.63542291   0.        ]
 [ -0.4750318  335.1725154   49.02063895   0.           0.        ]
 [  0.           0.         145.92130202 569.53857254  91.50956704]
 [  0.          36.14384565 536.37659432   4.88827141  54.17983813]
 