# Visualise the pre-trained agent in action

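Set `WEIGHTS` in the first cell to the path of the trained weights (without the `.pth` extension) and, optionally, `METRICS` to the training-metrics CSV file. Then run the notebook from top to bottom.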

In [1]:
# Checkpoint to visualise, given WITHOUT the '.pth' extension
# (the extension is appended when the file is loaded).
WEIGHTS = "./weights/weights-40000"

# Optional CSV file with training metrics, e.g. './src/output/metrics.csv'.
# Leave as None to skip the metrics plot.
METRICS = None

In [2]:
import sys
import torch
import gymnasium as gym
import matplotlib.pyplot as plt

# Extend the module search path two directories up so that the `src` and
# `_helpers` imports below can be resolved. The path is relative, so it
# depends on the kernel's working directory (presumably this notebook's
# folder, two levels below the repository root -- TODO confirm).
sys.path.append('../../')

# Project-local imports; these must stay below the sys.path tweak above.
# The `# noqa` markers presumably silence linter complaints about that.
from src.agent import VisionDeepQ
from _helpers.plotting import visualise_csv  # noqa
from _helpers.gif import gif_stack  # noqa

## Parameters

In [3]:
# Network layout handed to VisionDeepQ.
# NOTE(review): these values presumably have to match the ones used when the
# checkpoint was trained, otherwise loading the weights will fail -- confirm.
network = {
    "input_channels": 4,
    "outputs": 5,
    "channels": [64, 32],
    "kernels": [3, 5],
    "padding": ["same", "same"],
    "strides": [],
    "nodes": [],
}

# Optimizer class and its settings, also handed to VisionDeepQ.
optimizer = {
    "optimizer": torch.optim.Adam,
    "lr": 1e-3,
    "hyperparameters": {},
}

# Frame geometry: the "original" observation shape, the row/column slices that
# are kept, and a pooling factor.
# NOTE(review): the slices presumably crop the frame to the playing field and
# "max_pooling" presumably downsamples it -- verify against VisionDeepQ.
shape = {
    "original": (1, 1, 210, 160),
    "height": slice(27, 203),
    "width": slice(22, 64),
    "max_pooling": 2,
}

# Passed to gif_stack further down.
# NOTE(review): presumably the number of frames per agent decision -- confirm.
skip = 4

## Setup

In [4]:
# Build the agent from the parameter dictionaries defined earlier in the notebook.
# NOTE(review): exploration_rate=1.0 is kept as-is; confirm that this is the
# intended value when showcasing an already trained agent.
value_agent = VisionDeepQ(
    network=network,
    optimizer=optimizer,
    shape=shape,
    exploration_rate=1.0,
)

# Restore the trained parameters, mapping all tensors onto the CPU so the
# checkpoint can be loaded on machines without a GPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True where the installed torch supports it).
weights = torch.load(f"{WEIGHTS}.pth", map_location="cpu")
value_agent.load_state_dict(weights)

# Tetris environment that renders to RGB arrays and emits grayscale observations.
environment = gym.make(
    "ALE/Tetris-v5",
    render_mode="rgb_array",
    obs_type="grayscale",
    frameskip=1,
    repeat_action_probability=0.0,
)
environment.metadata["render_fps"] = 30

## Visualise

Plot the metrics from the CSV file created during training. This step is skipped when `METRICS` is `None`.

In [5]:
# Plot the training metrics only when a metrics file has been configured.
if METRICS:
    visualise_csv(METRICS, title="Training Metrics")
    plt.show()

Create a gif of the agent in action. The gif is saved alongside the weights, at the `WEIGHTS` path with a `.gif` extension.

In [6]:
# Create the gif of the agent in action and save it next to the weights file.
# `WEIGHTS` is already a complete path (relative or absolute), so it is used
# as-is: prefixing it with './' is redundant for a relative path and would
# turn an absolute path into a relative one.
gif_stack(environment, value_agent, f'{WEIGHTS}.gif', skip)