# 1. Importing libraries

In [1]:
from gymnasium import make
import torch as T
from torch import nn
from torch.distributions import Normal

# 2. Creating network

In [2]:
class ActorCriticNetwork(nn.Module):
    """Actor-critic network with a shared backbone and a Gaussian policy.

    A two-layer tanh backbone feeds two separate heads: an actor head that
    outputs the mean of a Normal distribution over actions, and a critic
    head that outputs a scalar state value. The standard deviation of the
    policy is state-independent and stored as a learnable log-std parameter.

    Args:
        state_dim: Size of the observation vector (default 2).
        action_dim: Size of the action vector (default 1).
        hidden_dim: Width of every hidden layer (default 32).
    """

    def __init__(self, state_dim=2, action_dim=1, hidden_dim=32):
        super().__init__()
        # Feature extractor shared by the actor and the critic.
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )

        # Policy head: produces the mean of the action distribution.
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
        )

        # Value head: produces one scalar value estimate per state.
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

        # Learnable log standard deviation, initialised to 0 (std = 1).
        # nn.Parameter already sets requires_grad=True.
        self.log_std = nn.Parameter(T.zeros(action_dim))

    def forward(self, x):
        """Run the network on a state or a batch of states.

        Args:
            x: Float tensor whose last dimension is ``state_dim``.

        Returns:
            A tuple ``(actor_mean, critic_value, dist)`` where ``actor_mean``
            has last dimension ``action_dim``, ``critic_value`` has last
            dimension 1, and ``dist`` is ``Normal(actor_mean, exp(log_std))``.
        """
        features = self.backbone(x)
        actor_mean = self.actor(features)
        critic_value = self.critic(features)
        std = T.exp(self.log_std)
        dist = Normal(loc=actor_mean, scale=std)
        return actor_mean, critic_value, dist

# 3. Train loop

## 3.1. Creating objects

In [3]:
# Fixed seed so that a fresh "Restart & Run All" starts from the same
# network weights, action samples and environment start states.
SEED = 0
T.manual_seed(SEED)

env = make("MountainCarContinuous-v0")
# Seed the environment's random generator once; later reset() calls made
# without a seed keep drawing from this generator.
env.reset(seed=SEED)

network = ActorCriticNetwork()
optimizer = T.optim.Adam(network.parameters(), lr=2e-4)
loss_fn = nn.HuberLoss()  # used for the critic (value) loss

# Number of environment transitions collected before each update phase.
train_step = 2048

# parameters
gamma_ = 0.99   # discount factor used in the TD error
lambda_ = 0.95  # decay of the advantage recursion (multiplied with gamma_)

## 3.2 Training

<li>states: torch.Size([2048, 2]) <br>
<li>actions: torch.Size([2048, 1]) <br>
<li>rewards: torch.Size([2048]) <br>
<li>next_states: torch.Size([2048, 2]) <br>
<li>dones: torch.Size([2048]) <br>
<li>critic_values: torch.Size([2048, 1])

In [4]:
# Hyper-parameters that are local to the training cell.
total_steps = int(1e7)          # upper bound on environment steps
batch_size = 256                # transitions per gradient step
log_interval = train_step * 10  # print diagnostics every 10 update phases
value_coef = 0.5                # weight of the critic loss in the total loss
entropy_coef = 0.01             # weight of the entropy bonus
max_grad_norm = 0.5             # gradient-norm clipping threshold
velocity_sq_bonus = 100         # shaping weight on velocity ** 2
velocity_abs_bonus = 5          # shaping weight on |velocity|

buffer = []
state, _ = env.reset()
network.train()
total_reward = 0          # unshaped environment reward of the running episode
total_reward_list = []    # unshaped reward of every finished episode

for i in range(total_steps):
    # ---- 1. Act in the environment with the current policy ----
    tensor_state = T.tensor(state).float()
    with T.no_grad():
        _, critic_value, dist = network(tensor_state)
    action = dist.sample()
    next_state, reward, terminated, truncated, _ = env.step(action.numpy())
    total_reward += float(reward)
    episode_over = terminated or truncated
    done = T.tensor(float(episode_over))

    # Reward shaping: add a bonus that grows with the magnitude of the last
    # observation component (the velocity). Python floats are used so the
    # stored tensor is always float32, independent of NumPy promotion rules.
    velocity = float(next_state[-1])
    shaped_reward = (
        float(reward)
        + velocity_sq_bonus * (velocity ** 2)
        + velocity_abs_bonus * abs(velocity)
    )
    buffer.append(
        (
            tensor_state,
            action,
            T.tensor(shaped_reward, dtype=T.float32),
            T.tensor(next_state),
            done,
            critic_value.detach(),
        )
    )

    if episode_over:
        state, _ = env.reset()
        total_reward_list.append(total_reward)
        total_reward = 0
    else:
        state = next_state

    # ---- 2. Update the network once a full rollout has been collected ----
    if len(buffer) == train_step:
        states, actions, rewards, next_states, dones, critic_values = (
            [T.stack(column, dim=0) for column in zip(*buffer)]
        )
        buffer.clear()

        # Value of the state that follows the last stored transition, so the
        # final transition can be used too instead of being dropped. If that
        # transition ended an episode, `state` is a freshly reset state and
        # the (1 - done) mask below removes this value from the target.
        with T.no_grad():
            _, bootstrap_value, _ = network(T.tensor(state).float())
        values = critic_values[:, 0]

        # calculate returns and advantages (GAE), walking backwards in time.
        # Results are appended and reversed once, avoiding repeated
        # insert-at-front on a list.
        advantages_reversed = []
        gae = T.tensor(0.0)
        next_value = bootstrap_value[0]
        for k in reversed(range(train_step)):
            not_done = 1 - dones[k]
            td_error = rewards[k] + gamma_ * next_value * not_done - values[k]
            gae = td_error + gamma_ * lambda_ * not_done * gae
            advantages_reversed.append(gae)
            next_value = values[k]

        advantages = T.stack(advantages_reversed[::-1], dim=0)
        returns = advantages + values

        for batch_start in range(0, train_step, batch_size):
            batch_range = slice(batch_start, batch_start + batch_size)

            batch_advantages = advantages[batch_range]
            batch_states = states[batch_range]
            batch_actions = actions[batch_range]
            batch_returns = returns[batch_range]

            # Normalise advantages per minibatch. std() of a single element
            # is NaN, so a one-element batch is left unnormalised.
            if batch_advantages.numel() > 1:
                batch_advantages = (
                    (batch_advantages - batch_advantages.mean())
                    / (batch_advantages.std() + 1e-8)
                )

            # calculate losses
            optimizer.zero_grad()
            _, batch_critic_value, batch_dist = network(batch_states)

            # Policy-gradient loss: log-probability of the taken action
            # weighted by its (detached) advantage.
            log_prob = batch_dist.log_prob(batch_actions)
            policy_loss = -(log_prob.sum(-1) * batch_advantages.detach()).mean()

            critic_loss = loss_fn(batch_critic_value.squeeze(-1), batch_returns)

            entropy = batch_dist.entropy().mean()

            loss = policy_loss + value_coef * critic_loss - entropy_coef * entropy
            loss.backward()
            T.nn.utils.clip_grad_norm_(network.parameters(), max_grad_norm)
            optimizer.step()

        # ---- 3. Periodic diagnostics (values from the last minibatch) ----
        if (i + 1) % log_interval == 0:
            recent_rewards = total_reward_list[-10:]
            # Skip the episode summary if no episode has finished yet.
            if recent_rewards:
                print(f"Step: {i}, last rewards: {sum(recent_rewards)/len(recent_rewards):.2f}, max reward: {max(recent_rewards):.2f}")
            print(f"Loss: {loss.item():.4f}, policy loss: {policy_loss.item():.4f}, critic_loss: {critic_loss.item():.4f}, entropy: {entropy.item():.4f}")
            print(f"Returns: {returns.mean().item():.4f}, advantages: {advantages.mean().item():.4f}, actions: {actions.mean().item():.4f}, actions std: {actions.std().item():.4f}")

Step: 20479, last rewards: -102.37, max reward: -96.03
Loss: 0.1704, policy loss: -0.1153, critic_loss: 0.5996, entropy: 1.4039
Returns: -0.5629, advantages: -0.0981, actions: 0.0526, actions std: 0.9947
Step: 40959, last rewards: -93.98, max reward: -85.43
Loss: 0.0688, policy loss: -0.1295, critic_loss: 0.4242, entropy: 1.3903
Returns: -0.8652, advantages: -0.2951, actions: -0.0021, actions std: 0.9866
Step: 61439, last rewards: -92.83, max reward: -88.47
Loss: 0.1013, policy loss: -0.1336, critic_loss: 0.4973, entropy: 1.3752
Returns: -0.4021, advantages: 0.3312, actions: -0.0006, actions std: 0.9648
Step: 81919, last rewards: -90.00, max reward: -83.00
Loss: 0.1869, policy loss: 0.0240, critic_loss: 0.3531, entropy: 1.3615
Returns: -0.6653, advantages: 0.0491, actions: -0.0285, actions std: 0.9488
Step: 102399, last rewards: -87.60, max reward: -82.89
Loss: -0.0036, policy loss: -0.1029, critic_loss: 0.2255, entropy: 1.3488
Returns: -0.6760, advantages: -0.0623, actions: 0.0090, ac

KeyboardInterrupt: 