# SAC smoke test on HalfCheetah-v5

Minimal setup that instantiates the SAC agent with a short HalfCheetah-v5 rollout so we can sanity-check wiring before running long experiments.


In [None]:
from __future__ import annotations

import functools

import torch
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn, optim
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
from torchrl.data import (
    LazyMemmapStorage,
    LazyTensorStorage,
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.envs import (
    CatTensors,
    Compose,
    DMControlEnv,
    DoubleToFloat,
    EnvCreator,
    ParallelEnv,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal
from torchrl.objectives import SoftUpdate
from torchrl.objectives.sac import SACLoss
from torchrl.record import VideoRecorder

# ====================================================================
# Environment utils
# -----------------


def env_maker(cfg, device="cpu", from_pixels=False):
    """Build one base environment for the backend named in ``cfg.env.library``.

    Args:
        cfg: Config object; reads ``cfg.env.library`` and ``cfg.env.name``,
            plus ``cfg.env.task`` for the ``"dm_control"`` backend.
        device: Device forwarded to the environment constructor.
        from_pixels: Forwarded to the environment constructor as
            ``from_pixels``; ``pixels_only`` is always ``False``.

    Returns:
        A ``GymEnv`` for ``"gym"`` / ``"gymnasium"``, or a ``TransformedEnv``
        wrapping a ``DMControlEnv`` whose observation entries are concatenated
        under the ``"observation"`` key for ``"dm_control"``.

    Raises:
        NotImplementedError: If ``cfg.env.library`` is not a supported backend.
    """
    lib = cfg.env.library

    if lib in ("gym", "gymnasium"):
        # The backend context selects which gym flavour GymEnv instantiates.
        with set_gym_backend(lib):
            return GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=from_pixels,
                pixels_only=False,
            )

    if lib == "dm_control":
        # Forward ``device`` here as well so both backends honour the argument
        # (it used to be dropped silently on this branch).
        env = DMControlEnv(
            cfg.env.name,
            cfg.env.task,
            device=device,
            from_pixels=from_pixels,
            pixels_only=False,
        )
        # NOTE(review): every observation key is concatenated; when
        # from_pixels=True this presumably includes "pixels" as well —
        # confirm that is intended.
        return TransformedEnv(
            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
        )

    raise NotImplementedError(f"Unknown lib {lib}.")


def apply_env_transforms(env, max_episode_steps=1000):
    """Wrap ``env`` in the standard transform stack used by this notebook.

    The stack is, in order: ``InitTracker``, ``StepCounter`` configured with
    ``max_episode_steps``, ``DoubleToFloat`` and ``RewardSum``.

    Args:
        env: Environment to wrap.
        max_episode_steps: Value passed to ``StepCounter``.

    Returns:
        A ``TransformedEnv`` around ``env`` carrying the composed transforms.
    """
    pipeline = Compose(
        InitTracker(),
        StepCounter(max_episode_steps),
        DoubleToFloat(),
        RewardSum(),
    )
    return TransformedEnv(env, pipeline)


def make_environment(cfg, logger=None):
    """Create the training environment and a matching evaluation environment.

    Both are ``ParallelEnv`` instances with ``cfg.collector.env_per_collector``
    workers built through ``env_maker``. Only the training workers are seeded
    (with ``cfg.env.seed``). The evaluation env reuses a clone of the training
    transforms and, when ``cfg.logger.video`` is truthy, renders from pixels
    and gets a ``VideoRecorder`` inserted at the front of its transform stack.

    Args:
        cfg: Config object; reads ``cfg.collector.env_per_collector``,
            ``cfg.env.seed``, ``cfg.env.max_episode_steps`` and
            ``cfg.logger.video``.
        logger: Logger handed to ``VideoRecorder`` when video is enabled.

    Returns:
        Tuple ``(train_env, eval_env)``.
    """
    num_workers = cfg.collector.env_per_collector

    # --- training side ------------------------------------------------------
    train_base = ParallelEnv(
        num_workers,
        EnvCreator(functools.partial(env_maker, cfg=cfg)),
        serial_for_single=True,
    )
    train_base.set_seed(cfg.env.seed)
    train_env = apply_env_transforms(train_base, cfg.env.max_episode_steps)

    # --- evaluation side ----------------------------------------------------
    record_video = cfg.logger.video
    eval_factory = functools.partial(env_maker, cfg=cfg, from_pixels=record_video)

    # Clone so the eval env does not share transform instances with training.
    eval_transforms = train_env.transform.clone()
    if record_video:
        recorder = VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
        eval_transforms.insert(0, recorder)

    eval_base = ParallelEnv(
        num_workers,
        EnvCreator(eval_factory),
        serial_for_single=True,
    )
    return train_env, TransformedEnv(eval_base, eval_transforms)


def make_train_environment(cfg):
    """Create only the training environment.

    Builds a ``ParallelEnv`` with ``cfg.collector.env_per_collector`` workers
    created through ``env_maker``, seeds it with ``cfg.env.seed`` and wraps it
    with ``apply_env_transforms`` using ``cfg.env.max_episode_steps``.

    Args:
        cfg: Config object providing the fields listed above.

    Returns:
        The transformed training environment.
    """
    worker_factory = EnvCreator(functools.partial(env_maker, cfg=cfg))
    base_env = ParallelEnv(
        cfg.collector.env_per_collector,
        worker_factory,
        serial_for_single=True,
    )
    base_env.set_seed(cfg.env.seed)
    return apply_env_transforms(base_env, cfg.env.max_episode_steps)

In [None]:
import gymnasium as gym
import torch
from pathlib import Path

from rlopt.agent.sac import SAC, SACRLOptConfig
from rlopt.config_base import NetworkConfig
from rlopt.env_utils import env_maker


def build_halfcheetah_config(total_frames: int = 1024) -> SACRLOptConfig:
    """Return a minimally tuned SAC config for HalfCheetah-v5 smoke tests.

    Args:
        total_frames: Value stored in ``cfg.collector.total_frames``.

    Returns:
        A ``SACRLOptConfig`` with environment, collector, replay-buffer,
        optimizer, logger and network fields filled in. The policy and
        Q-function input sizes are read from a temporary Gymnasium instance
        of the environment.
    """

    cfg = SACRLOptConfig()
    cfg.seed = 7
    cfg.device = "auto"

    # Environment + collector knobs -----------------------------------------
    cfg.env.env_name = "HalfCheetah-v5"
    cfg.env.library = "gymnasium"
    cfg.env.num_envs = 8

    # NOTE(review): 1000 is not a multiple of the 8 envs configured above —
    # confirm how the collector handles the remainder.
    cfg.collector.frames_per_batch = 1000
    cfg.collector.total_frames = total_frames
    # NOTE(review): this exceeds the default total_frames (1024) — confirm the
    # warm-up length is intended for short runs.
    cfg.collector.init_random_frames = 25_000
    cfg.collector.prefetch = 1

    # Replay + optimization -------------------------------------------------
    cfg.loss.mini_batch_size = 256
    cfg.replay_buffer.size = 1_000_000
    cfg.replay_buffer.prefetch = 1
    cfg.optim.lr = 3e-4
    cfg.optim.scheduler = None
    cfg.optim.target_update_polyak = 0.995
    cfg.sac.utd_ratio = 1.0

    # Lightweight logging so the notebook runs without external services ----
    log_dir = Path.cwd() / "notebook_logs"
    cfg.logger.backend = ""
    cfg.logger.log_to_file = True
    cfg.logger.log_dir = str(log_dir)
    cfg.logger.exp_name = "sac_halfcheetah_smoketest"
    cfg.logger.python_level = "info"

    # Network dimensions depend on env specs --------------------------------
    # The probe env exists only to read the space shapes; close it even if
    # reading a shape raises so the simulator is not leaked.
    dummy_env = gym.make(cfg.env.env_name)
    try:
        obs_dim = dummy_env.observation_space.shape[0]
        action_dim = dummy_env.action_space.shape[0]
    finally:
        dummy_env.close()

    cfg.policy.input_dim = obs_dim
    cfg.policy.num_cells = [256, 256]
    cfg.policy.activation_fn = "relu"

    # The Q-function consumes observation and action together, hence the
    # summed input size.
    cfg.q_function = NetworkConfig(
        num_cells=[256, 256],
        input_dim=obs_dim + action_dim,
        input_keys=["observation", "action"],
        activation_fn="relu",
    )
    cfg.collector.env_per_collector = cfg.env.num_envs
    cfg.env.seed = 8
    cfg.env.max_episode_steps = 1000

    return cfg


# ---------------------------------------------------------------------------
# Build the config, create the train/eval environments, and train the agent.
cfg = build_halfcheetah_config(total_frames=1_000_000)
train_env, eval_env = make_environment(cfg, logger=None)
agent = SAC(env=train_env, eval_env=eval_env, config=cfg)
agent.train()

# Quick deterministic rollout on a fresh env to verify predict() -------------
# The fresh env gets its own name so `eval_env` keeps pointing at the
# evaluation environment handed to the agent, and it is closed once the
# sample has been taken.
# NOTE(review): cfg.device is "auto" as set by build_halfcheetah_config —
# confirm env_maker accepts that value as a device.
probe_env = env_maker(cfg, device=cfg.device)
try:
    with torch.no_grad():
        td = probe_env.reset().to(agent.device)
        action = agent.predict(td.clone())
finally:
    probe_env.close()
print(f"Deterministic action sample (shape={tuple(action.shape)}):\n{action}")

