In [1]:
# Actor-Critic example on CartPole-v1
# using PyTorch

In [9]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import session_info

In [3]:
class ActorCritic(nn.Module):
    """Two-headed network: a shared trunk feeding a policy head and a value head.

    The trunk maps ``input_dim`` features to a 64-dimensional representation.
    The actor head turns that representation into a softmax distribution over
    ``n_actions``; the critic head turns it into a single state-value estimate.
    """

    def __init__(self, input_dim, n_actions):
        """Build the trunk, then the actor head, then the critic head.

        Args:
            input_dim: Number of features in one state vector.
            n_actions: Number of discrete actions the policy chooses between.
        """
        super().__init__()

        trunk_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        ]
        policy_layers = [
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),
            nn.Softmax(dim=-1),  # probabilities over the last (action) dimension
        ]
        value_layers = [
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        ]

        self.shared = nn.Sequential(*trunk_layers)
        self.actor = nn.Sequential(*policy_layers)
        self.critic = nn.Sequential(*value_layers)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` computed from ``state``.

        Both heads read the same trunk output, so the trunk is evaluated once.
        """
        hidden = self.shared(state)
        return self.actor(hidden), self.critic(hidden)

def normalize_reward(rewards):
    """Standardise a float tensor to zero mean and (approximately) unit std.

    Args:
        rewards: Float tensor of values to normalise (rewards or advantages).

    Returns:
        Tensor of the same shape: ``(rewards - mean) / (std + 1e-5)``.
        When ``rewards`` holds fewer than two elements the sample standard
        deviation is undefined (``Tensor.std()`` yields NaN), so only the mean
        is subtracted. This matches what a constant multi-element input
        produces (all zeros) instead of propagating NaN.
    """
    centered = rewards - rewards.mean()
    if rewards.numel() < 2:
        return centered
    # The small epsilon keeps the division finite when every value is equal.
    return centered / (rewards.std() + 1e-5)

def train(env, model, optimizer, n_episodes, gamma):
    """Train an actor-critic model with one Monte-Carlo update per episode.

    For every episode the current policy is rolled out until the environment
    reports that the episode is over. Discounted returns are then accumulated
    backwards from the last step and a single optimizer step is taken on
    ``actor_loss + 0.5 * critic_loss`` with gradients clipped to norm 0.5.

    Args:
        env: Environment following the Gymnasium API: ``reset()`` returns
            ``(observation, info)`` and ``step(action)`` returns
            ``(observation, reward, terminated, truncated, info)``.
        model: Callable mapping a state tensor to
            ``(action_probs, state_value)``. ``state_value`` must be a
            one-element 1-D tensor so the per-step values can be concatenated
            and compared against the returns.
        optimizer: Torch optimizer holding the parameters of ``model``.
        n_episodes: Number of episodes to run.
        gamma: Discount factor applied to future rewards.

    Returns:
        None. The undiscounted episode reward is printed after every update.
    """
    value_loss_fn = nn.MSELoss()

    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        episode_rewards = []
        episode_log_probs = []
        episode_values = []

        while not done:
            state = torch.FloatTensor(state)
            action_probs, state_value = model(state)

            dist = Categorical(action_probs)
            action = dist.sample()

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            # The episode is over on a terminal state OR a time-limit cutoff;
            # looking at `terminated` alone keeps stepping a truncated episode.
            done = terminated or truncated

            episode_rewards.append(reward)
            episode_log_probs.append(dist.log_prob(action))
            episode_values.append(state_value)

            state = next_state

        # Discounted returns, accumulated from the last step backwards.
        # Appending and reversing once avoids repeated front insertions.
        discounted = []
        running_return = 0.0
        for step_reward in reversed(episode_rewards):
            running_return = step_reward + gamma * running_return
            discounted.append(running_return)
        discounted.reverse()

        values = torch.cat(episode_values)
        # Match the critic's dtype explicitly so MSELoss never sees a
        # float32/float64 mix when the environment returns NumPy scalars.
        returns = torch.tensor(discounted, dtype=values.dtype)

        # Advantages are treated as constants for the policy-gradient term.
        advantages = returns - values.detach()
        # The sample std of a single value is NaN, which would poison every
        # weight on the next step; one-step episodes keep the raw advantage.
        if advantages.numel() > 1:
            advantages = normalize_reward(advantages)

        # Compute losses
        actor_loss = -(torch.stack(episode_log_probs) * advantages).mean()
        critic_loss = value_loss_fn(values, returns)

        loss = actor_loss + 0.5 * critic_loss

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        print(f"Episode {episode + 1}, Reward: {sum(episode_rewards)}")

def test(env, model, n_episodes=15):
    total_rewards = []
    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0
        
        k=0
        while not done:
            state = torch.FloatTensor(state)
            action_probs, _ = model(state)
            
            action = torch.argmax(action_probs).item()
            
            next_state, reward, done, _ , _= env.step(action)
            , _
            episode_reward += reward
            state = next_state
            k = k + 1
        total_rewards.append(episode_reward)
        print(f"Test Episode {episode + 1}, Reward: {episode_reward} steps: {k}")
    avg_reward = sum(total_rewards) / len(total_rewards)
    print(f"Average Reward over {n_episodes} episodes: {avg_reward}")

In [4]:
# Hyperparameters
# These names are read by the cell that builds the model/optimizer and calls train().
input_dim = 4  # size of the state vector; input width of ActorCritic's first Linear layer
n_actions = 2  # number of discrete actions; output width of the actor head
lr = 0.001  # learning rate handed to the Adam optimizer
n_episodes = 2000  # number of training episodes (one optimizer step per episode)
gamma = 0.99  # discount factor used when accumulating returns in train()

In [5]:
# Create environment and model
# No render_mode is passed here; the evaluation cell creates a separate
# environment with render_mode='human'.
env = gym.make('CartPole-v1')
model = ActorCritic(input_dim, n_actions)
# A single Adam optimizer updates the shared trunk and both heads together.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train the model
# train() prints one "Episode N, Reward: R" line per episode.
train(env, model, optimizer, n_episodes, gamma)

Episode 1, Reward: 21.0
Episode 2, Reward: 18.0
Episode 3, Reward: 19.0
Episode 4, Reward: 25.0
Episode 5, Reward: 17.0
Episode 6, Reward: 21.0
Episode 7, Reward: 29.0
Episode 8, Reward: 18.0
Episode 9, Reward: 12.0
Episode 10, Reward: 19.0
Episode 11, Reward: 12.0
Episode 12, Reward: 12.0
Episode 13, Reward: 20.0
Episode 14, Reward: 12.0
Episode 15, Reward: 25.0
Episode 16, Reward: 40.0
Episode 17, Reward: 24.0
Episode 18, Reward: 53.0
Episode 19, Reward: 21.0
Episode 20, Reward: 12.0
Episode 21, Reward: 19.0
Episode 22, Reward: 45.0
Episode 23, Reward: 28.0
Episode 24, Reward: 35.0
Episode 25, Reward: 11.0
Episode 26, Reward: 41.0
Episode 27, Reward: 29.0
Episode 28, Reward: 14.0
Episode 29, Reward: 21.0
Episode 30, Reward: 30.0
Episode 31, Reward: 15.0
Episode 32, Reward: 9.0
Episode 33, Reward: 13.0
Episode 34, Reward: 37.0
Episode 35, Reward: 13.0
Episode 36, Reward: 21.0
Episode 37, Reward: 23.0
Episode 38, Reward: 10.0
Episode 39, Reward: 13.0
Episode 40, Reward: 14.0
Episode 41

Episode 338, Reward: 14.0
Episode 339, Reward: 29.0
Episode 340, Reward: 32.0
Episode 341, Reward: 31.0
Episode 342, Reward: 38.0
Episode 343, Reward: 28.0
Episode 344, Reward: 31.0
Episode 345, Reward: 23.0
Episode 346, Reward: 18.0
Episode 347, Reward: 53.0
Episode 348, Reward: 23.0
Episode 349, Reward: 99.0
Episode 350, Reward: 12.0
Episode 351, Reward: 23.0
Episode 352, Reward: 20.0
Episode 353, Reward: 22.0
Episode 354, Reward: 17.0
Episode 355, Reward: 35.0
Episode 356, Reward: 12.0
Episode 357, Reward: 23.0
Episode 358, Reward: 27.0
Episode 359, Reward: 19.0
Episode 360, Reward: 12.0
Episode 361, Reward: 22.0
Episode 362, Reward: 18.0
Episode 363, Reward: 24.0
Episode 364, Reward: 38.0
Episode 365, Reward: 32.0
Episode 366, Reward: 34.0
Episode 367, Reward: 9.0
Episode 368, Reward: 15.0
Episode 369, Reward: 20.0
Episode 370, Reward: 10.0
Episode 371, Reward: 12.0
Episode 372, Reward: 38.0
Episode 373, Reward: 13.0
Episode 374, Reward: 49.0
Episode 375, Reward: 15.0
Episode 376, 

Episode 658, Reward: 96.0
Episode 659, Reward: 106.0
Episode 660, Reward: 57.0
Episode 661, Reward: 74.0
Episode 662, Reward: 16.0
Episode 663, Reward: 91.0
Episode 664, Reward: 27.0
Episode 665, Reward: 20.0
Episode 666, Reward: 32.0
Episode 667, Reward: 104.0
Episode 668, Reward: 15.0
Episode 669, Reward: 19.0
Episode 670, Reward: 82.0
Episode 671, Reward: 21.0
Episode 672, Reward: 67.0
Episode 673, Reward: 25.0
Episode 674, Reward: 12.0
Episode 675, Reward: 41.0
Episode 676, Reward: 67.0
Episode 677, Reward: 44.0
Episode 678, Reward: 68.0
Episode 679, Reward: 38.0
Episode 680, Reward: 69.0
Episode 681, Reward: 17.0
Episode 682, Reward: 75.0
Episode 683, Reward: 15.0
Episode 684, Reward: 39.0
Episode 685, Reward: 28.0
Episode 686, Reward: 52.0
Episode 687, Reward: 70.0
Episode 688, Reward: 66.0
Episode 689, Reward: 101.0
Episode 690, Reward: 39.0
Episode 691, Reward: 23.0
Episode 692, Reward: 130.0
Episode 693, Reward: 27.0
Episode 694, Reward: 78.0
Episode 695, Reward: 25.0
Episode 

Episode 974, Reward: 96.0
Episode 975, Reward: 63.0
Episode 976, Reward: 123.0
Episode 977, Reward: 40.0
Episode 978, Reward: 102.0
Episode 979, Reward: 125.0
Episode 980, Reward: 75.0
Episode 981, Reward: 102.0
Episode 982, Reward: 139.0
Episode 983, Reward: 21.0
Episode 984, Reward: 138.0
Episode 985, Reward: 160.0
Episode 986, Reward: 138.0
Episode 987, Reward: 103.0
Episode 988, Reward: 92.0
Episode 989, Reward: 117.0
Episode 990, Reward: 117.0
Episode 991, Reward: 70.0
Episode 992, Reward: 28.0
Episode 993, Reward: 78.0
Episode 994, Reward: 28.0
Episode 995, Reward: 51.0
Episode 996, Reward: 105.0
Episode 997, Reward: 24.0
Episode 998, Reward: 76.0
Episode 999, Reward: 53.0
Episode 1000, Reward: 123.0
Episode 1001, Reward: 81.0
Episode 1002, Reward: 108.0
Episode 1003, Reward: 30.0
Episode 1004, Reward: 66.0
Episode 1005, Reward: 26.0
Episode 1006, Reward: 49.0
Episode 1007, Reward: 147.0
Episode 1008, Reward: 27.0
Episode 1009, Reward: 40.0
Episode 1010, Reward: 47.0
Episode 1011

Episode 1274, Reward: 118.0
Episode 1275, Reward: 40.0
Episode 1276, Reward: 97.0
Episode 1277, Reward: 114.0
Episode 1278, Reward: 75.0
Episode 1279, Reward: 27.0
Episode 1280, Reward: 198.0
Episode 1281, Reward: 203.0
Episode 1282, Reward: 164.0
Episode 1283, Reward: 58.0
Episode 1284, Reward: 63.0
Episode 1285, Reward: 19.0
Episode 1286, Reward: 134.0
Episode 1287, Reward: 170.0
Episode 1288, Reward: 90.0
Episode 1289, Reward: 67.0
Episode 1290, Reward: 124.0
Episode 1291, Reward: 174.0
Episode 1292, Reward: 27.0
Episode 1293, Reward: 33.0
Episode 1294, Reward: 103.0
Episode 1295, Reward: 80.0
Episode 1296, Reward: 27.0
Episode 1297, Reward: 181.0
Episode 1298, Reward: 91.0
Episode 1299, Reward: 60.0
Episode 1300, Reward: 25.0
Episode 1301, Reward: 116.0
Episode 1302, Reward: 294.0
Episode 1303, Reward: 76.0
Episode 1304, Reward: 229.0
Episode 1305, Reward: 20.0
Episode 1306, Reward: 124.0
Episode 1307, Reward: 127.0
Episode 1308, Reward: 160.0
Episode 1309, Reward: 185.0
Episode 13

Episode 1576, Reward: 152.0
Episode 1577, Reward: 107.0
Episode 1578, Reward: 130.0
Episode 1579, Reward: 157.0
Episode 1580, Reward: 44.0
Episode 1581, Reward: 103.0
Episode 1582, Reward: 170.0
Episode 1583, Reward: 147.0
Episode 1584, Reward: 89.0
Episode 1585, Reward: 114.0
Episode 1586, Reward: 62.0
Episode 1587, Reward: 154.0
Episode 1588, Reward: 89.0
Episode 1589, Reward: 38.0
Episode 1590, Reward: 132.0
Episode 1591, Reward: 163.0
Episode 1592, Reward: 112.0
Episode 1593, Reward: 110.0
Episode 1594, Reward: 57.0
Episode 1595, Reward: 31.0
Episode 1596, Reward: 22.0
Episode 1597, Reward: 79.0
Episode 1598, Reward: 135.0
Episode 1599, Reward: 128.0
Episode 1600, Reward: 44.0
Episode 1601, Reward: 140.0
Episode 1602, Reward: 154.0
Episode 1603, Reward: 69.0
Episode 1604, Reward: 140.0
Episode 1605, Reward: 113.0
Episode 1606, Reward: 58.0
Episode 1607, Reward: 18.0
Episode 1608, Reward: 117.0
Episode 1609, Reward: 41.0
Episode 1610, Reward: 123.0
Episode 1611, Reward: 37.0
Episode

Episode 1875, Reward: 131.0
Episode 1876, Reward: 160.0
Episode 1877, Reward: 195.0
Episode 1878, Reward: 159.0
Episode 1879, Reward: 17.0
Episode 1880, Reward: 132.0
Episode 1881, Reward: 16.0
Episode 1882, Reward: 134.0
Episode 1883, Reward: 193.0
Episode 1884, Reward: 72.0
Episode 1885, Reward: 205.0
Episode 1886, Reward: 32.0
Episode 1887, Reward: 157.0
Episode 1888, Reward: 43.0
Episode 1889, Reward: 66.0
Episode 1890, Reward: 105.0
Episode 1891, Reward: 161.0
Episode 1892, Reward: 165.0
Episode 1893, Reward: 137.0
Episode 1894, Reward: 158.0
Episode 1895, Reward: 162.0
Episode 1896, Reward: 119.0
Episode 1897, Reward: 116.0
Episode 1898, Reward: 26.0
Episode 1899, Reward: 63.0
Episode 1900, Reward: 112.0
Episode 1901, Reward: 128.0
Episode 1902, Reward: 151.0
Episode 1903, Reward: 86.0
Episode 1904, Reward: 100.0
Episode 1905, Reward: 123.0
Episode 1906, Reward: 165.0
Episode 1907, Reward: 168.0
Episode 1908, Reward: 99.0
Episode 1909, Reward: 165.0
Episode 1910, Reward: 179.0
Ep

In [6]:
# Evaluate the trained policy in an environment that renders to a window.
env = gym.make('CartPole-v1', render_mode='human')
try:
    test(env, model)
finally:
    # Release the render window even if evaluation fails or is interrupted.
    env.close()

Test Episode 1, Reward: 274.0 steps: 274
Test Episode 2, Reward: 272.0 steps: 272
Test Episode 3, Reward: 209.0 steps: 209
Test Episode 4, Reward: 328.0 steps: 328
Test Episode 5, Reward: 211.0 steps: 211
Test Episode 6, Reward: 247.0 steps: 247
Test Episode 7, Reward: 204.0 steps: 204
Test Episode 8, Reward: 233.0 steps: 233
Test Episode 9, Reward: 265.0 steps: 265
Test Episode 10, Reward: 221.0 steps: 221
Test Episode 11, Reward: 252.0 steps: 252
Test Episode 12, Reward: 238.0 steps: 238
Test Episode 13, Reward: 230.0 steps: 230
Test Episode 14, Reward: 250.0 steps: 250
Test Episode 15, Reward: 228.0 steps: 228
Average Reward over 15 episodes: 244.13333333333333


In [12]:
# Close the rendering environment (and with it the pygame window) through the
# environment API. Calling exit() for this also ends the interpreter/kernel,
# and in a plain script the session report below would never run.
env.close()
session_info.show(html=False)

-----
gymnasium           0.29.1
numpy               1.26.4
session_info        1.0.0
torch               2.4.0+cu121
-----
IPython             8.26.0
jupyter_client      8.6.2
jupyter_core        5.7.2
-----
Python 3.12.3 (main, Sep 11 2024, 14:17:37) [GCC 13.2.0]
Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.39
-----
Session information updated at 2024-09-20 08:43
