Реализуйте алгоритм SAC для среды lunar lander

In [1]:
!pip install swig
!pip install "gymnasium[box2d]"

Collecting swig
  Downloading swig-4.3.1-py3-none-win_amd64.whl.metadata (3.5 kB)
Downloading swig-4.3.1-py3-none-win_amd64.whl (2.6 MB)
   ---------------------------------------- 0.0/2.6 MB ? eta -:--:--
   ---------------------------------------- 0.0/2.6 MB ? eta -:--:--
   -------- ------------------------------- 0.5/2.6 MB 3.7 MB/s eta 0:00:01
   ------------------------ --------------- 1.6/2.6 MB 4.6 MB/s eta 0:00:01
   ---------------------------- ----------- 1.8/2.6 MB 4.4 MB/s eta 0:00:01
   ---------------------------------------- 2.6/2.6 MB 3.6 MB/s eta 0:00:00
Installing collected packages: swig
Successfully installed swig-4.3.1



[notice] A new release of pip is available: 25.0.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting box2d-py==2.3.5 (from gymnasium[box2d])
  Downloading box2d-py-2.3.5.tar.gz (374 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: box2d-py
  Building wheel for box2d-py (pyproject.toml): started
  Building wheel for box2d-py (pyproject.toml): finished with status 'error'
Failed to build box2d-py


  error: subprocess-exited-with-error
  
  Building wheel for box2d-py (pyproject.toml) did not run successfully.
  exit code: 1
  
  [35 lines of output]
  !!
  
          ********************************************************************************
          Please consider removing the following classifiers in favor of a SPDX license expression:
  
          License :: OSI Approved :: zlib/libpng License
  
          See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#license for details.
          ********************************************************************************
  
  !!
    self._finalize_license_expression()
  Using setuptools (version 80.3.1).
  running bdist_wheel
  running build
  running build_py
  creating build\lib.win-amd64-cpython-313\Box2D
  copying library\Box2D\Box2D.py -> build\lib.win-amd64-cpython-313\Box2D
  copying library\Box2D\__init__.py -> build\lib.win-amd64-cpython-313\Box2D
  creating build\lib.win-amd64-cpython-313\Bo

In [3]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from torch.distributions import Normal

In [4]:
GAMMA = 0.99
TAU = 0.005
ALPHA = 0.2
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4
REPLAY_SIZE = 100000
BATCH_SIZE = 256
START_STEPS = 10000
TOTAL_STEPS = 200000
UPDATE_AFTER = 1000
UPDATE_EVERY = 50

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, act_limit):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.mu_layer = nn.Linear(256, act_dim)
        self.log_std_layer = nn.Linear(256, act_dim)
        self.act_limit = act_limit

    def forward(self, obs):
        x = F.relu(self.net(obs))
        mean, std = self.mu_layer(x),  torch.clamp(self.log_std_layer(x), -20, 2).exp()
        normal = torch.distributions.Normal(mean, std)

        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * (self.act_limit[1] - self.act_limit[0]) / 2.0 + (self.act_limit[0] + self.act_limit[1]) / 2.0

        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log((1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob

    def get_action(self, obs, deterministic=False):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float32, device=device)

        x = self.net(obs)
        mean = self.mu_layer(x)
        log_std = torch.clamp(self.log_std_layer(x), -20, 2)
        std = log_std.exp()

        normal = torch.distributions.Normal(mean, std)

        if deterministic:
            x_t = mean
        else:
            x_t = normal.rsample()

        y_t = torch.tanh(x_t)

        action = y_t * (self.act_limit[1] - self.act_limit[0]) / 2.0 + (self.act_limit[0] + self.act_limit[1]) / 2.0

        return action.squeeze(0).detach().cpu().numpy()

In [7]:
class Critic(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, obs, act):
        if isinstance(act, tuple):
            act = act[0]
        x = torch.cat([obs, act], dim=-1)
        return self.q1(x), self.q2(x)

In [8]:
class ReplayBuffer:
    def __init__(self, size):
        self.buffer = deque(maxlen=size)

    def add(self, *args):
        self.buffer.append(tuple(args))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return (
            torch.tensor(states, dtype=torch.float32).to(device),
            torch.tensor(actions, dtype=torch.float32).to(device),
            torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(device),
            torch.tensor(next_states, dtype=torch.float32).to(device),
            torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(device)
        )

    def __len__(self):
        return len(self.buffer)

In [10]:
env = gym.make("LunarLanderContinuous-v3")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
action_low, action_high = float(env.action_space.low[0]), float(env.action_space.high[0])
act_limit = [action_low, action_high]

actor = Actor(obs_dim, act_dim, act_limit)
critic = Critic(obs_dim, act_dim)

actor_target = Actor(obs_dim, act_dim, act_limit) 
critic_target = Critic(obs_dim, act_dim)

actor_target.load_state_dict(actor.state_dict())
critic_target.load_state_dict(critic.state_dict())

actor_opt = optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_opt = optim.Adam(critic.parameters(), lr=CRITIC_LR)

replay = ReplayBuffer(REPLAY_SIZE)

obs, _ = env.reset()
episode_return, episode_len = 0, 0

In [11]:
def update(actor, critic, actor_target, critic_target, replay_buffer, actor_opt, critic_opt, batch_size, gamma, tau):
    if len(replay_buffer) < batch_size:
        return

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    # ------------------------ Обновление Critic ------------------------
    with torch.no_grad():
        next_actions = actor_target(next_states)
        if isinstance(next_actions, tuple):
            next_actions = next_actions[0]

        target_q1, target_q2 = critic_target(next_states, next_actions)
        target_q = torch.min(target_q1, target_q2)

        target_q = rewards + gamma * (1 - dones) * target_q.squeeze()

    current_q1, current_q2 = critic(states, actions)

    critic_loss = nn.MSELoss()(current_q1.squeeze(), target_q) + nn.MSELoss()(current_q2.squeeze(), target_q)

    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # ------------------------ Обновление Actor ------------------------
    actor_loss = -critic(states, actor(states))[0].mean()

    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()

    # ------------------------ Мягкое обновление целевых сетей ------------------------
    for param, target_param in zip(critic.parameters(), critic_target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    for param, target_param in zip(actor.parameters(), actor_target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

In [12]:
actor.to(device)
critic.to(device)
actor_target.to(device)
critic_target.to(device)

Critic(
  (q1): Sequential(
    (0): Linear(in_features=10, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
  (q2): Sequential(
    (0): Linear(in_features=10, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [13]:
for step in range(TOTAL_STEPS):
    if step < START_STEPS:
        act = env.action_space.sample()
    else:
        with torch.no_grad():
            obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            act = actor.get_action(obs_t)

    next_obs, rew, terminated, truncated, _ = env.step(act)
    done = terminated or truncated

    replay.add(obs, act, rew, next_obs, done)

    obs = next_obs
    episode_return += rew
    episode_len += 1

    if done:
        obs, _ = env.reset()
        print(f"Step: {step}, Return: {episode_return:.2f}, Len: {episode_len}")
        episode_return, episode_len = 0, 0

    if step >= UPDATE_AFTER and step % UPDATE_EVERY == 0:
        for _ in range(UPDATE_EVERY):
            update(
                actor=actor,
                critic=critic,
                actor_target=actor_target,
                critic_target=critic_target,
                replay_buffer=replay,
                actor_opt=actor_opt,
                critic_opt=critic_opt,
                batch_size=BATCH_SIZE,
                gamma=GAMMA,
                tau=TAU
            )

Step: 123, Return: -212.72, Len: 124
Step: 221, Return: -368.53, Len: 98
Step: 304, Return: -200.26, Len: 83
Step: 389, Return: -167.25, Len: 85
Step: 502, Return: -63.07, Len: 113
Step: 602, Return: -229.85, Len: 100
Step: 695, Return: -271.38, Len: 93
Step: 776, Return: -83.85, Len: 81
Step: 926, Return: -425.70, Len: 150


  return F.mse_loss(input, target, reduction=self.reduction)


Step: 1040, Return: -105.88, Len: 114
Step: 1145, Return: -181.55, Len: 105
Step: 1244, Return: -237.27, Len: 99
Step: 1359, Return: -218.05, Len: 115
Step: 1503, Return: -290.72, Len: 144
Step: 1601, Return: -417.03, Len: 98
Step: 1763, Return: -376.73, Len: 162
Step: 1858, Return: -133.01, Len: 95
Step: 1929, Return: -99.35, Len: 71
Step: 2034, Return: -142.97, Len: 105
Step: 2164, Return: -135.06, Len: 130
Step: 2234, Return: -107.50, Len: 70
Step: 2310, Return: -66.90, Len: 76
Step: 2458, Return: -236.81, Len: 148
Step: 2564, Return: -513.88, Len: 106
Step: 2652, Return: -213.44, Len: 88
Step: 2772, Return: -242.73, Len: 120
Step: 2848, Return: -50.30, Len: 76
Step: 2934, Return: -235.05, Len: 86
Step: 3030, Return: -240.16, Len: 96
Step: 3225, Return: -163.55, Len: 195
Step: 3377, Return: -49.27, Len: 152
Step: 3541, Return: -285.65, Len: 164
Step: 3634, Return: -316.79, Len: 93
Step: 3759, Return: -315.04, Len: 125
Step: 3867, Return: -370.80, Len: 108
Step: 3966, Return: -532.76