# Records CartPole Performance

**Please expand the cells to view the code!**

### Description
This notebook loads a saved `.pt` model, records one episode of it playing CartPole, and saves the recording as an `.mp4` file.

### How to Run
Before running, set the model paths in the cells below to the models you want to record. Also change the video output directory (`./video/` by default) to where you would like the videos to be saved.

### Citations
Kang, C., 2021. REINFORCE on CartPole-v0 [Online]. Chan's Jupyter. Available from: https://goodboychan.github.io/python/reinforcement_learning/pytorch/udacity/2021/05/12/REINFORCE-CartPole.html [Accessed 8 May 2024].


In [None]:
import gym
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (16, 10)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
torch.manual_seed(0)

import base64, io

# For visualization
from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display 
import glob


class DDDQNet(nn.Module):
    """Dueling Q-network: a shared trunk followed by separate value and advantage heads.

    Q-values are combined as ``V(s) + A(s, a) - mean_a A(s, a)``. Subtracting the
    mean advantage keeps the value/advantage decomposition identifiable.

    Parameters
    ----------
    state_dim : int, default 4
        Length of the observation vector (CartPole's default).
    n_actions : int, default 2
        Number of discrete actions (CartPole's default).
    """

    def __init__(self, state_dim=4, n_actions=2):
        super().__init__()

        # Shared feature trunk.
        self.fc1 = nn.Linear(state_dim, 64)
        self.relu = nn.ReLU()

        # State-value head: one scalar V(s) per state.
        self.fc_value = nn.Linear(64, 256)
        self.value = nn.Linear(256, 1)

        # Advantage head: one A(s, a) per action.
        self.fc_adv = nn.Linear(64, 256)
        self.adv = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return Q-values of shape ``(..., n_actions)`` for states of shape ``(..., state_dim)``.

        Accepts either a single 1-D state or a batch of states.
        """
        features = self.relu(self.fc1(x))
        value = self.value(self.relu(self.fc_value(features)))
        adv = self.adv(self.relu(self.fc_adv(features)))
        # Reduce over the last (action) axis so both unbatched and batched
        # inputs work; for a (batch, state_dim) input this equals dim=1.
        return value + adv - adv.mean(dim=-1, keepdim=True)

    def select_action(self, x):
        """Return the greedy action index as a Python ``int``.

        ``x`` must hold exactly one state: either a 1-D state or a batch of size one.
        """
        with torch.no_grad():
            q_values = self.forward(x)
            action_index = torch.argmax(q_values, dim=-1)
        return action_index.item()
    
class Net(nn.Module):
    """Two-layer MLP mapping a state to one unnormalised score per action.

    Parameters
    ----------
    state_dim : int, default 4
        Length of the observation vector (CartPole's default).
    n_actions : int, default 2
        Number of discrete actions (CartPole's default).
    hidden_dim : int, default 32
        Width of the single hidden layer.
    """

    def __init__(self, state_dim=4, n_actions=2, hidden_dim=32):
        super().__init__()
        # Attribute names are kept stable so saved state_dict keys still match.
        self.layer1 = nn.Linear(state_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, state):
        """Return action scores of shape ``(..., n_actions)``.

        Works for a single state (1-D) or a batch of states ``(batch, state_dim)``.
        """
        hidden = F.relu(self.layer1(state))
        return self.layer2(hidden)

In [None]:
def record_video_of_model(model_path,
                          model_name, episodes, env_name="CartPole-v1", agent="CartPole",
                          video_dir="./video", max_steps=2000):
    """Play one greedy episode with a saved model and record it as an .mp4.

    Parameters
    ----------
    model_path : str
        Path to the saved ``.pt`` checkpoint: either a whole pickled
        ``nn.Module`` or a ``state_dict`` for the matching architecture.
    model_name : str
        ``"reinforce"`` selects the ``Net`` policy (greedy argmax over its
        outputs). Any other value selects ``DDDQNet`` (greedy argmax over
        Q-values). Also used in the video file name.
    episodes : int
        Only used to label the video file name.
    env_name : str
        Gym environment id to roll out in.
    agent : str
        Prefix of the video file name.
    video_dir : str
        Directory the video is written to; created if it does not exist.
    max_steps : int
        Maximum number of environment steps to play and record.

    Notes
    -----
    ``torch.load`` unpickles arbitrary Python objects; only load checkpoints
    you trust.
    """
    import inspect
    import os

    is_reinforce = model_name == "reinforce"

    # map_location lets GPU-trained checkpoints load on a CPU-only machine.
    # Whole pickled modules need weights_only=False on PyTorch versions where
    # it defaults to True; older versions do not accept the argument at all.
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(model_path, **load_kwargs)

    if isinstance(checkpoint, nn.Module):
        policy = checkpoint
    else:
        # Otherwise treat it as a state_dict for the architecture picked by model_name.
        policy = Net() if is_reinforce else DDDQNet()
        policy.load_state_dict(checkpoint)
    policy.eval()

    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, f"{agent}_{model_name}_{episodes}.mp4")

    env = gym.make(env_name, render_mode="rgb_array")
    vid = None
    try:
        vid = video_recorder.VideoRecorder(env, path=video_path)
        state, _ = env.reset()
        total_reward = 0.0  # episode return, kept separate from the per-step reward
        steps = 0
        for _ in range(max_steps):
            vid.capture_frame()
            tensor_state = torch.from_numpy(state).float().unsqueeze(0)
            with torch.no_grad():  # inference only; no autograd graph needed
                if is_reinforce:
                    action = torch.argmax(policy(tensor_state), dim=1).item()
                else:
                    action = policy.select_action(tensor_state)
            state, reward, terminated, truncated, _info = env.step(action)
            total_reward += reward
            steps += 1
            # The episode ends when the env either terminates or is truncated
            # (e.g. by a time limit); both flags must be honoured.
            if terminated or truncated:
                break
        # Capture the final state so the video shows how the episode ended.
        vid.capture_frame()
        print("Reward earned: ", total_reward)
        print("Steps: ", steps)
    finally:
        if vid is not None:
            vid.close()
        env.close()

    
# Record the DDDQN checkpoint from the random-replay run.
# NOTE(review): this file is "E1500.pt" while the other checkpoints follow an
# "E_<n>.pt" pattern — confirm the path is correct.
record_video_of_model(
    model_path="./models/random_replay/CN8e-05_LR1e-05_B128/E1500.pt",
    model_name="dddqn_random_rep",
    episodes=1500,
)




In [None]:

# Record the DDDQN checkpoint from the prioritised-replay run.
record_video_of_model(
    model_path="./models/prioritised_replay/CN8e-05_LR1e-05_B128/E_2600.pt",
    model_name="dddqn_prioritised_rep",
    episodes=2600,
)




In [None]:
# Record the REINFORCE checkpoint labelled E_1700 from the LR0.001 run.
record_video_of_model(
    model_path="./models/reinforce/LR0.001/E_1700.pt",
    model_name="reinforce",
    episodes=1700,
)

In [None]:
# Record the REINFORCE checkpoint labelled E_1400 from the LR0.001 run.
record_video_of_model(
    model_path="./models/reinforce/LR0.001/E_1400.pt",
    model_name="reinforce",
    episodes=1400,
)