In [1]:
import gym
import numpy as np

import time

In [2]:
import sys
import os

# Make the parent directory importable so the project packages used below
# (`models`, `control`, `utils`) resolve when the notebook is run from its
# own folder. The membership check keeps the cell idempotent: re-running it
# must not pile duplicate entries onto `sys.path`.
_parent_directory = os.path.abspath('../')
if _parent_directory not in sys.path:
    sys.path.append(_parent_directory)

# Keep the notebook namespace clean, as the original cell did.
del sys, os, _parent_directory

In [3]:
import matplotlib.pyplot as plt

In [4]:
# LaTeX rendering in graphs
# `distutils` is deprecated (PEP 632) and removed in Python 3.12, so the
# executable lookup uses the standard-library replacement `shutil.which`,
# which returns the path of the executable or None when it is not found.
from shutil import which
if which('latex'):
    # Only enable TeX text rendering when a `latex` binary is on the PATH.
    plt.rc('text', usetex=True)

plt.rc('font', family='serif')

# High resolution graphs: ask the inline backend to emit figures in the
# 'retina' format.
%config InlineBackend.figure_format = 'retina'

In [5]:
import torch

In [6]:
# (Re)load IPython's autoreload extension and use mode 2 so that edits to the
# project modules imported below are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [7]:
import models.rnn as rnns
import models.mlp as mlps
import models.linear as linears
import control.agents as agents
import control.environments as env

In [8]:
from utils.notifications import Slack

In [9]:
import copy

# Setup

## Environment


In [10]:
# Gym identifier of the task to solve.
env_name = 'Taxi-v2'

# Arguments for the project's environment wrapper, gathered in one place so
# that switching the state representation is a one-word change.
environment_options = dict(
    environment=gym.make(env_name),
    agent=None,  # no agent yet; it is assigned to `environment.agent` later
    verbose=True,
    max_steps=200,
    capacity=500,
    representation_method='observation',  # alternative: 'one_hot_encoding'
)

environment = env.Environment(**environment_options)

  result = entry_point.load(False)


## Model

In [11]:
def _io_sizes():
    """Return the input/output size arguments that all three models share.

    Both values are read from `environment` on every call, so each model is
    built from a fresh lookup exactly as if the arguments were written inline.
    """
    return dict(
        input_dimension=environment.get_input_dimension(),
        n_actions=environment.n_actions,
    )


# Candidate models; the agent cell picks which one is actually trained.
model_linear = linears.Linear(**_io_sizes())

model_mlp = mlps.MLP(
    hidden_dimension=250,
    n_hidden_layers=2,
    dropout=0.5,
    **_io_sizes(),
)

model_rnn = rnns.RNN(
    hidden_dimension=50,
    dropout=0.3,
    truncate=10,
    **_io_sizes(),
)

## Agent

In [12]:
# Model handed to the agent; `model_linear` and `model_rnn` are the other
# candidates built in the model cell.
model = model_mlp

# Adam over the selected model's parameters.
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

agent = agents.DQNAgent(
    model=model,
    optimiser=optimiser,
    gamma=0.99,
    temperature=10,
    algorithm='sarsa',
    n_actions=environment.n_actions,
    terminal_state=environment.max_obs,
)

# The environment was created with `agent=None`; attach the agent now.
environment.agent = agent

Load a previously saved agent (optional):

In [13]:
# Optional: uncomment ONE `load_state_dict` line (the checkpoint matching the
# architecture assigned to `model`) together with `agent.commit()` to start
# from saved weights instead of a fresh model.
# NOTE(review): what `agent.commit()` does is not visible in this notebook —
# confirm against `control.agents` before relying on it.
#model.load_state_dict(torch.load('../saved/taxi/linear/state_dict_saved.pth'))
#model.load_state_dict(torch.load('../saved/taxi/mlp/state_dict_saved.pth'))
#agent.commit()

# Experiment

## Training

In [19]:
# Numeric settings of the training run.
run_settings = dict(
    epochs=10,
    segments=10,
    episodes=100,
    wall_time=2,
    num_evaluation=200,
    batch_size=100,
)

# Where the run writes its artefacts.
# NOTE(review): 'obersvations' looks like a typo for 'observations'; it is
# kept unchanged because renaming it would move where logs are written.
run_locations = dict(
    save_directory='../saved/taxi/mlp',
    log_directory='mlp_obersvations_taxi',
)

environment.run(**run_settings, **run_locations)

100%|###############################################################| 10/10 [04:02<00:00, 25.85s/it]


>> Training return : -505.13


  0%|                                                                        | 0/10 [00:00<?, ?it/s]

>> Evaluation return : -201.49, steps : 200.00


 20%|############8                                                   | 2/10 [00:50<03:26, 25.85s/it]


KeyboardInterrupt: 

## Testing

In [17]:
# Run a single exploration episode with rendering switched on; the grids
# printed below are the rendered frames, each followed by the action taken.
environment.exploration_episode(render=True)

+---------+
|R: | : :[35mG[0m|
| : : :[43m [0m: |
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+

+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : :[43m [0m: |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (South)
269
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : :[43m [0m: |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (Dropoff)
269
+---------+
|R: | : :[35mG[0m|
| : : :[43m [0m: |
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (North)
169
+---------+
|R: | : :[35mG[0m|
| : :[43m [0m: : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (West)
149
+---------+
|R: | : :[35mG[0m|
| : : :[43m [0m: |
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (East)
169
+---------+
|R: | : :[35mG[0m|
| : : : :[43m [0m|
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (East)
189
+---------+
|R: | : :[35m[43mG[0m[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |B: |
+---------+
  (North)
89

+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+---------+
  (Pickup)
469
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+---------+
  (Dropoff)
469
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+---------+
  (West)
469
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+---------+
  (Dropoff)
469
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : |[43m [0m: |
|[34;1mY[0m| : |B: |
+---------+
  (North)
369
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : |[43m [0m: |
|[34;1mY[0m| : |B: |
+---------+
  (West)
369
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+---------+
  (South)
469
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| : |[43mB[0m: |
+------

+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| :[43m [0m|B: |
+---------+
  (East)
449
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| :[43m [0m|B: |
+---------+
  (Pickup)
449
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m|[43m [0m: |B: |
+---------+
  (West)
429
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m|[43m [0m: |B: |
+---------+
  (South)
429
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m|[43m [0m: |B: |
+---------+
  (Pickup)
429
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| :[43m [0m|B: |
+---------+
  (East)
449
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| :[43m [0m|B: |
+---------+
  (South)
449
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|[34;1mY[0m| :[43m [0m|B: |
+---------+

In [15]:
# Run a single evaluation episode with rendering switched on; the grids
# printed below are the rendered frames, each followed by the action taken.
environment.evaluation_episode(render=True)

+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | : | : |
|Y| :[43m [0m|B: |
+---------+

State: 444
+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | :[43m [0m| : |
|Y| : |B: |
+---------+
  (North)
State: 344
+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : : : |
| : :[43m [0m: : |
| | : | : |
|Y| : |B: |
+---------+
  (North)
State: 244
+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : : : |
| : : :[43m [0m: |
| | : | : |
|Y| : |B: |
+---------+
  (East)
State: 264
+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : :[43m [0m: |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (North)
State: 164
+---------+
|[35mR[0m: | : :[34;1mG[0m|
| : : : :[43m [0m|
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (East)
State: 184
+---------+
|[35mR[0m: | : :[34;1m[43mG[0m[0m|
| : : : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (North)
State: 84
+---------+
|[35mR[0m: | : :[42mG[0m|
| : : : : |
| : : : : |
| | : | 

## Save

In [41]:
# Persist the agent currently attached to the environment. The target is the
# same directory that the training cell passes as `save_directory`.
environment.agent.save('../saved/taxi/mlp')