Attempt fix ci: only cast reward from float64 to float32 #2
@@ -5,10 +5,10 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
-
+    branches: [master]
+  workflow_dispatch:
 jobs:
   build:
     env:
@@ -23,38 +23,38 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        # cpu version of pytorch
-        pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
 
-        # Install Atari Roms
-        pip install autorom
-        wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-        base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-        AutoROM --accept-license --source-file Roms.tar.gz
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
Correct the redirection operator in the `base64` command.
The use of `&>` redirects both stdout and stderr to `Roms.tar.gz`, so any error message printed by `base64` would be written into the archive. Use `>` so that only the decoded output is written:
- base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+ base64 Roms.tar.gz.b64 --decode > Roms.tar.gz
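The difference between the two redirections can be sketched in Python with `subprocess`. This is only an illustration: the child command below is a hypothetical stand-in for a decoder that writes its payload to stdout and a diagnostic to stderr, and it is not part of the CI workflow.

```python
import subprocess
import sys

# Hypothetical child process: writes payload bytes to stdout and a
# diagnostic message to stderr, like a decoder that also emits a warning.
CHILD = [
    sys.executable,
    "-c",
    "import sys; sys.stdout.write('PAYLOAD'); sys.stderr.write('warning: bad input')",
]


def capture_merged() -> bytes:
    """Analogue of `cmd &> file`: stderr is folded into the same stream as stdout."""
    result = subprocess.run(
        CHILD, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=True
    )
    return result.stdout


def capture_stdout_only() -> bytes:
    """Analogue of `cmd > file`: only stdout is captured, stderr is kept apart."""
    result = subprocess.run(
        CHILD, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
    )
    return result.stdout


if __name__ == "__main__":
    print(capture_merged())       # payload and warning end up in the same bytes
    print(capture_stdout_only())  # payload only
```

With the merged capture, the warning text ends up inside the data that would be saved as the archive, which is the failure mode the review comment is pointing at.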
+          AutoROM --accept-license --source-file Roms.tar.gz
Comment on lines +38 to +41
Simplify the installation of Atari ROMs using `AutoROM`.
Manually downloading and decoding the ROMs is unnecessary.
 pip install autorom
-wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-AutoROM --accept-license --source-file Roms.tar.gz
+AutoROM --accept-license
This approach is cleaner and reduces potential points of failure.
 
-        pip install .[extra_no_roms,tests,docs]
-        # Use headless version
-        pip install opencv-python-headless
-    - name: Lint with ruff
-      run: |
-        make lint
-    - name: Build the doc
-      run: |
-        make doc
-    - name: Check codestyle
-      run: |
-        make check-codestyle
-    - name: Type check
-      run: |
-        make type
-    - name: Test with pytest
-      run: |
-        make pytest
+          pip install .[extra_no_roms,tests,docs]
+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+      - name: Test with pytest
+        run: |
+          make pytest
Comment on lines +32 to +60
Review shell commands for compatibility issues flagged by ShellCheck.
Please inspect the shell commands for any misuse of character ranges, especially in loops or pattern matches.
Tools: actionlint
@@ -125,6 +125,20 @@ def _sanity_checks(self) -> None:
                     f"not {self.observation_space}"
                 )
 
+    @staticmethod
+    def _maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
+        """
+        Cast `np.float64` reward datatype to `np.float32`,
+        keep the others dtype unchanged.
+
+        :param dtype: The original action space dtype
+        :return: ``np.float32`` if the dtype was float64,
+            the original dtype otherwise.
+        """
Comment on lines +130 to +137
Correct the parameter names in the docstring.
The docstring of `_maybe_cast_reward` documents a `dtype` parameter and an action space dtype, but the function takes and returns a `reward` array. Apply this diff to fix the docstring:
 """
 Cast `np.float64` reward datatype to `np.float32`,
 keep the other dtypes unchanged.
-:param dtype: The original action space dtype
-:return: ``np.float32`` if the dtype was float64,
+:param reward: The reward array to potentially cast
+:return: Reward cast to ``np.float32`` if the original dtype was ``np.float64``,
     the original reward otherwise.
 """
+        if reward.dtype == np.float64:
+            return reward.astype(np.float32)
+        return reward
Comment on lines +138 to +140
Simplify the conditional return statement.
To make the code more concise, consider using a conditional expression (Python's ternary operator). This refactor enhances readability. Apply this diff:
-if reward.dtype == np.float64:
-    return reward.astype(np.float32)
-return reward
+# Simplified using a ternary operator
+return reward.astype(np.float32) if reward.dtype == np.float64 else reward
+
     def __getstate__(self) -> Dict[str, Any]:
         """
         Gets state for pickling.
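The behaviour of the new helper can be checked in isolation. The sketch below copies the body of `_maybe_cast_reward` from the diff into a standalone function (the name `maybe_cast_reward` and the sample arrays are assumptions made for illustration) and applies it to arrays of three dtypes.

```python
import numpy as np


def maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
    """Standalone copy of the helper from the diff, for illustration only."""
    if reward.dtype == np.float64:
        # astype returns a new array with the narrower dtype
        return reward.astype(np.float32)
    # Any other dtype is returned as-is (same object, no copy)
    return reward


# Hypothetical sample rewards
float64_rewards = np.array([0.5, -1.0, 2.0], dtype=np.float64)
int_rewards = np.array([1, 0, -1], dtype=np.int64)
float32_rewards = np.array([0.5, -1.0], dtype=np.float32)

cast = maybe_cast_reward(float64_rewards)
kept_int = maybe_cast_reward(int_rewards)
kept_f32 = maybe_cast_reward(float32_rewards)

if __name__ == "__main__":
    print(cast.dtype, kept_int.dtype, kept_f32.dtype)
```

Only the `float64` input is converted. Integer and `float32` rewards pass through untouched, which matches the intent stated in the PR title.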
@@ -254,7 +268,8 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
         """
         if self.norm_reward:
             reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
-        return reward.astype(np.float32)
+
+        return self._maybe_cast_reward(reward)
 
     def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
         # Avoid modifying by reference the original object
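The effect of this change on `normalize_reward` can be sketched as a free function. This is not the class method itself: `ret_var` stands in for `self.ret_rms.var`, and the default values for `clip_reward` and `epsilon` are assumptions chosen for the example.

```python
import numpy as np


def maybe_cast_reward(reward: np.ndarray) -> np.ndarray:
    """Cast float64 rewards to float32, leave every other dtype alone."""
    if reward.dtype == np.float64:
        return reward.astype(np.float32)
    return reward


def normalize_reward(
    reward: np.ndarray,
    ret_var: np.float64,
    norm_reward: bool = True,
    clip_reward: float = 10.0,  # assumed default, for illustration
    epsilon: float = 1e-8,      # assumed default, for illustration
) -> np.ndarray:
    """Free-function sketch of the method after the change in the diff."""
    if norm_reward:
        reward = np.clip(reward / np.sqrt(ret_var + epsilon), -clip_reward, clip_reward)
    return maybe_cast_reward(reward)


# Normalization enabled: the division yields float64, which is cast to float32.
normalized = normalize_reward(
    np.array([1.0, 0.0, -1.0], dtype=np.float64), ret_var=np.float64(4.0)
)

# Large rewards are clipped to the [-clip_reward, clip_reward] interval.
clipped = normalize_reward(np.array([100.0], dtype=np.float64), ret_var=np.float64(1.0))

# Normalization disabled: integer rewards keep their dtype, whereas the
# previous unconditional astype(np.float32) would have converted them.
passthrough = normalize_reward(
    np.array([1, 0, -1], dtype=np.int64), ret_var=np.float64(4.0), norm_reward=False
)

if __name__ == "__main__":
    print(normalized.dtype, clipped, passthrough.dtype)
```

The observable difference from the old code is the last case: with `norm_reward` disabled, a non-`float64` reward array is now returned with its original dtype.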
Use `--extra-index-url` instead of `--index-url` when installing PyTorch.
Using `--index-url` replaces the default PyPI index, potentially causing dependency resolution issues. To add the PyTorch CPU repository without overriding PyPI, use `--extra-index-url`:
pip install torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cpu
This ensures all dependencies are correctly installed from both indexes.