In [None]:
#@title -- Installation of Packages -- { display-mode: "form" }
import sys
import shutil
USE_NBCAP = False

if not shutil.which('apt') is None:
    !apt update
    !apt install -y xvfb x11-utils
    !{sys.executable} -m pip install pyscreenshot pyvirtualdisplay
    !{sys.executable} -m pip install --upgrade pyglet
    !{sys.executable} -m pip install git+https://github.com/michalgregor/nbcap.git

    USE_NBCAP = True

!{sys.executable} -m pip install gym[classic_control]
!{sys.executable} -m pip install class_utils[tensorboard]@git+https://github.com/michalgregor/class_utils.git
!{sys.executable} -m pip install tianshou
!{sys.executable} -m pip install git+https://github.com/michalgregor/tianshou_agents.git

In [1]:
#@title -- Import of Necessary Packages -- { display-mode: "form" }
%load_ext tensorboard

import shutil
if shutil.which('apt') is None:
    USE_NBCAP = False
else:
    USE_NBCAP = True

    from nbcap import ShowVideoCallback, ScreenRecorder, OutputManager, DisplayProcess

import os
import torch
from tianshou_agents.utils import VectorEnvRenderWrapper
from tianshou_agents.sac import sac_simple
from tianshou.data import Collector
from tianshou.env import BaseVectorEnv
from tianshou_agents.preset import AgentPresetWrapper

In [2]:
#@title -- Auxiliary Functions -- { display-mode: "form" }

if USE_NBCAP:
    display_size = (600, 400)
    show_video = ShowVideoCallback(dimensions=display_size)

    # Make sure that only one instance of the display is ever created:
    # DISP_PROC persists in the kernel namespace, so re-running this cell
    # reuses the existing display process instead of starting another one.
    try:
        DISP_PROC
    except NameError:
        DISP_PROC = DisplayProcess(display_size=display_size)

    def make_screen_recorder(max_gui_outputs=1, video_path="output", segment_time=10):
        """Create a ScreenRecorder for the shared display ``DISP_PROC``.

        Args:
            max_gui_outputs: passed to ``OutputManager``; presumably caps how
                many video outputs are kept in the notebook (TODO confirm).
            video_path: output location handed to ``ScreenRecorder``
                (defaults to ``"output"``).
            segment_time: passed to ``ScreenRecorder`` as ``segment_time``;
                presumably the segment length in seconds (confirm in nbcap).

        Returns:
            A ``ScreenRecorder`` for ``DISP_PROC``'s display, with
            ``show_video`` (routed through an ``OutputManager``) as its
            video callback.
        """
        output_manager = OutputManager(max_gui_outputs=max_gui_outputs)
        video_callback = output_manager(show_video)

        return ScreenRecorder(
            DISP_PROC.id, display_size, video_path,
            segment_time=segment_time, video_callback=video_callback
        )

    SCREEN_RECORDER = make_screen_recorder()
else:
    # Without nbcap, use a no-op context manager so that code doing
    # ``with SCREEN_RECORDER:`` works unchanged.
    from contextlib import suppress
    SCREEN_RECORDER = suppress()

class RenderCollector(Collector):
    """Proxy around a tianshou ``Collector`` that renders collected episodes.

    Every ``collect`` call runs inside ``SCREEN_RECORDER`` (a screen recorder
    when nbcap is available, a no-op context manager otherwise) and forwards
    to the wrapped collector with a ``render`` value. Reads of public
    attributes not found on the proxy are delegated to the wrapped collector.
    Attribute writes are NOT delegated, except for ``collect_time``, which
    has an explicit setter.

    Args:
        collector: the collector to wrap. If its ``env`` is a
            ``BaseVectorEnv``, that env is replaced on the collector by a
            ``VectorEnvRenderWrapper`` around it.
        render: default ``render`` value for ``collect`` (forwarded to the
            wrapped collector; presumably the pause between rendered frames —
            confirm against tianshou's ``Collector.collect``).
    """

    def __init__(self, collector, render=0.01):
        # Collector.__init__ is deliberately not called: this object is only
        # a proxy and the actual collection state lives on `collector`.
        self.collector = collector

        if isinstance(self.collector.env, BaseVectorEnv):
            self.collector.env = VectorEnvRenderWrapper(
                self.collector.env)

        self._render = render

    @property
    def collect_time(self):
        # Clamped away from zero — NOTE(review): presumably so that callers
        # dividing by collect_time (e.g. speed statistics) never divide by 0.
        return max(self.collector.collect_time, 1e-20)

    @collect_time.setter
    def collect_time(self, val):
        self.collector.collect_time = val

    def collect(
        self, n_step = None, n_episode = None, random = False,
        render = None, no_grad = True,
    ):
        """Collect through the wrapped collector inside ``SCREEN_RECORDER``.

        Takes the same arguments as ``Collector.collect`` and returns whatever
        the wrapped collector returns. A falsy ``render`` (``None``, ``0``)
        falls back to the default given to ``__init__`` — NOTE(review): so
        rendering cannot be switched off per call.
        """
        with SCREEN_RECORDER:
            render = render or self._render
            return self.collector.collect(n_step, n_episode, random, render, no_grad)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if name.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
        if name == 'collector':
            # The proxy has no wrapped collector yet (e.g. an instance created
            # via __new__ without __init__). Delegating here would look up
            # self.collector again and recurse until RecursionError.
            raise AttributeError("'{}' has no wrapped collector".format(type(self).__name__))
        return getattr(self.collector, name)

    def __str__(self):
        return '<{}{}>'.format(type(self).__name__, self.collector)

    def __repr__(self):
        return str(self)

class AgentPresetPatch(AgentPresetWrapper):
    """Preset wrapper whose agents render their test episodes.

    Calling the patch builds an agent with the wrapped preset and then
    replaces ``agent.test_collector`` with a ``RenderCollector``. Before
    that, the test envs of the agent built by the previous call (if any)
    are closed.

    Args:
        preset: the agent preset to wrap (e.g. ``sac_simple``).
        render: default ``render`` value handed to each ``RenderCollector``.
    """

    def __init__(self, preset, render=0.01):
        super().__init__(preset)
        self._prev_test_envs = None
        self.render = render

    def __call__(self, *args, **kwargs):
        agent = self._preset(*args, **kwargs)

        # Close the previous agent's test envs (and thus its pyglet window)
        # before a new window is opened — a workaround for a bug on Windows.
        previous_envs = self._prev_test_envs
        if previous_envs is not None:
            previous_envs.close()

        agent.test_collector = RenderCollector(agent.test_collector, render=self.render)
        self._prev_test_envs = agent.test_envs

        return agent
        
# Wrap the *original* preset rather than whatever `sac_simple` currently
# names: re-running this cell would otherwise wrap an already patched preset
# a second time, nesting RenderCollectors and screen recorders.
from tianshou_agents.sac import sac_simple as _sac_simple_preset
sac_simple = AgentPresetPatch(_sac_simple_preset)

# Checkpointing and Saving Agents

This example illustrates how agents are checkpointed automatically and how saving and loading an agent's state works. This is useful not only for loading and using a previously trained agent – it also allows you to resume training after it has been interrupted.

Let's start by creating a simple SAC agent from a preset and training it.

In [3]:
# SAC agent for Pendulum built from the patched preset; stop_criterion is
# presumably the reward threshold at which training stops early.
agent = sac_simple('Pendulum-v0', stop_criterion=-250, seed=0)

In [4]:
# Train for at most 10 epochs of 1000 environment steps each.
train_results = agent.train(
    max_epoch=10,
    step_per_epoch=1000,
)

Epoch #1: 1001it [00:17, 57.03it/s, alpha=0.753, env_step=1000, len=200, loss/actor=25.850, loss/alpha=-0.651, loss/critic1=0.384, loss/critic2=0.406, n/ep=0, n/st=1, rew=-1829.07]                          
Epoch #2:   1%|1         | 11/1000 [00:00<00:14, 70.18it/s, alpha=0.751, env_step=1011, len=200, loss/actor=26.207, loss/alpha=-0.657, loss/critic1=0.384, loss/critic2=0.406, n/ep=0, n/st=1, rew=-1829.07]

Epoch #1: test_reward: -1810.144515 ± 148.657324, best_reward: -1186.824446 ± 304.551268 in #0


Epoch #2: 1001it [00:16, 60.63it/s, alpha=0.565, env_step=2000, len=200, loss/actor=55.363, loss/alpha=-1.139, loss/critic1=0.530, loss/critic2=0.707, n/ep=0, n/st=1, rew=-1547.19]                          
Epoch #3:   1%|1         | 10/1000 [00:00<00:14, 67.96it/s, alpha=0.564, env_step=2010, len=200, loss/actor=55.721, loss/alpha=-1.141, loss/critic1=0.531, loss/critic2=0.710, n/ep=0, n/st=1, rew=-1547.19]

Epoch #2: test_reward: -1135.350921 ± 102.566455, best_reward: -1135.350921 ± 102.566455 in #2


Epoch #3: 1001it [00:16, 59.15it/s, alpha=0.442, env_step=3000, len=200, loss/actor=74.259, loss/alpha=-1.162, loss/critic1=0.816, loss/critic2=1.078, n/ep=0, n/st=1, rew=-387.59]                          
Epoch #4:   1%|1         | 10/1000 [00:00<00:14, 66.67it/s, alpha=0.441, env_step=3010, len=200, loss/actor=74.476, loss/alpha=-1.161, loss/critic1=0.869, loss/critic2=1.133, n/ep=0, n/st=1, rew=-387.59]

Epoch #3: test_reward: -1003.528745 ± 601.529164, best_reward: -1003.528745 ± 601.529164 in #3


Epoch #4: 1001it [00:42, 23.72it/s, alpha=0.352, env_step=4000, len=200, loss/actor=82.729, loss/alpha=-1.147, loss/critic1=1.822, loss/critic2=1.651, n/ep=0, n/st=1, rew=-1.53]                          
Epoch #5:   1%|1         | 11/1000 [00:00<00:14, 69.57it/s, alpha=0.351, env_step=4011, len=200, loss/actor=82.332, loss/alpha=-1.164, loss/critic1=1.802, loss/critic2=1.626, n/ep=0, n/st=1, rew=-1.53]

Epoch #4: test_reward: -373.723060 ± 335.072812, best_reward: -373.723060 ± 335.072812 in #4


Epoch #5: 1001it [00:28, 35.11it/s, alpha=0.279, env_step=5000, len=200, loss/actor=88.286, loss/alpha=-1.286, loss/critic1=2.950, loss/critic2=2.810, n/ep=0, n/st=1, rew=-1355.56]                          
Epoch #6:   1%|1         | 13/1000 [00:00<00:12, 79.21it/s, alpha=0.278, env_step=5013, len=200, loss/actor=88.141, loss/alpha=-1.292, loss/critic1=2.938, loss/critic2=2.859, n/ep=0, n/st=1, rew=-1355.56]

Epoch #5: test_reward: -427.622719 ± 461.578917, best_reward: -373.723060 ± 335.072812 in #4


Epoch #6:  47%|####7     | 472/1000 [00:46<00:51, 10.19it/s, env_step=5472, len=200, n/ep=1, n/st=1, rew=-123.81]


The training above was stopped partway through epoch 6. Let's look at how far it progressed in terms of epochs, environment steps and gradient steps:

In [5]:
# Training progress counters: (epoch, environment steps, gradient steps).
(agent.epoch, agent.env_step, agent.gradient_step)

(6, 5471, 5471)

## Loading an Agent from a Checkpoint

By default, the agent is checkpointed every time there is an improvement. The checkpoint is stored under the logging path (here ``log/Pendulum-v0/sac/best_agent.pth``), and we can load it back using ``torch.load``; the loaded state is a dictionary.

In [6]:
# Load the automatically written checkpoint from the agent's logging path.
state_dict = torch.load(
    "log/Pendulum-v0/sac/best_agent.pth"
)

Once we have loaded the checkpointed state, we can restore it into a new agent of the same kind (built from the same preset with the same arguments) using ``agent.load_state_dict``.

In [7]:
# Build a fresh agent with the same preset and arguments, so that its
# structure matches the checkpointed state, then restore that state into it.
agent = sac_simple(
    'Pendulum-v0',
    stop_criterion=-250,
    seed=0,
)
agent.load_state_dict(state_dict)

If you now inspect the epoch, environment step and gradient step of this new agent, you will see that they reflect the training we did with the previous agent.

In [8]:
# Progress counters of the restored agent: (epoch, env steps, gradient steps).
(agent.epoch, agent.env_step, agent.gradient_step)

(6, 5471, 5471)

You can also run ``agent.test()`` and check the agent's performance to verify that the trained policy was restored as well.

In [9]:
# Evaluate the restored agent; the returned dict summarizes the test episodes.
agent.test()

{'n/ep': 10,
 'n/st': 2000,
 'rews': array([-126.8196232 , -118.639652  , -126.15630426, -123.37939039,
        -127.0365974 , -246.65841828, -129.50245252, -240.11327871,
        -127.65953754, -126.30450594]),
 'lens': array([200, 200, 200, 200, 200, 200, 200, 200, 200, 200]),
 'idxs': array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])}

## Saving the Agent Manually

You can also save the agent's state manually – the interface mirrors the one PyTorch uses for modules: retrieve the state using ``agent.state_dict()`` and save that dictionary using ``torch.save``.

In [5]:
# Save the agent's state dict (not the agent object itself): the cell below
# passes what it loads to `load_state_dict`, which expects this dict.
# torch.save does not create missing directories, so make sure one exists.
os.makedirs("output", exist_ok=True)
state_dict = agent.state_dict()
torch.save(state_dict, "output/agent.pth")

To restore the state, construct a new agent, load the state dict from disk using ``torch.load`` and pass it to the new agent's ``load_state_dict`` method.

In [7]:
# A second agent with the same configuration, restored from the manually saved state.
agent2 = sac_simple('Pendulum-v0', stop_criterion=-250, seed=0)

state_dict = torch.load(
    "output/agent.pth"
)
agent2.load_state_dict(state_dict)

As before, the training-progress statistics are retained along with the saved policy.

In [None]:
# Test the agent restored from "output/agent.pth" (agent2), not the original
# `agent`, so that this actually verifies the saved policy was restored.
agent2.test()