In [None]:
#import argparse
import gymnasium as gym
import numpy as np
from itertools import count
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Video
%matplotlib inline




In [None]:
plt.ion()
# With the Jupyter "inline" backend, live-updating training plots need
# IPython's display helpers (display / clear_output).
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Hyperparameters. Upstream (PyTorch REINFORCE example) these were argparse
# flags: --gamma 0.99, --seed 543, --render, --log-interval 10. Only the
# discount factor is defined in this cell.
gamma = 0.99  # reward discount factor used when computing returns

In [None]:
# Reproducibility: the upstream example seeded both the environment and torch
# from --seed (default 543); restore that with a plain constant.
seed = 543
env = gym.make('CartPole-v1', render_mode="rgb_array")
# Seeding one reset() seeds the environment's RNG; the unseeded env.reset()
# calls made later during training continue from that seeded stream.
state, info = env.reset(seed=seed)
# Run before the Policy cell so weight initialisation, dropout masks and
# action sampling all draw from a seeded torch RNG.
torch.manual_seed(seed)

In [None]:
class Policy(nn.Module):
    """Stochastic policy network mapping observations to action probabilities.

    Architecture: Linear(obs_dim -> hidden_dim) -> Dropout -> ReLU ->
    Linear(hidden_dim -> n_actions) -> softmax. The defaults reproduce the
    CartPole setup (4-dimensional observation, 2 discrete actions).

    Attributes:
        saved_log_probs: log-probabilities of the actions sampled during the
            current episode (appended by the action-selection code).
        rewards: per-step rewards of the current episode.
        Both lists are emptied by the end-of-episode training update.
    """

    def __init__(self, obs_dim=4, n_actions=2, hidden_dim=128, dropout_p=0.6):
        super().__init__()
        self.affine1 = nn.Linear(obs_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout_p)
        self.affine2 = nn.Linear(hidden_dim, n_actions)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return action probabilities for observation(s) ``x``.

        ``x`` may be a single observation of shape (obs_dim,) or a batch of
        shape (batch, obs_dim). The softmax is taken over the last (action)
        axis, so each row of the result sums to 1.
        """
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # dim=-1 (not dim=1) so an unbatched 1-D observation also works.
        return F.softmax(action_scores, dim=-1)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
# Smallest float32 increment; added to the std when normalising returns so
# the division never hits zero.
eps = np.finfo(np.float32).eps.item()

In [None]:
def select_action(state):
    """Sample an action from the global ``policy`` for one observation.

    ``state`` is expected to be a NumPy array (it goes through
    ``torch.from_numpy``); it is cast to float and given a leading batch axis
    of size 1 before the forward pass. The log-probability of the sampled
    action is appended to ``policy.saved_log_probs`` for the end-of-episode
    REINFORCE update.

    Returns:
        The sampled action index as a Python int.
    """
    batch = torch.from_numpy(state).float().unsqueeze(0)
    dist = Categorical(policy(batch))
    sampled = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(sampled))
    return sampled.item()

In [None]:
def _discounted_returns(rewards, discount):
    """Return a tensor of discounted returns ``G_t = r_t + discount * G_{t+1}``.

    Computed back-to-front in one pass; the last step's return is its reward.
    """
    running = 0
    returns = deque()
    for reward in reversed(rewards):
        running = reward + discount * running
        returns.appendleft(running)
    return torch.tensor(list(returns))


def finish_episode():
    """Apply one REINFORCE update from the episode buffered on ``policy``.

    Builds discounted returns from ``policy.rewards`` (discount ``gamma``),
    standardises them, and takes one ``optimizer`` step on
    ``sum_t -log_prob_t * G_t`` using ``policy.saved_log_probs``. The episode
    buffers are cleared afterwards in every case, including when the update
    is skipped or raises.
    """
    try:
        # Nothing to learn from an empty rollout (and torch.stack([]) raises).
        if not policy.rewards or not policy.saved_log_probs:
            return
        returns = _discounted_returns(policy.rewards, gamma)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + eps)
        else:
            # The (unbiased) std of a single sample is NaN, which would turn
            # the loss and then every weight into NaN. A single centred return
            # is 0, so such an episode simply contributes no gradient.
            returns = returns - returns.mean()
        policy_loss = torch.stack([
            -log_prob * ret
            for log_prob, ret in zip(policy.saved_log_probs, returns)
        ]).sum()
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
    finally:
        # Reset per-episode buffers so a skipped or failed update cannot leak
        # stale steps into the next episode.
        del policy.rewards[:]
        del policy.saved_log_probs[:]

In [None]:
episode_durations = []


def plot_durations(show_result=False, window=100):
    """Plot per-episode durations and their moving average on figure 1.

    Args:
        show_result: If True, draw a final "Result" figure that stays in the
            notebook output; otherwise draw a "Training..." frame that the
            next call replaces in place (live update).
        window: Width of the trailing moving average, drawn once at least
            ``window`` episodes have been recorded.
    """
    plt.figure(1)
    # Always start from an empty figure; otherwise the final "Result" call
    # re-draws every line on top of the last training frame.
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Result' if show_result else 'Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    if window > 0 and len(durations_t) >= window:
        means = durations_t.unfold(0, window, 1).mean(1).view(-1)
        # Left-pad so the average lines up with the episode index axis.
        means = torch.cat((torch.zeros(window - 1), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.display(plt.gcf())
        if not show_result:
            # Replace this frame on the next output instead of stacking figures.
            display.clear_output(wait=True)

In [None]:
# NOTE(review): nothing in this notebook moves the policy or its tensors to a
# GPU/MPS device, so this check only picks the episode budget.
num_episodes = (
    600
    if torch.cuda.is_available() or torch.backends.mps.is_available()
    else 50
)

# Rendered frames of the first episode and of the most recent episode,
# filled by main() and exported as GIFs further down.
first_episode_images = []
last_episode_images = []

def main():
    """Train ``policy`` with REINFORCE on ``env`` for up to ``num_episodes``.

    Each episode is rolled out with ``select_action`` until the environment
    terminates or truncates (with a 10000-step safety cap). Then one update
    is applied with ``finish_episode`` and the duration curve is redrawn.
    Frames from episode 0 go to ``first_episode_images`` and frames from the
    most recent later episode go to ``last_episode_images``. Training stops
    early once the exponential moving average of episode reward exceeds
    ``env.spec.reward_threshold``.
    """
    global first_episode_images, last_episode_images
    # Fresh buffer so re-running this cell does not append a second copy of
    # episode 0 to frames left over from a previous run.
    first_episode_images = []
    running_reward = 10
    for i_episode in range(num_episodes):
        state, _ = env.reset()
        ep_reward = 0
        last_episode_images = []
        frames = first_episode_images if i_episode == 0 else last_episode_images
        for t in range(1, 10000):  # Don't infinite loop while learning
            action = select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            policy.rewards.append(reward)
            ep_reward += reward
            frames.append(env.render())
            # An episode ends on failure (terminated) OR on the time limit
            # (truncated); ignoring truncation lets episodes run past it.
            if terminated or truncated:
                break
        # `t` is the number of env.step() calls made in this episode.
        episode_durations.append(t)

        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        finish_episode()
        plot_durations()

        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


if __name__ == '__main__':
    main()
    print('Complete')
    # Redraw the curve once more as a static "Result" figure.
    plot_durations(show_result=True)
    plt.ioff()
    plt.show()

In [None]:
def write_first_and_last_episode_gifs(first_episode_images, last_episode_images,
                                      out_dir="./images"):
    """Save the recorded first- and last-episode frames as looping GIFs.

    Writes ``reinforce_cartpole_first_episode.gif`` and
    ``reinforce_cartpole_last_episode.gif`` into ``out_dir``. The directory is
    created if needed, because writing into a missing directory fails.

    Args:
        first_episode_images: frames (as returned by ``env.render()``) of the
            first training episode.
        last_episode_images: frames of the most recent training episode.
        out_dir: destination directory (default ``./images``).
    """
    import os

    import imageio

    os.makedirs(out_dir, exist_ok=True)
    for label, frames in (("first", first_episode_images),
                          ("last", last_episode_images)):
        path = os.path.join(out_dir, f"reinforce_cartpole_{label}_episode.gif")
        imageio.mimwrite(path, frames, format="gif", loop=0, duration=50)

In [None]:
def display_episode(episode_filename):
    """Render the GIF stored at ``episode_filename`` inline in the notebook.

    The file is read as raw bytes and handed to IPython's ``Image`` with an
    explicit ``gif`` format.
    """
    from IPython.display import Image, display

    with open(episode_filename, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    display(Image(gif_bytes, format='gif'))

In [None]:
# One rendered frame is stored per env.step(), so these are step counts.
n_first_frames = len(first_episode_images)
n_last_frames = len(last_episode_images)
print("Length of first episode (in frames): ", n_first_frames)
print("Length of last episode (in frames): ", n_last_frames)

In [None]:
# Export the recorded rollouts as GIF files under ./images/.
write_first_and_last_episode_gifs(
    first_episode_images=first_episode_images,
    last_episode_images=last_episode_images,
)

In [None]:
# Untrained policy: the pole typically falls quickly.
display_episode(episode_filename='./images/reinforce_cartpole_first_episode.gif')

In [None]:
# Most recent episode recorded during training.
display_episode(episode_filename='./images/reinforce_cartpole_last_episode.gif')