Commit fb28f66: Bugfix/state (#189)

* make GAE schedulable

* make cliprewards toggleable

* pass on type error

* run formatter

* more robust fix for life_lost
cpnota committed Dec 29, 2020
1 parent d2dc3ab commit fb28f66
Showing 7 changed files with 17 additions and 4 deletions.
5 changes: 3 additions & 2 deletions all/bodies/atari.py
@@ -5,9 +5,10 @@


 class DeepmindAtariBody(Body):
-    def __init__(self, agent, lazy_frames=False, episodic_lives=True, frame_stack=4):
+    def __init__(self, agent, lazy_frames=False, episodic_lives=True, frame_stack=4, clip_rewards=True):
         agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack)
-        agent = ClipRewards(agent)
+        if clip_rewards:
+            agent = ClipRewards(agent)
         if episodic_lives:
             agent = EpisodicLives(agent)
         super().__init__(agent)
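
With clip_rewards exposed as a constructor flag, the DeepMind-style reward clipping can now be switched off, e.g. to evaluate on raw game scores. A minimal usage sketch; StubAgent is a stand-in so the snippet is self-contained and is not part of this commit:

from all.bodies import DeepmindAtariBody

class StubAgent:
    # Stand-in for any real agent; only here so the sketch runs.
    def act(self, state):
        return 0

# Default: behavior is unchanged, rewards are clipped as before.
body = DeepmindAtariBody(StubAgent())

# New in this commit: skip the ClipRewards wrapper entirely.
body = DeepmindAtariBody(StubAgent(), clip_rewards=False)
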
1 change: 1 addition & 0 deletions all/bodies/vision.py
@@ -28,6 +28,7 @@ class TensorDeviceCache:
     To efficiently implement device transfer of lazy states, this class
     caches the transferred tensor so that it is not copied multiple times.
     '''
+
     def __init__(self, max_size=16):
         self.max_size = max_size
         self.cache_data = []
2 changes: 2 additions & 0 deletions all/core/state.py
@@ -31,6 +31,7 @@ class State(dict):
         device (string):
             The torch device on which component tensors are stored.
     """
+
     def __init__(self, x, device='cpu', **kwargs):
         if not isinstance(x, dict):
             x = {'observation': x}
@@ -260,6 +261,7 @@ class StateArray(State):
         device (string):
             The torch device on which component tensors are stored.
     """
+
     def __init__(self, x, shape, device='cpu', **kwargs):
         if not isinstance(x, dict):
             x = {'observation': x}
6 changes: 6 additions & 0 deletions all/environments/atari.py
@@ -7,6 +7,7 @@
     WarpFrame,
     LifeLostEnv,
 )
+from all.core import State


 class AtariEnvironment(GymEnvironment):
@@ -31,6 +32,11 @@ def __init__(self, name, *args, **kwargs):
     def name(self):
         return self._name

+    def reset(self):
+        state = self._env.reset(), 0., False, {'life_lost': False}
+        self._state = State.from_gym(state, dtype=self._env.observation_space.dtype, device=self._device)
+        return self._state
+
     def duplicate(self, n):
         return [
             AtariEnvironment(self._name, *self._args, **self._kwargs) for _ in range(n)
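
The new reset override synthesizes a full gym-style (observation, reward, done, info) tuple with life_lost preset to False, so the flag exists on the very first state instead of appearing only after the first step. A short sketch of the intended effect, assuming State.from_gym merges info keys into the resulting state; the environment name and device are illustrative:

from all.environments import AtariEnvironment

env = AtariEnvironment('Breakout', device='cpu')
state = env.reset()
# Because reset() now passes {'life_lost': False} as the info dict,
# consumers such as EpisodicLives can read the flag immediately.
assert state['life_lost'] is False
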
3 changes: 2 additions & 1 deletion all/environments/gym.py
@@ -41,7 +41,8 @@ def name(self):
         return self._name

     def reset(self):
-        self._state = State.from_gym(self._env.reset(), dtype=self._env.observation_space.dtype, device=self._device)
+        state = self._env.reset(), 0., False, None
+        self._state = State.from_gym(state, dtype=self._env.observation_space.dtype, device=self._device)
         return self._state

     def step(self, action):
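
The base GymEnvironment.reset gets the matching treatment: instead of handing State.from_gym a bare observation, it wraps it as (observation, 0., False, None), so from_gym receives the same four-element shape from reset() and step() alike. The Atari subclass above differs only in supplying a real info dict in place of None.
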
3 changes: 2 additions & 1 deletion all/memory/generalized_advantage.py
@@ -1,8 +1,9 @@
 import torch
 from all.core import State
+from all.optim import Schedulable


-class GeneralizedAdvantageBuffer:
+class GeneralizedAdvantageBuffer(Schedulable):
     def __init__(
         self,
         v,
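
With Schedulable mixed in, the buffer's hyperparameters can be annealed over training by assigning schedule objects to them. A hypothetical sketch of the pattern; the toy class stands in for the buffer, and the LinearScheduler signature (initial value, final value, decay start, decay end) is an assumption based on how all.optim schedulers are used elsewhere in the library, not something this diff shows:

from all.optim import LinearScheduler, Schedulable

class ToyBuffer(Schedulable):
    # Stand-in with the same mix-in; GeneralizedAdvantageBuffer now
    # resolves scheduled attributes the same way.
    def __init__(self, lam):
        self.lam = lam

buffer = ToyBuffer(lam=0.95)
# Hypothetical: drift lambda from 0.95 toward 1.0 over 1e6 reads.
buffer.lam = LinearScheduler(0.95, 1.0, 0, 1_000_000, name='lam')
print(buffer.lam)  # near 0.95 early in training, approaching 1.0 later
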
1 change: 1 addition & 0 deletions all/memory/replay_buffer.py
@@ -151,6 +151,7 @@ def _sample_proportional(self, batch_size):

 class NStepReplayBuffer(ReplayBuffer):
     '''Converts any ReplayBuffer into an NStepReplayBuffer'''
+
     def __init__(
         self,
         steps,
