In [40]:
import numpy as np
import torch
import gym
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

In [41]:
def mish(input):
    return input * torch.tanh(F.softplus(input))

class Mish(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, input): return mish(input)

In [42]:
# helper function to convert numpy arrays to tensors
def t(x): return torch.from_numpy(x).float()

In [43]:
# Actor module, categorical actions only
class Actor(nn.Module):
    def __init__(self, state_dim, n_actions, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 64),
            activation(),
            nn.Linear(64, n_actions),
            nn.Softmax()
        )
    
    def forward(self, X):
        return self.model(X)

In [44]:
# Critic module
class Critic(nn.Module):
    def __init__(self, state_dim, activation=nn.Tanh):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64),
            activation(),
            nn.Linear(64, 64),
            activation(),
            nn.Linear(64, 1)
        )
    
    def forward(self, X):
        return self.model(X)

In [45]:
env = gym.make("CartPole-v1")

In [46]:
# config
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
actor = Actor(state_dim, n_actions, activation=Mish)
critic = Critic(state_dim, activation=Mish)
adam_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
adam_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

torch.manual_seed(1)

<torch._C.Generator at 0x1080d0190>

In [47]:
def clip_grad_norm_(module, max_grad_norm):
    nn.utils.clip_grad_norm_([p for g in module.param_groups for p in g["params"]], max_grad_norm)

def policy_loss(old_log_prob, log_prob, advantage, eps):
    ratio = (log_prob - old_log_prob).exp()
    clipped = torch.clamp(ratio, 1-eps, 1+eps)*advantage
    
    m = torch.min(ratio*advantage, clipped)
    return -m

In [48]:
episode_rewards = []
gamma = 0.99
eps = 0.2
s = 0
max_grad_norm = 0.5

for i in range(800):
    prev_prob_act = None
    done = False
    total_reward = 0
    state, _ = env.reset()

    while not done:
        s += 1
        probs = actor(t(state))
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()
        prob_act = dist.log_prob(action)
        
        next_state, reward, done, _, info = env.step(action.detach().data.numpy())
        advantage = reward + (1-done)*gamma*critic(t(next_state)) - critic(t(state))
        
        total_reward += reward
        state = next_state
        
        if prev_prob_act:
            actor_loss = policy_loss(prev_prob_act.detach(), prob_act, advantage.detach(), eps)
            adam_actor.zero_grad()
            actor_loss.backward()
            # clip_grad_norm_(adam_actor, max_grad_norm)
            adam_actor.step()

            critic_loss = advantage.pow(2).mean()
            adam_critic.zero_grad()
            critic_loss.backward()
            # clip_grad_norm_(adam_critic, max_grad_norm)
            adam_critic.step()
        
        prev_prob_act = prob_act
    
    print(f"Episode {i+1} finished after {s} steps with total reward {total_reward}")
    episode_rewards.append(total_reward)

  return self._call_impl(*args, **kwargs)


Episode 1 finished after 26 steps with total reward 26.0
Episode 2 finished after 50 steps with total reward 24.0
Episode 3 finished after 71 steps with total reward 21.0
Episode 4 finished after 87 steps with total reward 16.0
Episode 5 finished after 107 steps with total reward 20.0
Episode 6 finished after 132 steps with total reward 25.0
Episode 7 finished after 153 steps with total reward 21.0
Episode 8 finished after 201 steps with total reward 48.0
Episode 9 finished after 223 steps with total reward 22.0
Episode 10 finished after 254 steps with total reward 31.0
Episode 11 finished after 275 steps with total reward 21.0
Episode 12 finished after 296 steps with total reward 21.0
Episode 13 finished after 308 steps with total reward 12.0
Episode 14 finished after 320 steps with total reward 12.0
Episode 15 finished after 336 steps with total reward 16.0
Episode 16 finished after 366 steps with total reward 30.0
Episode 17 finished after 381 steps with total reward 15.0
Episode 18