# Policy gradients

## But first, Gymnasium

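From here on, everything follows the Gymnasium API: `reset()` returns `(observation, info)` and `step()` returns `(observation, reward, terminated, truncated, info)`.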

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces

class CustomEnv(gym.Env):
    """
    A custom Gymnasium environment skeleton.

    Toy dynamics: the state is a float32 vector of length 3 that is shifted by
    ``action - 0.5`` on every step (action 0 moves it by -0.5, action 1 by +0.5).
    The reward is the negative squared norm of the state, and the episode
    terminates once the state's Euclidean norm exceeds 10.
    """
    # Gymnasium reads the supported render modes from the "render_modes" key
    # (the dotted "render.modes" spelling is the legacy Gym name).
    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode=None):
        """
        Args:
            render_mode: Optional render mode (``None`` or ``"human"``), stored
                on ``self.render_mode`` as the Gymnasium API expects.
        """
        super().__init__()
        self.render_mode = render_mode

        # Define action and observation spaces
        # Example: Discrete action space with 2 actions (0, 1)
        self.action_space = spaces.Discrete(2)

        # Example: Continuous observation space (1D array of size 3).
        # Unbounded, because the dummy dynamics below let the state drift until
        # its norm exceeds 10; a [-1, 1] box would be violated after a couple
        # of steps.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32
        )

        # Environment state; set by reset()
        self.state = None

    def reset(self, seed=None, options=None):
        """
        Reset the environment to an initial state.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` (seeds ``self.np_random``).
            options: Unused; accepted for API compatibility.

        Returns:
            ``(observation, info)``: the all-zeros float32 state and an info dict.
        """
        super().reset(seed=seed)

        # Example: Initialize the state
        self.state = np.zeros(self.observation_space.shape, dtype=np.float32)

        # Return the initial observation, and a possible dictionary with information
        return self.state, {'message':'successfully reset!', 'hello':'world'}

    def step(self, action):
        """
        Perform one step in the environment.

        Args:
            action: An element of ``self.action_space`` (0 or 1).

        Returns:
            ``(observation, reward, terminated, truncated, info)`` per the Gymnasium API.
        """
        # Dummy dynamics. Cast back to float32: adding a NumPy integer action
        # (what Discrete.sample() returns) promotes the array to float64 under
        # NumPy 2's promotion rules, which would no longer match
        # observation_space.dtype.
        self.state = (self.state + action - 0.5).astype(np.float32)

        # Compute the reward (example logic): plain Python float
        reward = float(-np.sum(np.square(self.state)))

        # Terminated -> reached a terminal state of the MDP (example logic)
        terminated = bool(np.linalg.norm(self.state) > 10.0)

        # Truncated -> ended for a reason outside the MDP, e.g. a step limit
        truncated = False

        # Additional info (can be empty)
        info = {}

        return self.state, reward, terminated, truncated, info

    def render(self, mode="human"):
        """
        Render the environment (optional) by printing the current state.

        ``mode`` is kept for backward compatibility; Gymnasium itself calls
        ``render()`` with no arguments.
        """
        print(f"State: {self.state}")

    def close(self):
        """
        Clean up resources (optional). Nothing to release here.
        """
        pass

# Example usage: roll out one episode that repeats a single randomly sampled action
env = CustomEnv()
obs, info = env.reset()
print("Initial Observation:", obs)
action = env.action_space.sample()

# Step with the same action until the env reports termination or truncation.
# Starting from done = False means the loop body always runs at least once.
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    print("Step:", obs, reward, done, info)

## Now on to REINFORCE

### Another standard benchmark: the CartPole problem

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Set up the CartPole environment
# (pass render_mode="human" or "rgb_array" to gym.make to visualize it)
env = gym.make("CartPole-v1")

# Read the sizes from the environment's spaces so the network matches them.
# This assumes a 1-D Box observation space and a Discrete action space,
# which is what CartPole-v1 provides.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

print(f"State dimension: {state_dim}, Action dimension: {action_dim}")

In [None]:
class PolicyNetwork(nn.Module):
    """Stochastic policy pi(a|s): a one-hidden-layer MLP that outputs action probabilities.

    Args:
        state_dim: Size of the (flat) state vector fed to the first layer.
        action_dim: Number of discrete actions; the output has this many entries.
        hidden_dim: Width of the hidden layer (default 128).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The attribute name ``fc`` is kept so existing state_dict keys stay valid.
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),  # Output probabilities for actions (last dim)
        )

    def forward(self, state):
        """Return action probabilities for ``state`` of shape ``(..., state_dim)``.

        The result has shape ``(..., action_dim)`` and sums to 1 over the last dimension.
        """
        return self.fc(state)

In [None]:
def select_action(policy, state):
    """Sample one action from ``policy`` for a single environment state.

    Args:
        policy: Callable mapping a float32 state tensor to a tensor of action
            probabilities (e.g. a ``PolicyNetwork``).
        state: Observation from the environment (array-like or tensor).

    Returns:
        ``(action, log_prob)``: the sampled action as a Python int, and
        ``log pi(action | state)`` as a tensor that is still attached to the
        autograd graph, so gradients can flow back into ``policy``.
    """
    # as_tensor avoids an extra copy (and the copy-construct warning when
    # ``state`` is already a tensor).
    state = torch.as_tensor(state, dtype=torch.float32)

    # Predict action probabilities using the policy network
    action_probs = policy(state)

    # Trick: wrap the probabilities in a Categorical distribution so we can use
    # '.sample()' (otherwise we would need something like np.random.choice) and
    # '.log_prob()' (otherwise we would compute the log separately). Categorical
    # also works directly on batches of probability vectors.
    action_dist = torch.distributions.Categorical(probs=action_probs)
    action = action_dist.sample()  # Sample an action
    return action.item(), action_dist.log_prob(action)

In [None]:
# Initialize policy network
# (state_dim / action_dim come from the environment-setup cell above; this same
# `policy` object is what the optimizer and training loop below update in place)
policy = PolicyNetwork(state_dim, action_dim)

# Test action sampling
# Smoke test: draw one action from the untrained policy on a freshly reset state.
state, info = env.reset()
action, log_prob = select_action(policy, state)
print(f"Sampled action: {action}, Log-probability: {log_prob}")

In [None]:
def collect_trajectory(policy, env, max_steps=200):
    """Roll out one episode of at most ``max_steps`` steps with ``policy`` in ``env``.

    The env is reset first, and the rollout stops early as soon as the env
    reports ``terminated`` or ``truncated``.

    Returns:
        Four parallel lists indexed by time step t: the states s_t that were
        acted on, the actions a_t, the ``log pi(a_t | s_t)`` tensors returned
        by ``select_action``, and the rewards r_t.
    """
    states, actions, log_probs, rewards = [], [], [], []
    obs, _info = env.reset()

    for _step in range(max_steps):
        # Sample a_t ~ pi(.|s_t) and keep log pi(a_t|s_t) for the gradient later
        act, logp = select_action(policy, obs)
        next_obs, r, terminated, truncated, _info = env.step(act)

        # Record this transition
        states.append(obs)
        actions.append(act)
        log_probs.append(logp)
        rewards.append(r)

        if terminated or truncated:
            break
        obs = next_obs  # advance to s_{t+1}

    return states, actions, log_probs, rewards

In [None]:
def compute_returns(rewards, gamma=0.99):
    """Compute the discounted return G_t for every step of a trajectory.

    G_t = r_t + gamma * G_{t+1}, with G_T = 0 after the last reward.

    Args:
        rewards: Sequence of rewards r_0 ... r_{T-1}.
        gamma: Discount factor.

    Returns:
        List ``[G_0, G_1, ..., G_{T-1}]``, the same length as ``rewards``.
    """
    returns = []
    G = 0
    # Accumulate backwards from the last step, appending (O(1)) and reversing
    # once at the end instead of inserting at the front (O(n) per insert).
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    return returns

## The REINFORCE update

In [None]:
# Created once, at module level, and reused for every episode in the training
# loop, so Adam's running moment estimates persist across updates.
optimizer = optim.Adam(policy.parameters(), lr=0.01)
def reinforce_update(policy, optimizer, log_probs, returns):
    """Apply one REINFORCE (Monte-Carlo policy-gradient) step for one trajectory.

    The loss is ``-sum_t G_t * log pi(a_t|s_t)``. Minimizing it with
    ``optimizer`` performs gradient ascent on the policy-gradient objective.

    Args:
        policy: Not used inside this function (the optimizer already holds the
            policy's parameters); kept so existing call sites keep working.
        optimizer: A ``torch.optim`` optimizer over the policy's parameters.
        log_probs: Sequence of ``log pi(a_t|s_t)`` tensors, one per step, still
            attached to the autograd graph (e.g. as returned by ``select_action``).
        returns: Sequence of returns ``G_t``, same length as ``log_probs``.

    Returns:
        The loss value as a Python float. An empty trajectory returns ``0.0``
        and performs no parameter update.

    Raises:
        ValueError: If ``log_probs`` and ``returns`` differ in length.
    """
    if len(log_probs) != len(returns):
        raise ValueError(
            f"log_probs and returns must have the same length "
            f"({len(log_probs)} != {len(returns)})"
        )
    if len(log_probs) == 0:
        # An empty trajectory has nothing to learn from, and a zero loss has
        # no autograd graph to backpropagate through.
        return 0.0

    # Flatten so that 0-d and 1-element log-prob tensors both give shape (T,).
    log_probs_t = torch.stack(list(log_probs)).reshape(-1)
    returns_t = torch.as_tensor(returns, dtype=torch.float32, device=log_probs_t.device)

    # Normalizing the returns reduces gradient variance; it is a good idea, but
    # optional. std() uses Bessel's correction and is NaN for a single
    # element, which would poison the parameters, so skip it in that case.
    if returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)

    # Sum over the trajectory of G_t * log_prob(a_t|s_t); negated because
    # optimizers minimize.
    loss = -(log_probs_t * returns_t).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

## Let's train

In [None]:
num_episodes = 1000
gamma = 0.99  # discount factor passed to compute_returns

# Undiscounted total reward of each episode, presumably kept for plotting/inspection later
reward_progression = []
for episode in range(num_episodes):
    
    # Collect a single trajectory
    # (collect_trajectory's default max_steps=200 caps the episode length)
    states, actions, log_probs, rewards = collect_trajectory(policy, env)
    
    # Compute the returns for the rewards collected
    returns = compute_returns(rewards, gamma)
    
    # Update policy
    # One gradient step per episode: each freshly sampled trajectory is used
    # once, with the policy that generated it, and then discarded.
    reinforce_update(policy, optimizer, log_probs, returns)
    
    # Logging
    # Every episode's total goes into reward_progression; only every 100th is printed.
    total_reward = sum(rewards)
    reward_progression.append(total_reward)
    if episode % 100 == 0:
        print(f"Episode {episode}, Total Reward: {total_reward}")
