diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1078cd285..113b9d7bdb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,10 +5,10 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] - + branches: [master] + workflow_dispatch: jobs: build: env: @@ -23,38 +23,38 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # Install Atari Roms + pip install autorom + wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] - # Use headless version - pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - - name: Build the doc - run: | - make doc - - name: Check codestyle - run: | - make check-codestyle - - name: Type check - run: | - make 
type - - name: Test with pytest - run: | - make pytest + pip install .[extra_no_roms,tests,docs] + # Use headless version + pip install opencv-python-headless + - name: Lint with ruff + run: | + make lint + - name: Build the doc + run: | + make doc + - name: Check codestyle + run: | + make check-codestyle + - name: Type check + run: | + make type + - name: Test with pytest + run: | + make pytest diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index cda2370aa0..ca0a980127 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -125,6 +125,20 @@ def _sanity_checks(self) -> None: f"not {self.observation_space}" ) + @staticmethod + def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray: + """ + Cast a `np.float64` reward array to `np.float32`, + leave arrays of any other dtype unchanged. + + :param reward: The reward array to check + :return: The reward cast to ``np.float32`` if its dtype was float64, + the same reward array otherwise. + """ + if reward.dtype == np.float64: + return reward.astype(np.float32) + return reward + def __getstate__(self) -> Dict[str, Any]: """ Gets state for pickling. @@ -254,7 +268,8 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray: """ if self.norm_reward: reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) - return reward.astype(np.float32) + + return self._maybe_cast_reward(reward) def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: # Avoid modifying by reference the original object