Skip to content

Commit

Permalink
Feature/error on nan (#255)
Browse files Browse the repository at this point in the history
* update pytorch version

* refactor Approximation.step()

* raise RuntimeError if clip_grad is enabled and the gradient norm is non-finite

* update pytorch version in github workflow
  • Loading branch information
cpnota committed Aug 5, 2021
1 parent 01836e0 commit e8f3f16
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
run: |
sudo apt-get install swig
sudo apt-get install unrar
pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
make install
AutoROM -v
python -m atari_py.import_roms $(python -c 'import site; print(site.getsitepackages()[0])')/multi_agent_ale_py/ROM
Expand Down
20 changes: 14 additions & 6 deletions all/approximation/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def eval(self, *inputs):
with torch.no_grad():
# check current mode
mode = self.model.training
# switch to eval mode
# switch model to eval mode
self.model.eval()
# run forward pass
result = self.model(*inputs)
Expand Down Expand Up @@ -144,14 +144,11 @@ def step(self):
Returns:
self: The current Approximation object
'''
if self._clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self._clip_grad)
self._clip_grad_norm()
self._optimizer.step()
self._optimizer.zero_grad()
self._step_lr_scheduler()
self._target.update()
if self._scheduler:
self._writer.add_schedule(self._name + '/lr', self._optimizer.param_groups[0]['lr'])
self._scheduler.step()
self._checkpointer()
return self

Expand All @@ -164,3 +161,14 @@ def zero_grad(self):
'''
self._optimizer.zero_grad()
return self

def _clip_grad_norm(self):
'''Clip the gradient norm if set. Raises RuntimeError if norm is non-finite.'''
if self._clip_grad != 0:
utils.clip_grad_norm_(self.model.parameters(), self._clip_grad, error_if_nonfinite=True)

def _step_lr_scheduler(self):
'''Step the . Raises RuntimeError if norm is non-finite.'''
if self._scheduler:
self._writer.add_schedule(self._name + '/lr', self._optimizer.param_groups[0]['lr'])
self._scheduler.step()
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
"gym~=0.18.0", # common environment interface
"numpy>=1.18.0", # math library
"matplotlib>=3.3.0", # plotting library
"opencv-python~=3.4.0", # used by atari wrappers
"torch~=1.8.0", # core deep learning library
"opencv-python~=3.4.0", # used by atari wrappers
"torch~=1.9.0", # core deep learning library
"tensorboard>=2.3.0", # logging and visualization
"tensorboardX>=2.1.0", # tensorboard/pytorch compatibility
"cloudpickle>=1.2.0", # used to copy environments
Expand Down

0 comments on commit e8f3f16

Please sign in to comment.