In [1]:
import gym
import numpy as np

import time

In [2]:
import sys
import os

sys.path.append(os.path.abspath('../'))

del sys, os

In [3]:
import matplotlib.pyplot as plt

In [4]:
# LaTeX rendering in graphs
from distutils.spawn import find_executable
if find_executable('latex'):
    plt.rc('text', usetex=True)

plt.rc('font', family='serif')

# High resolution graphs
%config InlineBackend.figure_format = 'retina'

In [5]:
import torch

In [6]:
%reload_ext autoreload
%autoreload 2

In [7]:
import models.rnn as rnns
import models.mlp as mlps
import models.linear as linears
import control.agents as agents
import control.environments as env

In [8]:
from utils.notifications import Slack

In [9]:
import copy

# Setup

In [10]:
env_name = 'Taxi-v2'
#env_name = 'Breakout-ram-v0'

In [11]:
environment = env.Environment(
    environment=gym.make(env_name), 
    agent=None,
    verbose=True,
    max_steps=300,
    capacity=1000,
    representation_method='one_hot_encoding'
)

  result = entry_point.load(False)


In [12]:
input_dimension = environment.get_input_dimension()
n_actions = environment.n_actions
print("Input dimension: {}; number of actions: {}".format(input_dimension, n_actions))

Input dimension: 501; number of actions: 6


In [13]:
model = mlps.MLP(input_dimension=input_dimension, 
                 hidden_dimension=200,
                 n_hidden_layers=1,
                 n_actions=n_actions,
                 dropout=0)
optimiser = torch.optim.Adam(model.parameters(), lr=.1)
agent = agents.DQNAgent(model, optimiser, gamma=.99, temperature=1, algorithm='expsarsa', n_actions=n_actions)
environment.agent = agent

In [18]:
print(environment.agent.q(environment.state_representation(environment.environment.reset())))
model.load_state_dict(torch.load('../saved/taxi/mlp/state_dict.pth'))
agent.commit()
print(environment.agent.q(environment.state_representation(environment.environment.reset())))
print(environment.boltzmann(environment.state_representation(environment.environment.reset())))

[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
[0.123688   0.32130238 0.10206328 0.10530633 0.22347164 0.12416831]


# Experiment

## Boltzmann

In [29]:
environment.reset()

done = False
full_return = 0.

counter = 0
while not done and counter < environment.max_steps:
    
    s, reward, done, i = environment.step(environment.action)

    p, q = environment.boltzmann(s, return_q=True)
    a = environment.sample_action(p)
    
    environment.state, environment.action = s, a

    full_return = environment.agent.gamma * full_return + reward
    counter += 1
    
    print(np.argmax(s))
    print(environment.agent.q(s))


249
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
149
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
89
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
169
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
169
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
269
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
2

[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
69
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.5814  -151.62679 -152.77357 -152.7423  -151.98988 -152.57753]
49
[-152.

In [32]:
n_episodes = 3
agent.temperature = 0.1

plt.figure()

for i in range(n_episodes):

    full_return, counter, observations = environment.evaluation_episode(render=False,return_observations=True)
    
    q = []
    
    observation_old = None
    q_old = None
    
    for observation in observations:
        
        observation_new = environment.state_representation(observation)
        q_new = environment.agent.q(environment.state_representation(observation))
        
        if observation_old is not None:
            same_obs = np.array_equal(observation_old, observation_new)
            same_q = np.array_equal(q_old, q_new)
            if not same_obs and not same_q:
                print("great")
        
        observation_old = observation_new
        q_old = q_new
        
        q.append(q_new)
        
    q = np.asarray(q)
    print(np.std(q, axis=0))
    print(q[:,0])
    plt.plot(q[:,0], label='0')
    break
    #plt.plot(q[:,1], label='1')
    #plt.plot(q[:,2], label='2')
    
plt.show()

SyntaxError: invalid syntax (<ipython-input-32-4cc637ebf2f2>, line 12)

In [15]:
plt.figure()
for i in range(n_episodes):
    x = np.asarray(q_estimation[i])
    plt.plot(x[:,1])

plt.show()

NameError: name 'q_estimation' is not defined

<Figure size 432x288 with 0 Axes>

## Testing

In [23]:
agent.temperature = 0.1
for _ in range(5):
    environment.exploration_episode(render=True)

In [24]:
for _ in range(5):
    environment.evaluation_episode(render=True)