In [None]:
# TODO: Reimplement the uniform sampling in Tianshou. The code below is carried over from the previous implementation.

In [1]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import gymnasium as gym
import numpy as np
from tqdm import trange
from tianshou.data import Batch, ReplayBuffer


def _sample_uniform_action(
    action_space: gym.Space, rng: np.random.Generator
) -> np.ndarray | int:
    """Draw one action uniformly at random from ``action_space`` using ``rng``.

    Gym spaces carry their own RNG (``action_space.sample()``). Drawing from an
    explicit ``np.random.Generator`` instead keeps the action sequence
    reproducible from a single seed.

    Args:
        action_space: The environment's action space.
        rng: Generator used for every supported space type. It is not used for
            unbounded ``Box`` spaces or unsupported space types.

    Returns:
        - Bounded ``Box``: an array shaped like the space, in the space's dtype,
          within ``[low, high]``.
        - ``Discrete``: a Python ``int`` in ``[start, start + n)``.
        - ``MultiDiscrete``: an ``int64`` array in ``[start, start + nvec)``.
        - ``MultiBinary``: an ``int8`` array of 0/1 values.
        - Anything else, or a ``Box`` with non-finite bounds: whatever
          ``action_space.sample()`` returns.
    """
    if isinstance(action_space, gym.spaces.Box):
        low = action_space.low
        high = action_space.high
        if not (np.all(np.isfinite(low)) and np.all(np.isfinite(high))):
            # rng.uniform overflows or yields inf on unbounded dimensions; let the
            # space's own sampler handle those (it uses the space's RNG, not `rng`).
            return action_space.sample()
        sample = rng.uniform(low=low, high=high, size=action_space.shape)
        # Cast to the space's dtype, then clip: rounding (e.g. float64 -> float32)
        # can nudge a value just outside [low, high].
        return np.clip(sample.astype(action_space.dtype), low, high)
    elif isinstance(action_space, gym.spaces.Discrete):
        # Discrete(n, start=s) covers {s, ..., s + n - 1}, not {0, ..., n - 1}.
        start = int(getattr(action_space, "start", 0))
        return start + int(rng.integers(0, action_space.n))
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        # `start` only exists on newer gymnasium MultiDiscrete spaces; default to 0.
        start = np.asarray(getattr(action_space, "start", 0), dtype=np.int64)
        offsets = rng.integers(
            low=0, high=action_space.nvec, size=action_space.shape, dtype=np.int64
        )
        return start + offsets
    elif isinstance(action_space, gym.spaces.MultiBinary):
        return rng.integers(0, 2, size=action_space.n, dtype=np.int8)
    else:
        # Fallback: uses the space's own RNG (seed it via action_space.seed()).
        return action_space.sample()


def _lookup_env_attr(env: gym.Env, name: str, default: Any = None) -> Any:
    """Return attribute ``name`` from the first layer of ``env``'s wrapper chain that has it, else ``default``."""
    get_wrapper_attr = getattr(env, "get_wrapper_attr", None)
    if get_wrapper_attr is None:
        # Gymnasium versions without get_wrapper_attr forward unknown attributes
        # through Wrapper.__getattr__, so a plain getattr reaches inner layers.
        return getattr(env, name, default)
    try:
        return get_wrapper_attr(name)
    except AttributeError:
        return default


def _obs_from_reset(reset_out: Any) -> Any:
    """Return the observation from a ``uniform_reset()`` result, dropping a trailing info dict if present."""
    # NOTE(review): the wrapper's return convention isn't visible here, so this
    # accepts both a bare observation and a gymnasium-style (obs, info) pair.
    if (
        isinstance(reset_out, tuple)
        and len(reset_out) == 2
        and isinstance(reset_out[1], dict)
    ):
        return reset_out[0]
    return reset_out


def collect_uniform_transitions_tianshou(
    env_name: str | gym.Env,
    buffer_size: int = 100_000,
    seed: int = 0,
    print_rejections: bool = True,
) -> ReplayBuffer:
    """
    Build a Tianshou ReplayBuffer of ``buffer_size`` one-step transitions by:
      1) sampling a *uniform* non-terminal state (via ``uniform_reset()``),
      2) sampling a *uniform* action (``_sample_uniform_action``),
      3) stepping once, and
      4) storing the transition.

    Args:
        env_name: Either a registered Gymnasium env id (passed to ``gym.make``)
            or an already-constructed env. Passing an env is how you supply one
            wrapped with your ``UniformExplorationWrapper``; ``gym.make`` on a
            plain id cannot add that wrapper.
        buffer_size: Capacity of the buffer, and number of transitions collected.
        seed: Seeds the action RNG, the env (one ``reset(seed=seed)`` before
            collection) and the action space's own RNG.
        print_rejections: If any layer of the env exposes a ``rejections``
            counter, print how many proposals were rejected during collection.

    Returns:
        The filled ``ReplayBuffer``.

    Raises:
        AttributeError: If no layer of the env provides ``uniform_reset()``.

    IMPORTANT:
    - This assumes your environment wrapper implements `uniform_reset()` that returns a valid non-terminal state.
    - If you *don't* have uniform_reset, you cannot generally sample uniformly from the true reachable state space.
    - An env created here from an id is closed before returning; an env passed
      in by the caller is left open.
    """
    rng = np.random.default_rng(seed)

    # Only close what we opened: an env passed in by the caller stays theirs.
    owns_env = isinstance(env_name, str)
    env = gym.make(env_name) if owns_env else env_name
    try:
        # gym.make() returns the env inside wrappers (TimeLimit, OrderEnforcing, ...),
        # and Gymnasium >= 1.0 wrappers no longer forward unknown attributes to the
        # inner env, so search the whole wrapper chain rather than the outer layer.
        uniform_reset = _lookup_env_attr(env, "uniform_reset")
        if uniform_reset is None:
            raise AttributeError(
                "env must provide uniform_reset() to sample a uniformly random non-terminal state. "
                "Wrap it with your UniformExplorationWrapper first."
            )

        # One ordinary seeded reset. It seeds env.np_random and satisfies
        # OrderEnforcing in case uniform_reset() sets the state without going
        # through the wrapper chain's reset().
        env.reset(seed=seed)
        # Seed the action space's own RNG, used by any action_space.sample() fallback.
        env.action_space.seed(seed)

        buf = ReplayBuffer(size=buffer_size)

        # Optional: track how many proposals the wrapper rejected.
        rejections_before = _lookup_env_attr(env, "rejections", 0)

        for _ in trange(buffer_size):
            obs = _obs_from_reset(uniform_reset())
            act = _sample_uniform_action(env.action_space, rng)

            obs_next, rew, terminated, truncated, info = env.step(act)

            # Without buffer_ids, ReplayBuffer.add() takes ONE transition with no
            # leading batch dimension.
            b = Batch(
                obs=np.asarray(obs),
                act=np.asarray(act),
                rew=np.asarray(rew, dtype=np.float32),
                terminated=np.asarray(terminated, dtype=bool),
                truncated=np.asarray(truncated, dtype=bool),
                obs_next=np.asarray(obs_next),
                info=info,  # optional
            )
            buf.add(b)

        if print_rejections:
            rejections_after = _lookup_env_attr(env, "rejections")
            if rejections_after is not None:
                print(
                    f"{rejections_after - rejections_before} of proposed states were rejected"
                )

        return buf
    finally:
        if owns_env:
            env.close()

In [21]:
env = gym.make("Pendulum-v1")
# NOTE(review): the stock Pendulum-v1 reset() presumably reads only "x_init" /
# "y_init" from `options`, so the "uniform" key may be silently ignored. Confirm
# which env is registered under this id before relying on it for uniform sampling.
initial_obs, initial_info = env.reset(options={"uniform": True})
initial_obs, initial_info

(array([-0.99650794, -0.08349816, -0.7885808 ], dtype=float32), {})

In [2]:
# Known failure: the env built from this id has no uniform_reset() (see the
# AttributeError in the output), so this cannot run until a uniform_reset()-capable
# wrapper is in place (see the TODO at the top).
rb = collect_uniform_transitions_tianshou(env_name="Pendulum-v1")

AttributeError: env must provide uniform_reset() to sample a uniformly random non-terminal state. Wrap it with your UniformExplorationWrapper first.