<a href="https://colab.research.google.com/github/keeeehun/RL/blob/main/REINFORCE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class CartpolePolicy(nn.Module):
    """Stochastic MLP policy over a discrete action space (used for CartPole).

    Maps an observation vector of length ``obs_dim`` to a probability
    distribution over ``n_actions`` actions through two ReLU hidden layers
    (64 and 128 units) followed by a softmax.

    Args:
        obs_dim: length of the flat observation vector.
        n_actions: number of discrete actions.
        device: device that observations are moved to before the forward
            pass. Moving the module itself there is the caller's job
            (e.g. ``policy.to(device)``).
    """

    def __init__(self, obs_dim, n_actions, device):
        super().__init__()

        self.obs_dim = obs_dim
        self.n_actions = n_actions
        self.device = device

        self.l1 = nn.Linear(obs_dim, 64)
        self.l2 = nn.Linear(64, 128)
        self.l3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return action probabilities (softmax over the last dim) for ``x``."""
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return F.softmax(self.l3(x), dim=-1)

    def get_action_logprob(self, obs):
        """Sample an action for one observation and return it with its log-prob.

        Args:
            obs: a single observation as a NumPy array, a torch tensor, or any
                sequence of numbers accepted by ``torch.as_tensor``.

        Returns:
            ``(action, logprob)``. ``action`` is a Python int, obtained via
            ``.item()``, so ``obs`` must describe a single state. ``logprob``
            is the log-probability tensor of that action. It stays attached
            to the autograd graph so it can enter the policy-gradient loss.
        """
        # as_tensor accepts ndarrays, tensors and plain sequences alike. Casting
        # to the weights' dtype avoids a Linear dtype mismatch, e.g. for
        # float64 arrays or tensors.
        obs = torch.as_tensor(obs, dtype=self.l1.weight.dtype, device=self.device)
        probs = self.forward(obs)
        dist = Categorical(probs=probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

In [None]:
def train(logprobs, returns, optim):
    """Apply one REINFORCE policy-gradient update from a finished episode.

    Minimises ``-sum_t logprob_t * G_t`` in a single backward pass. This
    yields the same gradient as calling ``backward()`` once per timestep and
    accumulating, but traverses autograd only once.

    Args:
        logprobs: per-step log-probabilities of the actions taken. Each is a
            single-element tensor still attached to the policy's graph.
        returns: per-step returns ``G_t`` (numbers), aligned with ``logprobs``.
        optim: optimizer over the policy parameters. The name shadows the
            module-level ``torch.optim`` alias only inside this function.

    Returns:
        The scalar loss value as a float (0.0 for an empty episode, in which
        case no optimizer step is taken).

    Raises:
        ValueError: if ``logprobs`` and ``returns`` differ in length. ``zip``
            would otherwise silently drop the extra entries.
    """
    logprobs = list(logprobs)
    returns = list(returns)
    if len(logprobs) != len(returns):
        raise ValueError(
            f"logprobs and returns must have equal length, "
            f"got {len(logprobs)} and {len(returns)}"
        )

    optim.zero_grad()
    if not logprobs:
        # Nothing to learn from; stepping on zero/absent grads could still
        # move parameters for stateful optimizers such as Adam.
        return 0.0

    # reshape(()) turns each single-element log-prob into a scalar so the
    # stacked tensor is 1-D and multiplies element-wise with the returns.
    logprob_t = torch.stack([lp.reshape(()) for lp in logprobs])
    returns_t = torch.as_tensor(returns, dtype=logprob_t.dtype, device=logprob_t.device)
    loss = -(logprob_t * returns_t).sum()
    loss.backward()
    optim.step()
    return loss.item()

In [None]:
def plot(rewards):
    """Redraw the per-episode reward curve in the notebook output.

    The cell's previous output is cleared first, deferred until new output
    arrives (``wait=True``). Repeated calls during training therefore replace
    the old chart instead of stacking new ones beneath it.

    Args:
        rewards: sequence of values to plot, one per episode.
    """
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.set_ylabel('reward')
    ax.set_xlabel('episode')
    ax.plot(rewards)
    plt.show()

In [None]:
# --- Compute device: CUDA when available, otherwise CPU ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Environment (classic gym API: reset() -> obs, step() -> 4-tuple) -------
env = gym.make("CartPole-v0")

# Sizes derived from the environment's spaces.
obs_shape = env.observation_space.shape
obs_dim = obs_shape[0]
n_actions = env.action_space.n
act_shape = (1,)  # one discrete action per step, as a 1-tuple
buffer_size = 1000  # not referenced by any other cell of this notebook

# --- Policy network and optimizer -------------------------------------------
# nn.Module.to() moves the parameters and returns the module itself.
policy_model = CartpolePolicy(obs_dim=obs_dim, n_actions=n_actions, device=device).to(device)
optimizer = optim.Adam(policy_model.parameters())  # Adam with default hyper-parameters

# --- Training hyper-parameters ----------------------------------------------
total_step = 60000  # environment steps in the training loop
batch_size = 32  # not referenced by any other cell; updates happen per episode
gamma = 0.99  # discount factor

# --- Bookkeeping ------------------------------------------------------------
rewards = []  # total reward of each finished episode (plotted during training)
logprobs = []  # per-step action log-probabilities of the current episode
returns = []  # per-step rewards, turned into discounted returns at episode end

In [None]:
# Fresh per-run state. logprobs/returns are cleared here as well. If this cell
# is re-run, or a previous run stopped mid-episode, stale entries from that
# unfinished episode would otherwise be mixed into the next episode's returns
# and gradient update. `rewards` is intentionally kept so the curve continues.
episode_reward = 0
best_test_reward = -999999  # best *training* episode reward seen in this run
logprobs = []
returns = []
obs = env.reset()
for t in range(total_step):
    action, logprob = policy_model.get_action_logprob(obs)
    next_obs, rew, done, _ = env.step(action)
    logprobs.append(logprob)
    returns.append(rew)  # raw reward for now; converted to a return at episode end
    episode_reward += rew
    obs = next_obs

    if done:
        # Checkpoint the weights that produced this episode. This happens
        # before the update below, whenever the episode ties or beats the best.
        if best_test_reward <= episode_reward:
            best_test_reward = episode_reward
            torch.save(policy_model.state_dict(), "policy_model.pt")

        # Turn rewards into discounted returns in place, back to front:
        # G_i = r_i + gamma * G_{i+1}; the last step keeps its own reward.
        for i in range(len(returns) - 2, -1, -1):
            returns[i] += returns[i + 1] * gamma
        rewards.append(episode_reward)
        train(logprobs, returns, optimizer)
        logprobs = []
        returns = []
        obs = env.reset()
        episode_reward = 0

    if t % 1000 == 0:
        plot(rewards)

# Final refresh: the periodic plot above last fired at the last multiple of
# 1000, so episodes finished after that would otherwise not be shown.
plot(rewards)

In [None]:
!apt update
!apt install xvfb
!pip install pyvirtualdisplay
!pip install gym-notebook-wrapper

In [None]:
import gnwrapper
import gym

# Baseline: roll out a uniformly random policy for 500 steps, recording each
# frame via the LoopAnimation wrapper, then show the looping animation.
env = gnwrapper.LoopAnimation(gym.make('CartPole-v0'))

obs = env.reset()
for _ in range(500):
    random_action = env.action_space.sample()
    obs, rew, done, _ = env.step(random_action)
    env.render()  # capture the post-step frame
    if done:
        obs = env.reset()  # start a new episode and keep recording

env.display()

In [None]:
# Load the best checkpoint written during training. map_location places the
# tensors on the current `device`, even if the file was saved from another one.
policy_model.load_state_dict(torch.load("policy_model.pt", map_location=device))

env = gnwrapper.LoopAnimation(gym.make('CartPole-v0'))

obs = env.reset()

# Inference only: no_grad stops get_action_logprob from building an autograd
# graph on every step, since the log-probabilities are discarded here.
with torch.no_grad():
    for _ in range(500):
        action, _ = policy_model.get_action_logprob(obs)  # sample from the trained policy
        env.render()
        obs, rew, done, _ = env.step(action)
        if done:
            obs = env.reset()

env.display()