This repository has been archived by the owner on Sep 8, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
76 lines (61 loc) · 1.74 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import random
from models import DQAgent, RandomAgent, BubbleSortAgent
import utils
def test(max_steps=50):
    """Run the pretrained DQN agent on a shuffled 5-element list.

    Builds a shuffled list of ``0..4``, wraps it in a ``DQAgent`` created
    with ``is_train=False``, loads weights from ``./data/dqn_latest.pt``
    and repeatedly calls ``agent.update()`` until ``agent.arr`` is sorted
    or the step budget is exhausted. The array is printed before every
    step, and ``"Done!"`` is printed on exit.

    Args:
        max_steps: Maximum number of ``agent.update()`` calls to make
            before giving up. Defaults to 50.
    """
    arr = list(range(5))
    random.shuffle(arr)

    agent = DQAgent(arr, is_train=False)
    agent.load_model('./data/dqn_latest.pt')

    # NOTE: this used to be `while range(50):`. A non-empty range object is
    # always truthy, so that loop had no upper bound and would spin forever
    # if the agent never reached a sorted state. Iterating the range gives
    # the step cap that was evidently intended.
    for _ in range(max_steps):
        print(agent.arr)
        if agent.arr == sorted(agent.arr):
            break
        agent.update()
    else:
        # Budget exhausted: show the state left by the final update, which
        # the loop body did not get a chance to print.
        print(agent.arr)

    print("Done!")
def test_compare(make_gif):
    """Race a random, a bubble-sort and a DQN agent on the same shuffled list.

    Every agent receives its own copy of one shuffled list of ``0..4``.
    Each round the current arrays are printed, then the agents are updated
    in order (Random, Bubble, DQN); the first agent whose array is sorted
    after its update ends the race and is announced as the winner.

    Args:
        make_gif: When truthy, a frame is captured with
            ``utils.visualize_agents`` before every round, the final frame
            is appended ten times, and the frames are handed to
            ``utils.imgs2gif``.
    """
    start = list(range(5))
    random.shuffle(start)

    rand_agent = RandomAgent(start.copy())
    bubble_agent = BubbleSortAgent(start.copy())
    dqn_agent = DQAgent(start.copy(), is_train=False)
    dqn_agent.load_model('./data/dqn_latest.pt')

    agents = [rand_agent, bubble_agent, dqn_agent]
    agent_names = ["Random", "Bubble", "DQN"]
    # Console labels are padded to equal width, so they differ from agent_names.
    labels = ("Random :", "Bubble :", "DQNsort:")

    def show_state():
        # Print every agent's current array followed by a separator line.
        for label, contender in zip(labels, agents):
            print(label, contender.arr)
        print("===")

    frames = []
    winner = -1
    while True:
        show_state()
        if make_gif:
            frames.append(utils.visualize_agents(agents, agent_names))

        for idx, contender in enumerate(agents):
            contender.update()
            if contender.arr == sorted(contender.arr):
                winner = idx
                break
        else:
            # Nobody finished this round; play another one.
            continue
        break

    show_state()
    if make_gif:
        # Repeat the final frame ten times in the frame list.
        last_frame = utils.visualize_agents(agents, agent_names)
        frames.extend([last_frame] * 10)

    print(agent_names[winner], "wins!")

    if make_gif:
        # Generate visuals
        utils.imgs2gif(frames)