-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
83 lines (77 loc) · 5.94 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from FileModHandler import FileModified
from Algos import Algos
import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import check_env
import matplotlib.pyplot as plt
import numpy as np
import ML_Env
from FileModHandler import FileModified
# Widen numpy's print limits so the full 4-D Q-table prints without truncation.
np.set_printoptions(threshold=np.inf, linewidth=140)

# --- SARSA smoke test --------------------------------------------------------
# Drive Algos' SARSA update on an 11x5 grid whose 4-D Q-table, indexed as
# (fromRow, fromCol, toRow, toCol), starts at all ones; then print summaries
# showing where value accumulated.
sarsa = Algos(numRows=11, numCols=5, alpha=0.8, gamma=0.9, epsilon=0.8)
sarsa.qTable = np.ones((11, 5, 11, 5))
print(sarsa.getMaxActionIndex((0, 0)))

# One update from (2, 2) -> (3, 3), then one update from (row, 0) -> (10, 4)
# for every row, all with reward 55.  The loop replaces eleven copy-pasted
# calls that differed only in the starting row.
newLoc = sarsa.updateQTable_Sarsa(
    (2, 2), (3, 3), (3, 3),
    sarsa.getNextAction((2, 2), sarsa.epsilon), 55, sarsa.epsilon)
for row in range(11):
    newLoc = sarsa.updateQTable_Sarsa(
        (row, 0), (10, 4), (10, 4),
        sarsa.getNextAction((10, 4), sarsa.epsilon), 55, sarsa.epsilon)

# shows the values that previously led to high rewards (starting state values)
print(np.unravel_index(np.argmax(sarsa.qTable[7, 0]), sarsa.qTable[7, 0].shape))
collapsed = np.average(sarsa.qTable, axis=(0, 1))
print(len(collapsed))
print(collapsed.transpose())
# shows the positions that led to high rewards after starting from previous
# position (ending state values)
collapsed = np.average(sarsa.qTable, axis=(2, 3))
print(len(collapsed))
print(collapsed.transpose())
print(np.shape(sarsa.qTable))
# BUG FIX: the flat argmax index of the 3-D slice qTable[i] (shape (5, 11, 5))
# was being unravelled against the full 4-D table shape (11, 5, 11, 5), which
# yields meaningless coordinates.  Unravel against the slice's own shape, as
# the qTable[7, 0] prints above and below already do.
print(np.unravel_index(np.argmax(sarsa.qTable[0]), sarsa.qTable[0].shape))
print(np.unravel_index(np.argmax(sarsa.qTable[1]), sarsa.qTable[1].shape))
print(np.unravel_index(np.argmax(sarsa.qTable[2]), sarsa.qTable[2].shape))
print(sarsa.qTable[8, 0])
print(sarsa.qTable[2])
print(np.unravel_index(np.argmax(sarsa.qTable[7, 0]), sarsa.qTable[7, 0].shape))
print("\n\n")
# --- Q-learning smoke test ---------------------------------------------------
# Same exercise as the SARSA section, but through the Q-learning update rule.
qlearning = Algos(numRows=11, numCols=5, alpha=0.8, gamma=0.9, epsilon=0.8)
qlearning.qTable = np.ones((11, 5, 11, 5))

newLoc = qlearning.updateQTable_QLearning(
    (2, 2), (3, 3), (3, 3),
    qlearning.getNextAction((2, 2), qlearning.epsilon), 55, qlearning.epsilon)
# NOTE(review): this print (and the one inside the loop below) queries the
# SARSA table from earlier in the script, not the qlearning table just
# updated — looks like a copy-paste slip; confirm whether
# qlearning.getMaxActionIndex was intended.  Preserved as-is.
print(sarsa.getMaxActionIndex((2, 2)))

# One update from (row, 0) -> (10, 4) per row.  Reward is 55 everywhere
# except the final row, which the original script gave 56 — preserved as-is.
# The loop replaces eleven copy-pasted calls differing only in the row.
for row in range(11):
    reward = 56 if row == 10 else 55
    newLoc = qlearning.updateQTable_QLearning(
        (row, 0), (10, 4), (10, 4),
        qlearning.getNextAction((10, 4), qlearning.epsilon),
        reward, qlearning.epsilon)
    if row == 0:
        print(sarsa.getMaxActionIndex((0, 0)))

# shows the values that previously led to high rewards (starting state values)
print(np.unravel_index(np.argmax(qlearning.qTable[7, 0]), qlearning.qTable[7, 0].shape))
collapsed = np.average(qlearning.qTable, axis=(0, 1))
print(len(collapsed))
print(collapsed.transpose())
# shows the positions that led to high rewards after starting from previous
# position (ending state values)
collapsed = np.average(qlearning.qTable, axis=(2, 3))
print(len(collapsed))
print(collapsed.transpose())
print(np.shape(qlearning.qTable))
# BUG FIX: the flat argmax index of the 3-D slice qTable[i] (shape (5, 11, 5))
# was being unravelled against the full 4-D table shape (11, 5, 11, 5), which
# yields meaningless coordinates.  Unravel against the slice's own shape, as
# the qTable[7, 0] prints already do.
print(np.unravel_index(np.argmax(qlearning.qTable[0]), qlearning.qTable[0].shape))
print(np.unravel_index(np.argmax(qlearning.qTable[1]), qlearning.qTable[1].shape))
print(np.unravel_index(np.argmax(qlearning.qTable[2]), qlearning.qTable[2].shape))
print(np.unravel_index(np.argmax(qlearning.qTable[3]), qlearning.qTable[3].shape))
print(np.unravel_index(np.argmax(qlearning.qTable[7, 0]), qlearning.qTable[7, 0].shape))