# RLify: using a custom NN model
In this notebook we will see examples of running different agent algorithms with a custom neural-network model, getting the training metrics, and watching the agents in action.


In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

from rlify.models.base_model import BaseModel

class my_resnet(BaseModel):
    """ResNet-18 backbone with a linear head, usable as an rlify policy or critic network.

    Observations arrive as a dict whose image batch sits under the default ``"data"``
    key (see rlify's ObsWrapper) in channels-last layout ``(batch, height, width,
    channels)``. They are permuted to channels-first and resized so the shorter edge
    is 224 px. They then go through a randomly initialised ResNet-18, whose output is
    projected to ``prod(self.out_shape)`` values.

    Args:
        input_shape: forwarded unchanged to ``BaseModel``.
        out_shape: forwarded unchanged to ``BaseModel``; the linear head emits
            ``prod(self.out_shape)`` features (the output is not reshaped here).
    """

    def __init__(self, input_shape, out_shape):
        super().__init__(input_shape, out_shape)
        self.preprocess = transforms.Compose(
            [
                # An int size resizes the shorter image edge to 224 px, keeping aspect ratio.
                transforms.Resize(224, antialias=True),
            ]
        )
        # Build ResNet-18 from the installed torchvision instead of torch.hub:
        # no network access / hub cache is required, and the hub's v0.10.0 entry
        # point predates the ``weights=`` keyword. weights=None -> random init.
        self.resnet = models.resnet18(weights=None)
        # NOTE(review): self.out_shape is presumably set by BaseModel.__init__.
        # int() so nn.Linear gets a plain Python int rather than a numpy scalar.
        self.out_layer = nn.Linear(
            self.resnet.fc.out_features, int(np.prod(self.out_shape))
        )

    def forward(self, x):
        """Map a batch of image observations to ``prod(self.out_shape)`` outputs.

        Args:
            x: dict of tensors; the image batch is read from ``x["data"]``.

        Returns:
            Tensor of shape ``(batch, prod(self.out_shape))``.
        """
        # The env's observation space is a single Box (Pong frames, 210x160x3), not a
        # Dict, so rlify stores it under the default key 'data'
        # (for more info check the ObsWrapper class).
        # Assumes the values were already scaled to float by the env wrapper — TODO confirm.
        x = x["data"]
        # Permute to (batch_size, channels, height, width) as torchvision expects.
        x = self.preprocess(x.permute(0, 3, 1, 2))
        x = self.resnet(x)
        return self.out_layer(x)

    def reset(self):
        """No-op: the model is not recurrent, so there is no hidden state to reset."""
        pass

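## Example 1 - Train using discrete PPO
Let's train a PPO agent on the Atari Pong (`Pong-v4`) gym env, using the custom ResNet model above for both the policy and the critic networks.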

In [2]:
import numpy as np
import gymnasium as gym
from rlify.agents.ppo_agent import PPO_Agent

In [3]:
def norm_obs(x):
    """Scale raw pixel observations (uint8, 0-255) to float32 values in [0, 1]."""
    return (x / 255).astype(np.float32)


class NormObsWrapper(gym.ObservationWrapper):
    """Apply ``norm_obs`` to every observation the wrapped env returns.

    ``observation_space`` is not overridden, so the wrapped env's space is exposed
    unchanged (matching the previous TransformObservation usage).
    """

    def observation(self, observation):
        return norm_obs(observation)


env_name = "Pong-v4"
env = gym.make(env_name, render_mode=None)
# Stay within gymnasium. Pulling TransformObservation from the legacy `gym`
# package mixes two libraries and needs an extra, unmaintained dependency.
# gymnasium's own TransformObservation changed its signature in 1.0
# (observation_space became required), so a minimal ObservationWrapper
# subclass is the version-independent option.
env = NormObsWrapper(env)
models_shapes = PPO_Agent.get_models_input_output_shape(
    env.observation_space, env.action_space
)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [4]:
from rlify.utils import init_torch

# Torch device chosen by rlify's helper.
device = init_torch()

# Per-network I/O shapes reported by PPO_Agent in the previous cell.
policy_shapes, critic_shapes = models_shapes["policy_nn"], models_shapes["critic_nn"]
policy_input_shape, policy_out_shape = policy_shapes["input_shape"], policy_shapes["out_shape"]
critic_input_shape, critic_out_shape = critic_shapes["input_shape"], critic_shapes["out_shape"]

# Two separate (non-shared) ResNet models: one for the policy, one for the critic.
policy_nn, critic_nn = (
    my_resnet(input_shape=policy_input_shape, out_shape=policy_out_shape),
    my_resnet(input_shape=critic_input_shape, out_shape=critic_out_shape),
)

# PPO hyperparameters, gathered in one place so they are easy to tweak.
ppo_hparams = {
    "batch_size": 64,
    "max_mem_size": 10**4,
    "num_parallel_envs": 1,
    "lr": 3e-4,
    "entropy_coeff": 0.05,
    "discount_factor": 0.99,
    "kl_div_thresh": 0.05,
    "clip_param": 0.2,
    "tensorboard_dir": "./tensorboard/",
}

agent = PPO_Agent(
    obs_space=env.observation_space,
    action_space=env.action_space,
    device=device,
    policy_nn=policy_nn,
    critic_nn=critic_nn,
    **ppo_hparams,
)

# Short demonstration run (see PPO_Agent.train_n_steps for how steps are counted).
train_stats = agent.train_n_steps(env=env, n_steps=1000)

Using cache found in /home/nitsan57/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/nitsan57/.cache/torch/hub/pytorch_vision_v0.10.0
episode 0, curr_mean_R:-00019.0, best_mean_R:-19.0, total_steps:1497: : 1497it [01:46, 14.10it/s]             
