Permalink
Browse files

Script to store .gif for a pretrained network. Thanks to @tambetm

  • Loading branch information...
1 parent 8466031 commit 7e5feff68d2c9e96d97311e0ebcfd738e06927f2 @kuz committed Aug 15, 2015
Showing with 110 additions and 39 deletions.
  1. +54 −25 dqn/test_agent.lua
  2. BIN {images → gifs}/breakout.gif
  3. +3 −0 install_dependencies.sh
  4. +48 −0 test_cpu
  5. +5 −14 watch_pretrained → test_gpu
View
@@ -1,9 +1,10 @@
--[[
Copyright (c) 2014 Google Inc.
-
See LICENSE file for full terms of limited license.
]]
+gd = require "gd"
+
if not dqn then
require "initenv"
end
@@ -29,21 +30,13 @@ cmd:option('-network', '', 'reload pretrained network')
cmd:option('-agent', '', 'name of agent file to use')
cmd:option('-agent_params', '', 'string of agent parameters')
cmd:option('-seed', 1, 'fixed input seed for repeatable experiments')
-cmd:option('-saveNetworkParams', false,
- 'saves the agent network in a separate file')
-cmd:option('-prog_freq', 5*10^3, 'frequency of progress output')
-cmd:option('-save_freq', 5*10^4, 'the model is saved every save_freq steps')
-cmd:option('-eval_freq', 10^4, 'frequency of greedy evaluation')
-cmd:option('-save_versions', 0, '')
-
-cmd:option('-steps', 10^5, 'number of training steps to perform')
-cmd:option('-eval_steps', 10^5, 'number of evaluation steps')
cmd:option('-verbose', 2,
'the higher the level, the more information is printed to screen')
cmd:option('-threads', 1, 'number of BLAS threads')
cmd:option('-gpu', -1, 'gpu flag')
-cmd:option('-ep', 0.0, 'exploration probability')
+cmd:option('-gif_file', '', 'GIF path to write session screens')
+cmd:option('-csv_file', '', 'CSV path to write session data')
cmd:text()
@@ -59,24 +52,60 @@ local print = function(...)
io.flush()
end
-local step = 0
+-- file names from command line
+local gif_filename = opt.gif_file
+
+-- start a new game
+local screen, reward, terminal = game_env:newGame()
+
+-- compress screen to JPEG with 100% quality
+local jpg = image.compressJPG(screen:squeeze(), 100)
+-- create gd image from JPEG string
+local im = gd.createFromJpegStr(jpg:storage():string())
+-- convert truecolor to palette
+im:trueColorToPalette(false, 256)
+
+-- write GIF header, use global palette and infinite looping
+im:gifAnimBegin(gif_filename, true, 0)
+-- write first frame
+im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE)
-local screen, reward, terminal = game_env:getState()
+-- remember the image and show it first
+local previm = im
+local win = image.display({image=screen})
-print("Running...")
-local win = nil
-while step < opt.steps do
- step = step + 1
- local action_index = agent:perceive(reward, screen, terminal, true, opt.ep)
+print("Started playing...")
- -- game over? get next game!
- if not terminal then
- screen, reward, terminal = game_env:step(game_actions[action_index], true)
- else
- screen, reward, terminal = game_env:newGame()
- end
+-- play one episode (game)
+while not terminal do
+ -- if action was chosen randomly, Q-value is 0
+ agent.bestq = 0
+
+ -- choose the best action
+ local action_index = agent:perceive(reward, screen, terminal, true, 0.05)
+
+ -- play game in test mode (episodes don't end when losing a life)
+ screen, reward, terminal = game_env:step(game_actions[action_index], false)
-- display screen
- win = image.display({image=screen, win=win})
+ image.display({image=screen, win=win})
+
+ -- create gd image from tensor
+ jpg = image.compressJPG(screen:squeeze(), 100)
+ im = gd.createFromJpegStr(jpg:storage():string())
+
+ -- use palette from previous (first) image
+ im:trueColorToPalette(false, 256)
+ im:paletteCopy(previm)
+
+ -- write new GIF frame, no local palette, starting from left-top; delay 7 is in GIF units of 1/100 s (i.e. 70 ms)
+ im:gifAnimAdd(gif_filename, false, 0, 0, 7, gd.DISPOSAL_NONE)
+ -- remember this frame so the next one can copy its palette (previm is only used by paletteCopy above; it is not passed to gifAnimAdd)
+ previm = im
end
+
+-- end GIF animation (no CSV file is opened in the code shown, so there is nothing else to close)
+gd.gifAnimEnd(gif_filename)
+
+print("Finished playing, close window to exit!")
File renamed without changes
@@ -34,6 +34,8 @@ sudo apt-get install -qqy ncurses-dev
sudo apt-get install -qqy imagemagick
sudo apt-get install -qqy unzip
sudo apt-get install -qqy libqt4-dev
+sudo apt-get install -qqy liblua5.1-0-dev
+sudo apt-get install -qqy libgd-dev
sudo apt-get update
@@ -116,6 +118,7 @@ RET=$?; if [ $RET -ne 0 ]; then echo "Error. Exiting."; exit $RET; fi
echo "Alewrap installation completed"
echo "Installing Lua-GD ... "
+mkdir $PREFIX/src
cd $PREFIX/src
rm -rf lua-gd
git clone https://github.com/ittner/lua-gd.git
View
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+if [ -z "$1" ]
+ then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0
+fi
+
+if [ -z "$2" ]
+ then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0
+fi
+
+ENV=$1
+NETWORK=$2
+FRAMEWORK="alewrap"
+
+game_path=$PWD"/roms/"
+env_params="useRGB=true"
+agent="NeuralQLearner"
+n_replay=1
+netfile="\"convnet_atari3\""
+update_freq=4
+actrep=4
+discount=0.99
+seed=1
+learn_start=50000
+pool_frms_type="\"max\""
+pool_frms_size=2
+initial_priority="false"
+replay_memory=1000000
+eps_end=0.1
+eps_endt=replay_memory
+lr=0.00025
+agent_type="DQN3_0_1"
+preproc_net="\"net_downsample_2x_full_y\""
+agent_name=$agent_type"_"$1"_FULL_Y"
+state_dim=7056
+ncols=1
+agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1"
+gif_file="../gifs/$ENV.gif"
+gpu=0
+random_starts=30
+pool_frms="type="$pool_frms_type",size="$pool_frms_size
+num_threads=4
+
+args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file"
+echo $args
+
+cd dqn
+../torch/bin/qlua test_agent.lua $args
@@ -1,22 +1,17 @@
#!/bin/bash
if [ -z "$1" ]
- then echo "Please provide the name of the game, e.g. ./view_pretrained breakout"; exit 0
+ then echo "Please provide the name of the game, e.g. ./watch_pretrained breakout"; exit 0
fi
if [ -z "$2" ]
- then echo "Please provide the pretrained network file, e.g. ./view_pretrained breakout DQN3_0_1_breakout_FULL_Y"; exit 0
+ then echo "Please provide the pretrained network file, e.g. ./watch_pretrained breakout DQN3_0_1_breakout_FULL_Y.t7"; exit 0
fi
ENV=$1
NETWORK=$2
FRAMEWORK="alewrap"
-EP=0.0 # for visualizing exploration probability
-if [ ! -z "$3" ]; then
- EP=$3
-fi
-
game_path=$PWD"/roms/"
env_params="useRGB=true"
agent="NeuralQLearner"
@@ -40,18 +35,14 @@ agent_name=$agent_type"_"$1"_FULL_Y"
state_dim=7056
ncols=1
agent_params="lr="$lr",ep=1,ep_end="$eps_end",ep_endt="$eps_endt",discount="$discount",hist_len=4,learn_start="$learn_start",replay_memory="$replay_memory",update_freq="$update_freq",n_replay="$n_replay",network="$netfile",preproc="$preproc_net",state_dim="$state_dim",minibatch_size=32,rescale_r=1,ncols="$ncols",bufferSize=512,valid_size=500,target_q=10000,clip_delta=1,min_reward=-1,max_reward=1"
-steps=50000000
-eval_freq=250000
-eval_steps=125000
-prog_freq=10000
-save_freq=125000
+gif_file="../gifs/$ENV.gif"
gpu=1
random_starts=30
pool_frms="type="$pool_frms_type",size="$pool_frms_size
num_threads=4
-args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -steps $steps -eval_freq $eval_freq -eval_steps $eval_steps -prog_freq $prog_freq -save_freq $save_freq -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -ep $EP"
+args="-framework $FRAMEWORK -game_path $game_path -name $agent_name -env $ENV -env_params $env_params -agent $agent -agent_params $agent_params -actrep $actrep -gpu $gpu -random_starts $random_starts -pool_frms $pool_frms -seed $seed -threads $num_threads -network $NETWORK -gif_file $gif_file"
echo $args
cd dqn
-qlua test_agent.lua $args
+../torch/bin/qlua test_agent.lua $args

0 comments on commit 7e5feff

Please sign in to comment.