In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import gymnasium as gym

In [2]:
class Actor(nn.Module):
    def __init__(self, n_obs, n_actions):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(n_obs, 128),
            nn.Tanh(),
            nn.Linear(128, n_actions),
            nn.Softmax(dim = -1)
        )

    def forward(self, x):
        return self.model(x)

class Critic(nn.Module):
    def __init__(self, n_obs):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(n_obs, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.model(x)

In [3]:
def compute_gae(rewards, values, dones, gamma, lam):
    returns = []
    gae = 0
    values = values + [torch.tensor(0.0)]
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + gamma * values[i+1] * (1.0 - dones[i]) - values[i]
        gae = delta + gamma * lam * (1.0 - dones[i]) * gae
        returns.insert(0, gae + values[i])

    return returns        

def collect_samples(env, actor, critic, gamma, lam):
    states, actions, dones, log_probs, action_probs_collect = [], [], [], [], []

    # play episode
    state, _ = env.reset()
    rewards, values = [], []

    done = False
    while not done:
        
        state_tensor = torch.tensor(state, dtype = torch.float32)
        states.append(state_tensor)
        action_probs = actor(state_tensor.unsqueeze(0))
        action_probs_collect.append(action_probs.detach())
        dist = torch.distributions.Categorical(action_probs)
        
        action = dist.sample()
        actions.append(action.item())

        log_prob = dist.log_prob(action)
        log_probs.append(log_prob)

        value = critic(state_tensor.unsqueeze(0)).squeeze(0)
        values.append(value)

        state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated

        rewards.append(reward)
        dones.append(done)


    # aggregate collected data
    total_reward = sum(rewards)
    returns = compute_gae(rewards, values, dones, gamma, lam)
    
    advantages = torch.tensor(returns, dtype = torch.float32) - torch.tensor(values, dtype = torch.float32)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)    

    # Convert to tensors
    states = torch.stack(states)
    actions = torch.tensor(actions)
    old_log_probs = torch.stack(log_probs).detach()
    old_probs = torch.stack(action_probs_collect).detach()
    returns = torch.tensor(returns, dtype = torch.float32)
    advantages = advantages.detach()

    return states, actions, old_log_probs, old_probs, returns, advantages, total_reward           

def train(env, actor, critic, gamma = 0.99, lam = 0.95, clip_eps = 0.2, lr = 0.001, entropy_coef = 0.001,  episodes = 1000, update_epochs = 5, batch_size = 64, print_reward_every = 1000.0):
    actor_optim = optim.Adam(actor.parameters(), lr = lr)
    critic_optim = optim.Adam(critic.parameters(), lr = lr)

    reward_collection = []

    for episode in range(episodes):
        states, actions, old_log_probs, old_probs, returns, advantages, total_reward = collect_samples(env, actor, critic, gamma, lam)
        reward_collection.append(total_reward)
        
        # cyclical policy update using collected data
        for _ in range(update_epochs):
            if states.shape[0] < batch_size:
                idx = torch.randperm(states.shape[0])
            else:
                idx = torch.randperm(batch_size)
                
            batch_states = states[idx]
            batch_actions = actions[idx]
            batch_old_log_probs = old_log_probs[idx]
            batch_old_probs = old_probs[idx]
            batch_returns = returns[idx]
            batch_advantages = advantages[idx]

            action_probs = actor(batch_states)
            dist = torch.distributions.Categorical(action_probs)
            new_log_probs = dist.log_prob(batch_actions)
            entropy = dist.entropy().mean()
                
            ratio = (new_log_probs - batch_old_log_probs).exp()
            
            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * batch_advantages
            actor_loss = - torch.min(surr1, surr2).mean() - entropy_coef * entropy

            critic_loss = nn.MSELoss()(critic(batch_states).squeeze(-1), batch_returns)
            
            actor_optim.zero_grad()
            actor_loss.backward()
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()

        if episode % print_reward_every == 0 and episode > 0:
            print(f"Episode {episode}, mean reward {sum(reward_collection[-10:]) / 10}")            

In [4]:
if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    n_obs = env.observation_space.shape[0]
    n_actions = env.action_space.n

    actor = Actor(n_obs, n_actions)
    critic = Critic(n_obs)

    train(env, actor, critic, gamma = 0.99, lam = 0.95, clip_eps = 0.3, lr = 0.003, entropy_coef = 0.0001,  episodes = 2500, update_epochs = 5, batch_size = 128, print_reward_every = 100.0)

Episode 100, mean reward 28.9
Episode 200, mean reward 41.8
Episode 300, mean reward 57.9
Episode 400, mean reward 57.3
Episode 500, mean reward 80.8
Episode 600, mean reward 392.1
Episode 700, mean reward 369.6
Episode 800, mean reward 500.0
Episode 900, mean reward 454.7
Episode 1000, mean reward 498.8
Episode 1100, mean reward 444.1
Episode 1200, mean reward 500.0
Episode 1300, mean reward 476.7
Episode 1400, mean reward 492.1
Episode 1500, mean reward 500.0


KeyboardInterrupt: 