Реализуйте алгоритм GAIL на среде Mountain Car. Перед этим сгенерируйте экспертные данные (из детерминированной стратегии с первой практики). Хорошей идеей будет добавить в state (observation) синус и косинус от временной метки t для лучшего обучения.

In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import torch.optim as optim
import torch.nn.functional as F

In [2]:
env = gym.make("MountainCar-v0")

In [3]:
def generate_expert_data(env, num_episodes=1000):
    states, actions = [], []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        done = False
        t = 0  # Инициализация переменной времени
        while not done:
            # Простая детерминированная стратегия: толкать влево или вправо в зависимости от положения
            action = 0 if obs[1] < 0 else 2
            next_obs, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            states.append(np.append(obs, [np.sin(t), np.cos(t)]))  # Используем t вместо _
            actions.append(action)
            obs = next_obs
            t += 1  # Увеличиваем t
    return np.array(states), np.array(actions)

In [4]:
states, actions = generate_expert_data(env)

In [5]:
obs_dim = env.observation_space.shape[0] + 2  # +2 для синуса и косинуса временной метки
act_dim = env.action_space.n
expert_obs = np.copy(states)
expert_acts = np.copy(actions)

In [6]:
class Policy(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, act_dim)
        )

    def forward(self, obs):
        logits = self.net(obs)
        return Categorical(logits=logits)

    def get_action(self, obs):
        dist = self.forward(obs)
        return dist.sample().item()

In [7]:
class Discriminator(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 64), nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, obs, act):
        act_onehot = F.one_hot(act, num_classes=act_dim).float()
        x = torch.cat([obs, act_onehot], dim=1)
        return self.net(x)

In [8]:
class TrajectoryBuffer:
    def __init__(self):
        self.obs, self.acts, self.rews = [], [], []

    def store(self, o, a, r):
        self.obs.append(o)
        self.acts.append(a)
        self.rews.append(r)

    def get(self):
        return (
            torch.tensor(np.array(self.obs), dtype=torch.float32),
            torch.tensor(np.array(self.acts), dtype=torch.long),
            torch.tensor(np.array(self.rews), dtype=torch.float32)
        )

In [9]:
policy = Policy(obs_dim, act_dim)
discrim = Discriminator(obs_dim, act_dim)
policy_opt = optim.Adam(policy.parameters(), lr=1e-3)
discrim_opt = optim.Adam(discrim.parameters(), lr=1e-3)

In [10]:
for epoch in range(3000):
    buf = TrajectoryBuffer()
    obs, _ = env.reset()
    done = False
    total_reward = 0
    t = 0

    while not done:
        # Добавляем синус и косинус временной метки в состояние
        obs_with_time = np.append(obs, [np.sin(t), np.cos(t)])
        obs_tensor = torch.tensor(obs_with_time, dtype=torch.float32).unsqueeze(0)
        action = policy.get_action(obs_tensor)
        next_obs, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        buf.store(obs_with_time, action, 0)
        obs = next_obs
        t += 1

    agent_obs, agent_acts, _ = buf.get()

    idxs = np.random.choice(len(expert_obs), len(agent_obs), replace=False)
    exp_obs = torch.tensor(expert_obs[idxs], dtype=torch.float32)
    exp_acts = torch.tensor(expert_acts[idxs], dtype=torch.long)

    for _ in range(2):
        discrim_opt.zero_grad()

        discr_exp = discrim(exp_obs, exp_acts)
        discr_ag = discrim(agent_obs, agent_acts)
        disc_loss = -torch.mean(torch.log(discr_exp + 1e-8)) - torch.mean(torch.log(1 - discr_ag + 1e-8))

        disc_loss.backward()
        discrim_opt.step()

    with torch.no_grad():
        rewards = -torch.log(1 - discrim(agent_obs, agent_acts) + 1e-8).squeeze()

    policy_opt.zero_grad()
    dist = policy(agent_obs)
    log_probs = dist.log_prob(agent_acts)
    loss = -(log_probs * rewards).mean()
    loss.backward()
    policy_opt.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: GAIL Loss {loss.item():.3f}, Disc Loss {disc_loss.item():.3f}")

Epoch 0: GAIL Loss 0.778, Disc Loss 1.414
Epoch 10: GAIL Loss 0.732, Disc Loss 1.358
Epoch 20: GAIL Loss 0.730, Disc Loss 1.320
Epoch 30: GAIL Loss 0.743, Disc Loss 1.296
Epoch 40: GAIL Loss 0.667, Disc Loss 1.255
Epoch 50: GAIL Loss 0.613, Disc Loss 1.289
Epoch 60: GAIL Loss 0.585, Disc Loss 1.250
Epoch 70: GAIL Loss 0.490, Disc Loss 1.268
Epoch 80: GAIL Loss 0.470, Disc Loss 1.284
Epoch 90: GAIL Loss 0.455, Disc Loss 1.268
Epoch 100: GAIL Loss 0.433, Disc Loss 1.267
Epoch 110: GAIL Loss 0.459, Disc Loss 1.280
Epoch 120: GAIL Loss 0.462, Disc Loss 1.252
Epoch 130: GAIL Loss 0.481, Disc Loss 1.295
Epoch 140: GAIL Loss 0.450, Disc Loss 1.258
Epoch 150: GAIL Loss 0.424, Disc Loss 1.250
Epoch 160: GAIL Loss 0.419, Disc Loss 1.237
Epoch 170: GAIL Loss 0.431, Disc Loss 1.221
Epoch 180: GAIL Loss 0.402, Disc Loss 1.208
Epoch 190: GAIL Loss 0.414, Disc Loss 1.175
Epoch 200: GAIL Loss 0.407, Disc Loss 1.156
Epoch 210: GAIL Loss 0.364, Disc Loss 1.118
Epoch 220: GAIL Loss 0.333, Disc Loss 1.073

Протестируйте ваш алгоритм

In [11]:
for episode in range(10):
    obs, _ = env.reset()
    done = False
    total_reward = 0
    t = 0

    while not done:
        obs_with_time = np.append(obs, [np.sin(t), np.cos(t)])
        obs_tensor = torch.tensor(obs_with_time, dtype=torch.float32).unsqueeze(0)
        action = policy.get_action(obs_tensor)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        obs = next_obs
        total_reward += reward
        t += 1
    print(total_reward)

env.close()

-190.0
-163.0
-200.0
-198.0
-132.0
-164.0
-194.0
-136.0
-200.0
-200.0
