In [1]:
# IPython magic: select matplotlib's inline backend so figures are drawn in the cell output.
%matplotlib inline

In [5]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Environment instance; later cells read its action space and an initial
# observation from it to size the network.
env = gym.make("CartPole-v1")

# Matplotlib configuration: IPython's display helpers are only imported when
# the active backend is an inline one.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch pyplot to interactive mode.
plt.ion()

# Choose the torch device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
class DQN(nn.Module):
    """Fully connected network with two ReLU hidden layers of width 128.

    Args:
        n_observations: size of each input vector.
        n_actions: number of values produced by the final linear layer.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        hidden_units = 128
        # Attribute names are part of the checkpoint format (state_dict keys),
        # so they must stay layer1 / layer2 / layer3.
        self.layer1 = nn.Linear(n_observations, hidden_units)
        self.layer2 = nn.Linear(hidden_units, hidden_units)
        self.layer3 = nn.Linear(hidden_units, n_actions)

    def forward(self, x):
        """Map observations to one output value per action.

        Used both with a single element (to choose the next action) and with
        a batch (during optimization). The result has the form
        tensor([[left0exp, right0exp], ...]).
        """
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)

In [7]:
# Number of actions, read from the environment's action space.
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

# Build the network with input/output sizes taken from the environment and
# move it to the selected device.
policy_net = DQN(n_observations, n_actions).to(device)

# Read the saved parameters from "policy.pt". map_location="cpu" places the
# loaded tensors on the CPU regardless of where they were saved from;
# weights_only=True asks torch.load to restrict what it will unpickle.
state_dict = torch.load("policy.pt", map_location="cpu", weights_only=True)
# Copy the parameters into the network. This is deliberately the last
# expression of the cell: its return value is what the notebook displays.
policy_net.load_state_dict(state_dict)

<All keys matched successfully>

Visualize the trained policy

In [8]:
# Play a single episode in an on-screen window, always taking the action with
# the largest network output (no exploration).
env_render = gym.make("CartPole-v1", render_mode="human")
try:
    state, info = env_render.reset()
    # The network expects a leading batch dimension, hence unsqueeze(0).
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        # Inference only: no gradients are needed to pick an action.
        with torch.no_grad():
            action = policy_net(state).max(1).indices.view(1, 1)
        observation, reward, terminated, truncated, _ = env_render.step(action.item())
        done = terminated or truncated
        if done:
            # The episode is over either way, so there is no next state to
            # build; report the length and stop.
            print(f"Episode finished after {t+1} timesteps.")
            break
        state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
finally:
    # Release the environment even if a step raises or the kernel is
    # interrupted, so the render window is not left open.
    env_render.close()
print("Visualization complete.")

Episode finished after 500 timesteps.
Visualization complete.


Record episode as video

In [None]:
import os
from gymnasium.wrappers import RecordVideo

# Record a single greedy-policy episode through the RecordVideo wrapper.
# render_mode="rgb_array" is used instead of an on-screen window.
env_record = gym.make("CartPole-v1", render_mode="rgb_array")

video_dir = "./videos"
# NOTE(review): the truncated logger warning in this cell's saved output may
# come from RecordVideo finding the folder already present - confirm, and
# drop this call if the wrapper creates the folder itself.
os.makedirs(video_dir, exist_ok=True)
# NOTE(review): video_path is not read anywhere in this cell and has no file
# extension; RecordVideo derives its own file names from name_prefix, so this
# value may not match the file that is actually written - confirm before use.
video_path = os.path.join(video_dir, "cartpole-episode-0")
env_record = RecordVideo(env_record, video_folder=video_dir, name_prefix="cartpole-episode")

try:
    # Fixed seed so the recorded episode starts from the same initial state
    # on every run.
    state, info = env_record.reset(seed=42)
    # The network expects a leading batch dimension, hence unsqueeze(0).
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        # Inference only: take the action with the largest network output.
        with torch.no_grad():
            action = policy_net(state).max(1).indices.view(1, 1)

        observation, reward, terminated, truncated, _ = env_record.step(action.item())
        done = terminated or truncated

        if done:
            # The episode is over either way, so there is no next state to
            # build; report the length and stop.
            print(f"Recorded episode finished after {t+1} timesteps.")
            break

        state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
finally:
    # Always close the wrapped environment, even on an error or interrupt,
    # so the recorder is shut down cleanly and can finish its output.
    env_record.close()

print("Video recording complete. Check the 'videos' directory.")

  logger.warn(


Recorded episode finished after 500 timesteps.
Video recording complete. Check the 'videos' directory.
