Реализуйте алгоритм SAC (Soft Actor-Critic) для среды LunarLander с непрерывными действиями (`LunarLanderContinuous-v3`).

In [1]:
!pip install swig
!pip install "gymnasium[box2d]"



In [2]:
import math
import random
from collections import deque

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

In [3]:
# --- SAC hyperparameters ------------------------------------------------------
# Return discount and Polyak coefficient used when blending the target critic.
GAMMA = 0.99
TAU = 5e-3
# Entropy temperature (alpha) that weights the policy log-probability terms.
ALPHA = 0.2
# Adam learning rates for the policy and the twin critics.
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4
# Replay memory capacity and minibatch size.
REPLAY_SIZE = 100_000
BATCH_SIZE = 256
# Schedule, in environment steps: uniform-random exploration phase, total
# budget, warm-up before the first update, and update period (also the number
# of gradient updates run per period).
START_STEPS = 10_000
TOTAL_STEPS = 200_000
UPDATE_AFTER = 1_000
UPDATE_EVERY = 50

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when CUDA is present, else CPU

In [5]:
class Actor(nn.Module):
    """Squashed-Gaussian SAC policy.

    An MLP maps an observation to the mean and (clamped) log-std of a diagonal
    Gaussian over pre-tanh actions. Samples are squashed with ``tanh`` and then
    affinely rescaled from ``[-1, 1]`` to ``[action_low, action_high]``.

    Args:
        obs_dim: size of the observation vector.
        act_dim: number of action dimensions.
        action_low, action_high: action bounds; scalars shared by every
            dimension, or per-dimension sequences broadcastable to ``act_dim``.
    """

    LOG_STD_MIN = -20.0
    LOG_STD_MAX = 2.0

    def __init__(self, obs_dim, act_dim, action_low, action_high):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.mu_layer = nn.Linear(256, act_dim)
        self.log_std_layer = nn.Linear(256, act_dim)

        low = torch.tensor(action_low, dtype=torch.float32)
        high = torch.tensor(action_high, dtype=torch.float32)
        # Non-persistent buffers move with .to(device) but stay out of
        # state_dict, so checkpoint keys are unchanged.
        self.register_buffer("action_low", low, persistent=False)
        self.register_buffer("action_high", high, persistent=False)
        self.register_buffer("action_scale", (high - low) / 2.0, persistent=False)
        self.register_buffer("action_bias", (high + low) / 2.0, persistent=False)

    def _distribution(self, obs):
        """Return the pre-tanh Gaussian ``N(mean, std)`` for ``obs``."""
        features = self.net(obs)  # ``net`` already ends with a ReLU
        mean = self.mu_layer(features)
        log_std = torch.clamp(self.log_std_layer(features), self.LOG_STD_MIN, self.LOG_STD_MAX)
        return Normal(mean, log_std.exp())

    def _squash(self, pre_tanh):
        """Map an unbounded pre-tanh value into ``[action_low, action_high]``."""
        return torch.tanh(pre_tanh) * self.action_scale + self.action_bias

    def forward(self, obs):
        """Sample reparameterized actions and their log-densities.

        Args:
            obs: observation tensor with the feature dimension last,
                e.g. ``(batch, obs_dim)``.

        Returns:
            ``(action, log_prob)``. ``action`` has the same leading shape as
            ``obs`` and ``act_dim`` features. ``log_prob`` has a trailing
            dimension of 1 and is the log-density of the returned (squashed
            and rescaled) action. Gradients flow through both.
        """
        dist = self._distribution(obs)
        x_t = dist.rsample()  # reparameterized so the actor loss can backprop
        action = self._squash(x_t)

        # Change of variables for a = scale * tanh(x) + bias:
        #   log pi(a) = log N(x) - sum(log scale + log(1 - tanh(x)^2)).
        # log(1 - tanh(x)^2) is written as 2 * (log 2 - x - softplus(-2x)),
        # which stays finite where tanh saturates (no epsilon fudge needed).
        log_det = 2.0 * (math.log(2.0) - x_t - F.softplus(-2.0 * x_t)) + torch.log(self.action_scale)
        log_prob = (dist.log_prob(x_t) - log_det).sum(dim=-1, keepdim=True)
        return action, log_prob

    def get_action(self, obs, deterministic=False):
        """Return a single environment action as a NumPy array, without gradients.

        Args:
            obs: one (unbatched) observation, array-like of length ``obs_dim``.
            deterministic: if True, squash and rescale the Gaussian mean instead
                of sampling.

        Returns:
            NumPy array of shape ``(act_dim,)`` within the action bounds.
        """
        with torch.no_grad():
            obs_tensor = torch.as_tensor(
                obs, dtype=torch.float32, device=self.mu_layer.weight.device
            ).unsqueeze(0)
            dist = self._distribution(obs_tensor)
            pre_tanh = dist.mean if deterministic else dist.sample()
            action = self._squash(pre_tanh)
        return action.squeeze(0).cpu().numpy()


In [6]:
class Critic(nn.Module):
    """Twin Q-networks over concatenated (observation, action) inputs.

    The two heads ``q1`` and ``q2`` share no parameters. Each one is a
    two-hidden-layer ReLU MLP that emits a single value per input row.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        in_features = obs_dim + act_dim
        self.q1 = self._build_head(in_features)
        self.q2 = self._build_head(in_features)

    @staticmethod
    def _build_head(in_features, hidden=256):
        """Return an MLP ``in_features -> hidden -> hidden -> 1`` with ReLU activations."""
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs, act):
        """Return ``(Q1(obs, act), Q2(obs, act))``, each with a trailing dimension of 1."""
        joint = torch.cat((obs, act), dim=-1)
        return self.q1(joint), self.q2(joint)

In [7]:
class ReplayBuffer:
    """Fixed-capacity FIFO replay memory stored as a ring buffer.

    Transitions are kept as tuples in a Python list (O(1) random access for
    sampling). Once ``size`` items are held, each new item overwrites the
    oldest one.

    Args:
        size: maximum number of stored transitions; must be positive.
        device: torch device for sampled tensors. Defaults to CUDA when
            available, otherwise CPU.

    Raises:
        ValueError: if ``size`` is not positive.
    """

    def __init__(self, size, device=None):
        if size <= 0:
            raise ValueError("size must be positive")
        self.size = size
        self.buffer = []
        self._next = 0  # slot to overwrite once the buffer is full
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def __len__(self):
        return len(self.buffer)

    def add(self, *args):
        """Store one transition, e.g. ``(state, action, reward, next_state, done)``."""
        transition = tuple(args)
        if len(self.buffer) < self.size:
            self.buffer.append(transition)
        else:
            self.buffer[self._next] = transition
        self._next = (self._next + 1) % self.size

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Float32 tensors ``(states, actions, rewards, next_states, dones)``
            on ``self.device``. ``rewards`` and ``dones`` get a trailing
            dimension of 1.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds the
                number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = (
            torch.as_tensor(np.asarray(column, dtype=np.float32), device=self.device)
            for column in zip(*batch)
        )
        return states, actions, rewards.unsqueeze(1), next_states, dones.unsqueeze(1)

In [8]:
# Build the environment, networks, optimizers, and learned entropy temperature,
# then reset the environment for the training loop.
env = gym.make("LunarLanderContinuous-v3")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Only index 0 is read, so this assumes every action dimension shares the
# same bounds.
action_low, action_high = float(env.action_space.low[0]), float(env.action_space.high[0])

actor = Actor(obs_dim, act_dim, action_low, action_high).to(device)
critic = Critic(obs_dim, act_dim).to(device)
critic_target = Critic(obs_dim, act_dim).to(device)
critic_target.load_state_dict(critic.state_dict())
# The target critic is only updated by Polyak averaging, never by gradients.
critic_target.requires_grad_(False)

actor_opt = optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_opt = optim.Adam(critic.parameters(), lr=CRITIC_LR)

# SAC heuristic: target entropy = -(number of action dimensions).
target_entropy = -float(np.prod(env.action_space.shape))
# Optimize log(alpha) so alpha stays positive. Start from the configured ALPHA
# rather than exp(0) = 1, which would silently discard that hyperparameter.
log_alpha = torch.full((1,), float(np.log(ALPHA)), device=device, requires_grad=True)
alpha_optim = optim.Adam([log_alpha], lr=ACTOR_LR)
ALPHA = log_alpha.exp().item()

replay = ReplayBuffer(REPLAY_SIZE)

obs, _ = env.reset()
episode_return, episode_len = 0, 0

In [9]:
# Main SAC loop: act in the environment and store transitions. Every
# UPDATE_EVERY steps (after UPDATE_AFTER) run UPDATE_EVERY gradient updates of
# the critics, the actor, and the entropy temperature.
for step in range(TOTAL_STEPS):
    if step < START_STEPS:
        # Uniform random actions early on to fill the replay buffer.
        action = env.action_space.sample()
    else:
        action = actor.get_action(obs)
        action = np.clip(action, action_low, action_high)

    next_obs, rew, terminated, truncated, _ = env.step(action)
    # Store `terminated`, not `terminated or truncated`: a time-limit
    # truncation is not a true terminal state, so its target must still
    # bootstrap from the next state's value.
    replay.add(obs, action, rew, next_obs, terminated)

    obs = next_obs
    episode_return += rew
    episode_len += 1

    if terminated or truncated:
        obs, _ = env.reset()
        print(f"Step: {step}, Return: {episode_return:.2f}, Len: {episode_len}")
        episode_return, episode_len = 0, 0

    if step >= UPDATE_AFTER and step % UPDATE_EVERY == 0:
        for _ in range(UPDATE_EVERY):
            state_batch, action_batch, reward_batch, next_state_batch, done_batch = replay.sample(BATCH_SIZE)

            # Critic target: r + gamma * (1 - d) * (min Q'(s', a') - alpha * log pi(a'|s')).
            with torch.no_grad():
                next_action, next_log_prob = actor(next_state_batch)
                q1_next, q2_next = critic_target(next_state_batch, next_action)
                q_next = torch.min(q1_next, q2_next) - ALPHA * next_log_prob
                target_q = reward_batch + (1 - done_batch) * GAMMA * q_next

            q1, q2 = critic(state_batch, action_batch)
            critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()

            # Actor: minimize alpha * log pi(a|s) - min Q(s, a) on reparameterized actions.
            # Critic grads produced here are cleared by critic_opt.zero_grad() next iteration.
            actions, log_probs = actor(state_batch)
            q1, q2 = critic(state_batch, actions)
            actor_loss = (ALPHA * log_probs - torch.min(q1, q2)).mean()

            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

            # Temperature: move policy entropy toward target_entropy.
            alpha_loss = -(log_alpha * (log_probs.detach() + target_entropy)).mean()
            alpha_optim.zero_grad()
            alpha_loss.backward()
            alpha_optim.step()
            ALPHA = log_alpha.exp().item()

            # Polyak averaging: target <- (1 - TAU) * target + TAU * online.
            with torch.no_grad():
                for target_param, source_param in zip(critic_target.parameters(), critic.parameters()):
                    target_param.lerp_(source_param, TAU)

# Close the environment once, after training has finished.
env.close()


Step: 83, Return: -102.50, Len: 84
Step: 190, Return: -308.40, Len: 107
Step: 325, Return: -186.76, Len: 135
Step: 446, Return: -269.16, Len: 121
Step: 523, Return: -61.73, Len: 77
Step: 609, Return: -383.28, Len: 86
Step: 733, Return: -127.10, Len: 124
Step: 830, Return: -163.79, Len: 97
Step: 919, Return: -147.09, Len: 89
Step: 1043, Return: -102.68, Len: 124
Step: 1149, Return: -224.89, Len: 106
Step: 1262, Return: -103.07, Len: 113
Step: 1374, Return: -210.87, Len: 112
Step: 1518, Return: -235.68, Len: 144
Step: 1768, Return: -163.81, Len: 250
Step: 1857, Return: -93.49, Len: 89
Step: 1965, Return: -73.42, Len: 108
Step: 2073, Return: -236.21, Len: 108
Step: 2182, Return: -443.43, Len: 109
Step: 2297, Return: -180.26, Len: 115
Step: 2427, Return: -244.69, Len: 130
Step: 2504, Return: -70.68, Len: 77
Step: 2643, Return: -334.98, Len: 139
Step: 2785, Return: -146.11, Len: 142
Step: 2879, Return: -331.85, Len: 94
Step: 2959, Return: -404.88, Len: 80
Step: 3044, Return: -66.87, Len: 85