train_dqn.py — 109 lines (91 loc) · 3.21 KB
import rung_rl.plotter as plt
from rung_rl.agents.dqn.dqn_agent import DQNAgent
from rung_rl.agents.random_agent import RandomAgent
from rung_rl.game.Game import Game
CONCURRENT_GAMES = 128
def train_dqn(num_games, debug=False):
    """Train a DQN agent via self-play for ``num_games`` games.

    Periodically (every 10000 games) the current agent is evaluated against
    a frozen "weak" snapshot of itself and against random agents; win rates
    are appended to the plot and the models are checkpointed to disk.

    :param num_games: total number of self-play games to run.
        (Bug fix: the original looped forever with ``while 1`` and ignored
        this parameter entirely.)
    :param debug: passed through to ``Game`` to enable verbose game output.
    """
    win_rate_radiant = []  # win rate vs. the frozen "weak" snapshot
    win_rate_dire = []     # win rate vs. random agents
    games = []             # x-axis values (approximate game counts) for plotting
    weak_agent = DQNAgent(False, True)
    weak_agent.eval = True
    print("Starting training")
    agent = DQNAgent(True, True)  # to indicate that we want to train the agent
    # Seed the "weak" checkpoint so the first periodic evaluation can load it.
    agent.save_model("weak")
    # Self-play: the same learning agent controls all four seats.
    players = [agent, agent, agent, agent]
    plt.plot(games, win_rate_radiant, win_rate_dire)
    i = 0
    games_i = 0
    while i < num_games:  # was `while 1`, which ignored num_games
        game = Game(players, debug, debug)
        game.initialize()
        game.play_game()
        agent.optimize_model()
        if i % 250 == 0:
            # Sync the target network with the policy network.
            agent.mirror_models()
        if i % 300 == 0:
            print("Total Games: {}".format(i))
        if i % 10000 == 0 and i != 0:
            players[0].save_model("final")
            # Switch to evaluation mode (no exploration / no learning).
            agent.eval = True
            agent.train = False
            weak_agent.load_model("weak")
            players3 = [agent, weak_agent, agent, weak_agent]
            win_rate_r, _ = evaluate(500, players3, 0)
            players2 = [agent, RandomAgent(), agent, RandomAgent()]
            win_rate_d, _ = evaluate(500, players2, 0)
            win_rate_radiant.append(win_rate_r / 100)
            win_rate_dire.append(win_rate_d / 100)
            games.append(games_i)
            plt.plot(games, win_rate_radiant, win_rate_dire)
            plt.savefig()
            # Back to training mode.
            agent.eval = False
            agent.train = True
            if win_rate_r < 50:
                # if the previous agent beats you, train against that
                strategy_collapse(players3, agent)
                games_i += 2500  # strategy_collapse plays 2500 extra games
            # Refresh the "weak" snapshot for the next evaluation round.
            agent.save_model("weak")
        i += 1
        games_i += 1
def strategy_collapse(players, agent, num_games=2500):
    """Train against a former version of ourselves that beat us.

    In order to prevent strategy collapse, we occasionally train against a
    former version of ourselves that beat us in evaluation.

    :param players: seating order including the frozen former agent
        (e.g. ``[agent, weak_agent, agent, weak_agent]``).
    :param agent: the learning agent whose model is optimized each game.
    :param num_games: number of remedial games to play (default 2500,
        matching the original hard-coded count).
    """
    # Removed an unused local (`wins = 0`) from the original.
    for i in range(num_games):
        # NOTE(review): unlike the main loop, no debug flags are passed here —
        # assumes Game's extra parameters have defaults; confirm.
        game = Game(players)
        game.initialize()
        game.play_game()
        agent.optimize_model()
        if i % 250 == 0:
            # Sync the target network with the policy network.
            agent.mirror_models()
def evaluate(num_games, players, idx=0, debug=False):
    """
    Evaluate the agent with the given index to count the number of wins of the particular agent

    Returns a ``(win_percentage, avg_reward)`` pair; ``avg_reward`` is a
    placeholder that is always 0 in the current implementation.
    """
    print("Starting evaluation...")
    wins = 0
    toss = None  # first game uses the default toss; later games reuse the last winner
    for _ in range(num_games):
        match = Game(players, debug, debug)
        match.initialize(toss)
        winners = match.play_game()
        if idx in winners:
            wins = wins + 1
        toss = winners[0]
    avg_reward = 0
    print(wins, wins / num_games, avg_reward)
    return wins / num_games * 100, avg_reward
if __name__ == "__main__":
    train_dqn(1)
    # NOTE(review): with the original `while 1` implementation of train_dqn,
    # the call above never returns and everything below is unreachable —
    # confirm whether the evaluation below is meant to run.
    agent = DQNAgent(False, True)
    agent.load_model("final")
    shaped_agent = DQNAgent(False, True)
    shaped_agent.load_model_from_path("../saved_models/dqn_best_recurrent_2/model_dqn_final")
    shaped_agent.eval = True
    agent.eval = True
    # Removed a dead assignment: `players` was first set to
    # [agent, shaped_agent, agent, shaped_agent] and immediately overwritten.
    players = [shaped_agent, RandomAgent(), shaped_agent, RandomAgent()]
    evaluate(1000, players, 0, False)