###  Reference

https://www.youtube.com/watch?v=e20EY4tFC_Q

https://huggingface.co/learn/deep-rl-course/en/unit4/hands-on

In [3]:
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import os
import shutil
import torch
import time
from collections import deque
import random
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import torch.nn as nn

In [8]:
# Check if CUDA is available
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)


cuda


In [42]:
def build_env(name = 'LunarLander-v3', record_name = 'lunar', max_record_steps = int(1e3)):
    # Delete all contents in lunar-agent folder

    # Initialise the environment
    env = gym.make(name, render_mode="rgb_array")

    if record_name != None and record_name != "":
        path = os.path.join('output', record_name)
        if os.path.exists(path):
            shutil.rmtree(path)

        env = RecordVideo(
            env,
            video_folder=path,
            episode_trigger=lambda x: True,  # Record every episode
            name_prefix="training",
            video_length=max_record_steps,  # Maximum number of steps to record per episode
        )

    return env

env = build_env()

# Reset the environment to generate the first observation
observation, info = env.reset(seed=42)
print('state =', observation.shape ,' type =', type(observation))
print("action shape = ", env.action_space.sample().shape, ' type =', type(env.action_space.sample()))

for _ in range(1000):
    # this is where you would insert your policy
    action = env.action_space.sample()

    # step (transition) through the environment with the action
    # receiving the next observation, reward and if the episode has terminated or truncated
    observation, reward, terminated, truncated, info = env.step(action)

    # If the episode has ended then we can reset to start a new episode
    if terminated or truncated:
        observation, info = env.reset()

env.close()


state = (8,)  type = <class 'numpy.ndarray'>
action shape =  ()  type = <class 'numpy.int64'>


### REINFORCE

in value based, we use value function (Qnet, Q-table) to estimate the policy

policy-based directly optimize the policy function without using intermediate value function

In [72]:
class Policy(nn.Module):
    def __init__(self, state_size = 8, hidden_size = 16, action_size = 4):
        super(Policy, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
            nn.Softmax(dim=1) #turn output into probability distribution
        )

    def forward(self, x):
        return self.model(x)

    def act(self, state):
        tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        probs = self.forward(tensor)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)
    
    def save(self, name = None):
        torch.save(self.state_dict(), 'policy' if name == None else name + '.pth')

# Test
model = Policy()

state = np.random.rand(8)
print(state.shape)

action, log_prob = model.act(state)
print(action)
print(log_prob.item())

(8,)
1
-1.3966193199157715


In [75]:
#train
gamma = 0.99
num_episodes = 1000
reward_threshold = 200
print_every = 10
training_eps = 5000
max_steps = 2000

env_id = 'LunarLander-v3'

env = build_env(name = env_id  ,record_name=None)
state, info = env.reset()
action = env.action_space.sample()

#TODO:  add DEVICE

# step 2 - prep policy and optimizer
policy = Policy(state_size=state.shape[0], action_size=4) 
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

loop = tqdm(range(training_eps))
latest_rewards = deque(maxlen=100)
best_100_rewards = 0
for ep in loop:
    saved_log_probs = []
    rewards = []
    state, info = env.reset()
    # step 3 - generate episode S0, A0, R0, ..., ST-1, AT-1, RT-1, using policy
    for t in range(max_steps):
        action, log_prob = policy.act(state) 
        saved_log_probs.append(log_prob)
        state, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)

        if terminated or truncated:
            break
    
    # step 4 - compute discounted return Gt
    returns = deque(maxlen=(max_steps))
    n_steps = len(rewards)

    # G_t = r_(t+1) + gamma*G_(t+1)
    # use DP, from last to first
    # -> G(t - 1) = r_t + gamma*(G_t)

    for t in range(n_steps)[::-1]:
        discount_return_t = returns[0] if len(returns) > 0 else 0
        returns.appendleft(rewards[t] + gamma * discount_return_t)

    
    e = np.finfo(np.float32).eps.item()
    returns = torch.tensor(returns)
    # do the following to avoid numerical instability
    returns = (returns - returns.mean()) / (returns.std() + e)

    # step 5 - calculate objective function 
    policy_loss = []
    for log_prob, discounted_return in zip(saved_log_probs, returns):
        policy_loss.append(-log_prob * discounted_return) # add minus to log prop to maximize reward
    policy_loss = torch.cat(policy_loss).sum()

    # step 6 - policy gradient update

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # training stats report, 
    latest_rewards.append(sum(rewards))
    avg_rewards = sum(latest_rewards)/len(latest_rewards) if len(latest_rewards) > 0 else 0
    if avg_rewards > best_100_rewards :
        best_100_rewards = avg_rewards
        policy.save('best_policy')
    loop.set_description(f"Episode: {ep}\tAvg reward:\t{avg_rewards:.2f}")

env.close()



Episode: 4999	Total reward:	76.20: 100%|██████████| 5000/5000 [22:18<00:00,  3.74it/s] 


In [79]:
# Inference
env = build_env(record_name='lunar')
policy = Policy()
policy.load_state_dict(torch.load('best_policy.pth'))
policy.eval()

state, info = env.reset()
for _ in range(2000):
    action, _ = policy.act(state)
    state, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        break

env.close()

  policy.load_state_dict(torch.load('best_policy.pth'))
