In [1]:
# from https://gymnasium.farama.org/tutorials/training_agents/reinforce_invpend_gym_v26/
import gymnasium as gym
import torch 
import torch.nn as nn
import numpy as np

from torch.utils.tensorboard import SummaryWriter

In [2]:
# env = gym.make("CartPole-v1", render_mode='human')
env = gym.make("CartPole-v1")
#%load_ext tensorboard
#wandb.init(project="cartpole-v1", entity="bpanthi977")

In [12]:
def evaluate_agent(agent, steps=100):
    observation, info = env.reset()
    total_reward = 0
    total_episodes = 0
    for _ in range(steps):
        action = agent.action(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        
        total_reward += reward
        # env.render()
        if terminated or truncated:
            observation, info = env.reset()
            total_episodes += 1
            
    return total_reward/total_episodes

In [14]:
class RandomAgent():
    def action(self, state):
        return env.action_space.sample()

In [15]:
evaluate_agent(RandomAgent(), 100)

33.333333333333336

In [3]:
class CartPoleAgent(nn.Module):
    def __init__(self):
        super(CartPoleAgent, self).__init__()
        input_dim = 4
        out_dim = 2
        
        self.net = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=out_dim),
            nn.Softmax(dim=0)
        )
        
    def forward(self, x):
        return self.net(x)
    
    def action(self, state):
        probs = self.forward(torch.tensor(state).to(device))
        action = np.random.choice([0,1], p=probs.detach().cpu().numpy())
        self.action_probs = probs
        
        return action

device = torch.device('mps')   
def reset_network():
    global network, optim, total_episodes, writer
    network = CartPoleAgent().to(device)
    optim = torch.optim.Adam(network.parameters())
    total_episodes = 0
    writer = SummaryWriter()

reset_network()

In [74]:
evaluate_agent(network, 1_000)

16.666666666666668

In [1]:
GAMMA=0.99
ENTROPY_BETA=0.1
def train_reinforce(steps):
    global total_episodes
    observation, info = env.reset(seed=42)
    
    total_reward = 0
    episode_steps = 0

    par = []
    
    def train(par):
        # REINFORCE Update
        # compute returns in backward order and compute loss
        g_t = 0
        loss = 0
        for t in range(len(par)-1, -1, -1):
            prob, action, reward = par[t]
            g_t = reward + GAMMA * g_t

            log_prob = torch.log(prob)
            # L = - G ln \pi(a)
            loss += - g_t * log_prob[action]
            ## L = entropy penalty
            entropy = - (prob * log_prob).sum()         # H = - \sum p_i log p_i
            loss += - ENTROPY_BETA * entropy          # increase entropy 
        
        if g_t == 0:
            return 

        optim.zero_grad()
        loss.backward()
        optim.step()
            
    for step in range(steps):
        action = network.action(observation)
        observation, reward, terminated, truncated, info = env.step(action)

        par.append([network.action_probs, action, reward])
        
        total_reward += reward
        episode_steps += 1
        if terminated or truncated:
            train(par)
            par = []
            observation, info = env.reset()
            total_episodes += 1
            writer.add_scalar("Reward/Episode", episode_steps, total_episodes)
            episode_steps = 0
            
    train(par)
            
    avg_reward = total_reward/total_episodes
    writer.add_hparams({'entropy': ENTROPY_BETA, 'gamma': GAMMA, 'baseline': False}, {'avg_reward': avg_reward})
    return avg_reward

In [20]:
reset_network()
train_reinforce(30_000)

42.918454935622314

In [28]:
evaluate_agent(network, 1_000)

500.0