bug #8
Thank you for pointing this out! All examples currently use the for
Can you give an example of how to train using multiple GPUs? I'm trying to integrate with my custom JAX environment and have run into a problem. This is the key code:
This is the error: eval_state = self.update_root(eval_state, env_state, root_metadata, params)
eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
eval_keys = jax.random.split(eval_key, batch_size)
env_keys = jax.random.split(env_key, batch_size)
env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)
# template embedding should not have a batch dimension
template_env_state, _ = init_fn(jax.random.PRNGKey(0))
evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
eval_state = jax.vmap(evaluator_init)(eval_keys)
evaluate = partial(az_evaluator.evaluate,
                   env_step_fn=env_step_fn,
                   params=param)
output = jax.vmap(evaluate)(
    eval_state=eval_state,
    env_state=env_state,
    root_metadata=env_state_metadata)

I recommend using the
I haven't fully documented a lot of the underlying classes yet, which do have their peculiarities --
Thanks, now I have a new problem: File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
Could you share your code?
It uses jax.lax.scan, like this: output = jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=400, unroll=1). Maybe this part is incorrect?
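To illustrate the pattern being used, here is a minimal, self-contained sketch of jax.lax.scan with a tuple carry, standing in for the self-play loop above. The carry contents here (a step counter and an rng key) are simplified stand-ins for the real (env_state, rng_key, step, eval_state) tuple; one thing to watch is that the carry returned by the step function must match the initial carry in structure and dtype.

```python
import jax
import jax.numpy as jnp

# Minimal stand-in for the self-play loop: the carry holds a step counter
# and an rng key, mirroring (env_state, rng_key, step, eval_state).
def one_step(carry, _):
    step, key = carry
    key, sub = jax.random.split(key)
    # return (new carry, per-step output); dtypes must match the init carry
    return (step + 1, key), jax.random.normal(sub)

# use jnp.array(0) rather than a Python int so the carry dtype is stable
init = (jnp.array(0), jax.random.PRNGKey(0))
(final_step, _), outputs = jax.lax.scan(one_step, init, None, length=400)
print(int(final_step), outputs.shape)  # 400 steps, (400,) stacked outputs
```

The per-step outputs are stacked along a new leading axis of length 400, which is where memory can grow if each step emits a large structure.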
The error suggests to me that some of the data might have been passed to
Yes, there is a feature I need that's missing: Go self-play data processing, which needs symmetries.
Could you describe how the feature should work? I can help you find the right spot to integrate it.
And I need to use multiple GPUs; I don't see anything like pmap in the code. I need an example of how to train with multiple GPUs.
Supporting multiple GPUs should be straightforward to add but is not currently included. I can work on this and let you know when it is supported. I've created a separate issue to track it.
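Until built-in support lands, the standard JAX data-parallel pattern may be a useful reference. This is a generic sketch, not turbozero's API: the loss function and shapes here are made up for illustration. The idea is to replicate the parameters across devices, shard the batch along a leading device axis, and all-reduce gradients with jax.lax.pmean inside a jax.pmap-ed step.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss_fn(params, batch):
    # hypothetical stand-in loss, not the real AlphaZero loss
    return jnp.mean((batch @ params) ** 2)

def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # all-reduce: average gradients across devices so replicas stay in sync
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads

p_train_step = jax.pmap(train_step, axis_name="devices")

params = jnp.ones((4,))
# replicate params and shard the batch along the leading device axis
rep_params = jnp.broadcast_to(params, (n_dev,) + params.shape)
batch = jnp.ones((n_dev, 8, 4))
new_params = p_train_step(rep_params, batch)
print(new_params.shape)  # (n_dev, 4)
```

Because pmean averages the gradients, every device ends each step with identical parameters, so the replicated leading axis stays consistent across steps.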
In Go, self-play data is usually augmented with transformations like np.rot90 and np.fliplr.
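For reference, the eight dihedral symmetries of a square board can be generated from exactly those two numpy ops. This is a sketch assuming a (19, 19, C) board array and a (19, 19) policy grid; a real Go policy also carries a pass entry, which is left out here for simplicity.

```python
import numpy as np

def board_symmetries(board, policy):
    """Yield the 8 dihedral symmetries of a square board and its policy.

    Assumes board has shape (N, N, C) and policy has shape (N, N).
    The same transform is applied to both so they stay aligned.
    """
    for k in range(4):
        b = np.rot90(board, k, axes=(0, 1))  # rotate spatial axes only
        p = np.rot90(policy, k)
        yield b, p
        yield np.fliplr(b), np.fliplr(p)     # mirrored version of each rotation

board = np.arange(19 * 19 * 17).reshape(19, 19, 17)
policy = np.random.rand(19, 19)
syms = list(board_symmetries(board, policy))
print(len(syms))  # 8
```

The key invariant is that board and policy are transformed together, so the policy target still points at the same board locations after augmentation.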
About that error: could it be that this place needs a vmap?
Yes, every function that operates on an evaluator state assumes a single input rather than a batch and should be vmapped.
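To make the convention concrete, here is a minimal sketch (not turbozero's actual API) of a per-instance function being batched with jax.vmap. Inside the function the state has no batch dimension, which also matches the (19, 19, 17) observation shape seen inside eval_fn.

```python
import jax
import jax.numpy as jnp

# A per-instance function: takes ONE state (no batch dim), returns one value.
def evaluate_one(state):
    # state.shape here is (19, 19, 17), not (batch, 19, 19, 17)
    return jnp.sum(state ** 2)

batch = jnp.ones((32, 19, 19, 17))   # a batch of 32 states
out = jax.vmap(evaluate_one)(batch)  # vmap supplies the batch dimension
print(out.shape)  # (32,)
```

Calling evaluate_one directly on the batched array would instead sum over all 32 states at once, which is why forgetting the vmap typically shows up as a shape mismatch deeper in the call.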
If I were you, I would implement a custom class extending here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140 You should only need to extend the behavior of collect. It might be a good idea for me to allow users to specify any number of transforms to apply to augment experiences prior to storing them in replay memory -- this is a fairly common use case, so it ideally should not require a custom class.
class GoTrainer(Trainer):
    def collect(self,
        state: CollectionState,
        params: chex.ArrayTree
    ) -> CollectionState:
        step_key, new_key = jax.random.split(state.key)
        eval_output, new_env_state, new_metadata, terminated, rewards = \
            self.step_train(
                key=step_key,
                env_state=state.env_state,
                env_state_metadata=state.metadata,
                eval_state=state.eval_state,
                params=params
            )
        buffer_state = self.memory_buffer.add_experience(
            state=state.buffer_state,
            experience=BaseExperience(
                env_state=state.env_state,
                policy_mask=state.metadata.action_mask,
                policy_weights=eval_output.policy_weights,
                reward=jnp.empty_like(state.metadata.rewards)
            )
        )
        # generate symmetries here and add to replay memory just like above
        buffer_state = jax.lax.cond(
            terminated,
            lambda s: self.memory_buffer.assign_rewards(s, rewards),
            lambda s: s,
            buffer_state
        )
        return state.replace(
            key=new_key,
            eval_state=eval_output.eval_state,
            env_state=new_env_state,
            buffer_state=buffer_state,
            metadata=new_metadata
        )
Yes, I'm going to try this, but if I can't use multiple GPUs, training will be slow. I changed it to this: eval_state = output.eval_state, but the error occurred before the code reached that point.
And when I print(state.observation.shape) in eval_fn, it shows (19, 19, 17), without a batch dimension. Is this normal?
I think I found why this error happened.
I am going to close this issue as the original problem has been resolved. I hope to continue to expand the documentation of this project so that it is more clear how to approach unique use cases. Thank you for your questions and feedback! Please create another issue if you run into problems, and feel free to email me if you have more questions.
The speed I tested was slow. Then I tested the example at https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb, and it also seems slow judging by GPU utilization. What does max_nodes mean? I found that max_nodes has a significant impact on speed; the larger max_nodes is, the slower it runs.
I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.
I'm still confused about max_nodes. Yes, I look forward to it running faster.
I think the main factor affecting speed is computation on the CPU side; the GPU is not fully utilized.
I advise setting it higher than … Increasing …
Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop, and in my experiments I've had no issues with GPU utilization.
In my experiments, keeping num_simulations unchanged, when max_nodes = 32 it cost 460 seconds,
Some more details could be useful here. What are you running, one call to search? I'm not sure this is entirely unexpected behavior. Computational workload is higher when max_nodes is larger.
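One plausible reason max_nodes affects cost, sketched here as an assumption rather than a description of turbozero's internals: a JIT-friendly MCTS tree is typically preallocated as dense per-node arrays of length max_nodes, so memory and per-operation work scale with the capacity whether or not the slots are ever visited. The field names and the 362-action count (19*19 points plus pass) below are hypothetical.

```python
import jax.numpy as jnp

def make_tree(max_nodes, num_actions=362):
    # Dense, preallocated per-node storage (hypothetical layout):
    # every field is sized by max_nodes up front, so tree ops touch
    # arrays that grow linearly with the capacity.
    return {
        "visit_counts": jnp.zeros((max_nodes,), dtype=jnp.int32),
        "values": jnp.zeros((max_nodes,)),
        "children": jnp.full((max_nodes, num_actions), -1, dtype=jnp.int32),
    }

small, big = make_tree(32), make_tree(512)
print(small["children"].shape, big["children"].shape)  # (32, 362) (512, 362)
```

Under this layout, going from max_nodes = 32 to 512 multiplies the children table by 16x, which would explain slowdowns even when num_simulations is held fixed.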
Environment: Go 19*19. There is another strange phenomenon: as the number of steps increases, the GPU utilization becomes lower and lower.
Thank you for letting me know, I will look into what you are describing to see if I can replicate and diagnose it. Are you running more than one call to search?
It uses jax.lax.scan, like this: jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=410, unroll=1). JIT compilation should only happen once.
Thanks for pointing this out, I will see if I can replicate.
Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.
I haven't worked with MuZero specifically as much, but I have read the paper. Feel free to send me an email with your questions and I'll see if I can answer. I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.
It seems root_metadata is not used in update_root; the parameters are mismatched in mcts.py.
line 81
eval_state = self.update_root(eval_state, env_state, root_metadata, params)
line 97