# Policy-based vision agent using PyTorch

In [1]:
import copy
import time
import torch
import imageio
import gymnasium as gym
import matplotlib.pyplot as plt

from agents.torch_policy_vision import VisionPolicyGradient

## Tetris environment

|        | TYPE                   | VALUES          | DESCRIPTION                                                                                                |
|--------|------------------------|-----------------|------------------------------------------------------------------------------------------------------------|
| Action Space | ndarray<br/>(1,) | {0, 1, 2, 3, 4} | Action to manipulate the current tile.<br/>0: No action<br/>1: Rotate<br/>2: Right<br/>3: Left<br/>4: Down |
| Observation Space | ndarray<br/>(210,160,) | <0, 255> | The game screen.                                                                                           |
| Reward | float |  | Reward given when a row is filled.<br/>Single: 50<br/>Double: 150<br/>Triple: 400<br/>Quadruple: 900       |
| Termination | boolean |  | The game ends when the pieces stack up to the top of the playing field.                                    |

In [2]:
# Single-channel (grayscale) observations feed the policy network, while
# "rgb_array" rendering is kept so full-colour frames can be recorded later.
environment = gym.make(
    "ALE/Tetris-v5",
    obs_type="grayscale",
    render_mode="rgb_array",
)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


### Training

In [3]:
# Number of training episodes.
GAMES = 50

# Policy network architecture, passed verbatim to VisionPolicyGradient.
# One grayscale input channel; one output per Tetris action (5 actions).
# NOTE(review): "kernels" has 3 entries but "channels" has 2 — confirm how
# VisionPolicyGradient pairs them.
NETWORK = {
    "input_channels": 1,
    "outputs": 5,
    "channels": [4, 8],
    "kernels": [(3, 3), (3, 3), (5, 5)],
}

# Optimizer class and learning rate used by the agent.
OPTIMIZER = {"optim": torch.optim.RMSprop, "lr": 0.00025}

In [4]:
# Trainable agent built from the network/optimizer configuration.
policy_agent = VisionPolicyGradient(network=NETWORK, optimizer=OPTIMIZER)

# Independent snapshot of the agent taken before any training.
# NOTE(review): not referenced in the cells shown here — presumably kept as
# an untrained baseline; remove it if nothing uses it.
_policy_agent = copy.deepcopy(policy_agent)

In [5]:
# Report progress roughly every 10% of training. max(1, ...) keeps the
# interval positive when GAMES < 10; otherwise GAMES // 10 == 0 and the
# training loop's `game % checkpoint` would raise ZeroDivisionError.
checkpoint = max(1, GAMES // 10)

# One slot per game for each tracked metric; filled in by the training loop.
metrics = {metric: torch.zeros(GAMES) for metric in ["steps", "gradients"]}

In [6]:
start = time.time()

# Guard against checkpoint == 0 (GAMES < 10), which would otherwise make the
# modulo below raise and the averaging window empty.
report_every = max(1, checkpoint)

for game in range(1, GAMES + 1):

    # reset() returns (observation, info); unsqueeze(0) adds a leading axis
    # so the single grayscale frame matches NETWORK["input_channels"] == 1.
    state = torch.tensor(environment.reset()[0], dtype=torch.float32).unsqueeze(0)  # noqa
    terminated = truncated = False

    # LEARNING FROM GAME
    # --------------------------------------------------

    steps = 0
    while not (terminated or truncated):
        steps += 1
        action, logarithm = policy_agent.action(state)
        state, reward, terminated, truncated, _ = environment.step(action)
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        policy_agent.memorize(logarithm, reward)
    # One policy update per finished episode.
    gradient = policy_agent.learn()

    # METRICS
    # --------------------------------------------------

    metrics["steps"][game-1] = steps
    metrics["gradients"][game-1] = gradient

    if game % report_every == 0 or game == GAMES:

        # Average over the most recent `report_every` games, INCLUDING the
        # current one: 0-based indices game-report_every .. game-1.
        window = slice(max(0, game - report_every), game)
        _mean_steps = metrics["steps"][window].mean()
        _mean_gradient = metrics["gradients"][window].mean()

        print(f"Game {game:>6} {int(game/GAMES * 100):>16} % \n"
              f"{'-'*30} \n"
              f" > Average steps: {int(_mean_steps):>12} \n"
              f" > Average gradients: {_mean_gradient:>8.4f} \n ")

print(f"Total training time: {time.time()-start:.2f} seconds")

Game      5               10 % 
------------------------------ 
 > Average steps:          265 
 > Average gradients:   0.0000 
 
Game     10               20 % 
------------------------------ 
 > Average steps:          301 
 > Average gradients:   0.0000 
 
Game     15               30 % 
------------------------------ 
 > Average steps:          277 
 > Average gradients:   0.0000 
 



KeyboardInterrupt



#### Visualisation

In [None]:
def moving_average(data, window_size=50):
    """Compute a centred moving average of ``data``.

    For each index ``i`` the mean is taken over the slice
    ``data[i - window_size // 2 : i + window_size // 2]``, clipped to the
    bounds of ``data``. The element at ``i`` is always included, so a
    ``window_size`` of 0 or 1 yields the data point itself instead of the
    NaN an empty slice would produce.

    Parameters
    ----------
    data : sliceable sequence whose slices support ``.mean()``
        E.g. a 1-D float ``torch.Tensor``; its length determines the output
        length (previously the global ``GAMES`` was used).
    window_size : int, default 50
        Approximate width of the averaging window.

    Returns
    -------
    list
        One averaged value per element of ``data``.
    """
    half_window = max(0, window_size // 2)
    return [data[max(0, i - half_window):max(i + half_window, i + 1)].mean()
            for i in range(len(data))]

# Smooth the per-game metrics so the trend is visible through episode noise.
# Distinct names avoid shadowing the `steps` counter used in the training loop.
smoothed_steps = moving_average(metrics["steps"])
smoothed_gradients = moving_average(metrics["gradients"])

fig, ax = plt.subplots(2, 1, figsize=(12, 8))
fig.suptitle("Policy-based vision gradient agent")

# NOTE(review): the y=500 reference line is an undocumented target — confirm
# what it represents for Tetris.
ax[0].axhline(y=500, color="red", linestyle="dotted", linewidth=1)
ax[0].plot(smoothed_steps, color="black", linewidth=1)
ax[0].set_xticks([])
ax[0].set_title("Average steps per game")

ax[1].axhline(y=0, color="red", linestyle="dotted", linewidth=1)
ax[1].plot(smoothed_gradients, color="black", linewidth=1)
ax[1].set_xlabel("Game nr.")
ax[1].set_title("Average gradients")

# Vertical guides every ~10% of training. max(1, ...) prevents a zero step,
# for which range() raises ValueError, when GAMES < 10.
for i in range(0, GAMES, max(1, GAMES // 10)):
    ax[0].axvline(x=i, color='gray', linewidth=0.5)
    ax[1].axvline(x=i, color='gray', linewidth=0.5)

fig.savefig("./static/images/torch-pbg-tetris.png")
plt.show()

##### In action

In [None]:
# Play one episode with the trained policy, always taking the most probable
# action, and record every rendered frame into a GIF.
state = torch.tensor(environment.reset()[0], dtype=torch.float32).unsqueeze(0)

images = []
terminated = truncated = False
# Inference only: disable autograd so no computation graph is built per step.
with torch.no_grad():
    while not (terminated or truncated):
        probabilities = torch.softmax(policy_agent(state), dim=-1)
        action = torch.argmax(probabilities).item()

        state, _, terminated, truncated, _ = environment.step(action)
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        images.append(environment.render())
_ = imageio.mimsave('./static/images/torch-pbg-tetris.gif', images, duration=25)

<img src="./static/images/torch-pbg-tetris.gif" width="1000" height="1000" />

In [None]:
# Release the resources held by the Gymnasium environment now that training
# and recording are finished.
environment.close()