In [3]:
import os
import json

import argparse
import warnings
import random

import numpy as np
import sys; sys.path.insert(0, '..')
import torch

from distutils.dir_util import copy_tree

from rl2 import collectors, envs, agents, models, defaults

import rl2.utils.common as common
from rl2.utils.distributions import CategoricalHead, ScalarHead
from rl2.agents.agent import AbstractAgent
from rl2.modules import DeepMindEnc

In [4]:
from rl2.envs.gym.atari import make_atari

In [5]:
seed = 42
env = 'atari'
env_id = 'Breakout'

In [6]:
eps = 1e-8
inf = 1e8

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
env = make_atari(env_id+'NoFrameskip-v4', 1, 42)




Base class

inherit this class to implement your agent

In [9]:
class GeneralAgent(AbstractAgent):
    name = 'General'

    def __init__(self, model, collector, epoch):
#         self.args = args
        self.model = model
        self.epoch = epoch
        self.collector = collector
#         self.logger = Logger(self.name, args=args)

#         self.info = EvaluationMetrics([
#             'Time/Step',
#             'Time/Item',
#             'Loss/Total',
#             'Loss/Value',
#             'Values/Reward',
#             'Values/Value',
#             'Score/Train',
#         ])

    def loss_func(self, *args, **kwargs):
        raise NotImplementedError

    def train(self):
        self.collector.step_env()
        for epoch in range(self.epoch):
            self.collector.reset_count()
            while self.collector.has_next():
                data = self.collector.step()
                loss = self.loss_func(*data, info=self.info)
                self.model.step(loss)

Your implementation

In [10]:
class PPOAgent(GeneralAgent):
    name = 'PPO'

    def __init__(self,
                 model,
                 collector,
                 epoch=4,
                 vf_coef=0.5,
                 ent_coef=0.01,
                 cliprange=0.1):
        super().__init__(model=model, collector=collector, epoch=epoch)

        self.vf_coef = vf_coef
        self.ent_coef = ent_coef

        self.cliprange = cliprange

#         self.metrics.set([
#             'Loss/Policy',
#             'Values/Entropy',
#             'Values/Adv',
#         ])

    def loss_func(self, obs, old_acs, old_nlps, advs, old_rets):
        ac_dist, val_dist = self.model.forward(obs)
        vals = val_dist.mean
        nlps = -ac_dist.log_prob(old_acs)
        ent = ac_dist.entropy().mean()
        old_vals = old_rets - advs

        advs = (advs - advs.mean()) / (advs.std() + settings.EPS)

        vals_clipped = (old_vals + torch.clamp(vals - old_vals,
                                               -self.cliprange,
                                               self.cliprange))

        vf_loss_clipped = 0.5 * (vals_clipped - old_rets.detach()).pow(2)
        vf_loss = 0.5 * (vals - old_rets.detach()).pow(2)

        vf_loss = torch.max(vf_loss, vf_loss_clipped).mean()

        ratio = torch.exp(old_nlps - nlps).unsqueeze(-1)
        pg_loss1 = -advs * ratio

        ratio = torch.clamp(ratio, 1 - self.cliprange, 1 + self.cliprange)
        pg_loss2 = -advs * ratio

        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Total loss
        loss = pg_loss - self.ent_coef * ent + self.vf_coef * vf_loss

        if self.metrics is not None:
            self.metrics.update('Values/Value', vals.mean().item())
            self.metrics.update('Values/Adv', advs.mean().item())
            self.metrics.update('Values/Entropy', ent.item())
            self.metrics.update('Loss/Value', vf_loss.item())
            self.metrics.update('Loss/Policy', pg_loss.item())
            self.metrics.update('Loss/Total', loss.item())

        return loss


ppo

env & input shape

In [11]:
# env = getattr(envs, args.env)(args)

# Create network components for the agent
input_shape = env.observation_space.shape

reshape from tf to pt

In [12]:
if len(input_shape) > 1:
    input_shape = (input_shape[-1], *input_shape[:-1])

networks

In [13]:
encoder = DeepMindEnc(input_shape).to(device)
actor = CategoricalHead(encoder.out_shape, env.action_space.n).to(device)
critic = ScalarHead(encoder.out_shape, 1).to(device)

networks = [encoder, actor, critic]

In [14]:
# Declare optimizer
optimizer = 'torch.optim.Adam'

In [15]:
from rl2 import models

In [16]:
# Create a model using the necessary networks
model = models.ActorCriticModel(networks, optimizer)

In [17]:
# Create a collector for managing data collection
collector = collectors.PGCollector(env, model, device)

In [18]:
agent = PPOAgent(model, collector)

In [19]:
from tqdm import tqdm

In [20]:
# # Finally create an agent with the defined components
# train(args, 'PPOAgent', 'ppo', model, collector)
steps = int(5e7)
steps = steps // collector.num_workers + 1
for step in tqdm(range(5)):
#     if train_fn is None:
    agent.train()
#     else:
#         train_fn(agent, step, steps)

100%|██████████| 5/5 [00:01<00:00,  3.96it/s]


In [21]:
model = agent.model

torch.save(model.nets[0].state_dict(), 'deepmindenc.pt')

torch.save(model.nets[2].state_dict(), 'scalarhead.pt')

In [34]:
model.modules()

<generator object Module.named_parameters at 0x1698bdf10>

In [38]:
torch.jit.save(torch.jit.script(next(model.modules())), 'deepmindenc_script.pt')

RuntimeError: 
undefined value super:
  File "../rl2/utils/distributions.py", line 127
    def __init__(self, logits):
        super().__init__(logits=logits)
        ~~~~~ <--- HERE
'CategoricalDist.__init__' is being compiled since it was called from 'CategoricalDist'
  File "../rl2/utils/distributions.py", line 41
    def forward(self, x):
        x = self.linear(x)
        dist = CategoricalDist(logits=F.log_softmax(x, dim=-1))
               ~~~~~~~~~~~~~~~ <--- HERE
        return dist
'CategoricalDist' is being compiled since it was called from 'CategoricalHead.forward'
  File "../rl2/utils/distributions.py", line 41
    def forward(self, x):
        x = self.linear(x)
        dist = CategoricalDist(logits=F.log_softmax(x, dim=-1))
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return dist


In [49]:
sp = env.observation_space.shape

In [50]:
input_shape = (sp[-1], *sp[:-1])

In [51]:
sp

(84, 84, 4)

In [47]:
example_inputs = torch.rand((32,4,8,8)).to(device)

In [53]:
idx = 0
for module in model.modules():
    print(idx)
    print(module)
    idx+=1

0
ActorCriticModel(
  (encoder): DeepMindEnc(
    (feature): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
    )
    (fc): Sequential(
      (0): Linear(in_features=1568, out_features=256, bias=True)
      (1): ReLU()
    )
  )
  (actor): CategoricalHead(
    (linear): Linear(in_features=256, out_features=4, bias=True)
  )
  (critic): ScalarHead(
    (linear): Linear(in_features=256, out_features=1, bias=True)
  )
)
1
DeepMindEnc(
  (feature): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=1568, out_features=256, bias=True)
    (1): ReLU()
  )
)
2
Seq

In [48]:
script = torch.jit.trace(next(model.modules()), example_inputs)

RuntimeError: Calculated padded input size per channel: (1 x 1). Kernel size: (4 x 4). Kernel size can't be greater than actual input size

In [35]:
torch.jit.save(script, 'categoricalhead_script.pt')

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_mm

In [None]:
torch.jit.save(torch.jit.script(model.nets[2]), 'scalarhead_script.pt')

In [108]:
loaded_net1 = torch.jit.load('deepmindenc_script.pt')

In [109]:
loaded_net2 = torch.jit.load('categoricalhead.pt')

RuntimeError: istream reader failed: reading file. (validate at /pytorch/caffe2/serialize/istream_adapter.cc:32)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7fa611dee193 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::IStreamAdapter::validate(char const*) const + 0x338 (0x7fa614f793c8 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::IStreamAdapter::read(unsigned long, void*, unsigned long, char const*) const + 0x2c (0x7fa614f796fc in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x2d26376 (0x7fa614f67376 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x2d2b3a4 (0x7fa614f6c3a4 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #5: caffe2::serialize::PyTorchStreamReader::init() + 0x8b (0x7fa614f74b1b in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #6: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7fa614f77c04 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #7: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::CompilationUnit>, std::string const&, c10::optional<c10::Device>, std::unordered_map<std::string, std::string, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::string> > >&) + 0x35 (0x7fa6162d7845 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x776ffb (0x7fa65cf58ffb in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x2961c4 (0x7fa65ca781c4 in /home/anthony/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #10: _PyCFunction_FastCallDict + 0x154 (0x559a110f8b94 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #11: <unknown function> + 0x19e67c (0x559a1118867c in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #12: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #13: <unknown function> + 0x197a94 (0x559a11181a94 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #14: <unknown function> + 0x198941 (0x559a11182941 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #15: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #17: PyEval_EvalCodeEx + 0x329 (0x559a11183459 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #18: PyEval_EvalCode + 0x1c (0x559a111841ec in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #19: <unknown function> + 0x1be6cb (0x559a111a86cb in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #20: _PyCFunction_FastCallDict + 0x91 (0x559a110f8ad1 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #21: <unknown function> + 0x19e67c (0x559a1118867c in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #23: _PyGen_Send + 0x256 (0x559a1118b866 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x13ad (0x559a111abd6d in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #25: _PyGen_Send + 0x256 (0x559a1118b866 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x13ad (0x559a111abd6d in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #27: _PyGen_Send + 0x256 (0x559a1118b866 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #28: _PyCFunction_FastCallDict + 0x115 (0x559a110f8b55 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #29: <unknown function> + 0x19e67c (0x559a1118867c in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #31: <unknown function> + 0x19870b (0x559a1118270b in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #32: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #33: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #34: <unknown function> + 0x19870b (0x559a1118270b in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #35: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #36: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #37: <unknown function> + 0x197a94 (0x559a11181a94 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #38: _PyFunction_FastCallDict + 0x3db (0x559a1118303b in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #39: _PyObject_FastCallDict + 0x26f (0x559a110f8f5f in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #40: _PyObject_Call_Prepend + 0x63 (0x559a110fda03 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #41: PyObject_Call + 0x3e (0x559a110f899e in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #42: _PyEval_EvalFrameDefault + 0x1ab0 (0x559a111ac470 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #43: <unknown function> + 0x197c26 (0x559a11181c26 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #44: <unknown function> + 0x198941 (0x559a11182941 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #45: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #46: _PyEval_EvalFrameDefault + 0x10ba (0x559a111aba7a in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #47: <unknown function> + 0x1a13d0 (0x559a1118b3d0 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #48: _PyCFunction_FastCallDict + 0x91 (0x559a110f8ad1 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #49: <unknown function> + 0x19e67c (0x559a1118867c in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #50: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #51: <unknown function> + 0x197c26 (0x559a11181c26 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #52: <unknown function> + 0x198941 (0x559a11182941 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #53: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #54: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #55: <unknown function> + 0x1a13d0 (0x559a1118b3d0 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #56: _PyCFunction_FastCallDict + 0x91 (0x559a110f8ad1 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #57: <unknown function> + 0x19e67c (0x559a1118867c in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #58: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #59: <unknown function> + 0x197c26 (0x559a11181c26 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #60: <unknown function> + 0x198941 (0x559a11182941 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #61: <unknown function> + 0x19e755 (0x559a11188755 in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #62: _PyEval_EvalFrameDefault + 0x2fa (0x559a111aacba in /home/anthony/miniconda3/envs/ml-stable/bin/python)
frame #63: <unknown function> + 0x1a13d0 (0x559a1118b3d0 in /home/anthony/miniconda3/envs/ml-stable/bin/python)


In [102]:
loaded_net.state_dict()

OrderedDict([('feature.0.weight',
              tensor([[[[-0.0456, -0.0061,  0.0086,  ..., -0.0013,  0.0013,  0.0611],
                        [ 0.0254, -0.0324, -0.0599,  ...,  0.0255, -0.0010,  0.0597],
                        [ 0.0516, -0.0215, -0.0080,  ...,  0.0341,  0.0477,  0.0229],
                        ...,
                        [-0.0095, -0.0022, -0.0173,  ...,  0.0236, -0.0310,  0.0007],
                        [-0.0565,  0.0590,  0.0299,  ..., -0.0021,  0.0109, -0.0398],
                        [ 0.0362, -0.0469,  0.0385,  ..., -0.0109,  0.0110,  0.0511]],
              
                       [[-0.0022, -0.0097, -0.0528,  ..., -0.0394, -0.0551,  0.0325],
                        [ 0.0042,  0.0489, -0.0604,  ...,  0.0624,  0.0612, -0.0432],
                        [-0.0139, -0.0344, -0.0452,  ..., -0.0390,  0.0050,  0.0188],
                        ...,
                        [ 0.0266, -0.0398,  0.0022,  ..., -0.0366, -0.0047,  0.0507],
                        [ 0.0342

In [98]:
torch.load('deepmindenc.pt')

OrderedDict([('feature.0.weight',
              tensor([[[[-0.0456, -0.0061,  0.0086,  ..., -0.0013,  0.0013,  0.0611],
                        [ 0.0254, -0.0324, -0.0599,  ...,  0.0255, -0.0010,  0.0597],
                        [ 0.0516, -0.0215, -0.0080,  ...,  0.0341,  0.0477,  0.0229],
                        ...,
                        [-0.0095, -0.0022, -0.0173,  ...,  0.0236, -0.0310,  0.0007],
                        [-0.0565,  0.0590,  0.0299,  ..., -0.0021,  0.0109, -0.0398],
                        [ 0.0362, -0.0469,  0.0385,  ..., -0.0109,  0.0110,  0.0511]],
              
                       [[-0.0022, -0.0097, -0.0528,  ..., -0.0394, -0.0551,  0.0325],
                        [ 0.0042,  0.0489, -0.0604,  ...,  0.0624,  0.0612, -0.0432],
                        [-0.0139, -0.0344, -0.0452,  ..., -0.0390,  0.0050,  0.0188],
                        ...,
                        [ 0.0266, -0.0398,  0.0022,  ..., -0.0366, -0.0047,  0.0507],
                        [ 0.0342

In [81]:
next(net1.modules())

DeepMindEnc(
  (feature): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=1568, out_features=256, bias=True)
    (1): ReLU()
  )
)

In [48]:
gen = net1.parameters()

In [67]:
aa = next(gen)
aa

StopIteration: 

In [68]:
aa.shape

torch.Size([256])

In [84]:
print(torch.save.__doc__)

Saves an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string
           containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

        If you are using Python 2, :func:`torch.save` does NOT support :class:`StringIO.StringIO`
        as a valid file-like object. This is because the write method should return
        the number of bytes written; :meth:`StringIO.write()` does not do this.

        Please use something like :class:`io.BytesIO` instead.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    


In [85]:
net1.state_dict()

OrderedDict([('feature.0.weight',
              tensor([[[[-1.3687e-02,  5.8401e-02,  2.8761e-02,  ...,  5.9029e-02,
                          5.5160e-02,  1.6049e-02],
                        [ 2.6644e-02, -5.2050e-02, -2.5416e-02,  ..., -1.9995e-02,
                         -1.8489e-02,  2.1497e-02],
                        [-4.0789e-02, -1.8163e-03,  3.1420e-02,  ..., -5.5475e-02,
                         -4.2569e-02,  3.1108e-02],
                        ...,
                        [ 3.8891e-02,  4.6222e-02,  1.5544e-03,  ...,  4.5083e-02,
                          2.3373e-02, -6.2243e-02],
                        [-5.1023e-02, -4.7213e-02,  1.7733e-02,  ...,  4.6972e-02,
                          1.7759e-02,  5.1614e-02],
                        [-3.4019e-02,  4.4249e-03,  4.8774e-02,  ..., -3.0882e-02,
                          5.6985e-02, -3.6005e-02]],
              
                       [[-2.3937e-02,  2.4887e-02,  3.2013e-02,  ...,  5.5352e-02,
                         -4.