Реализуйте алгоритм SAC (Soft Actor-Critic) для среды Lunar Lander с непрерывным пространством действий (`LunarLanderContinuous`).

In [1]:
!pip install swig
!pip install "gymnasium[box2d]"

Collecting swig
  Downloading swig-4.3.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (3.5 kB)
Downloading swig-4.3.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.3.1
Collecting box2d-py==2.3.5 (from gymnasium[box2d])
  Downloading box2d-py-2.3.5.tar.gz (374 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.4/374.4 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: box2d-py
  Building wheel for box2d-py (setup.py) ... [?25l[?25hdone
  Created wheel for box2d-py: filename=box2d_py-2.3.5-cp311-cp311-linux_x86_64.whl size=2379369 sha256=ad175533295272e3ecf80cdd6d9a363c5de9f3c08d1f54da803a76171db1a0c2
  Stored in directory: /root/.cache/pip/wheels/ab

In [2]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from torch.distributions import Normal

In [3]:
# --- SAC hyperparameters ---
GAMMA = 0.99            # discount factor for the Bellman target
TAU = 0.005             # Polyak coefficient for the target-critic soft update
ALPHA = 0.2             # fixed entropy temperature
ACTOR_LR = 3e-4         # Adam learning rate for the policy network
CRITIC_LR = 3e-4        # Adam learning rate for the twin critic
REPLAY_SIZE = 100_000   # replay buffer capacity, in transitions
BATCH_SIZE = 256        # minibatch size per gradient step
START_STEPS = 10_000    # initial steps that use uniformly random actions
TOTAL_STEPS = 200_000   # total number of environment steps
UPDATE_AFTER = 1_000    # first step at which gradient updates may run
UPDATE_EVERY = 50       # every UPDATE_EVERY steps, run UPDATE_EVERY updates

# --- Logging / early stopping ---
EPISODE_WINDOW = 100        # episodes in the moving-average window
TARGET_AVG_RETURN = 200     # early-stopping threshold for the windowed average return

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
class Actor(nn.Module):
    """Squashed-Gaussian SAC policy.

    The network maps an observation to the mean and (clamped) log standard
    deviation of a diagonal Gaussian over pre-squash actions ``u``. The
    emitted action is ``tanh(u)``, affinely rescaled from ``[-1, 1]`` onto
    ``[action_low, action_high]``.

    Args:
        obs_dim: Length of the flat observation vector.
        act_dim: Number of action dimensions.
        action_low: Lower action bound. Either a scalar or a per-dimension
            array that broadcasts against ``(..., act_dim)``.
        action_high: Upper action bound, in the same form as ``action_low``.
    """

    # Clamp range for the predicted log standard deviation.
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2
    _LOG_2 = float(np.log(2.0))

    def __init__(self, obs_dim, act_dim, action_low, action_high):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU()
        )
        self.mu_layer = nn.Linear(256, act_dim)
        self.log_std_layer = nn.Linear(256, act_dim)
        # Register the bounds as buffers so `.to(device)` moves them together
        # with the weights. They are non-persistent, so the `state_dict` keys
        # stay exactly the layer parameters.
        self.register_buffer(
            "action_low", torch.as_tensor(action_low, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "action_high", torch.as_tensor(action_high, dtype=torch.float32), persistent=False
        )

    def _dist_params(self, obs):
        """Return ``(mean, std)`` of the pre-squash Gaussian for ``obs``."""
        features = self.net(obs)
        mean = self.mu_layer(features)
        log_std = torch.clamp(self.log_std_layer(features), self.LOG_STD_MIN, self.LOG_STD_MAX)
        return mean, log_std.exp()

    def _rescale(self, squashed):
        """Map values in ``[-1, 1]`` affinely onto ``[action_low, action_high]``."""
        return squashed * (self.action_high - self.action_low) / 2 + (self.action_high + self.action_low) / 2

    def forward(self, obs):
        """Sample a reparameterized action and its log-probability.

        Args:
            obs: Observation batch. The final ``sum(1, keepdim=True)`` assumes
                a 2-D ``(batch, obs_dim)`` input.

        Returns:
            ``(action, log_prob)``. ``action`` has the shape of the Gaussian
            mean and ``log_prob`` has shape ``(batch, 1)``. Both stay
            differentiable w.r.t. the actor weights (``rsample``).
        """
        mean, std = self._dist_params(obs)
        normal = Normal(mean, std)
        x_t = normal.rsample()
        action = self._rescale(torch.tanh(x_t))

        # Change of variables for the tanh squash, written in the exact and
        # numerically stable form
        #   log(1 - tanh(x)^2) == 2 * (log 2 - x - softplus(-2x)).
        # This stays finite and correct even when tanh(x) rounds to +-1 in
        # float32, where an epsilon-padded log(1 - y^2) would saturate.
        log_prob = normal.log_prob(x_t)
        log_prob = log_prob - 2 * (self._LOG_2 - x_t - F.softplus(-2 * x_t))
        return action, log_prob.sum(1, keepdim=True)

    def get_action(self, obs, deterministic=False):
        """Choose an action for one unbatched observation, without gradients.

        Args:
            obs: A single observation (array-like of length ``obs_dim``).
            deterministic: If True, use the squashed mean instead of a sample.

        Returns:
            NumPy array of shape ``(act_dim,)`` within the action bounds.
        """
        # Follow the module's own device, so no global device is needed.
        param_device = self.mu_layer.weight.device
        with torch.no_grad():
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=param_device).unsqueeze(0)
            mean, std = self._dist_params(obs_tensor)
            pre_squash = mean if deterministic else Normal(mean, std).rsample()
            action = self._rescale(torch.tanh(pre_squash))
            # squeeze(0) removes only the batch axis, so a 1-D action space
            # still yields shape (1,) rather than a 0-d array.
            return action.squeeze(0).cpu().numpy()

In [6]:
class Critic(nn.Module):
    """Two independent Q-value heads over the concatenated ``[obs, act]`` input.

    Each head is an MLP that outputs one scalar per input row.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        in_features = obs_dim + act_dim

        def build_head():
            # Layout Linear -> ReLU -> Linear -> ReLU -> Linear(1). The
            # Sequential indices 0/2/4 define the state_dict keys.
            return nn.Sequential(
                nn.Linear(in_features, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
            )

        self.q1 = build_head()
        self.q2 = build_head()

    def forward(self, obs, act):
        """Return ``(q1, q2)``, each with a trailing dimension of size 1."""
        joint = torch.cat((obs, act), dim=-1)
        return self.q1(joint), self.q2(joint)

In [7]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform minibatch sampling.

    Once ``size`` transitions are held, each new one evicts the oldest.
    """

    def __init__(self, size):
        # deque(maxlen=...) discards the oldest entry automatically when full.
        self.buffer = deque(maxlen=size)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Float32 tensors on the module-level ``device``, in the order
            ``(states, actions, rewards, next_states, dones)``. ``rewards`` and
            ``dones`` get an extra trailing axis of size 1.

        Raises:
            ValueError: from ``random.sample`` if ``batch_size`` exceeds the
                number of stored transitions.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)

        def to_tensor(values, column_vector=False):
            tensor = torch.FloatTensor(np.array(values))
            if column_vector:
                tensor = tensor.unsqueeze(1)
            return tensor.to(device)

        return (
            to_tensor(states),
            to_tensor(actions),
            to_tensor(rewards, column_vector=True),
            to_tensor(next_states),
            to_tensor(dones, column_vector=True),
        )

In [8]:
# Environment and SAC components: actor, twin critic and its target copy,
# optimizers, and the replay buffer.
env = gym.make("LunarLanderContinuous-v3")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Pass the full per-dimension bounds rather than only the first entry. An
# action space whose dimensions have different ranges is then rescaled
# correctly; the arrays broadcast against the actor's (..., act_dim) output.
action_low = env.action_space.low
action_high = env.action_space.high

actor = Actor(obs_dim, act_dim, action_low, action_high).to(device)
critic = Critic(obs_dim, act_dim).to(device)
critic_target = Critic(obs_dim, act_dim).to(device)
critic_target.load_state_dict(critic.state_dict())
# The target critic changes only through Polyak averaging, never by autograd.
critic_target.requires_grad_(False)

actor_optim = optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_optim = optim.Adam(critic.parameters(), lr=CRITIC_LR)

replay_buffer = ReplayBuffer(REPLAY_SIZE)

In [9]:
def update():
    """Run one SAC gradient step on a minibatch drawn from ``replay_buffer``.

    Works on the module-level ``actor``, ``critic``, ``critic_target``,
    ``actor_optim``, ``critic_optim`` and the hyperparameters ``BATCH_SIZE``,
    ``GAMMA``, ``ALPHA`` and ``TAU``.

    Steps:
      1. Critic: regress both Q-heads onto the soft Bellman target
         ``r + GAMMA * (1 - d) * (min(Q1', Q2')(s', a') - ALPHA * log pi(a'|s'))``
         with ``a' ~ pi(.|s')``, where ``d`` is the stored done flag.
      2. Actor: minimize ``ALPHA * log pi(a|s) - min(Q1, Q2)(s, a)`` while the
         critic is frozen, so only the actor weights receive gradients.
      3. Target: Polyak-average the critic weights into ``critic_target``.
    """
    states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

    with torch.no_grad():
        next_actions, next_log_probs = actor(next_states)
        target_q1, target_q2 = critic_target(next_states, next_actions)
        soft_next_value = torch.min(target_q1, target_q2) - ALPHA * next_log_probs
        target_q = rewards + GAMMA * (1 - dones) * soft_next_value

    current_q1, current_q2 = critic(states, actions)
    critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()

    # Freeze the critic during the actor step so its weights get no gradients.
    # try/finally guarantees it is unfrozen even if this step is interrupted
    # (e.g. KeyboardInterrupt in a notebook). Otherwise the next
    # `critic_loss.backward()` would fail because no parameter requires grad.
    for param in critic.parameters():
        param.requires_grad = False
    try:
        actions_pred, log_probs_pred = actor(states)
        q1_pred, q2_pred = critic(states, actions_pred)
        actor_loss = (ALPHA * log_probs_pred - torch.min(q1_pred, q2_pred)).mean()

        actor_optim.zero_grad()
        actor_loss.backward()
        actor_optim.step()
    finally:
        for param in critic.parameters():
            param.requires_grad = True

    # Polyak averaging: target <- (1 - TAU) * target + TAU * online.
    with torch.no_grad():
        for param, target_param in zip(critic.parameters(), critic_target.parameters()):
            target_param.lerp_(param, TAU)

# Start the first episode and zero the running per-episode statistics.
obs, _ = env.reset()
episode_return, episode_length = 0, 0

In [10]:
# Main SAC interaction loop. It collects experience, runs batches of gradient
# updates, and stops early once the trailing EPISODE_WINDOW-episode average
# return reaches TARGET_AVG_RETURN.
episode_history = deque(maxlen=EPISODE_WINDOW)
episode_counter = 0

for step in range(1, TOTAL_STEPS + 1):
    # Warm-up: uniformly random actions for the first START_STEPS steps.
    if step <= START_STEPS:
        action = env.action_space.sample()
    else:
        action = actor.get_action(obs)

    next_obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    # Store only true termination as the done flag. A time-limit truncation
    # is not a terminal state, so the critic target must still bootstrap
    # from next_obs for it.
    replay_buffer.add(obs, action, reward, next_obs, terminated)

    obs = next_obs
    episode_return += reward
    episode_length += 1

    if done:
        # Record the finished episode in the history.
        episode_history.append(episode_return)
        episode_counter += 1

        print(f"Episode: {episode_counter}")

        # Per-episode statistics.
        print(f"Step: {step}, Return: {episode_return:.2f}, Length: {episode_length}")

        avg_return = np.mean(episode_history)

        # Report the moving average every EPISODE_WINDOW episodes.
        if episode_counter % EPISODE_WINDOW == 0:
            print(f"\n--- Средний возврат за последние {EPISODE_WINDOW} эпизодов: {avg_return:.2f} ---\n")

        # Early stopping: check after every episode once the window is full,
        # not only on multiples of EPISODE_WINDOW.
        if len(episode_history) == EPISODE_WINDOW and avg_return >= TARGET_AVG_RETURN:
            print(f"Обучение остановлено! Достигнут средний возврат {avg_return:.2f} за {EPISODE_WINDOW} эпизодов")
            break

        # Reset for the next episode.
        obs, _ = env.reset()
        episode_return = 0
        episode_length = 0

    if step >= UPDATE_AFTER and step % UPDATE_EVERY == 0:
        for _ in range(UPDATE_EVERY):
            update()

env.close()

Episode: 1
Step: 135, Return: -310.41, Length: 135
Episode: 2
Step: 222, Return: -375.83, Length: 87
Episode: 3
Step: 300, Return: -43.98, Length: 78
Episode: 4
Step: 464, Return: -442.25, Length: 164
Episode: 5
Step: 580, Return: -326.35, Length: 116
Episode: 6
Step: 693, Return: -300.06, Length: 113
Episode: 7
Step: 790, Return: -70.06, Length: 97
Episode: 8
Step: 908, Return: -72.75, Length: 118
Episode: 9
Step: 1016, Return: -183.87, Length: 108
Episode: 10
Step: 1097, Return: -145.50, Length: 81
Episode: 11
Step: 1174, Return: -438.46, Length: 77
Episode: 12
Step: 1278, Return: -332.20, Length: 104
Episode: 13
Step: 1370, Return: -306.59, Length: 92
Episode: 14
Step: 1435, Return: -240.44, Length: 65
Episode: 15
Step: 1526, Return: -166.52, Length: 91
Episode: 16
Step: 1661, Return: -251.95, Length: 135
Episode: 17
Step: 1786, Return: -71.45, Length: 125
Episode: 18
Step: 1886, Return: -52.88, Length: 100
Episode: 19
Step: 1996, Return: -90.99, Length: 110
Episode: 20
Step: 2119, 