# Policy-based agent using PyTorch

In [1]:
import time
import torch
import imageio
import gymnasium as gym
import matplotlib.pyplot as plt

from agents.torch_policy import PolicyGradient

KeyboardInterrupt: 

## Cart-pole environment

|        | TYPE             | VALUES | DESCRIPTION                                                                     |
|--------|------------------|--------|---------------------------------------------------------------------------------|
| Action | ndarray<br/>(1,) | {0, 1} | Direction to push cart.<br/>0: left<br/>1: right                                |
| Observation | ndarray<br/>(4,) | float  | 1. Cart position<br/>2. Cart velocity<br/>3. Pole angle<br/>4. Angular velocity |
| Reward |  | float | Reward is given for every step taken, including the termination step.<br/>Each step provides a reward of +1 |
| Termination &<br/>Truncation |  | boolean | Terminated: pole angle exceeds ± 0.2095 rad, or cart position exceeds ± 2.4 (i.e., edge of the display)<br/>Truncated: episode length is greater than 500 |

In [None]:
# CartPole-v1 environment shared by the training loop and the recording cell.
# render_mode="rgb_array" is what lets environment.render() hand back the
# frames that are later collected and written to a GIF.
environment = gym.make('CartPole-v1', render_mode="rgb_array")

### Training

In [None]:
# Number of games (episodes) played by the training loop.
GAMES = 10000

# Network settings passed to PolicyGradient. 4 inputs / 2 outputs line up with
# the 4-value observation and the two actions listed in the table above.
# "nodes" is presumably the hidden-layer widths -- TODO confirm in
# agents.torch_policy.
NETWORK = {"inputs": 4, "outputs": 2, "nodes": [15, 30]}
# Optimizer settings passed to PolicyGradient: optimizer class and learning rate.
OPTIMIZER = {"optim": torch.optim.RMSprop, "lr": 0.00025}

In [None]:
# Build the agent from the NETWORK and OPTIMIZER settings.
policy_agent = PolicyGradient(
    network=NETWORK, optimizer=OPTIMIZER
)

# Progress is reported every `checkpoint` games (ten reports per run). Clamped
# to 1 so that GAMES < 10 cannot yield a zero interval, which would make
# `game % checkpoint` in the training loop raise ZeroDivisionError.
checkpoint = max(1, GAMES // 10)
# Per-game metrics, one slot per game, filled in by the training loop.
metrics = {metric: torch.zeros(GAMES) for metric in ["steps", "gradients"]}

In [None]:
# Train the agent for GAMES episodes, logging averaged metrics at checkpoints.
start = time.time()
for game in range(1, GAMES + 1):

    # reset() returns a tuple; element 0 is the initial observation.
    state = torch.tensor(environment.reset()[0], dtype=torch.float32)  # noqa
    terminated = truncated = False

    # LEARNING FROM GAME
    # --------------------------------------------------

    steps = 0
    while not (terminated or truncated):
        steps += 1
        # `logarithm` is presumably the log-probability of the chosen action
        # -- TODO confirm against PolicyGradient.action.
        action, logarithm = policy_agent.action(state)
        state, reward, terminated, truncated, _ = environment.step(action)
        state = torch.tensor(state, dtype=torch.float32)
        policy_agent.memorize(logarithm, reward)
    # One learning update per finished game; the returned value is logged as
    # the game's "gradient" metric.
    gradient = policy_agent.learn()

    # METRICS
    # --------------------------------------------------

    metrics["steps"][game-1] = steps
    metrics["gradients"][game-1] = gradient

    if game % checkpoint == 0 or game == GAMES:

        # `game` is 1-based while the metric tensors are 0-based, so the most
        # recent `checkpoint` games, ending with the current one, live at
        # indices [game - checkpoint, game).
        window = slice(max(0, game - checkpoint), game)
        _mean_steps = metrics["steps"][window].mean()
        _mean_gradient = metrics["gradients"][window].mean()

        print(f"Game {game:>6} {int(game/GAMES * 100):>16} % \n"
              f"{'-'*30} \n"
              f" > Average steps: {int(_mean_steps):>12} \n"
              f" > Average gradients: {_mean_gradient:>8.4f} \n ")

print(f"Total training time: {time.time()-start:.2f} seconds")

#### Visualisation

In [None]:
def moving_average(data, window_size=50):
    """Compute a moving average over ``data``.

    The window for position ``i`` covers indices
    ``[i - window_size // 2, i + window_size // 2)``, truncated at both ends
    of the sequence, so windows near the edges are shorter.

    Args:
        data: Sliceable sequence whose slices provide ``.mean()``
            (e.g. a torch tensor or a numpy array).
        window_size: Nominal window width. Values below 2 give a half-width
            of 0 and therefore empty windows.

    Returns:
        A list with one mean per element of ``data``.
    """
    half_window = window_size // 2
    # Use the sequence's own length rather than a global, so the helper works
    # for input of any size.
    length = len(data)
    return [data[max(0, i - half_window):min(length, i + half_window)].mean()
            for i in range(length)]

# Smooth the per-game metrics before plotting.
steps = moving_average(metrics["steps"])
gradients = moving_average(metrics["gradients"])

fig, ax = plt.subplots(2, 1, figsize=(12, 8))
fig.suptitle("Policy-based gradient agent")

# One panel per metric: (axis, smoothed series, reference level, title).
# The red dotted reference line sits at 500 steps (the episode-length cap
# listed in the environment table) and at 0 for the gradients.
panels = (
    (ax[0], steps, 500, "Average steps per game"),
    (ax[1], gradients, 0, "Average gradients"),
)
for axis, series, reference, title in panels:
    axis.axhline(y=reference, color="red", linestyle="dotted", linewidth=1)
    axis.plot(series, color="black", linewidth=1)
    axis.set_title(title)

# Only the bottom panel carries the x-axis ticks and label.
ax[0].set_xticks([])
ax[1].set_xlabel("Game nr.")

# Grey vertical guides every tenth of the run. The step is clamped to 1 so
# that GAMES < 10 does not call range() with a zero step (ValueError).
for boundary in range(0, GAMES, max(1, GAMES // 10)):
    for axis in ax:
        axis.axvline(x=boundary, color='gray', linewidth=0.5)

plt.savefig("./static/images/torch-pbg.png")
plt.show()

##### In action

In [None]:
# Play one episode with the trained policy, always taking the
# highest-probability action, and save the rendered frames as a GIF.
state = torch.tensor(environment.reset()[0], dtype=torch.float32)

images = []
terminated = truncated = False
# Inference only: no_grad() avoids building an autograd graph on every
# forward pass, since nothing is back-propagated here.
with torch.no_grad():
    while not (terminated or truncated):
        actions = torch.softmax(policy_agent(state), dim=-1)
        action = torch.argmax(actions).item()

        state, reward, terminated, truncated, _ = environment.step(action)
        state = torch.tensor(state, dtype=torch.float32)

        # Frame rendered after the step has been applied.
        images.append(environment.render())
_ = imageio.mimsave('./static/images/torch-pbg.gif', images, duration=25)

<img src="./static/images/torch-pbg.gif" width="1000" height="1000" />

In [None]:
# Shut the environment down; nothing in the notebook uses it after this point.
environment.close()