In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
from datetime import datetime
import numpy as np
from tqdm import trange
from dm_control import suite
import yaml
from train_baselines import train_baselines
from rollout import record_rollout
from utils import save_rewards, save_plot
import os
import shutil

/usr/local/lib/python3.10/dist-packages/glfw/__init__.py:917: GLFWError: (65550) b'X11: The DISPLAY environment variable is missing'
2025-05-05 04:02:55.447938: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746417775.469767    7632 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746417775.476289    7632 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746417775.493764    7632 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746417775.493788    7632 computation_placer.cc:177] computation placer already registered. Please check linkage and 

In [2]:
# Load the run configuration. The encoding is pinned to UTF-8 so parsing does
# not depend on the machine's locale default (which is what open() would
# otherwise use for text mode).
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Hyper-parameters read from config.yaml; every key below is required and a
# missing one raises KeyError at this point.
EPISODES = cfg["episodes"]                    # outer training-loop iterations
STEPS_PER_EPISODE = cfg["steps_per_episode"]  # cap on environment steps per episode
LATENT_DIM = cfg["latent_dim"]                # width of the world-model latent
HIDDEN_DIM = cfg["hidden_dim"]                # GRU / MLP hidden width
BATCH_SIZE = cfg["batch_size"]                # transitions per gradient update
LEARNING_RATE = cfg["learning_rate"]          # shared Adam learning rate
GAMMA = cfg["gamma"]                          # discount used in the critic target
SEED = cfg["seed"]                            # seed for the torch / NumPy generators

In [3]:
# Seed the global PyTorch generator and NumPy's legacy global generator from
# the configured SEED. NumPy's global generator is the one ReplayBuffer.sample
# draws from via np.random.choice.
# NOTE(review): no seed is passed to the dm_control environment created by
# suite.load(domain, task) in this file, so environment-side randomness is not
# controlled by these two calls — confirm whether that is intended.
torch.manual_seed(SEED)
np.random.seed(SEED)

class DMCWrapper:
    """Adapter exposing a dm_control suite task through flat NumPy observations.

    Attributes:
        env: the environment returned by ``suite.load``.
        action_spec: the environment's action spec.
        obs_dim: total number of scalar entries across all observation
            entries, as a plain Python ``int``.
        act_dim: first dimension of the action spec's shape.
        ts: the most recent time step returned by ``reset``/``step``.
    """

    def __init__(self, domain, task, seed=None):
        """Load ``domain``/``task`` and reset the environment once.

        Args:
            domain: suite domain name passed to ``suite.load``.
            task: suite task name passed to ``suite.load``.
            seed: optional seed forwarded to the task as
                ``task_kwargs={"random": seed}``. When ``None`` (default) the
                environment is loaded exactly as ``suite.load(domain, task)``.
        """
        if seed is None:
            self.env = suite.load(domain, task)
        else:
            self.env = suite.load(domain, task, task_kwargs={"random": seed})
        self.action_spec = self.env.action_spec()
        # np.prod of an empty shape is the float 1.0 and of a non-empty shape a
        # NumPy integer; force every term to a Python int so obs_dim can be
        # used directly as a layer size, including for scalar observations.
        self.obs_dim = sum(
            int(np.prod(spec.shape, dtype=np.int64))
            for spec in self.env.observation_spec().values()
        )
        self.act_dim = self.action_spec.shape[0]
        self.reset()

    def reset(self):
        """Reset the environment and return the flattened first observation."""
        self.ts = self.env.reset()
        return self._flatten_obs(self.ts)

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done)``.

        ``reward`` falls back to ``0.0`` when the time step carries a falsy
        reward (e.g. ``None``); ``done`` is the time step's ``last()`` flag.
        """
        self.ts = self.env.step(action)
        obs = self._flatten_obs(self.ts)
        reward = self.ts.reward or 0.0
        done = self.ts.last()
        return obs, reward, done

    def _flatten_obs(self, ts):
        """Concatenate every observation entry of ``ts`` into one 1-D array."""
        # np.asarray lets scalar entries (0-d arrays or plain Python numbers)
        # be flattened the same way as array entries.
        return np.concatenate(
            [np.asarray(value).ravel() for value in ts.observation.values()]
        )

In [4]:
class MemoryAugmentedWorldModel(nn.Module):
    """Recurrent encoder turning one (observation, action) pair into a latent.

    The concatenated observation and action are fed to a GRU of width
    ``HIDDEN_DIM`` as a length-one sequence; a linear layer then projects the
    GRU output to ``LATENT_DIM`` features.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        input_dim = obs_dim + act_dim
        self.rnn = nn.GRU(input_dim, HIDDEN_DIM, batch_first=True)
        self.linear = nn.Linear(HIDDEN_DIM, LATENT_DIM)

    def forward(self, obs, act, hidden=None):
        """Encode a single step.

        Args:
            obs: observation tensor (assumed 2-D ``(batch, obs_dim)`` — the
                code inserts a time axis at position 1).
            act: action tensor, concatenated to ``obs`` along the last axis.
            hidden: optional GRU state from a previous call; ``None`` lets the
                GRU start from its default initial state.

        Returns:
            ``(z, hidden)``: the latent projection and the updated GRU state.
        """
        joint = torch.cat((obs, act), dim=-1)
        # Insert a time axis of length one for the batch_first GRU.
        step_input = torch.unsqueeze(joint, 1)
        rnn_out, next_hidden = self.rnn(step_input, hidden)
        # Drop the time axis again before projecting.
        latent = self.linear(torch.squeeze(rnn_out, 1))
        return latent, next_hidden

class Actor(nn.Module):
    """Two-layer MLP policy head: latent -> ReLU hidden -> tanh-squashed action."""

    def __init__(self, latent_dim, act_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, act_dim)

    def forward(self, z):
        """Return the action for latent ``z``; tanh keeps each component in [-1, 1]."""
        features = F.relu(self.fc1(z))
        raw_action = self.fc2(features)
        return torch.tanh(raw_action)

class Critic(nn.Module):
    """Two-layer MLP value head: latent -> ReLU hidden -> one unbounded scalar."""

    def __init__(self, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, 1)

    def forward(self, z):
        """Return the value estimate for latent ``z`` (last axis has size 1)."""
        features = F.relu(self.fc1(z))
        value = self.fc2(features)
        return value

In [5]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Transitions are kept as tuples in a ``deque(maxlen=size)``, so once the
    buffer is full each new transition evicts the oldest one.
    """

    def __init__(self, size=10000):
        self.buffer = deque(maxlen=size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def add(self, *transition):
        """Append one transition; ``sample`` expects
        ``(obs, act, rew, next_obs, done)``."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Indices come from NumPy's global generator (``np.random.choice``
        without replacement).

        Returns:
            ``(obs, act, rew, next_obs, done)`` as float32 tensors stacked
            along a new leading batch axis; ``rew`` and ``done`` get a
            trailing axis of size 1.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions.
        """
        available = len(self.buffer)
        if batch_size > available:
            raise ValueError(
                f"cannot sample {batch_size} transitions from a buffer holding {available}"
            )
        idx = np.random.choice(available, batch_size, replace=False)
        batch = [self.buffer[i] for i in idx]
        obs, act, rew, next_obs, done = zip(*batch)
        return (
            self._to_tensor(obs),
            self._to_tensor(act),
            self._to_tensor(rew).unsqueeze(-1),
            self._to_tensor(next_obs),
            self._to_tensor(done).unsqueeze(-1),
        )

    @staticmethod
    def _to_tensor(values):
        """Stack a sequence of per-transition values into one float32 tensor."""
        # Stacking in NumPy first avoids torch.tensor's slow element-by-element
        # conversion of a sequence of ndarrays; values and dtype are the same.
        return torch.from_numpy(np.asarray(values, dtype=np.float32))

In [8]:
def train_dreamer(save_dir):
    """Train the world-model / actor / critic agent on dm_control cheetah-run.

    Each environment step selects an action from the actor, stores the
    transition in a replay buffer and, once the buffer holds at least
    ``BATCH_SIZE`` transitions, performs one critic, one actor and one
    world-model update on a sampled batch.

    Args:
        save_dir: experiment directory supplied by the caller. It is accepted
            for parity with the call site but is not used inside this function.

    Returns:
        list: the summed reward of every episode, in order.
    """
    env = DMCWrapper("cheetah", "run")
    wm = MemoryAugmentedWorldModel(env.obs_dim, env.act_dim)
    actor = Actor(LATENT_DIM, env.act_dim)
    critic = Critic(LATENT_DIM)

    # Prediction heads that give the world model a real learning signal: from
    # the latent of (obs, action) predict the next observation and the reward.
    # (A loss comparing the latent with its own detached copy is identically
    # zero and would leave the world model at its random initialisation.)
    obs_head = nn.Linear(LATENT_DIM, int(env.obs_dim))
    rew_head = nn.Linear(LATENT_DIM, 1)

    wm_params = (
        list(wm.parameters())
        + list(obs_head.parameters())
        + list(rew_head.parameters())
    )
    wm_opt = optim.Adam(wm_params, lr=LEARNING_RATE)
    actor_opt = optim.Adam(actor.parameters(), lr=LEARNING_RATE)
    critic_opt = optim.Adam(critic.parameters(), lr=LEARNING_RATE)

    buffer = ReplayBuffer()
    rewards = []

    # NOTE(review): during rollouts the world model is always fed an all-zero
    # action, while the training updates below feed it the stored action —
    # confirm whether the previous action should be used here instead.
    no_action = torch.zeros((1, env.act_dim))

    for ep in trange(EPISODES, desc="Dreamer"):
        obs = env.reset()
        total = 0
        hidden = None

        for _ in range(STEPS_PER_EPISODE):
            # Acting needs no gradients. Without no_grad the recurrent state
            # would drag an ever-growing autograd graph through the episode.
            with torch.no_grad():
                obs_t = torch.from_numpy(np.array(obs)).float().unsqueeze(0)
                z, hidden = wm(obs_t, no_action, hidden)
                action = actor(z).squeeze(0).numpy()

            next_obs, reward, done = env.step(action)
            buffer.add(obs, action, reward, next_obs, done)
            obs = next_obs
            total += reward
            if done:
                break

            if len(buffer.buffer) < BATCH_SIZE:
                continue

            o, a, r, o2, d = buffer.sample(BATCH_SIZE)

            # Latents and the bootstrap target are plain inputs to the critic
            # and actor: the world model is trained only by its own loss, so
            # nothing here needs to backpropagate into it.
            with torch.no_grad():
                z_batch, _ = wm(o, a)
                next_a = actor(wm(o2, a)[0])
                next_z, _ = wm(o2, next_a)
                target = r + GAMMA * critic(next_z) * (1 - d)

            # --- Critic ---
            critic_loss = F.mse_loss(critic(z_batch), target)
            critic_opt.zero_grad()
            critic_loss.backward()
            critic_opt.step()

            # --- Actor ---
            # The gradient reaches the actor through the world model and the
            # critic; the gradients this leaves on those two modules are
            # cleared by their own zero_grad() before their next update.
            pred_action = actor(z_batch)
            actor_loss = -critic(wm(o, pred_action)[0]).mean()
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()

            # --- World model ---
            z_wm, _ = wm(o, a)
            wm_loss = F.mse_loss(obs_head(z_wm), o2) + F.mse_loss(rew_head(z_wm), r)
            wm_opt.zero_grad()
            wm_loss.backward()
            wm_opt.step()

        rewards.append(total)

    return rewards


In [None]:
# Each run writes into its own timestamped folder under results/.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
exp_dir = os.path.join("results", f"exp_{timestamp}")
os.makedirs(exp_dir, exist_ok=True)


def _artifact(filename):
    """Return the path of ``filename`` inside this run's experiment folder."""
    return os.path.join(exp_dir, filename)


# Keep a snapshot of the configuration next to the results it produced.
shutil.copy("config.yaml", _artifact("config.yaml"))

# Train the three agents one after another: Dreamer first, then the baselines.
print("🧠 Training DreamerV3-lite agent...")
dreamer_rewards = train_dreamer(exp_dir)

print("🤖 Training PPO agent...")
ppo_rewards = train_baselines("ppo", exp_dir)

print("⚡ Training TD3 agent...")
td3_rewards = train_baselines("td3", exp_dir)

# (plot label, reward curve, CSV file name) for every trained agent.
_runs = [
    ("Dreamer", dreamer_rewards, "rewards_dreamer.csv"),
    ("PPO", ppo_rewards, "rewards_ppo.csv"),
    ("TD3", td3_rewards, "rewards_td3.csv"),
]

# Persist each reward curve as its own CSV file.
for _label, _curve, _csv_name in _runs:
    save_rewards(_curve, _artifact(_csv_name))

# One comparison figure with all curves.
save_plot(
    [_curve for _, _curve, _ in _runs],
    [_label for _label, _, _ in _runs],
    _artifact("reward_comparison.png"),
)

# Finally record a rollout video into the experiment folder.
print("🎥 Recording rollout...")
record_rollout(_artifact("dreamer_rollout.mp4"))

print(f"✅ Эксперимент завершён. Результаты сохранены в: {exp_dir}")

🧠 Training DreamerV3-lite agent...


Dreamer:   7%|▋         | 2/30 [00:09<02:14,  4.82s/it]