In [13]:
import gymnasium as gym
import torch
import math
import random
import matplotlib.pyplot as plt
from IPython import display
from collections import namedtuple, deque
from itertools import count

In [14]:

def plot_durations(episode_durations, show_result=False):
    """Plot per-episode durations on figure 1, with a 100-episode moving average.

    Args:
        episode_durations: Sequence of numbers, one entry per finished episode.
        show_result: If False (default), figure 1 is cleared before drawing, the
            title is 'Training...', and the cell output is cleared after the
            figure is displayed so the next call replaces it. If True, the
            figure is not cleared, the title is 'Result', and the displayed
            figure is left in the output.
    """
    window = 100

    plt.figure(1)
    history = torch.tensor(episode_durations, dtype=torch.float)

    if not show_result:
        plt.clf()
    plt.title('Result' if show_result else 'Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(history.numpy())

    # Moving average over the last `window` episodes; the first window-1 points
    # have no full window yet and are drawn as zeros.
    if len(history) >= window:
        rolling = history.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    plt.pause(0.001)  # give the GUI event loop a moment to redraw
    display.display(plt.gcf())
    if not show_result:
        display.clear_output(wait=True)

In [15]:
class ReplayMemory:
    """Bounded buffer of observed transitions used for experience replay.

    Transitions are kept in arrival order; once `capacity` of them are stored,
    every new one pushes the oldest out. Drawing training batches at random
    from this buffer decorrelates the samples within a batch, which has been
    shown to stabilise and improve DQN training considerably.
    """

    def __init__(self, capacity):
        # Record type for a single step. It lives on the instance so that
        # callers can use it to regroup a sampled batch field by field.
        self.transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition built from (state, action, next_state, reward)."""
        record = self.transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of `batch_size` stored transitions drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [16]:
class DQN(torch.nn.Module):
    """Fully connected Q-network: two ReLU hidden layers of 128 units, linear output.

    Args:
        n_states: Size of the input feature vector (last dimension of `x`).
        n_actions: Number of outputs, one value per action.
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        hidden = 128
        self.layer1 = torch.nn.Linear(n_states, hidden)
        self.layer2 = torch.nn.Linear(hidden, hidden)
        self.layer3 = torch.nn.Linear(hidden, n_actions)

    def forward(self, x):
        """Map a batch of inputs to one value per action (no output activation)."""
        relu = torch.nn.functional.relu
        features = relu(self.layer1(x))
        features = relu(self.layer2(features))
        return self.layer3(features)

In [17]:
class Agent:
    """Deep Q-learning agent for Gymnasium's CartPole-v1.

    The idea behind Q-learning: if we had a function Q*(s, a) telling us the
    return of taking action a in state s, the optimal policy would simply pick
    the action with the highest value. Q* is unknown, but neural networks are
    universal function approximators, so `policy_net` is trained to resemble it.
    A slowly-tracking copy (`target_net`) provides the bootstrap targets and a
    replay buffer provides decorrelated training batches.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = gym.make("CartPole-v1")
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        self.episodes = 600

        self.batch_size = 128  # number of transitions sampled randomly from the replay buffer
        self.gamma = 0.99  # discount factor
        self.epsilon_start = 0.9  # starting value of epsilon (probability of choosing random action)
        self.epsilon_end = 0.05  # final value of epsilon
        self.epsilon_decay = 1000  # rate of exponential decay of epsilon (higher -> slower decay)
        self.tau = 0.005  # update rate of the target network
        self.lr = 1e-4  # learning rate of the optimizer

        self.policy_net = DQN(self.state_size, self.action_size).to(self.device)
        self.target_net = DQN(self.state_size, self.action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)
        self.memory = ReplayMemory(10000)

        self.episode_durations = []  # steps taken in each finished training episode
        self.steps_done = 0  # total act() calls; drives the epsilon decay

    def act(self, state):
        """Choose an action epsilon-greedily.

        Epsilon decays exponentially from `epsilon_start` to `epsilon_end` as
        `steps_done` grows; every call increments `steps_done`.

        Args:
            state: Float tensor of shape (1, state_size) on `self.device`.

        Returns:
            An int64 tensor of shape (1, 1) holding the chosen action index.
        """
        sample = random.random()
        eps_threshold = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.steps_done / self.epsilon_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # max(1) reduces over the action dimension; .indices is the argmax
                return self.policy_net(state).max(1).indices.unsqueeze(0)
        # Exploratory action. It must be int64 like the greedy branch because
        # optimise() uses the stored actions as a gather() index.
        return torch.tensor([[self.env.action_space.sample()]], dtype=torch.long, device=self.device)

    def optimise(self):
        """Run one gradient step on `policy_net` from a random replay batch.

        Requires at least `batch_size` transitions in `self.memory`.
        """
        # Convert batch-array of Transitions to Transition of batch-arrays (https://stackoverflow.com/a/19343/3343043)
        transitions = self.memory.sample(self.batch_size)
        batch = self.memory.transition(*zip(*transitions))

        # Mask of transitions whose next state is not terminal
        # (a terminal transition is stored with next_state=None).
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=self.device, dtype=torch.bool)
        state_batch = torch.cat(batch.state)
        # gather() only accepts int64 indices; the cast is a no-op when they already are.
        action_batch = torch.cat(batch.action).long()
        reward_batch = torch.cat(batch.reward)

        # Q(s_t, a_t): evaluate all actions, then pick the one actually taken.
        # squeeze(1) drops only the action column, so a batch of size 1 keeps its batch axis.
        Q_policy = self.policy_net(state_batch).gather(1, action_batch).squeeze(1)

        # V(s_{t+1}) from the "older" target_net: max over actions for
        # non-terminal next states, 0 for terminal ones.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        if non_final_mask.any():  # torch.cat() rejects an empty list
            non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values

        # Target Q values according to the Bellman equation
        Q_target = reward_batch + (self.gamma * next_state_values)

        # Huber loss on the temporal-difference error
        loss = torch.nn.functional.smooth_l1_loss(Q_policy, Q_target)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

    def soft_update(self):
        """Move `target_net` towards `policy_net`: θ' ← τ·θ + (1 − τ)·θ'."""
        with torch.no_grad():
            for policy_param, target_param in zip(self.policy_net.parameters(), self.target_net.parameters()):
                target_param.copy_(self.tau * policy_param + (1.0 - self.tau) * target_param)

    def train(self):
        """Train for `self.episodes` episodes.

        Each step is stored in the replay buffer; once the buffer holds a full
        batch, every step also runs one optimisation step and one soft target
        update. The number of steps of every finished episode is appended to
        `self.episode_durations`. Episodes that last 500 steps or more save the
        policy weights to "model_500_steps.pth".
        """
        max_steps = 500  # CartPole-v1 time limit

        for e in range(self.episodes):
            # Initialize the environment and get its state
            observation, _ = self.env.reset()
            state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

            for t in count():
                action = self.act(state)
                observation, reward, terminated, truncated, _ = self.env.step(action.item())
                reward = torch.tensor([reward], dtype=torch.float32, device=self.device)

                # The episode is over on failure (terminated) or on the time
                # limit (truncated); the environment must be reset after either.
                done = terminated or truncated

                # Only a real failure is a terminal state for the Bellman
                # target; a time-limit cut-off still has a valid next state.
                if terminated:
                    next_state = None
                else:
                    next_state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

                # Store the transition in memory
                self.memory.push(state, action, next_state, reward)

                if len(self.memory) >= self.batch_size:
                    # Optimise the policy network
                    self.optimise()

                    # Soft update the target network
                    self.soft_update()

                if done:
                    duration = t + 1  # t is zero-based
                    self.episode_durations.append(duration)

                    if duration < max_steps:
                        print(f"Episode: {e}. Reached {duration} steps")
                    else:
                        print(f"Episode: {e}. Reached 500 steps. Saving the model.")
                        self.save("model_500_steps.pth")
                    break

                state = next_state

    def test(self, episodes=10, t_max=None, display=None):
        """Run the greedy policy and print the score of each episode.

        Args:
            episodes: Number of episodes to play.
            t_max: Optional step cap. When given, an episode ends when the
                environment terminates or the zero-based step index exceeds
                `t_max`; the environment's own time limit is ignored so runs
                can go past it. When None, an episode ends when the
                environment terminates or truncates.
            display: "gui" to render in a window, "video" to record every 10th
                episode to ./cartpole-*.mp4, anything else for no rendering.

        The environment held in `self.env` is closed when this returns.
        """
        if display == "gui":
            self.env.close()  # release the environment being replaced
            self.env = gym.make("CartPole-v1", render_mode='human')
        elif display == "video":
            self.env.close()  # release the environment being replaced
            tmp_env = gym.make("CartPole-v1", render_mode='rgb_array')
            trigger = lambda t: t % 10 == 0
            self.env = gym.wrappers.RecordVideo(env=tmp_env, video_folder=".", name_prefix="cartpole", episode_trigger=trigger)

        self.policy_net.eval()
        for e in range(episodes):
            observation, _ = self.env.reset()
            state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)
            total_reward = 0

            for t in count():
                with torch.no_grad():
                    action = self.policy_net(state).max(1).indices.unsqueeze(0)
                observation, reward, terminated, truncated, _ = self.env.step(action.item())
                total_reward += reward

                if t % 100 == 0:
                    print(f"Step: {t}, Reward: {total_reward}")

                if t_max is None:
                    # No explicit cap: fall back to the environment's own limits.
                    done = terminated or truncated
                else:
                    done = terminated or t > t_max

                if done:
                    print(f"Episode: {e+1}/{episodes}, Score: {total_reward}")
                    break

                state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

        self.env.close()

    def save(self, name):
        """Write the policy network's state dict to the file `name`."""
        torch.save(self.policy_net.state_dict(), name)

    def load(self, name):
        """Load a state dict from `name` into both the policy and target networks.

        Best effort: on any failure the error is printed and the networks keep
        their current weights.
        """
        try:
            # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
            state_dict = torch.load(name, map_location=self.device)
            self.policy_net.load_state_dict(state_dict)
            self.target_net.load_state_dict(state_dict)
            print(f"Model loaded successfully from {name}")
        except Exception as e:
            print(f"Error loading the model: {e}")

In [None]:
# Build a fresh agent, run the full training loop, then write the final
# policy-network weights to "model_final.pth" in the working directory.
agent = Agent()
agent.train()
agent.save("model_final.pth")

In [None]:
# Show the final training curve for the agent trained in the previous cell.
# NOTE(review): the curve is drawn from agent.episode_durations, so it stays
# empty unless Agent.train() records one duration per episode — confirm it does.
print('Complete')
plot_durations(agent.episode_durations, show_result=True)
plt.ioff()  # switch matplotlib's interactive mode off before the final show()
plt.show()

In [None]:
# Evaluate saved weights: new agent, load a checkpoint, play one greedy episode
# with the step cap t_max=1000.
# NOTE(review): Agent.train() writes "model_500_steps.pth" to the working
# directory, while this reads it from "../weights/" — presumably the file is
# moved there by hand; confirm the path. Agent.load() only prints on failure,
# so a missing file means the episode runs with freshly initialised weights.
agent = Agent()
agent.load("../weights/model_500_steps.pth")
agent.test(episodes=1, t_max=1000)