Load trained weights into agent and get predicted actions #2
Comments
Hey, I updated the checkpoint code and run scripts to make this easy. You can now train an agent as normal:

```
python dreamerv3/train.py --run.logdir ~/logdir/train --configs crafter --run.script train
```

And then load the agent to evaluate it in an environment without further training:

```
python dreamerv3/train.py --run.logdir ~/logdir/eval --configs crafter \
  --run.script eval_only --run.from_checkpoint ~/logdir/train/checkpoint.pkl
```

You also asked for a minimal example to load the agent yourself. The relevant code is:

```python
import embodied
from dreamerv3 import Agent  # import path may differ depending on the repo version

env = ...
config = ...
step = embodied.Counter()
agent = Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load('path/to/checkpoint.pkl', keys=['agent'])
state = None
act, state = agent.policy(obs, state, mode='eval')
```
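For reference, a hedged sketch of what the `obs` passed to `agent.policy` is expected to look like: a dict of arrays with a leading batch dimension (size 1 for a single environment), including the bookkeeping keys such as `is_first`/`is_last` added by the env wrappers. The exact keys depend on your environment.

```python
import numpy as np

# obs as returned by the wrapped env: a dict of per-step values.
obs = env.step(act)
# Add a batch dimension of 1 so the policy sees shape (1, ...) for every key.
obs = {k: np.asarray(v)[None] for k, v in obs.items()}
act, state = agent.policy(obs, state, mode='eval')
```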
Great, thank you so much.
Hello! Any idea how the initial state should be formatted? I am trying to run the minimal code you provided above with a Gym environment. However, the policy call returns an error, and from the `policy()` function I can see that it is expecting a specific state structure. Is there a way to initialize the state?
You can just pass in `None` as the initial state; the policy returns the state to use for the next step. This is also how the run scripts initialize it.
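As a tiny sketch of how the state is threaded (assuming `agent` and a batched `obs` dict as above): pass `None` once, then keep feeding back the state the policy returns.

```python
state = None  # the agent substitutes its own initial state on the first call
for _ in range(10):  # arbitrary number of steps, for illustration only
  act, state = agent.policy(obs, state, mode='eval')  # reuse the returned state
  obs = ...  # step the environment with `act` and re-batch the next observation
```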
@ThomasRochefortB did you manage to run the minimal snippet successfully? On my side, I run into an error that seems to come from the observation data not being formatted as expected when it is passed to the agent policy. Here is what I did: training works well, but when I then try to run the minimal inference snippet, I get an error from the policy call.
@jobesu14 The easiest way is to start from the existing `eval_only` run script. I think the issue in your snippet is that the policy expects a batch dimension. I think it should look something like the following, but I don't have the time to test it right now:

```python
import crafter
import numpy as np

import dreamerv3
import embodied
from embodied.envs import from_gym  # import paths may differ depending on the repo version

# Load the config saved during training and rebuild the same env with the same wrappers.
logdir = embodied.Path('~/logdir/test_1')
config = embodied.Config.load(logdir / 'config.yaml')
env = crafter.Env()
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)

# Create the agent and restore only its weights from the checkpoint.
step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])

state = None
act = {'action': env.act_space['action'].sample(), 'reset': np.array(True)}
while True:
  obs = env.step(act)
  obs = {k: v[None] for k, v in obs.items()}  # add a batch dimension of 1
  act, state = agent.policy(obs, state, mode='eval')
  act = {'action': act['action'][0], 'reset': obs['is_last'][0]}  # unbatch
```
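A hedged usage sketch building on the loop above: tracking the episode return and stopping after one episode. The `reward` and `is_last` keys come from the env wrapper; everything else is as in the snippet.

```python
episode_return = 0.0
state = None
act = {'action': env.act_space['action'].sample(), 'reset': np.array(True)}
while True:
  obs = env.step(act)
  episode_return += float(obs['reward'])  # accumulate before batching
  obs = {k: v[None] for k, v in obs.items()}  # batch dimension of 1
  act, state = agent.policy(obs, state, mode='eval')
  act = {'action': act['action'][0], 'reset': obs['is_last'][0]}
  if obs['is_last'][0]:
    print('Episode return:', episode_return)
    break
```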
I didn't manage to make the above snippets work; I kept having issues with the parameters of the agent `initial_policy` calls. What worked for me instead was basically adding a callback for each step. Hope this helps.
@danijar Any chance you've been able to figure out rendering from your snippet above? It still produces the error mentioned above for me. It would be amazing to have working example code for this.
If your env returns an image key as part of the observation dictionary, it already gets rendered and can be viewed in TensorBoard. Does that work for your use case?
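If your env doesn't expose an image observation yet, here is a hedged sketch of a small wrapper that adds one, assuming the older Gym API where `step` returns four values and `render(mode='rgb_array')` returns an RGB frame. The `AddImageObs` name and the `'obs'` fallback key are made up for illustration.

```python
import gym
import numpy as np

class AddImageObs(gym.Wrapper):
  """Adds an 'image' key to dict observations using the env's rendered frame."""

  def reset(self, **kwargs):
    obs = self.env.reset(**kwargs)
    return self._add_image(obs)

  def step(self, action):
    obs, reward, done, info = self.env.step(action)
    return self._add_image(obs), reward, done, info

  def _add_image(self, obs):
    frame = np.asarray(self.env.render(mode='rgb_array'))
    if not isinstance(obs, dict):
      obs = {'obs': obs}  # hypothetical key for non-dict observations
    obs['image'] = frame
    return obs
```

The wrapper would go around the raw env before the DreamerV3 wrappers, so the image ends up in the observation dictionary the agent and logger see.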
Thanks for such a great research algo @danijar! Wondering if there's any good way now to render a trained agent playing in its environment.
After 2 long days, I found the answer based on this issue! Leaving it here for anyone who wants to render their DRL agents playing. In the Gym wrapper's `step()` method, add matplotlib calls to display each observation:

```python
import matplotlib.pyplot as plt

def step(self, action):
  if action['reset'] or self._done:
    self._done = False
    obs = self._env.reset()
    return self._obs(obs, 0.0, is_first=True)
  if self._act_dict:
    action = self._unflatten(action)
  else:
    action = action[self._act_key]
  obs, reward, self._done, self._info = self._env.step(action)
  plt.imshow(obs)
  plt.show(block=False)
  plt.pause(0.001)  # Pause to ensure the plot updates
  plt.clf()  # Clear the plot so that the next image replaces this one
  return self._obs(
      obs, reward,
      is_last=bool(self._done),
      is_terminal=bool(self._info.get('is_terminal', self._done)))
```
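As an alternative to showing frames live with matplotlib, a hedged sketch of writing the collected frames to a video file with `imageio` (assuming `imageio` and its ffmpeg plugin are installed and each frame is an RGB array; the helper name and output path are illustrative):

```python
import imageio

def save_video(frames, path='rollout.mp4', fps=30):
  """Writes a list of RGB frames (H, W, 3 uint8 arrays) to a video file."""
  imageio.mimsave(path, frames, fps=fps)
```

Appending each `obs` to a list inside `step()` and calling `save_video` at the end of an episode avoids blocking the rollout on a GUI window.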
Hello, thanks for sharing this amazing piece of work!
Is there an easy way to load the trained weights from the `checkpoint.pkl` into an agent and get the predicted action from it (`agent.policy(obs, state, mode='eval')['action']`)? The idea would be to visualize online in a standard pygame loop, for instance. Looking at the code, I guess the easiest would be to use the `dreamerv3.Agent` class, but I don't understand how to load the weights from the pickle file 😅