In [None]:
import torch
import torch.utils.data
import ray
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
from ray.tune import JupyterNotebookReporter
%matplotlib inline

In [None]:
# Sanity check: print whether PyTorch reports a usable CUDA device.
print(torch.cuda.is_available())

In [None]:
from ray.rllib.agents import ppo
from ray import tune

# Start from RLlib's PPO defaults. NOTE: .copy() is shallow, so nested dicts
# (e.g. config["model"]) are still shared with ppo.DEFAULT_CONFIG; only
# top-level keys are overridden here.
config = ppo.DEFAULT_CONFIG.copy()
# Edit default config to do hyperparameter search
config["framework"] = "torch"
config["lr"] = 0.001
# Request at most 2 GPUs, but never more than torch can actually see; asking
# for GPUs that do not exist prevents the trainer from being built.
config["num_gpus"] = min(2, torch.cuda.device_count())
config["env"] = "BreakoutNoFrameskip-v4"
config["preprocessor_pref"] = "deepmind"
config["num_workers"] = 4

In [None]:
import logging

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

# Rebinds the global name `torch` (already imported in the first cell) and
# binds `nn` from RLlib's import helper.
# NOTE(review): presumably these are the torch module and torch.nn — confirm
# against the installed RLlib version.
torch, nn = try_import_torch()

# Module-level logger; not referenced by any of the visible cells.
logger = logging.getLogger(__name__)


class ConvNet(TorchModelV2, nn.Module):
    """Convolutional policy/value network for channels-last image observations.

    Three conv + ReLU + max-pool stages form a shared trunk. Two parallel
    12x12 convolutions sit on top of it: one producing ``num_outputs`` policy
    outputs and one producing a scalar state-value estimate.

    Each pooling stage maps a side length H to floor(H / 2) + 1, so an 84x84
    input is reduced 84 -> 43 -> 22 -> 12 and the 12x12 heads collapse it to
    1x1. Other input sizes will not reduce to a single spatial location.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the trunk, the policy head and the value head.

        Args:
            obs_space: Observation space; its last dimension is used as the
                number of input channels.
            action_space: Action space (forwarded to ``TorchModelV2``).
            num_outputs: Number of channels produced by the policy head.
            model_config: Model config dict (forwarded to ``TorchModelV2``).
            name: Model name (forwarded to ``TorchModelV2``).
        """
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        in_channels = obs_space.shape[-1]
        # The last entry of this Sequential is the policy head; everything
        # before it is the trunk shared with the value head.
        self._conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=[7, 7], padding=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=1),
            nn.Conv2d(8, 16, kernel_size=[5, 5], padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=1),
            nn.Conv2d(16, 32, kernel_size=[3, 3], padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=1),
            nn.Conv2d(32, num_outputs, kernel_size=[12, 12])
        )
        # Value head: same receptive field as the policy head, one output.
        self._value_branch = nn.Conv2d(32, 1, kernel_size=[12, 12])
        # Trunk activations from the most recent forward() call; consumed by
        # value_function().
        self._features = None
        self._num_outputs = num_outputs

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute policy outputs for a batch of observations.

        Args:
            input_dict: Dict whose ``"obs"`` entry is a 4-D tensor; it is
                permuted with (0, 3, 1, 2), i.e. treated as channels-last.
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            Tuple of (tensor reshaped to ``(-1, num_outputs)``, ``state``).
        """
        # Conv2d expects channels-first input.
        obs = input_dict["obs"].float().permute(0, 3, 1, 2)
        # Run the shared trunk (all layers but the policy head) and cache its
        # output so value_function() can reuse it.
        self._features = self._conv_layers[:-1](obs)
        logits = self._conv_layers[-1](self._features)
        return logits.reshape(-1, self._num_outputs), state

    @override(TorchModelV2)
    def value_function(self):
        """Return the value estimate for the last forward() batch as a 1-D tensor."""
        assert self._features is not None, "must call forward() first"
        return self._value_branch(self._features).reshape(-1)

In [None]:
class PPOWrapper:
    """Small convenience wrapper around an RLlib ``PPOTrainer``.

    Bundles building the trainer, running training iterations, saving and
    restoring checkpoints, and rendering evaluation episodes inline.
    """

    def __init__(self, env, config, local_dir=None):
        """Store the settings; the trainer itself is built lazily.

        Args:
            env: Environment passed to ``PPOTrainer(env=...)``.
            config: Trainer config passed to ``PPOTrainer(config=...)``.
            local_dir: Optional directory for checkpoints written by
                ``train()``. When ``None`` the trainer's default is used.
        """
        self.agent = None
        self.env = env
        self.config = config
        self.local_dir = local_dir

    def _new_agent(self):
        """Stop any existing trainer, then build and return a fresh one."""
        if self.agent is not None:
            # Release the previous trainer's workers before creating new ones.
            self.agent.stop()
        self.agent = ppo.PPOTrainer(config=self.config, env=self.env)
        return self.agent

    def train(self, num_steps):
        """Train a fresh trainer for ``num_steps`` iterations and checkpoint it.

        Args:
            num_steps: Number of ``agent.train()`` calls to make.

        Returns:
            The value returned by the trainer's ``save()`` call.
        """
        agent = self._new_agent()
        for i in range(num_steps):
            result = agent.train()
            print("Iteration: {}".format(i))
            print("Reward: {}".format(result['episode_reward_mean']))
        # Always checkpoint after the loop, including when num_steps == 0.
        if self.local_dir is None:
            checkpoint = agent.save()
        else:
            checkpoint = agent.save(self.local_dir)
        print('checkpoint saved at', checkpoint)
        return checkpoint

    def load(self, path):
        """Build a fresh trainer and restore its state from ``path``."""
        self._new_agent().restore(path)

    def test(self, num_episodes):
        """Run ``num_episodes`` episodes with the current agent, rendering inline.

        Prints the total reward of each episode.

        Raises:
            RuntimeError: If neither ``train()`` nor ``load()`` has been
                called yet.
        """
        if self.agent is None:
            raise RuntimeError(
                "No agent available: call train() or load() before test().")
        # NOTE(review): assumes the local worker owns an env instance; verify
        # for the installed RLlib version when num_workers > 0.
        env = self.agent.workers.local_worker().env
        for episode in range(num_episodes):
            episode_reward = 0
            done = False
            obs = env.reset()
            while not done:
                action = self.agent.compute_action(obs)
                obs, reward, done, info = env.step(action)
                # Clear the previous frame so images do not pile up on the
                # same figure.
                plt.clf()
                plt.imshow(env.render(mode='rgb_array'))
                display.display(plt.gcf())
                display.clear_output(wait=True)
                episode_reward += reward
            print(episode_reward)

In [None]:
# Restart Ray so this cell can be re-run from a clean state.
ray.shutdown()
ray.init()
# Pass local_dir explicitly (None): the constructor takes
# (env, config, local_dir).
ppo_agent = PPOWrapper("BreakoutNoFrameskip-v4", config, None)
trainingSteps = 500
# train() returns the checkpoint it saved; the next cell restores from it.
checkpoint_path = ppo_agent.train(trainingSteps)

In [None]:
# Restore the trainer from the checkpoint produced by the training cell.
ppo_agent.load(path=checkpoint_path)
# Play a single rendered evaluation episode and print its total reward.
ppo_agent.test(num_episodes=1)