In [17]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import gymnasium as gym
import matplotlib.pyplot as plt
import networkx as nx
from IPython.display import Video


### Compute Returns

Computing the returns from our trajectory is identical to before!

### Train our Model

This is mostly identical to before, with a few changes:

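1) We now have two models (the actor and the critic) that both have to be optimized, so we define two optimizers, one per model.
2) Instead of using $G_t$ directly, we compute the advantage $A(s,a) = G_t - V(s)$. In the implementation below, $G_t$ is estimated with the one-step TD target $r_t + \gamma V(s_{t+1})$ (with the bootstrap term dropped at terminal states), so the advantage is the TD error $A(s_t,a_t) = r_t + \gamma V(s_{t+1}) - V(s_t)$.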

In [37]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym

class ActorNetwork(nn.Module):
    """Policy network: maps states to a probability distribution over discrete actions.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim)
    -> Softmax over the last dimension.

    Args:
        input_dim: size of the state vector.
        output_dim: number of discrete actions.
        hidden_dim: width of the hidden layer (default 64).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Normalise over the last (action) axis so that both batched (B, input_dim)
        # and single unbatched (input_dim,) states are supported.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, state):
        """Return action probabilities; each slice along the last axis sums to 1."""
        x = self.fc1(state)
        x = self.relu(x)
        x = self.fc2(x)
        action_probs = self.softmax(x)
        return action_probs

class CriticNetwork(nn.Module):
    """State-value network V(s): one hidden ReLU layer of width 64, scalar output.

    Args:
        input_dim: size of the state vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden = 64
        self.fc1 = nn.Linear(input_dim, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, state):
        """Return the value estimate(s); the last dimension has size 1."""
        return self.fc2(self.relu(self.fc1(state)))

class ActorCriticAgent:
    """One-step (online) actor-critic agent for discrete-action environments.

    The actor maps a state to action probabilities; the critic maps a state to a
    scalar value estimate. Both are updated after every environment step, using the
    one-step TD error as the advantage estimate.

    Args:
        env: environment exposing ``observation_space.shape[0]`` (state size) and
            ``action_space.n`` (number of discrete actions).
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for the critic.
        gamma: discount factor.
    """

    # Floor applied to the chosen action's probability before taking its log, so a
    # probability that underflows to 0 yields a finite loss instead of -inf/NaN.
    _LOG_PROB_EPS = 1e-8

    def __init__(self, env, actor_lr=0.001, critic_lr=0.005, gamma=0.99):
        self.env = env
        self.actor = ActorNetwork(env.observation_space.shape[0], env.action_space.n)
        self.critic = CriticNetwork(env.observation_space.shape[0])
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma

    def select_action(self, state):
        """Sample an action index from the actor's policy for a single state.

        Sampling goes through ``np.random.choice``, so it is reproducible via
        ``np.random.seed``.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        # Pure inference: no autograd graph is needed just to sample an action.
        with torch.no_grad():
            probs = self.actor(state)
        return np.random.choice(self.env.action_space.n, p=probs.numpy()[0])

    def train_step(self, state, action, reward, next_state, done):
        """Perform one critic update and one actor update from a single transition.

        Args:
            state: observation before the step.
            action: index of the action taken in ``state``.
            reward: scalar reward received for the step.
            next_state: observation after the step.
            done: True if ``next_state`` is terminal; disables bootstrapping from it.

        Returns:
            ``(critic_loss, actor_loss)`` as Python floats, for optional logging.
        """
        state = torch.FloatTensor(state).unsqueeze(0)
        next_state = torch.FloatTensor(next_state).unsqueeze(0)

        # Critic update: regress V(s) toward the one-step TD target.
        state_value = self.critic(state)
        with torch.no_grad():
            # Semi-gradient TD: the bootstrap target is a constant, so no gradient
            # flows through V(s'). Explicit float() casts keep numpy scalars / bools
            # out of the tensor arithmetic.
            next_state_value = self.critic(next_state)
            td_target = float(reward) + self.gamma * next_state_value * (1.0 - float(done))
        critic_loss = nn.MSELoss()(state_value, td_target)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: policy gradient weighted by the TD-error advantage, computed
        # from the pre-update value estimate and detached from the critic's graph.
        advantage = (td_target - state_value.detach()).squeeze()
        probs = self.actor(state)
        log_prob = torch.log(probs[0, action].clamp_min(self._LOG_PROB_EPS))
        actor_loss = -log_prob * advantage

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return critic_loss.item(), actor_loss.item()

In [39]:
# Fixed seed so the training run is reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)     # action sampling in select_action uses np.random.choice
torch.manual_seed(SEED)  # network weight initialisation

env = gym.make('LunarLander-v3')
agent = ActorCriticAgent(env)

episodes = 200
episode_rewards = []  # per-episode return, kept for plotting / later analysis
try:
    for episode in range(episodes):
        # Seed only the first reset; later resets continue the env's own RNG stream.
        state, _ = env.reset(seed=SEED if episode == 0 else None)
        done = False
        total_reward = 0

        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Only a true terminal state should zero the bootstrap term; a time-limit
            # truncation still has future value, so pass `terminated`, not `done`.
            agent.train_step(state, action, reward, next_state, terminated)

            total_reward += reward
            state = next_state

        episode_rewards.append(total_reward)
        print(f"Episode: {episode + 1}, Total Reward: {total_reward}")
finally:
    # Release the environment even if training is interrupted (e.g. KeyboardInterrupt).
    env.close()



Episode: 1, Total Reward: -49.657901693494395
Episode: 2, Total Reward: -398.14078481707793
Episode: 3, Total Reward: -395.348345029353
Episode: 4, Total Reward: -300.6674100394823
Episode: 5, Total Reward: -372.524935825173
Episode: 6, Total Reward: -213.11099822602583
Episode: 7, Total Reward: -82.12964535693312
Episode: 8, Total Reward: -213.85316140456888
Episode: 9, Total Reward: -166.88457185717914
Episode: 10, Total Reward: -240.32921176830365
Episode: 11, Total Reward: -464.78259080242725
Episode: 12, Total Reward: -340.58690768639246
Episode: 13, Total Reward: -78.00880697958297
Episode: 14, Total Reward: -147.35743586501297
Episode: 15, Total Reward: -438.6649297618385
Episode: 16, Total Reward: -60.51617845660549
Episode: 17, Total Reward: -315.0721493902753
Episode: 18, Total Reward: -29.207307740710007
Episode: 19, Total Reward: -252.18492559861417
Episode: 20, Total Reward: -172.65279078818725
Episode: 21, Total Reward: -252.13702887494765
Episode: 22, Total Reward: -247.

KeyboardInterrupt: 