In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym

In [14]:
class Actor(nn.Module):
    def __init__(self, n_obs, n_actions):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(n_obs, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
            nn.Softmax(dim = -1)
        )

    def forward(self, x):
        return self.model(x)

class Critic(nn.Module):
    def __init__(self, n_obs):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(n_obs, 128),
            nn.ReLU(),
            nn.Linear(128, 1)            
        )

    def forward(self, x):
        return self.model(x)

In [15]:
env = gym.make("CartPole-v1")
n_obs = env.observation_space.shape[0]
n_actions = env.action_space.n

print(n_obs)
print(n_actions)

4
2


In [16]:
LR = 0.001

actor = Actor(n_obs, n_actions)
critic = Critic(n_obs)
optimizer_actor = optim.Adam(actor.parameters(), lr = LR)
optimizer_critic = optim.Adam(critic.parameters(), lr = LR)

In [17]:
# A2C training after single episode
def get_returns(rewards, gamma):
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    return returns

def train(env, actor, critic, optimizer_actor, optimizer_critic, episodes = 5000, gamma = 0.99):
    total_rewards = []

    for episode in range(episodes):
        state, _ = env.reset()
        log_probs, states, rewards = [], [], []

        done = False
        while not done:
            # actor works
            state_tensor = torch.tensor(state, dtype = torch.float32)
            action_probs = actor(state_tensor.unsqueeze(0))
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)            
            
            log_probs.append(log_prob)  
            states.append(state_tensor) # collect states to estimate Q-values by the Critic after
            
            # make new step
            state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            rewards.append(reward)      # collect rewards for calculation of the discounted returns by gamma

        # train after episode ends
        total_rewards.append(sum(rewards))

        # calculate cumulative reward for every step
        returns = get_returns(rewards, gamma)
        returns_tensor = torch.tensor(returns, dtype = torch.float32)

        # get Q-values of these states by Critic
        values = critic(torch.stack(states)).squeeze(-1)    

        # train Critic
        loss_critic = nn.MSELoss()(values, returns_tensor)
        optimizer_critic.zero_grad()
        loss_critic.backward()
        optimizer_critic.step()

        # calc advantages considering feedback from Critic
        advantages = returns_tensor - values.detach()
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 5e-20)

        # train Actor
        actor_loss = - torch.sum(torch.stack(log_probs) * advantages)
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        if episode > 0 and episode % 100 == 0:
            print(f"Episode {episode}, mean reward {sum(total_rewards[-100:]) / 100.0}")       

In [18]:
train(env, actor, critic, optimizer_actor, optimizer_critic, episodes = 5000, gamma = 0.99)

Episode 100, mean reward 29.18
Episode 200, mean reward 39.17
Episode 300, mean reward 58.06
Episode 400, mean reward 70.43
Episode 500, mean reward 83.29
Episode 600, mean reward 181.84
Episode 700, mean reward 255.83
Episode 800, mean reward 319.47
Episode 900, mean reward 424.95
Episode 1000, mean reward 414.86
Episode 1100, mean reward 440.95
Episode 1200, mean reward 408.31
Episode 1300, mean reward 453.25
Episode 1400, mean reward 414.94
Episode 1500, mean reward 433.86
Episode 1600, mean reward 422.0
Episode 1700, mean reward 445.14
Episode 1800, mean reward 457.16
Episode 1900, mean reward 441.15
Episode 2000, mean reward 468.06
Episode 2100, mean reward 458.43
Episode 2200, mean reward 484.11
Episode 2300, mean reward 467.41
Episode 2400, mean reward 441.16
Episode 2500, mean reward 470.37
Episode 2600, mean reward 380.93
Episode 2700, mean reward 456.14
Episode 2800, mean reward 454.27
Episode 2900, mean reward 443.71
Episode 3000, mean reward 477.98
Episode 3100, mean reward

In [72]:
# A2C training with batches
def get_returns(rewards, gamma):
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    return returns

def train(env, actor, critic, optimizer_actor, optimizer_critic, episodes = 5000, gamma = 0.99, batch_size = 32):    
    total_rewards, batch_log_probs, batch_states, batch_returns = [], [], [], []

    for episode in range(episodes):
        state, _ = env.reset()
        log_probs, states, rewards = [], [], []      

        done = False
        while not done:
            # actor works
            state_tensor = torch.tensor(state, dtype = torch.float32)
            action_probs = actor(state_tensor.unsqueeze(0))
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)          
            
            log_probs.append(log_prob)  
            states.append(state_tensor) # collect states to estimate Q-values by the Critic after
            
            # make new step
            state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            rewards.append(reward)      # collect rewards for calculation of the discounted returns by gamma

        # EPISODE ENDS
        total_rewards.append(sum(rewards))

        # calculate cumulative reward for every step
        returns = get_returns(rewards, gamma)
        returns_tensor = torch.tensor(returns, dtype = torch.float32)

        # COLLECT DATA AFTER EPISODE
        batch_log_probs.extend(log_probs)
        batch_states.extend(states)
        batch_returns.extend(returns_tensor)

        # TRAIN
        if len(batch_log_probs) >= batch_size:
            # print("training...")
            # shuffling
            idxs = torch.randperm(batch_size)
            states_batch = torch.stack(batch_states)[idxs]
            log_probs_batch = torch.stack(batch_log_probs)[idxs]
            returns_batch = torch.tensor(batch_returns)[idxs]            
            
            # get Q-values of these states by Critic
            values = critic(states_batch).squeeze(-1) 
    
            # train Critic
            loss_critic = nn.MSELoss()(values, returns_batch)
            optimizer_critic.zero_grad()
            loss_critic.backward()
            optimizer_critic.step()
    
            # calc advantages considering feedback from Critic
            advantages = returns_batch - values.detach()
            # advantages = (advantages - advantages.mean()) / (advantages.std() + 5e-20)
    
            # train Actor
            actor_loss = - torch.sum(log_probs_batch * advantages)
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            # RESET MEMORY
            batch_log_probs = []
            batch_states = []    
            batch_returns = []

        if episode > 0 and episode % 100 == 0:
            print(f"Episode {episode}, mean reward {sum(total_rewards[-100:]) / 100.0}")     

In [73]:
LR = 0.001
BATCH_SIZE = 64

actor = Actor(n_obs, n_actions)
critic = Critic(n_obs)
optimizer_actor = optim.Adam(actor.parameters(), lr = LR)
optimizer_critic = optim.Adam(critic.parameters(), lr = LR)

train(env, actor, critic, optimizer_actor, optimizer_critic, episodes = 5000, gamma = 0.99, batch_size = BATCH_SIZE)

Episode 100, mean reward 22.69
Episode 200, mean reward 28.15
Episode 300, mean reward 31.76
Episode 400, mean reward 34.4
Episode 500, mean reward 38.61
Episode 600, mean reward 40.12
Episode 700, mean reward 45.88
Episode 800, mean reward 70.23
Episode 900, mean reward 106.04
Episode 1000, mean reward 166.62
Episode 1100, mean reward 229.99
Episode 1200, mean reward 283.56
Episode 1300, mean reward 368.27
Episode 1400, mean reward 415.42
Episode 1500, mean reward 389.66
Episode 1600, mean reward 417.35
Episode 1700, mean reward 441.54
Episode 1800, mean reward 445.58
Episode 1900, mean reward 461.41
Episode 2000, mean reward 388.51
Episode 2100, mean reward 432.87
Episode 2200, mean reward 437.05
Episode 2300, mean reward 459.05
Episode 2400, mean reward 454.8
Episode 2500, mean reward 470.0
Episode 2600, mean reward 475.3
Episode 2700, mean reward 474.15
Episode 2800, mean reward 483.9
Episode 2900, mean reward 483.02
Episode 3000, mean reward 486.85
Episode 3100, mean reward 489.27