Skip to content

Commit

Permalink
Refactor/state array (#167)
Browse files Browse the repository at this point in the history
* rename StateTensor to StateArray

* rename State.from_list(states) to State.array(states)

* add documentation to the State object

* fix documentation for State

* add StateArray documentation

* add State and StateArray paragraph to Basic Concepts doc

* fix bug in State constructor
  • Loading branch information
cpnota committed Sep 28, 2020
1 parent dffe5b7 commit 867ba55
Show file tree
Hide file tree
Showing 22 changed files with 272 additions and 84 deletions.
4 changes: 2 additions & 2 deletions all/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import all.nn
from all.core import State, StateTensor
from all.core import State, StateArray

__all__ = ['nn', 'State', 'StateTensor']
__all__ = ['nn', 'State', 'StateArray']
2 changes: 1 addition & 1 deletion all/agents/vpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _act(self, state, reward):

def _terminal(self, state, reward):
self._rewards.append(reward)
features = State.from_list(self._features)
features = State.array(self._features)
rewards = torch.tensor(self._rewards, device=features.device)
log_pis = torch.stack(self._log_pis)
self._trajectories.append((features, rewards, log_pis))
Expand Down
10 changes: 5 additions & 5 deletions all/approximation/q_dist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch import nn
import torch_testing as tt
from all.core import StateTensor
from all.core import StateArray
from all.approximation import QDist

STATE_DIM = 1
Expand All @@ -23,7 +23,7 @@ def test_atoms(self):
tt.assert_almost_equal(self.q.atoms, torch.tensor([-2, -1, 0, 1, 2]))

def test_q_values(self):
states = StateTensor(torch.randn((3, STATE_DIM)), (3,))
states = StateArray(torch.randn((3, STATE_DIM)), (3,))
probs = self.q(states)
self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
tt.assert_almost_equal(
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_q_values(self):
)

def test_single_q_values(self):
states = StateTensor(torch.randn((3, STATE_DIM)), (3,))
states = StateArray(torch.randn((3, STATE_DIM)), (3,))
actions = torch.tensor([0, 1, 0])
probs = self.q(states, actions)
self.assertEqual(probs.shape, (3, ATOMS))
Expand All @@ -73,7 +73,7 @@ def test_single_q_values(self):
)

def test_done(self):
states = StateTensor(torch.randn((3, STATE_DIM)), (3,), mask=torch.tensor([1, 0, 1]))
states = StateArray(torch.randn((3, STATE_DIM)), (3,), mask=torch.tensor([1, 0, 1]))
probs = self.q(states)
self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
tt.assert_almost_equal(
Expand All @@ -100,7 +100,7 @@ def test_done(self):
)

def test_reinforce(self):
states = StateTensor(torch.randn((3, STATE_DIM)), (3,))
states = StateArray(torch.randn((3, STATE_DIM)), (3,))
actions = torch.tensor([0, 1, 0])
original_probs = self.q(states, actions)
tt.assert_almost_equal(
Expand Down
6 changes: 3 additions & 3 deletions all/approximation/q_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.functional import smooth_l1_loss
import torch_testing as tt
import numpy as np
from all.core import State, StateTensor
from all.core import State, StateArray
from all.approximation import QNetwork, FixedTarget

STATE_DIM = 2
Expand All @@ -21,7 +21,7 @@ def optimizer(params):
self.q = QNetwork(self.model, optimizer)

def test_eval_list(self):
states = StateTensor(
states = StateArray(
torch.randn(5, STATE_DIM),
(5,),
mask=torch.tensor([1, 1, 0, 1, 0])
Expand All @@ -40,7 +40,7 @@ def test_eval_list(self):
)

def test_eval_actions(self):
states = StateTensor(torch.randn(3, STATE_DIM), (3,))
states = StateArray(torch.randn(3, STATE_DIM), (3,))
actions = [1, 2, 0]
result = self.q.eval(states, actions)
self.assertEqual(result.shape, torch.Size([3]))
Expand Down
6 changes: 3 additions & 3 deletions all/approximation/v_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import nn
import torch_testing as tt
from all.approximation.v_network import VNetwork
from all.core import StateTensor
from all.core import StateArray

STATE_DIM = 2

Expand All @@ -22,7 +22,7 @@ def setUp(self):
self.v = VNetwork(self.model, optimizer)

def test_reinforce_list(self):
states = StateTensor(
states = StateArray(
torch.randn(5, STATE_DIM),
(5,),
mask=torch.tensor([1, 1, 0, 1, 0])
Expand All @@ -35,7 +35,7 @@ def test_reinforce_list(self):
tt.assert_almost_equal(result, torch.tensor([0.9732854, 0.5453826, 0., 0.4344811, 0.]))

def test_multi_reinforce(self):
states = StateTensor(
states = StateArray(
torch.randn(6, STATE_DIM),
(6,),
mask=torch.tensor([1, 1, 0, 1, 0, 0, 0])
Expand Down
4 changes: 2 additions & 2 deletions all/bodies/time.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from all.core import StateTensor
from all.core import StateArray
from ._body import Body

class TimeFeature(Body):
Expand All @@ -9,7 +9,7 @@ def __init__(self, agent, scale=0.001):
super().__init__(agent)

def process_state(self, state):
if isinstance(state, StateTensor):
if isinstance(state, StateArray):
if self.timestep is None:
self.timestep = torch.zeros(state.shape, device=state.device)
observation = torch.cat((state.observation, self.scale * self.timestep.view(-1, 1)), dim=1)
Expand Down
6 changes: 3 additions & 3 deletions all/bodies/time_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import torch
import torch_testing as tt
from all.core import State, StateTensor
from all.core import State, StateArray
from all.bodies import TimeFeature


Expand Down Expand Up @@ -57,14 +57,14 @@ def test_reset(self):
[0.3923, -0.2236, -0.3195, -1.2050, 1e-3]), atol=1e-04)

def test_multi_env(self):
state = StateTensor(torch.randn(2, 2), (2,))
state = StateArray(torch.randn(2, 2), (2,))
self.agent.act(state)
tt.assert_allclose(self.test_agent.last_state.observation, torch.tensor(
[[0.3923, -0.2236, 0.], [-0.3195, -1.2050, 0.]]), atol=1e-04)
self.agent.act(state)
tt.assert_allclose(self.test_agent.last_state.observation, torch.tensor(
[[0.3923, -0.2236, 1e-3], [-0.3195, -1.2050, 1e-3]]), atol=1e-04)
self.agent.act(StateTensor(state.observation, (2,), done=torch.tensor([False, True])))
self.agent.act(StateArray(state.observation, (2,), done=torch.tensor([False, True])))
tt.assert_allclose(self.test_agent.last_state.observation, torch.tensor(
[[0.3923, -0.2236, 2e-3], [-0.3195, -1.2050, 2e-3]]), atol=1e-04)
self.agent.act(state)
Expand Down
4 changes: 2 additions & 2 deletions all/bodies/vision.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from all.core import State, StateTensor
from all.core import State, StateArray
from ._body import Body

class FrameStack(Body):
Expand All @@ -16,7 +16,7 @@ def process_state(self, state):
self._frames = self._frames[1:] + [state.observation]
if self._lazy:
return LazyState.from_state(state, self._frames)
if isinstance(state, StateTensor):
if isinstance(state, StateArray):
return state.update('observation', torch.cat(self._frames, dim=1))
return state.update('observation', torch.cat(self._frames, dim=0))

Expand Down
4 changes: 2 additions & 2 deletions all/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .state import State, StateTensor
from .state import State, StateArray

__all__ = ['State', 'StateTensor']
__all__ = ['State', 'StateArray']

0 comments on commit 867ba55

Please sign in to comment.