In [1]:
# Widen the notebook container to 90% of the browser window.
# `display`/`HTML` are taken from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
import matplotlib.pyplot as plt
from pprint import pprint

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

plt.style.use('ggplot')

%load_ext autoreload
%autoreload 2

In [3]:
import sys

# Make modules in the parent directory importable from this notebook.
sys.path.append('..')

from gym_minigrid_navigation.utils import show_video
from navigation_policy import gen_env, get_agent, run_episode, run_episodes

from rewards import get_reward_function
from utils import init_logger, switch_reproducibility_on, display_stats

# Enable logging for the modules whose progress messages appear in the outputs below.
for logger_name in ('dqn', 'navigation_policy', 'gym_minigrid_navigation.environments'):
    init_logger(logger_name)

### config 

In [4]:
import torch
from pyhocon import ConfigFactory

# Load the base experiment config, then override the settings this notebook needs.
config = ConfigFactory.parse_file('../conf/minigrid_dqn_navigation_resnet.hocon')
config['env']['video_path'] = '../outputs/video/'
config['training']['reward'] = 'image_net_similarity'
# Use the GPU when one is visible. Hard-coding 'cuda' would presumably fail on
# CPU-only machines, so fall back to 'cpu' there.
config['training']['reward_params'] = {
    'pretrained': False,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

In [5]:
# Turn on reproducible mode (presumably seeding the RNGs) with the configured seed.
seed = config['seed']
switch_reproducibility_on(seed)

### environment 

In [6]:
# Build the reward function(s) from the config, then the environment that uses them.
# `reward_functions` is reused later when the visualisation environment is created.
reward_functions = get_reward_function(config)
env_config = config['env']
env = gen_env(env_config, reward_functions)

### agent 

In [7]:
# Build the agent for this environment from the config.
# `init_logger` is already imported in the setup cell and is not used here,
# so it is not re-imported.
agent = get_agent(env, config)

2021-02-04 22:58:11,802 INFO    dqn                    : Running on device: cuda:0


### training

In [None]:
# Train the agent and summarise per-episode scores and step counts.
training_config = config['training']
scores, steps = run_episodes(
    env,
    agent,
    n_episodes=training_config['n_episodes'],
    verbose=training_config['verbose'],
)

display_stats(scores, steps)

2021-02-04 23:35:23,919 INFO    navigation_policy      : Episode: 100. Average score: 5.109515762329101. Average steps: 135.61
2021-02-05 00:08:46,535 INFO    navigation_policy      : Episode: 200. Average score: 6.98623589515686. Average steps: 120.62


### visualisation 

In [None]:
# Run one evaluation episode (train_mode=False) on a separate verbose environment,
# presumably the one that records the video shown below. A new name is used so
# that `env`, the training environment, is not overwritten. Otherwise re-running
# the training cell would silently train on this verbose environment.
eval_env = gen_env(config['env'], reward_functions, verbose=True)
print(run_episode(eval_env, agent, train_mode=False))

show_video()

In [None]:
# Deliberate halt: stop "Run All" here so that the ad-hoc inspection cells
# below only run when executed manually.
raise RuntimeError("Execution stopped on purpose: cells below are for manual inspection.")

In [None]:
# Show the `master` sub-module of the agent's target Q-network.
target_qnetwork = agent.qnetwork_target
target_qnetwork.master

In [None]:
# Keep a handle on the target network's `master` module and show its output size.
target_net = agent.qnetwork_target
model = target_net.master
model.output_size

In [None]:
# Replicate a single observation six times and stack it with the agent's
# (private) helper to build a small probe batch.
state = env.reset()
states = agent._vstack([state for _ in range(6)])
states.shape

In [None]:
# Run the probe batch through the model and show the resulting shape.
batch_output = model(states)
batch_output.shape