Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster committed Jan 9, 2023
1 parent 01338f0 commit 2efe0c6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
32 changes: 17 additions & 15 deletions nni/nas/strategy/_rl_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def reset(self) -> ObservationType:
'action_history': self.action_history,
'cur_step': self.cur_step,
'action_dim': self.num_choices[self.cur_step]
}
}, {}

def step(self, action: int) -> EnvStepType | Generator[Sample, float, EnvStepType]:
"""Step the environment.
Expand All @@ -145,7 +145,7 @@ def step(self, action: int) -> EnvStepType | Generator[Sample, float, EnvStepTyp
else:
done = False

return obs, 0., done, {}
return obs, 0., done, False, {}


class TuningTrajectoryGenerator:
Expand Down Expand Up @@ -185,7 +185,8 @@ def __init__(self, search_space: Mutable, policy: PolicyFactory | BasePolicy | N
self.simplified_space: dict[str, Categorical] = {}
self.search_labels: list[str] = []
num_choices: list[int] = []
for label, mutable in search_space.simplify().items():
# Expand CategoricalMultiple to Categorical by default.
for label, mutable in search_space.simplify(is_leaf=lambda x: isinstance(x, Categorical)).items():
if isinstance(mutable, MutableAnnotation):
# Skip annotations like constraints.
continue
Expand Down Expand Up @@ -226,7 +227,8 @@ def next_sample(self) -> Sample:
The class will be in a state pending for reward after a call of :meth:`next_sample`.
It will either receive the reward via :meth:`send_reward` or be reset via another :meth:`next_sample`.
"""
obs, done = self.env.reset(), False
obs, info = self.env.reset()
done = False
last_state = None # hidden state

self._trajectory = []
Expand All @@ -238,7 +240,7 @@ def next_sample(self) -> Sample:
truncated={},
done={},
obs_next={},
info={},
info=info,
policy={}
)

Expand Down Expand Up @@ -269,15 +271,15 @@ def next_sample(self) -> Sample:
if step_count == len(self.simplified_space):
return self.sample

obs_next, rew, done, info = self.env.step(self._last_action)
assert not done, 'The environment should not be done yet.'
obs_next, rew, terminated, truncated, info = self.env.step(self._last_action)
assert not terminated, 'The environment should not be done yet.'

self._transition.update(
obs_next=obs_next,
rew=rew,
terminated=done,
truncated=False,
done=done,
terminated=terminated,
truncated=truncated,
done=terminated,
info=info
)

Expand All @@ -297,15 +299,15 @@ def send_reward(self, reward: float) -> ReplayBuffer:
If None, the sample will be ignored.
"""

obs_next, _, done, info = self.env.step(self._last_action)
assert done, 'The environment should be done.'
obs_next, _, terminated, truncated, info = self.env.step(self._last_action)
assert terminated, 'The environment should be done.'

self._transition.update(
obs_next=obs_next,
rew=reward,
terminated=done,
truncated=False,
done=done,
terminated=terminated,
truncated=truncated,
done=terminated,
info=info
)
self._trajectory.append(deepcopy(self._transition))
Expand Down
4 changes: 0 additions & 4 deletions test/algo/nas/strategy/test_rl_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ def test_raises():
# TODO: improve message
TuningTrajectoryGenerator(search_space)

search_space = ExpressionConstraint(Categorical([0, 1], label='a') > 0)
with pytest.raises(ValueError, match='only supports Categorical'):
TuningTrajectoryGenerator(search_space)


class HiddenStatePolicy(BasePolicy):
def __init__(self, env):
Expand Down

0 comments on commit 2efe0c6

Please sign in to comment.