In [2]:
import getpass
import importlib
import random

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from IPython.display import clear_output

def get_class(name):
    """Resolve a dotted path such as ``"torch.nn.Tanh"`` to the object it names.

    Everything before the last dot is imported as a module; the final component is
    looked up as an attribute of that module. A path with no dot raises ValueError.
    """
    module_path, attr_name = name.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)

%env WANDB_NOTEBOOK_NAME reinforce
# API key found at https://wandb.ai/quickstart
wandb.login(key=input())

pygame 2.4.0 (SDL 2.26.4, Python 3.10.10)
Hello from the pygame community. https://www.pygame.org/contribute.html
env: WANDB_NOTEBOOK_NAME=reinforce


[34m[1mwandb[0m: Currently logged in as: [33mcaretcaret[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/amber/.netrc


True

In [11]:
# Run configuration; the whole dict is recorded on the W&B run.
config = {
    "env": "CartPole-v1",
    "device": "cpu",
    "seed": 42,
    "lr": 0.01,                      # Adam learning rate
    "betas": (0.9, 0.999),           # Adam betas
    "eps": 1e-5,                     # Adam epsilon
    "hidden": (64,),                 # hidden-layer widths for actor and critic
    "activation": "torch.nn.Tanh",   # dotted path resolved with get_class
    "episodes": 1000,
    "episode_length": 500,           # max steps unrolled per episode
}
wandb.init(
    project="reinforce",
    config=config
)
env: gym.Env = gym.make(config["env"], render_mode="rgb_array")
seed = config["seed"]
device = torch.device(config["device"])

# Seed every source of randomness. Gymnasium environments are seeded through
# reset(seed=...); later env.reset() calls without a seed continue from this
# seeded generator, so the sequence of episode start states is reproducible.
env.reset(seed=seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
wandb.run

0,1
entropy,█▇▆▄▄▂▁▃▁▁
episode,▁▂▃▃▄▅▆▆▇█
return,█▃▁▁▂▁▁▁▁▁
steps,█▃▁▁▂▁▁▁▁▁
terminated,▁▁▁▁▁▁▁▁▁▁
truncated,▁▁▁▁▁▁▁▁▁▁

0,1
entropy,0.10475
episode,9
return,10.0
steps,10
terminated,True
truncated,False


In [12]:
from typing import Type

# Show the spaces that the Agent below sizes its networks from.
print(f"{env.observation_space} {env.action_space}")

class Agent(torch.nn.Module):
    """Actor-critic MLP for a gymnasium environment.

    The actor outputs categorical logits for a ``Discrete`` action space, or the
    mean of an element-wise Normal (with a learned, state-independent log-std) for a
    ``Box`` action space. The critic outputs a single value estimate.

    Args:
        env: Environment whose spaces size the networks. The observation space is
            assumed to be a flat vector (only ``shape[0]`` is used).
        dims: Hidden-layer widths shared by actor and critic; must be non-empty.
        activation: Activation module class inserted after every hidden layer.

    Raises:
        NotImplementedError: if the action space is neither ``Discrete`` nor ``Box``.
    """

    def __init__(self, env: gym.Env, dims: tuple[int, ...], activation: Type[torch.nn.Module] = torch.nn.Tanh):
        super().__init__()
        action_space = env.action_space
        # Remember the action-space kind so get_action does not read the global `env`.
        self.discrete = isinstance(action_space, gym.spaces.Discrete)
        if self.discrete:
            output_size = int(action_space.n)
        elif isinstance(action_space, gym.spaces.Box):
            output_size = int(np.prod(action_space.shape))
        else:
            # Fail at construction instead of with an AttributeError in get_action.
            raise NotImplementedError(f"Unsupported action space: {action_space}")
        # Actor and critic are separate networks (no shared trunk).
        self.actor = torch.nn.Sequential(*self._make_layers(env, dims, activation, torch.nn.Linear(dims[-1], output_size)))
        if not self.discrete:
            self.actor_logstd = torch.nn.Parameter(torch.zeros(output_size))
        self.critic = torch.nn.Sequential(*self._make_layers(env, dims, activation, torch.nn.Linear(dims[-1], 1)))

    def _make_layers(self, env, dims, activation, output_layer) -> list[torch.nn.Module]:
        """Build ``obs -> dims[0] -> ... -> dims[-1] -> output_layer`` as a flat layer list.

        Every Linear layer (including ``output_layer``) gets an orthogonal weight
        init and a zero bias.
        """
        layers = [torch.nn.Linear(env.observation_space.shape[0], dims[0]), activation()]
        for in_features, out_features in zip(dims[:-1], dims[1:]):
            layers.append(torch.nn.Linear(in_features, out_features))
            layers.append(activation())
        layers.append(output_layer)
        for layer in layers:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.orthogonal_(layer.weight)
                torch.nn.init.zeros_(layer.bias)
        return layers

    def _as_tensor(self, observation) -> torch.Tensor:
        """Convert an observation to a float32 tensor on this module's device."""
        param_device = next(self.parameters()).device
        # as_tensor (unlike torch.Tensor) never treats a scalar as a size argument.
        return torch.as_tensor(observation, dtype=torch.float32, device=param_device)

    def get_action(self, observation):
        """Return the policy's action distribution for ``observation``.

        Categorical over logits for discrete spaces; otherwise an element-wise
        Normal, whose ``log_prob``/``entropy`` are per action dimension (sum them
        for the joint value).
        """
        actor_out = self.actor(self._as_tensor(observation))
        if self.discrete:
            return torch.distributions.Categorical(logits=actor_out)
        return torch.distributions.Normal(actor_out, torch.exp(self.actor_logstd))

    def get_value(self, observation):
        """Return the critic's value estimate for ``observation`` (last dim has size 1)."""
        return self.critic(self._as_tensor(observation))

# Build the agent and an optimizer over the policy (actor) parameters only;
# the critic is not trained yet (see TODO below).
agent = Agent(env, dims=config["hidden"], activation=get_class(config["activation"])).to(device)
actor_params = list(agent.actor.parameters())
# For Box action spaces the learned log-std lives on the agent, not inside
# agent.actor, so it must be added explicitly or it would never be updated.
if hasattr(agent, "actor_logstd"):
    actor_params.append(agent.actor_logstd)
agent_optimizer = torch.optim.Adam(actor_params, lr=config["lr"], betas=config["betas"], eps=config["eps"])

# Optional knobs. The defaults give undiscounted reward-to-go and no per-frame
# rendering (drawing every frame inline dominates the runtime of training).
gamma = config.get("gamma", 1.0)
render = config.get("render", False)
discrete_actions = isinstance(env.action_space, gym.spaces.Discrete)

for episode in range(config["episodes"]):
    observation, info = env.reset()
    # Unroll one episode with the current policy.
    observations, logprobs, rewards, entropies = [], [], [], []
    for _ in range(config["episode_length"]):
        observations.append(observation)  # kept for the future critic update
        action_dist = agent.get_action(observation)
        action = action_dist.sample()
        # .sum() joins per-dimension Normal terms; it is a no-op on Categorical scalars.
        logprobs.append(action_dist.log_prob(action).sum())
        entropies.append(action_dist.entropy().sum())
        # Discrete envs take an int; Box envs take an array of the action's shape.
        env_action = action.item() if discrete_actions else action.cpu().numpy()
        observation, reward, terminated, truncated, info = env.step(env_action)
        rewards.append(reward)
        if render:
            clear_output(wait=True)
            plt.imshow(env.render())
            plt.axis("off")
            plt.show()
        if terminated or truncated:
            break
    steps = len(rewards)

    # Reward-to-go G_t = r_t + gamma * G_{t+1}, computed in one backward pass (O(T)).
    returns = []
    running_return = 0.0
    for step_reward in reversed(rewards):
        running_return = step_reward + gamma * running_return
        returns.append(running_return)
    returns.reverse()
    returns_t = torch.as_tensor(returns, dtype=torch.float32, device=device)

    # Policy gradient: maximise sum_t log pi(a_t|s_t) * G_t by minimising its negative.
    # TODO: bootstrap with the critic's value when the episode was truncated.
    agent_loss = -(torch.stack(logprobs) * returns_t).sum()
    agent_optimizer.zero_grad()
    agent_loss.backward()
    agent_optimizer.step()
    # TODO: update critic

    wandb.log({
        "episode": episode,
        "steps": steps,
        "return": sum(rewards),
        "entropy": torch.stack(entropies).mean().item(),
    })


KeyboardInterrupt: 

In [13]:
# Finish the W&B run and release the environment's resources.
wandb.finish()
env.close()

0,1
entropy,██▃▇▆▄▅▆▅▅▆▅▆▄▄▄▄▃▅▄▄▅▁▂▄▄▂▃▃▃▃▄▄▄▄▃▂▃▄▃
episode,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
return,▁▁▁▂▁▁▁▂▂▁▂▁▂▂▂▁▁▁▁▁▁▂▁▁▂▂▂▂▅▂█▃▃▄▅▂▂▂█▂
steps,▁▁▁▂▁▁▁▂▂▁▂▁▂▂▂▁▁▁▁▁▁▂▁▁▂▂▂▂▅▂█▃▃▄▅▂▂▂█▂

0,1
entropy,0.55375
episode,63.0
return,90.0
steps,90.0
