In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import namedtuple, deque



# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class DQN(nn.Module):
    """Three-layer fully connected network mapping observations to per-action values.

    Args:
        n_observations: Number of features in one observation (input width).
        n_actions: Number of output units, one per action.
        hidden_size: Width of the two hidden layers. Defaults to 128, the
            value that was previously hard-coded, so existing callers and
            checkpoints keep working unchanged.
    """

    def __init__(self, n_observations, n_actions, hidden_size=128):
        super().__init__()
        # The attribute names below become the state_dict keys
        # ("layer1.weight", "layer1.bias", ...); renaming them would break
        # loading of previously saved checkpoints.
        self.layer1 = nn.Linear(n_observations, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Compute the network output for a single observation or a batch.

        Called with either one element to determine the next action, or a
        batch during optimization.

        Args:
            x: Float tensor whose last dimension has size ``n_observations``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``n_actions``. No activation is applied to the
            final layer.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [9]:
import os
from itertools import count

# Replay a saved policy greedily for a few episodes, rendering each one.
env = gym.make("CartPole-v1", render_mode='human')

try:
    # Get number of actions from gym action space
    n_actions = env.action_space.n
    # Get the number of state observations
    state, info = env.reset()
    n_observations = len(state)

    policy_net = DQN(n_observations, n_actions).to(device)
    model_folder_path = './model_saved'
    file_name = os.path.join(model_folder_path, 'model_599' + '.pth')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa) instead of failing during deserialization.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    policy_net.load_state_dict(torch.load(file_name, map_location=device))
    # Inference only: switch the module to evaluation mode.
    policy_net.eval()

    # With render_mode='human' the environment draws itself on every
    # reset()/step(), so no explicit env.render() call is needed.
    for i_episode in range(5):
        state, info = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        for t in count():
            # Greedy action: index of the largest network output. no_grad
            # avoids building an autograd graph that is never used.
            with torch.no_grad():
                action = policy_net(state).argmax(dim=1).item()
            observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
            if done:
                # t is zero-based, so t + 1 is the number of steps taken.
                print('episode ', i_episode, ' get ', t + 1, ' score')
                break
finally:
    # Always release the render window, even if loading or stepping fails.
    env.close()
print('done')


episode  0  get  500  score
episode  1  get  500  score
episode  2  get  500  score
episode  3  get  500  score
episode  4  get  500  score
done
