# REINFORCE in PyTorch

In [1]:
# Imports

from torch.distributions import Categorical
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Discount parameter: weight applied to the accumulated future return at each
# step when train() computes per-step returns (ret_t = r_t + gamma * ret_{t+1}).
gamma = 0.99

In [8]:
# Policy Class

class Pi(nn.Module):
    """Policy network that also stores the current episode's trajectory data.

    The network is ``Linear(in_dim, 64) -> ReLU -> Linear(64, out_dim)``; its
    output is used as the logits of a categorical distribution over actions.
    ``log_probs`` and ``rewards`` hold the per-step values of the episode in
    progress until ``onpolicy_reset`` clears them.
    """

    def __init__(self, in_dim, out_dim):
        """Build the network, start with empty buffers, and enter training mode."""
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        )
        self.onpolicy_reset()
        self.train()  # nn.Module.train(): switch the module to training mode

    def onpolicy_reset(self):
        """Discard the stored log-probabilities and rewards."""
        self.log_probs, self.rewards = [], []

    def forward(self, x):
        """Return the network output for ``x``; ``act`` treats it as action logits."""
        return self.model(x)

    def act(self, state):
        """Sample an action for ``state`` and record its log-probability.

        Args:
            state: NumPy array; it is cast to float32 before being fed to the network.

        Returns:
            The sampled action as a plain Python number (via ``.item()``).
        """
        obs = torch.from_numpy(state.astype(np.float32))
        dist = Categorical(logits=self.forward(obs))
        sampled = dist.sample()
        # Keep the log-probability tensor so the loss can be built after the episode
        self.log_probs.append(dist.log_prob(sampled))
        return sampled.item()

In [9]:
# Train Function

def train(pi, optimizer, discount=None):
    """Perform one REINFORCE update from the episode stored on ``pi``.

    Per-step returns are computed as ``ret_t = r_t + discount * ret_{t+1}``,
    and the loss is ``sum_t(-log_prob_t * ret_t)``.

    Args:
        pi: Policy object exposing ``rewards`` (per-step rewards) and
            ``log_probs`` (per-step log-probability tensors), of equal length.
        optimizer: Optimizer over the policy parameters; its gradients are
            zeroed and it is stepped exactly once.
        discount: Discount factor for future rewards. When ``None`` (the
            default) the module-level ``gamma`` is used, which keeps existing
            two-argument calls unchanged.

    Returns:
        The summed loss as a 0-dim tensor.
    """
    if discount is None:
        # Fall back to the notebook-level setting for existing callers
        discount = gamma
    T = len(pi.rewards)
    rets = np.empty(T, dtype=np.float32)
    future_ret = 0.0
    # Walk the episode backwards so each return reuses the one after it
    for t in reversed(range(T)):
        future_ret = pi.rewards[t] + discount * future_ret
        rets[t] = future_ret
    rets = torch.tensor(rets)
    log_probs = torch.stack(pi.log_probs)
    # Negated because the optimizer minimizes
    loss = torch.sum(-log_probs * rets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

In [10]:
# Main Function

def main(num_episodes=300, max_steps=200, lr=0.01, render=True):
    """Train a policy on CartPole-v0, printing one summary line per episode.

    Each episode is rolled out with the current policy, then ``train`` is
    called once on the collected trajectory and the policy's buffers are
    cleared. The environment is always closed on exit, including when an
    exception or interrupt stops training.

    Note: the rollout uses the classic Gym API, where ``env.reset()`` returns
    the observation and ``env.step()`` returns a 4-tuple.

    Args:
        num_episodes: Number of episodes (and policy updates) to run.
        max_steps: Cap on steps per episode (cartpole max timestep is 200).
        lr: Learning rate passed to the Adam optimizer.
        render: When True, call ``env.render()`` after every step.
    """
    env = gym.make('CartPole-v0')
    try:
        in_dim = env.observation_space.shape[0]
        out_dim = env.action_space.n
        pi = Pi(in_dim, out_dim)
        optimizer = optim.Adam(pi.parameters(), lr=lr)
        for epi in range(num_episodes):
            state = env.reset()
            for t in range(max_steps):
                action = pi.act(state)
                state, reward, done, _ = env.step(action)
                pi.rewards.append(reward)
                if render:
                    env.render()
                if done:
                    break
            loss = train(pi, optimizer)  # one update per completed episode
            total_reward = sum(pi.rewards)
            solved = total_reward > 195.0
            # Clear the trajectory so the next update only sees fresh data
            pi.onpolicy_reset()
            print("Episode: {}\tLoss: {}\tTotal Reward: {}\tSolved: {}".format(epi,loss, total_reward, solved))
    finally:
        # Release the render window / simulator even if training is interrupted
        env.close()

In [11]:
# Run: start training via main(), which prints one summary line per episode
main()

Episode: 0	Loss: 23.724102020263672	Total Reward: 9.0	Solved: False
Episode: 1	Loss: 30.376394271850586	Total Reward: 11.0	Solved: False
Episode: 2	Loss: 86.03150177001953	Total Reward: 14.0	Solved: False
Episode: 3	Loss: 12.686543464660645	Total Reward: 9.0	Solved: False
Episode: 4	Loss: 29.635496139526367	Total Reward: 10.0	Solved: False
Episode: 5	Loss: 39.673465728759766	Total Reward: 11.0	Solved: False
Episode: 6	Loss: 84.25299072265625	Total Reward: 16.0	Solved: False
Episode: 7	Loss: 41.09261703491211	Total Reward: 12.0	Solved: False
Episode: 8	Loss: 39.05577087402344	Total Reward: 12.0	Solved: False
Episode: 9	Loss: 163.9733123779297	Total Reward: 22.0	Solved: False
Episode: 10	Loss: 51.59680938720703	Total Reward: 14.0	Solved: False
Episode: 11	Loss: 1007.4610595703125	Total Reward: 58.0	Solved: False
Episode: 12	Loss: 71.6487808227539	Total Reward: 15.0	Solved: False
Episode: 13	Loss: 94.16250610351562	Total Reward: 17.0	Solved: False
Episode: 14	Loss: 59.10795593261719	Total

Episode: 121	Loss: 6061.29296875	Total Reward: 200.0	Solved: True
Episode: 122	Loss: 5921.93505859375	Total Reward: 200.0	Solved: True
Episode: 123	Loss: 5780.65087890625	Total Reward: 197.0	Solved: True
Episode: 124	Loss: 5215.279296875	Total Reward: 173.0	Solved: False
Episode: 125	Loss: 6103.06689453125	Total Reward: 200.0	Solved: True
Episode: 126	Loss: 6157.63818359375	Total Reward: 200.0	Solved: True
Episode: 127	Loss: 6141.435546875	Total Reward: 200.0	Solved: True
Episode: 128	Loss: 5836.3857421875	Total Reward: 200.0	Solved: True
Episode: 129	Loss: 6362.25927734375	Total Reward: 200.0	Solved: True
Episode: 130	Loss: 5989.697265625	Total Reward: 200.0	Solved: True
Episode: 131	Loss: 5893.7666015625	Total Reward: 200.0	Solved: True
Episode: 132	Loss: 6246.2109375	Total Reward: 200.0	Solved: True
Episode: 133	Loss: 851.380615234375	Total Reward: 59.0	Solved: False
Episode: 134	Loss: 6442.1953125	Total Reward: 200.0	Solved: True
Episode: 135	Loss: 6024.91064453125	Total Reward: 20

Episode: 240	Loss: 1943.4520263671875	Total Reward: 117.0	Solved: False
Episode: 241	Loss: 2115.682861328125	Total Reward: 131.0	Solved: False
Episode: 242	Loss: 1704.8128662109375	Total Reward: 120.0	Solved: False
Episode: 243	Loss: 2035.15625	Total Reward: 133.0	Solved: False
Episode: 244	Loss: 1464.816650390625	Total Reward: 112.0	Solved: False
Episode: 245	Loss: 1749.2225341796875	Total Reward: 111.0	Solved: False
Episode: 246	Loss: 2426.34326171875	Total Reward: 146.0	Solved: False
Episode: 247	Loss: 1741.4454345703125	Total Reward: 113.0	Solved: False
Episode: 248	Loss: 1494.2987060546875	Total Reward: 108.0	Solved: False
Episode: 249	Loss: 1385.2177734375	Total Reward: 113.0	Solved: False
Episode: 250	Loss: 1577.0018310546875	Total Reward: 112.0	Solved: False
Episode: 251	Loss: 1627.33154296875	Total Reward: 103.0	Solved: False
Episode: 252	Loss: 1827.466552734375	Total Reward: 100.0	Solved: False
Episode: 253	Loss: 1798.61669921875	Total Reward: 129.0	Solved: False
Episode: 254