# REINFORCE


In [None]:
# System package: Xvfb, the X virtual framebuffer server.
!apt-get install -y xvfb

# Python packages; gym and pytorch-lightning are pinned to exact versions,
# pygame and pyvirtualdisplay are left unpinned.
!pip install \
  pygame \
  gym==0.23.1 \
  pytorch-lightning==1.6 \
  pyvirtualdisplay

#### Setup virtual display

In [None]:
from pyvirtualdisplay import Display
# Start a non-visible 1400x900 virtual display for the rest of the session
# (presumably so environments can render frames on a headless machine).
# The Display object is not kept, so it is never stopped explicitly.
Display(visible=False, size=(1400, 900)).start()

#### Import the necessary code libraries

In [None]:
import copy
import torch
import random
import gym
import matplotlib

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from gym.wrappers import RecordVideo, RecordEpisodeStatistics, \
  NormalizeObservation, NormalizeReward


# Torch device string: the first CUDA GPU when CUDA is available, else the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of CUDA devices visible to PyTorch (0 on a CPU-only machine).
num_gpus = torch.cuda.device_count()

In [None]:
def plot_policy(policy):
  """Plot the policy's P(left | s) as a heatmap over cart position and pole angle.

  A 100x100 grid of (cart position, pole angle) pairs is evaluated. The two
  velocity components are not swept: every grid point gets an independent
  random value in [0, 0.1) for each of them.

  Args:
    policy: callable that accepts a NumPy array of shape (10000, 4) and
      returns a torch tensor of per-action probabilities with one row per
      state. Column 0 is plotted as the probability of the "left" action.

  Returns:
    None. A matplotlib figure is created as a side effect.
  """
  n = 100  # grid resolution along each plotted axis

  pos = np.linspace(-4.8, 4.8, n)
  vel = np.random.random(size=(n * n, 1)) * 0.1
  ang = np.linspace(-0.418, 0.418, n)
  ang_vel = np.random.random(size=(n * n, 1)) * 0.1

  # meshgrid (default 'xy' indexing) returns (n, n) arrays whose rows follow
  # `ang` and whose columns follow `pos`.
  g1, g2 = np.meshgrid(pos, ang)

  # Columns are ordered (position, velocity, angle, angular velocity) to match
  # Gym CartPole's observation layout.
  # NOTE(review): the environment is not created in this cell - confirm that
  # the policy is trained on CartPole-style observations.
  grid = np.hstack((g1.reshape(-1, 1), vel, g2.reshape(-1, 1), ang_vel))

  # .cpu() keeps this working when the policy's output lives on the GPU.
  probs = policy(grid).detach().cpu().numpy()
  probs_left = probs[:, 0].reshape(n, n)

  # imshow draws row 0 at the top. Reverse the rows so the largest angle is at
  # the top and the smallest at the bottom, as the y tick labels state. The
  # columns already run from -4.8 (left) to 4.8 (right).
  probs_left = np.flip(probs_left, axis=0)

  plt.figure(figsize=(8, 8))
  plt.imshow(probs_left, cmap='coolwarm')
  plt.colorbar()
  plt.clim(0, 1)
  plt.title("P(left | s)", size=20)
  plt.xlabel("Cart Position", size=14)
  plt.ylabel("Pole angle", size=14)
  # The last valid pixel index is n - 1; a tick at n would fall outside the image.
  plt.xticks(ticks=[0, n // 2, n - 1], labels=['-4.8', '0', '4.8'])
  plt.yticks(ticks=[n - 1, n // 2, 0], labels=['-0.418', '0', '0.418'])


def test_env(env_name, policy, obs_rms):
  """Run `policy` for 10 episodes of `env_name`, recording every episode.

  The environment is wrapped in RecordVideo (writing into the relative
  'videos' directory, every episode) and NormalizeObservation, whose running
  statistics object is replaced by `obs_rms`.

  Args:
    env_name: id passed to `gym.make`.
    policy: callable mapping an observation to a tensor of action
      probabilities; one action is sampled from it per step.
    obs_rms: observation running-statistics object assigned to the
      NormalizeObservation wrapper (presumably the one collected during
      training - confirm against the caller).
  """
  env = gym.make(env_name)
  env = RecordVideo(env, 'videos', episode_trigger=lambda e: True)
  env = NormalizeObservation(env)
  env.obs_rms = obs_rms

  try:
    for _ in range(10):
      done = False
      obs = env.reset()
      while not done:
        action = policy(obs).multinomial(1).cpu().item()
        obs, _, done, _ = env.step(action)
  finally:
    # Release the environment even if the policy or a step raises; dropping
    # the local name alone does not close it.
    env.close()


def display_video(episode=0):
  """Return an HTML <video> element embedding a recorded episode.

  Args:
    episode: index used in the file name
      '/content/videos/rl-video-episode-{episode}.mp4'.

  Returns:
    An IPython HTML object whose video source is a base64 data URL holding
    the whole file.

  Raises:
    FileNotFoundError: if no video exists for `episode`.
  """
  video_path = f'/content/videos/rl-video-episode-{episode}.mp4'
  # Read-only binary mode is enough, and the context manager closes the handle.
  with open(video_path, 'rb') as video_file:
    encoded = b64encode(video_file.read()).decode()
  video_url = f"data:video/mp4;base64,{encoded}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

#### Create the policy

#### Plot the untrained policy

#### Create the environment

#### Create the dataset

#### Create the REINFORCE algorithm

#### Purge logs and run the visualization tool (Tensorboard)

In [None]:
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

#### Train the policy

#### Check the resulting policy

#### Plot the trained policy