-
Notifications
You must be signed in to change notification settings - Fork 1
/
Trainer.py
130 lines (104 loc) · 3.46 KB
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
import os
import sys
import time

import numpy as np
import ujson
from matplotlib import pyplot as plt

import Agent
def main():
    """Command-line entry point: ``Trainer.py <board> <episodes>``.

    Trains both the White and Black agents on *board* for *episodes*
    episodes, with plotting disabled and saving enabled. If the number of
    command-line arguments is not exactly two, prints a usage hint and
    returns without training.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) == 2:
        board_name, episode_text = cli_args
        train(
            board_name,
            int(episode_text),
            plot=False,
            save=True,
            train_w=True,
            train_b=True,
        )
        return
    print("Please provide the board and the episodes as a command-line argument.")
def train(board, episodes, plot, save, train_w, train_b):
    """Train the White and/or Black agent against a random opponent.

    Each enabled side is trained with ``Agent.train_agent_vs_random`` and the
    elapsed time of that call is measured. Results are then optionally
    appended to ``stats/<board>_graph.stats``, ``stats/<board>_wr.stats`` and
    ``stats/<board>_time.stats``, and optionally plotted.

    Args:
        board: Board identifier forwarded to ``Agent.train_agent_vs_random``;
            also used as the prefix of the stats file names.
        episodes: Number of training episodes per side.
        plot: If true, show the learning curve and one bar plot per trained
            side.
        save: If true, append one line per trained side (White first) to each
            stats file, followed by a blank separator line. The ``stats``
            folder is created if it does not exist.
        train_w: Whether to train the White ("w") agent.
        train_b: Whether to train the Black ("b") agent.
    """
    # Dict insertion order (White before Black) fixes the order in which
    # sides are written to the stats files and plotted.
    results = {}
    if train_w:
        results["w"] = _train_side(board, episodes, "w")
    if train_b:
        results["b"] = _train_side(board, episodes, "b")

    # Save before plotting: plt.show() blocks until the windows are closed,
    # so persisting first keeps the (possibly long) training results even if
    # the plotting session is interrupted.
    if save:
        _save_stats(board, list(results.values()))

    if plot:
        # An untrained side is passed as an empty series.
        learning_curve_plot(
            episodes,
            results["w"]["avg_reward"] if "w" in results else [],
            results["b"]["avg_reward"] if "b" in results else [],
        )
        for side in results.values():
            bar_plot(side["bar_x"], side["bar_y"])


def _train_side(board, episodes, colour):
    """Train and time the agent for *colour* ("w" or "b").

    Returns a dict with the four values returned by
    ``Agent.train_agent_vs_random`` (``avg_reward``, ``win_rate``, ``bar_y``,
    ``bar_x``) plus ``seconds``, the elapsed time of the training call.
    """
    # perf_counter is monotonic, so the duration is unaffected by wall-clock
    # adjustments during a long training run.
    start = time.perf_counter()
    avg_reward, win_rate, bar_y, bar_x = Agent.train_agent_vs_random(board, episodes, colour)
    seconds = time.perf_counter() - start
    return {
        "avg_reward": avg_reward,
        "win_rate": win_rate,
        "bar_y": bar_y,
        "bar_x": bar_x,
        "seconds": seconds,
    }


def _save_stats(board, sides):
    """Append reward curves, win rates and timings of *sides* under stats/."""
    folder = "stats"
    # open(..., "a") creates the file but not its parent directory.
    os.makedirs(folder, exist_ok=True)
    prefix = os.path.join(folder, board)
    # np.asarray(...).tolist() accepts numpy arrays and plain sequences alike.
    _append_section(
        prefix + "_graph.stats",
        [ujson.dumps(np.asarray(side["avg_reward"]).tolist()) for side in sides],
    )
    _append_section(prefix + "_wr.stats", [str(side["win_rate"]) for side in sides])
    _append_section(prefix + "_time.stats", [str(side["seconds"]) for side in sides])


def _append_section(path, lines):
    """Append *lines* to *path*, one per row, followed by a blank separator line."""
    with open(path, "a", encoding="utf-8") as fh:
        fh.writelines(line + "\n" for line in lines)
        fh.write("\n")
def learning_curve_plot(episodes, avg_reward_w, avg_reward_b):
    """Show the average-reward learning curves for White and Black.

    Args:
        episodes: Episode count, shown in each legend label as ``T``.
        avg_reward_w: Average-reward series for White; may be empty (e.g.
            when White was not trained), in which case it is not drawn.
        avg_reward_b: Average-reward series for Black; may be empty, in which
            case it is not drawn.
    """
    series = (
        (avg_reward_w, "White Reward, T = " + str(episodes)),
        (avg_reward_b, "Black Reward, T = " + str(episodes)),
    )
    for rewards, label in series:
        # Skip empty series so an untrained side adds no empty legend entry.
        # len() instead of truthiness so numpy arrays are handled too.
        if len(rewards) > 0:
            # NOTE(review): log-scale y axis cannot represent non-positive
            # rewards meaningfully — confirm rewards are always > 0.
            plt.semilogy(rewards, label=label)
    plt.legend()
    plt.xlabel("Time")
    plt.ylabel("Average Reward")
    plt.show()
def bar_plot(bar_x, bar_y):
    """Print and show a grouped bar chart with White, Black and Draw bars.

    Args:
        bar_x: Tick labels for the x axis (labelled "Episodes"), one per
            group; presumably one per entry of *bar_y* — confirm with caller.
        bar_y: Sequence of rows; items 0, 1 and 2 of each row are plotted as
            the bars labelled White, Black and Draw respectively.
    """
    # Split the rows into one column per bar series; computed once and used
    # for both the console dump and the plot.
    white = [row[0] for row in bar_y]
    black = [row[1] for row in bar_y]
    draw = [row[2] for row in bar_y]
    print(white)
    print(black)
    print(draw)

    barwidth = 0.25
    # Open a new figure; the returned (figure, axes) handles are not needed.
    plt.subplots(figsize=(12, 8))

    # Three adjacent bars per group, each offset by one bar width.
    br1 = np.arange(len(white))
    br2 = br1 + barwidth
    br3 = br2 + barwidth
    plt.bar(br1, white, width=barwidth, label='White')
    plt.bar(br2, black, width=barwidth, label='Black')
    plt.bar(br3, draw, width=barwidth, label='Draw')

    plt.xlabel('Episodes', fontweight='bold', fontsize=15)
    plt.ylabel('Stats', fontweight='bold', fontsize=15)
    # Centre each group's tick label under its middle bar.
    plt.xticks(br2, bar_x)
    plt.grid()
    plt.legend()
    plt.show()
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()