In [1]:
from reflect.components.trainers.ppo.ppo_trainer import PPOTrainer
from reflect.components.trainers.value.critic import ValueCritic
import gymnasium as gym
from shimmy.registration import DM_CONTROL_SUITE_ENVS
import torch
import numpy as np
from torch.distributions import Normal
import torch.nn as nn
import random
from dataclasses import asdict


# Seed every RNG used by this notebook (stdlib, NumPy, PyTorch) with the same
# value, and ask cuDNN for deterministic kernels.
seed = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
torch.backends.cudnn.deterministic = True


class Shim:
    """Transparent proxy around an environment that rewrites ``step`` output.

    Every attribute that is not defined on the shim itself is forwarded to the
    wrapped environment, so the shim can be used wherever the environment is.
    """

    def __init__(self, env):
        """Store the wrapped environment.

        Args:
            env: object exposing ``step(action)`` that returns a 5-tuple.
        """
        self._env = env

    def __getattr__(self, name):
        """Forward lookups that fail on the shim to the wrapped environment.

        Raises:
            AttributeError: if the wrapped environment has not been set yet,
                or if it does not have ``name`` either.
        """
        # __getattr__ only runs after normal lookup failed. If ``_env`` itself
        # is missing (instance created through __new__, as copy/pickle do),
        # reading self._env below would re-enter this method until the
        # interpreter raises RecursionError; fail cleanly instead.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped environment and return a 5-tuple.

        The third element of the inner result is returned in both the third
        and fourth positions; the inner fourth element is discarded.

        NOTE(review): in the gymnasium 5-tuple API the fourth element is the
        ``truncated`` flag, so time-limit truncation is presumably hidden from
        callers here -- confirm this is intended.
        """
        obs, reward, done, _, info = self._env.step(action)
        return obs, reward, done, done, info
    

# Run configuration.
seed = 1
device = 'cpu'
NUM_STEPS = 1024  # length of the step axis of the rollout buffers
NUM_ENVS = 10     # number of environments stepped in parallel

# One gymnasium id per entry of the dm_control suite registry.
env_ids = [
    f"dm_control/{'-'.join(item)}-v0"
    for item in DM_CONTROL_SUITE_ENVS
]

def make_env(env_id='dm_control/walker-walk-v0', seed=1):
    """Return a zero-argument factory that builds one wrapped environment.

    The factory creates ``env_id`` with RGB-array rendering on camera 0,
    applies action clipping and observation flattening, seeds the action and
    observation spaces with ``seed``, and finally wraps the result in ``Shim``.

    Args:
        env_id: gymnasium environment id to create.
        seed: seed applied to the environment's action/observation spaces.

    Returns:
        A callable taking no arguments and returning the wrapped environment.
    """
    def thunk():
        base_env = gym.make(
            env_id,
            render_mode="rgb_array",
            render_kwargs={'camera_id': 0},
        )
        wrapped = gym.wrappers.FlattenObservation(
            gym.wrappers.ClipAction(base_env)
        )
        # Only the spaces are seeded here; the environment itself is not.
        for space in (wrapped.action_space, wrapped.observation_space):
            space.seed(seed)
        return Shim(wrapped)

    return thunk

# NUM_ENVS synchronous walker-walk environments; the i-th one gets seed 42 + i.
envs = gym.vector.SyncVectorEnv([
    make_env('dm_control/walker-walk-v0', env_seed)
    for env_seed in range(42, 42 + NUM_ENVS)
])


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    Args:
        layer: module with ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
        std: gain passed to the orthogonal weight initialiser.
        bias_const: value every bias entry is set to.

    Returns:
        The same ``layer`` object, so the call can be used inline.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer


class Actor(torch.nn.Module):
    """Gaussian policy: an MLP produces the mean, a learned vector the log-std.

    The mean network is Linear(in, 64) -> Tanh -> Linear(64, 64) -> Tanh ->
    Linear(64, out); the last layer is initialised with gain 0.01. The log
    standard deviation is a single parameter row of zeros that does not depend
    on the input.
    """

    def __init__(
            self,
            input_dim,
            output_dim,
        ):
        """Build the mean network and the log-std parameter.

        Args:
            input_dim: observation size (int or shape; the product is used).
            output_dim: action size (int or shape; the product is used).
        """
        super().__init__()
        n_inputs = np.array(input_dim).prod()
        n_outputs = np.prod(output_dim)

        # Layers are created and initialised one after another, in network order.
        first = layer_init(nn.Linear(n_inputs, 64))
        second = layer_init(nn.Linear(64, 64))
        mean_head = layer_init(nn.Linear(64, n_outputs), std=0.01)

        self.actor_mean = nn.Sequential(
            first,
            nn.Tanh(),
            second,
            nn.Tanh(),
            mean_head,
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, n_outputs))

    def forward(self, x):
        """Return the ``Normal`` action distribution for observations ``x``."""
        mean = self.actor_mean(x)
        # Broadcast the single log-std row to the shape of the mean.
        std = torch.exp(self.actor_logstd.expand_as(mean))
        return Normal(mean, std)
    

class Critic(nn.Module):
    """MLP mapping an observation to a single scalar output.

    Architecture: Linear(in, 64) -> Tanh -> Linear(64, 64) -> Tanh ->
    Linear(64, 1), with the output layer initialised with gain 1.0.
    """

    def __init__(
            self,
            input_dim
        ):
        """Build the network.

        Args:
            input_dim: observation size (int or shape; the product is used).
        """
        super().__init__()
        n_inputs = np.array(input_dim).prod()

        # Layers are created and initialised one after another, in network order.
        first = layer_init(nn.Linear(n_inputs, 64))
        second = layer_init(nn.Linear(64, 64))
        value_head = layer_init(nn.Linear(64, 1), std=1.0)

        self.critic = nn.Sequential(
            first,
            nn.Tanh(),
            second,
            nn.Tanh(),
            value_head,
        )

    def forward(self, x):
        """Return the network output for observations ``x``."""
        value = self.critic(x)
        return value


# Per-environment space shapes; only the first dimension is used below.
observation_space_dim = envs.single_observation_space.shape
action_space_dim = envs.single_action_space.shape

# Policy network defined in this file.
actor = Actor(
    input_dim=observation_space_dim[0],
    output_dim=action_space_dim[0],
)
actor = actor.to(device)

# Value network comes from the reflect package rather than the local Critic.
critic = ValueCritic(state_dim=observation_space_dim[0])
critic = critic.to(device)

trainer = PPOTrainer(
    actor=actor,
    critic=critic,
    actor_lr=3e-4,
    critic_lr=3e-4,
    grad_clip=0.5,
    batch_size=1024,
    num_minibatch=32,
    eta=0,
)
trainer.to(device)


import torch

def compute_model_weight_norm(model, norm_type=2):
    """Return the ``norm_type``-norm over all parameters of ``model``.

    The per-parameter norms are combined as
    ``(sum_i ||p_i|| ** norm_type) ** (1 / norm_type)``, i.e. the norm of the
    concatenation of all parameters. For ``norm_type=float("inf")`` the result
    is the largest absolute parameter value.

    The model's train/eval mode is left untouched: reading parameters does not
    depend on it, and forcing ``train()`` on exit would silently undo a
    caller's ``eval()``.

    Args:
        model: a ``torch.nn.Module``.
        norm_type: order of the norm (non-zero number, or ``float("inf")``).

    Returns:
        The norm as a Python float; ``0.0`` for a model without parameters.
    """
    with torch.no_grad():
        if norm_type == float("inf"):
            # x ** inf is not meaningful; the inf-norm is simply the max |w|.
            return max(
                (
                    param.abs().max().item()
                    for param in model.parameters()
                    if param.numel() > 0
                ),
                default=0.0,
            )

        total_norm = 0.0
        for param in model.parameters():
            param_norm = torch.norm(param, norm_type)
            total_norm += param_norm.item() ** norm_type
    return total_norm ** (1. / norm_type)


from livelossplot import PlotLosses


# Live plot configuration: maps each plot group name to the metric names
# drawn in that group.
plotlosses = PlotLosses(
    groups={
        'weight_norms': ['actor_wn', 'critic_wn'],
        'actor_loss': ['actor_loss'],
        'value_loss': ['value_loss'],
        'actor_grad_norms': ['actor_grad_norm'],
        'value_grad_norm': ['value_grad_norm'],
        'rewards': ['rewards'],
        # Per-dimension action statistics are currently disabled:
        # 'action_mean': [f'action_mean_{i}' for i in range(6)],
        # 'action_std': [f'action_std_{i}' for i in range(6)],
        'clipfrac': ['clipfrac'],
        'approxkl': ['approxkl'],
    }
)


# Rollout storage, indexed as [env, step, ...].
obs = torch.zeros(
    (NUM_ENVS, NUM_STEPS, *envs.single_observation_space.shape)
).to(device)
actions = torch.zeros(
    (NUM_ENVS, NUM_STEPS, *envs.single_action_space.shape)
).to(device)
# The three scalar-per-step buffers share one shape.
logprobs = torch.zeros((NUM_ENVS, NUM_STEPS)).to(device)
rewards = torch.zeros_like(logprobs)
dones = torch.zeros_like(logprobs)


def sample_reward(envs, actor, num_steps=500):
    """Roll the policy out from a fresh reset and return the mean summed reward.

    The environments are reset, then stepped ``num_steps`` times with actions
    sampled from ``actor``. Rewards are summed per environment over all steps
    (done flags are ignored, so the sum may span episode boundaries) and the
    mean over environments is returned.

    Side effect: ``envs`` is reset and advanced, so any observation the caller
    was holding for these environments is stale afterwards.

    Args:
        envs: vectorised environment exposing ``num_envs``, ``reset()`` and
            ``step(actions)``.
        actor: callable mapping an observation tensor to a distribution with
            a ``sample()`` method.
        num_steps: number of environment steps to roll out.

    Returns:
        0-dim CPU tensor holding the mean per-environment reward sum.
    """
    # Size the accumulator from the envs actually passed in rather than from
    # a module-level constant, so the two can never disagree.
    returns = torch.zeros(envs.num_envs).to(device)
    first_obs, _ = envs.reset()
    next_obs = torch.Tensor(first_obs).to(device)
    for _ in range(num_steps):
        with torch.no_grad():
            action = actor(next_obs).sample()

        next_obs, reward, _terminated, _truncated, _info = envs.step(
            action.cpu().numpy()
        )
        next_obs = torch.Tensor(next_obs).to(device)
        returns = returns + torch.tensor(reward).to(device).view(-1)
    return returns.cpu().detach().mean()


# Start the first rollout from freshly reset environments.
o, _ = envs.reset()
next_obs = torch.Tensor(o).to(device)
next_done = torch.zeros(NUM_ENVS).to(device)

for i in range(100000):
    # ---- Collect NUM_STEPS transitions from every environment ----
    for step in range(0, NUM_STEPS):
        # Store the observation the action is computed from, together with
        # the done flag returned by the env step that produced it.
        obs[:, step] = next_obs
        dones[:, step] = next_done
        with torch.no_grad():
            action_dist = actor(next_obs)
            action = action_dist.sample()
            # Sum over action dimensions: joint log-prob of the action vector.
            logprob = action_dist.log_prob(action).sum(-1)
        actions[:, step] = action
        # NOTE: logprobs is filled here but not passed to trainer.update below.
        logprobs[:, step] = logprob

        next_obs, reward, done, _, info = envs.step(action.cpu().numpy())
        rewards[:, step] = torch.tensor(reward).to(device).view(-1)
        next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

    # ---- Run 10 trainer updates on the collected rollout ----
    # Only the history of the last update is kept for plotting.
    for update in range(10):
        trainer_history = trainer.update(
            state_samples=obs.detach(),
            reward_samples=rewards[:, :, None].detach(),
            done_samples=dones[:, :, None].detach(),
            action_samples=actions.detach()
        )

    print(f"Step {i}")
    # i % 1 is always 0, so this logs on every iteration after the first.
    if (i > 0) and (i % 1) == 0:
        plotlosses.update({
            **asdict(trainer_history),
            **{
                # Each weight norm is reported under its own model's name
                # (these two were previously swapped).
                "actor_wn": compute_model_weight_norm(actor),
                "critic_wn": compute_model_weight_norm(critic),
                "rewards": sample_reward(envs, actor),
            }
        })
        plotlosses.send()

        # sample_reward() resets and steps these same environments, so the
        # observation held in next_obs no longer matches their state.
        # Re-synchronise before the next rollout, mirroring the initial setup.
        o, _ = envs.reset()
        next_obs = torch.Tensor(o).to(device)
        next_done = torch.zeros(NUM_ENVS).to(device)


Step 0


KeyboardInterrupt: 

In [3]:
obs.shape

torch.Size([10, 1024, 24])