In [1]:
import torch, numpy as np
import gymnasium as gym

In [2]:
# Create the environment
env = gym.make('HalfCheetah-v4')

print(f'Observation space: {env.observation_space}')
print(f'Action space: {env.action_space}')

Observation space: Box(-inf, inf, (17,), float64)
Action space: Box(-1.0, 1.0, (6,), float32)


In [3]:
gamma = 0.99
lr = 0.0005
episodes = 2000
hid_layer = 512
hid_layer2 = 512
randomness_begin = 1.0
randomness_end = 0.02

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
device

device(type='cuda', index=0)

In [9]:
from typing import Any


class A2C(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, hidden_dim2, action_dim):
        super(A2C, self).__init__()
        self.common = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim2),
            torch.nn.ReLU(),
        )
        self.actor_mu = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim2, action_dim),
            torch.nn.Tanh()
        )
        self.actor_var = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim2, action_dim),
            torch.nn.Softplus()
        )
        self.critic = torch.nn.Linear(hidden_dim2, 1)

    def forward(self, x):
        common_out = self.common(x)
        return self.actor_mu(common_out), self.actor_var(common_out), self.critic(common_out)
    
    def __call__(self, x):
        return self.forward(x)

In [66]:
import random
import math

class AgentA2C():
    def __init__(self, model, optim, device, eps_start, eps_end, eps_decay_time, loss, entropy_beta):
        self.model = model
        self.optim = optim
        self.device = device
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay_time = eps_decay_time
        self.steps_counter = 0
        self.episodes_counter = 0
        self.loss = loss
        self.entropy_beta = entropy_beta


    def calc_logprob(self, mu, var, action):
        p1 = - ((mu - action) ** 2) / (2*var.clamp(min=1e-3))
        p2 = - torch.log(torch.sqrt(2* math.pi* var))
        return p1 + p2

    def run_episode(self):
        state = env.reset()[0]
        done = False
        state = torch.tensor(env.reset()[0], dtype=torch.float).to(self.device)

        total_loss = 0.
        total_actor_loss = 0.
        total_critic_loss = 0.
        steps = 0
        risks = 0

        while not done:
            mu, var, value = self.model(state)

            std = torch.sqrt(var).data.cpu()
            log_std = torch.log(std).to(self.device)
            std = std.numpy()

            eps = self.eps_end + (self.eps_start - self.eps_end) * (1 - min(self.episodes_counter, self.eps_decay_time)/self.eps_decay_time)

            risk = random.random() <= eps

            if risk:
                risks += 1

            action = np.random.normal(mu, std) if not risk else np.random.uniform(-1, 1, 6)

            state, r, done, _, _ = env.step(action)
            state = torch.tensor(state, dtype=torch.float32).to(self.device)

            _, _, new_value = self.model(state)

            TD_err = r + gamma*new_value*(1 - int(done)) - value

            critic_loss = (TD_err**2).mean()

            action = torch.tensor(action, dtype=torch.float32).to(self.device)

            log_prob = self.calc_logprob(mu, var, action)

            actor_loss = (-log_prob * TD_err).mean()

            entropy_loss = (self.entropy_beta * (-(torch.log(2*math.pi*var) + 1)/2)).mean()

            loss = critic_loss+actor_loss+entropy_loss

            total_loss += loss
            total_actor_loss += actor_loss
            total_critic_loss += critic_loss

            print(f'Actor loss: {actor_loss}')
            print(f'Critic loss: {critic_loss}')
            print(f'Entropy loss: {entropy_loss}')
            print(f'Loss: {loss}')

            self.optim.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            self.optim.step()
            self.steps_counter += 1

        self.episodes_counter += 1

        return steps

In [67]:
model = A2C(env.observation_space.shape[0], hid_layer, hid_layer2, env.action_space.shape[0]).to(device)

In [68]:
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
epsilon_decay_time = 1000
entropy_beta = 0.01

In [69]:
agent = AgentA2C(model, optimizer, device, randomness_begin, randomness_end, epsilon_decay_time, loss_fn, entropy_beta)

In [70]:
for episode in range(episodes):
    steps = agent.run_episode()
    print(f'Episode: {episode}, steps: {steps}')

Actor loss: -0.17245137691497803
Critic loss: 0.026317566633224487
Entropy loss: -0.01236751675605774
Loss: -0.15850132703781128
Actor loss: 1.1307529211044312
Critic loss: 1.088650107383728
Entropy loss: -0.012605908326804638
Loss: 2.2067971229553223
Actor loss: 0.47797632217407227
Critic loss: 0.2523984909057617
Entropy loss: -0.012242275290191174
Loss: 0.7181325554847717
Actor loss: -0.10669003427028656
Critic loss: 0.01062803715467453
Entropy loss: -0.012842963449656963
Loss: -0.10890495777130127
Actor loss: 0.20158110558986664
Critic loss: 0.04163181781768799
Entropy loss: -0.012707000598311424
Loss: 0.23050592839717865
Actor loss: 0.811187744140625
Critic loss: 0.8126053214073181
Entropy loss: -0.012248784303665161
Loss: 1.6115443706512451
Actor loss: 0.17117080092430115
Critic loss: 0.030517254024744034
Entropy loss: -0.011908305808901787
Loss: 0.18977974355220795
Actor loss: -0.031254447996616364
Critic loss: 0.0009941385360434651
Entropy loss: -0.012503686361014843
Loss: -0.04

KeyboardInterrupt: 