In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Model

In [None]:
class REINFORCE(nn.Module):
    """Monte-Carlo policy-gradient (REINFORCE) agent with a small MLP policy.

    The policy maps a 4-dimensional state to a probability distribution over
    2 discrete actions. During an episode, ``(reward, action_probability)``
    pairs are buffered with :meth:`save`; :meth:`update` consumes the buffer
    and performs one optimizer step.

    Args:
        lr: Adam learning rate. When omitted, the notebook-level
            ``learning_rate`` is used (as before).
        gamma: Discount factor used by :meth:`update`. When omitted, the
            notebook-level ``gamma`` is looked up each time ``update`` runs.
    """

    def __init__(self, lr=None, gamma=None):
        super(REINFORCE, self).__init__()
        self.linear_1 = nn.Linear(4, 128)
        self.linear_2 = nn.Linear(128, 2)
        # The layers must be registered before the optimizer captures self.parameters().
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr)
        self.gamma = gamma
        self.data = []  # buffered (reward, action_probability) pairs for the current episode

    def policy(self, state):
        """Return action probabilities for ``state``.

        Softmax is taken over the last dimension. This handles a single state
        of shape (4,) and a batch of shape (B, 4) alike.
        """
        x = F.relu(self.linear_1(state))
        return F.softmax(self.linear_2(x), dim=-1)

    def get_action(self, state):
        """Sample an action for a NumPy ``state``.

        Returns:
            (action, prob): the sampled action index as a Python int, and the
            full probability vector, which is still attached to the autograd graph.
        """
        prob = self.policy(torch.from_numpy(state).float())
        action = Categorical(prob).sample()
        return action.item(), prob

    def save(self, item):
        """Buffer one ``(reward, action_probability)`` pair."""
        self.data.append(item)

    def update(self):
        """Take one REINFORCE gradient step on the buffered episode, then clear the buffer.

        The buffer is walked backwards to build discounted returns
        G_t = r_t + gamma * G_{t+1}. The step minimises
        sum_t -log(pi(a_t | s_t)) * G_t.
        """
        discount = gamma if self.gamma is None else self.gamma
        G = 0.0
        losses = []
        for reward, prob in reversed(self.data):
            G = reward + discount * G
            losses.append(-torch.log(prob) * G)
        self.data = []
        if not losses:
            return  # nothing buffered, so there is no gradient step to take
        self.optimizer.zero_grad()
        # A single backward over the summed loss gives the same gradient as
        # one backward per step, but traverses the graph only once.
        torch.stack(losses).sum().backward()
        self.optimizer.step()

# Train

In [None]:
# Hyperparameters
learning_rate, gamma = 2e-4, 0.98  # Adam step size, reward discount factor

# Run configuration
num_episodes = 10_000  # upper bound on the number of training episodes
print_every = 100      # episodes per progress report (and per averaging window)

In [None]:
def train(agent, env, episodes=None, log_every=None, target_score=200.0):
    """Train ``agent`` on ``env``, with one policy update per episode.

    After every full window of ``log_every`` episodes, the mean undiscounted
    return over that window is printed. Training stops early once that mean
    reaches ``target_score``. The environment is always closed, including on
    error or KeyboardInterrupt.

    Both environment APIs are accepted:

    - classic gym: ``reset() -> obs`` and ``step() -> (obs, reward, done, info)``
    - gym >= 0.26 / gymnasium: ``reset() -> (obs, info)`` and
      ``step() -> (obs, reward, terminated, truncated, info)``

    Args:
        agent: object providing ``get_action(state) -> (action, probs)``,
            ``save(item)`` and ``update()``, e.g. ``REINFORCE``.
        env: environment to interact with.
        episodes: maximum number of episodes. Defaults to the notebook-level
            ``num_episodes``.
        log_every: reporting/averaging window in episodes. Defaults to the
            notebook-level ``print_every``.
        target_score: windowed mean return at which training stops early.
    """
    n_episodes = num_episodes if episodes is None else episodes
    window = print_every if log_every is None else log_every
    score = 0.0  # sum of episode returns within the current window
    try:
        for n_epi in range(n_episodes):
            state = env.reset()
            # Newer APIs return (obs, info) from reset().
            if isinstance(state, tuple) and len(state) == 2 and isinstance(state[1], dict):
                state = state[0]
            done = False

            while not done:
                action, prob = agent.get_action(state)
                step = env.step(action)
                if len(step) == 5:  # (obs, reward, terminated, truncated, info)
                    state, reward, terminated, truncated, _ = step
                    done = terminated or truncated
                else:  # (obs, reward, done, info)
                    state, reward, done, _ = step
                agent.save((reward, prob[action]))
                score += reward

            agent.update()

            # Report once every `window` completed episodes, so each average
            # covers exactly `window` episodes.
            if (n_epi + 1) % window == 0:
                avg_return = score / window
                print("# of episode :{}, avg score : {}".format(n_epi + 1, avg_return))
                score = 0.0
                if avg_return >= target_score:
                    break
    finally:
        env.close()

In [None]:
# Seed torch before building the agent, so that weight initialisation and
# action sampling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

cartpole = gym.make('CartPole-v0')
# Seed the environment's RNG as well; the seeding call differs across gym versions.
try:
    cartpole.reset(seed=SEED)  # newer gym / gymnasium: seeds the RNG used by later resets
except TypeError:
    cartpole.seed(SEED)  # older gym: reset() takes no seed argument
reinforce = REINFORCE()

train(reinforce, cartpole)

# of episode :20, avg score : 22.35
# of episode :40, avg score : 21.4
# of episode :60, avg score : 24.5
# of episode :80, avg score : 23.5
# of episode :100, avg score : 25.9
# of episode :120, avg score : 23.6
# of episode :140, avg score : 18.95
# of episode :160, avg score : 22.75
# of episode :180, avg score : 27.1
# of episode :200, avg score : 24.9
# of episode :220, avg score : 29.45
# of episode :240, avg score : 35.05
# of episode :260, avg score : 35.4
# of episode :280, avg score : 33.6
# of episode :300, avg score : 28.65
# of episode :320, avg score : 34.1
# of episode :340, avg score : 34.3
# of episode :360, avg score : 46.15
# of episode :380, avg score : 44.3
# of episode :400, avg score : 54.85
# of episode :420, avg score : 38.9
# of episode :440, avg score : 40.4
# of episode :460, avg score : 44.55
# of episode :480, avg score : 49.75
# of episode :500, avg score : 39.4
# of episode :520, avg score : 44.25
# of episode :540, avg score : 50.6
# of episode :560, av