forked from siyamak45/CS188.1x-Project3
-
Notifications
You must be signed in to change notification settings - Fork 1
/
nbrun.py
76 lines (55 loc) · 1.98 KB
/
nbrun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import random
import sys
import mdp
import environment
import util
import optparse
import gridworld
class Opts(object):
    """Plain settings holder consumed by ``runAgent``.

    It stands in for a parsed command-line options object so an agent can
    be run without going through an option parser.

    Args:
        display: stored as ``self.display``. ``runAgent`` uses the graphical
            Gridworld display when it is true and the text display
            otherwise.
    """

    def __init__(self, display=True):
        # Forwarded to the graphical display constructor by runAgent.
        self.speed = 1.0
        self.gridSize = 150

        # Forwarded to gridworld.runEpisode by runAgent.
        self.discount = .9

        # Not read anywhere in this file.
        # NOTE(review): presumably agent hyper-parameters / iteration count
        # used by callers when building the agent -- confirm.
        self.learningRate = .1
        self.epsilon = .1
        self.iters = 10

        # Number of episodes runAgent executes.
        self.episodes = 10

        # Behaviour switches read by runAgent: user-driven action choice,
        # message suppression, and pausing between time steps.
        self.manual = False
        self.quiet = True
        self.pause = False

        self.display = display
def runAgent(a, mdp, opts):
    """Run ``opts.episodes`` Gridworld episodes with agent ``a``.

    Builds a ``gridworld.GridworldEnvironment`` around ``mdp``, starts a
    display, wires up the per-step callbacks selected by ``opts`` and calls
    ``gridworld.runEpisode`` once per episode. When at least one episode
    was requested, the average of the episode returns is printed.

    Args:
        a: the agent. Its ``getAction`` is used to pick actions unless
            ``opts.manual`` is set; it is also handed to the display for
            Q-value rendering.
        mdp: the Gridworld MDP. Note that this parameter shadows the
            file-level ``mdp`` module inside this function.
        opts: settings object providing ``display``, ``gridSize``,
            ``speed``, ``quiet``, ``pause``, ``manual``, ``discount`` and
            ``episodes`` (for example an ``Opts`` instance).

    Returns:
        None. The summary is written to standard output.
    """
    env = gridworld.GridworldEnvironment(mdp)

    # Both display modules are imported at call time, in the original order.
    import textGridworldDisplay
    import graphicsGridworldDisplay

    # The text display is the fallback; it is replaced by the graphical one
    # when a display was requested.
    display = textGridworldDisplay.TextGridworldDisplay(mdp)
    if opts.display:
        display = graphicsGridworldDisplay.GraphicsGridworldDisplay(
            mdp, opts.gridSize, opts.speed)
    display.start()

    ###########################
    # RUN EPISODES
    ###########################

    # FIGURE OUT WHAT TO DISPLAY EACH TIME STEP (IF ANYTHING)
    if opts.display:
        displayCallback = lambda state: display.displayQValues(
            a, state, "CURRENT Q-VALUES")
    else:
        displayCallback = lambda state: None

    # Messages are dropped entirely in quiet mode.
    if opts.quiet:
        messageCallback = lambda x: None
    else:
        messageCallback = lambda x: gridworld.printString(x)

    # FIGURE OUT WHETHER TO WAIT FOR A KEY PRESS AFTER EACH TIME STEP
    if opts.pause:
        pauseCallback = lambda: display.pause()
    else:
        pauseCallback = lambda: None

    # FIGURE OUT WHETHER THE USER WANTS MANUAL CONTROL (FOR DEBUGGING AND DEMOS)
    if opts.manual:
        # The bare name ``getUserAction`` is not defined or imported in this
        # file, so the original raised NameError on the first decision.
        # NOTE(review): assumes the gridworld module exposes getUserAction
        # alongside printString and runEpisode -- confirm.
        decisionCallback = lambda state: gridworld.getUserAction(
            state, mdp.getPossibleActions)
    else:
        decisionCallback = a.getAction

    returns = 0
    for episode in range(1, opts.episodes + 1):
        returns += gridworld.runEpisode(
            a, env, opts.discount, decisionCallback, displayCallback,
            messageCallback, pauseCallback, episode)

    if opts.episodes > 0:
        # Single-argument parenthesised print produces the same output under
        # the Python 2 print statement and the Python 3 print function.
        # ``returns + 0.0`` forces true division on Python 2.
        print("")
        print("AVERAGE RETURNS FROM START STATE: "
              + str((returns + 0.0) / opts.episodes))
        print("")
        print("")