bug #8

Closed
Nightbringers opened this issue Feb 2, 2024 · 34 comments

@Nightbringers

Nightbringers commented Feb 2, 2024

It seems root_metadata is not used in update_root, and the parameters are mismatched in mcts.py.

Line 81:
eval_state = self.update_root(eval_state, env_state, root_metadata, params)

Line 97:

def update_root(self, tree: MCTSTree, root_embedding: chex.ArrayTree,
                params: chex.ArrayTree, **kwargs) -> MCTSTree:
    key, tree = get_rng(tree)
    root_policy_logits, root_value = self.eval_fn(root_embedding, params, key)
    root_policy = jax.nn.softmax(root_policy_logits)
    root_node = tree.at(tree.ROOT_INDEX)
    root_node = self.update_root_node(root_node, root_policy, root_value, root_embedding)
    return set_root(tree, root_node)
@lowrollr lowrollr self-assigned this Feb 2, 2024
@lowrollr
Owner

lowrollr commented Feb 2, 2024

Thank you for pointing this out!

All examples currently use the AlphaZero class, which implements update_root with the correct parameter ordering.

In the base MCTS class, however, the parameter ordering is indeed incorrect.

@lowrollr
Owner

lowrollr commented Feb 2, 2024

4e55c72

@Nightbringers
Author

Can you give an example of how to train using multiple GPUs?

I'm trying to integrate TurboZero with my custom JAX environment and have run into a problem. This is the key code:

def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )
az_evaluator = AlphaZero(MCTS)(
    eval_fn=eval_fn,
    num_iterations=100,
    max_nodes=200,
    branching_factor=82,
    action_selector=PUCTSelector()
)
def env_step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

eval_key, rng_key = jax.random.split(rng_key)
eval_keys = jax.random.split(eval_key, batch_size)
env, env_state_metadata = jax.vmap(init_fn)(eval_keys)
evaluator_init = partial(az_evaluator.init, template_embedding=env)
eval_state = jax.vmap(evaluator_init)(eval_keys)
output = az_evaluator.evaluate(
    eval_state=eval_state,
    env_state=env,
    root_metadata=env_state_metadata,
    params=param,
    env_step_fn=env_step_fn
)

This is the error:

eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 31, in update_root
key, tree = get_rng(tree)
^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 143, in get_rng
rng, new_rng = jax.random.split(tree.key, 2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/random.py", line 303, in split
return _return_prng_keys(wrapped, _split(typed_key, num))
^^^^^^^^^^^^^^^^^^^^^^
File "/python3.11/site-packages/jax/_src/random.py", line 286, in _split
raise TypeError("split accepts a single key, but was given a key array of"
TypeError: split accepts a single key, but was given a key array ofshape (100,) != (). Use jax.vmap for batching.

@lowrollr
Owner

lowrollr commented Feb 2, 2024

evaluate should be vmapped; try something like this:

eval_key, env_key, rng_key = jax.random.split(rng_key, 3)
eval_keys = jax.random.split(eval_key, batch_size)
env_keys = jax.random.split(env_key, batch_size)

env_state, env_state_metadata = jax.vmap(init_fn)(env_keys)

# template embedding should not have a batch dimension
template_env_state, _ = init_fn(jax.random.PRNGKey(0))
evaluator_init = partial(az_evaluator.init, template_embedding=template_env_state)
eval_state = jax.vmap(evaluator_init)(eval_keys)

evaluate = partial(az_evaluator.evaluate, 
                   env_step_fn=env_step_fn,
                   params=param)

output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)

I recommend using the Trainer class, as described here: https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb

I haven't yet fully documented many of the underlying classes, which do have their peculiarities; Trainer should be more straightforward to work with.

@Nightbringers
Author

Thanks, but now I have a new problem:

File "/turbozero/core/evaluators/mcts/mcts.py", line 81, in evaluate
eval_state = self.update_root(eval_state, env_state, root_metadata, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/evaluators/alphazero.py", line 56, in update_root
return set_root(tree, root_node)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 121, in set_root
data=jax.tree_util.tree_map(
^^^^^^^^^^^^^^^^^^^^^^^
File "/turbozero/core/trees/tree.py", line 122, in <lambda>
lambda x, y: x.at[tree.ROOT_INDEX].set(y),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 497, in set
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/ops/scatter.py", line 115, in _scatter_impl
y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
return util._broadcast_to(array, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "//lib/python3.11/site-packages/jax/_src/numpy/util.py", line 425, in _broadcast_to
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is shorter than argument 1

@lowrollr
Owner

lowrollr commented Feb 3, 2024

Could you share your code?

@Nightbringers
Author

Nightbringers commented Feb 3, 2024

def one_step(prev):
    """Execute one self-play move using MCTS."""
    env_state, rng_key, step, eval_state = prev
    rng_key, rng_key_next = jax.random.split(rng_key, 2)
    env_state_metadata = StepMetadata(
        rewards=env_state.rewards,
        action_mask=env_state.legal_action_mask,
        terminated=env_state.terminated,
        cur_player_id=env_state.current_player,
    )
    terminated = env_state.terminated

    output = jax.vmap(evaluate)(
        eval_state=eval_state,
        env_state=env_state,
        root_metadata=env_state_metadata)

    env_state = step_fn_move(env_state, output.action)

    eval_state = output.eval_state
    eval_state = az_evaluator.step(eval_state, output.action)

    # Same 4-element structure as the input tuple.
    return (env_state, rng_key_next, step + 1, eval_state)

@Nightbringers
Author

Nightbringers commented Feb 3, 2024

It uses jax.lax.scan:

output = jax.lax.scan(one_step, (env_state, rng_key, step,eval_state), None, length=400, unroll=1)

Maybe this part is incorrect?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)
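For reference, jax.lax.scan calls its body as f(carry, x) and expects (new_carry, y) back, with new_carry having exactly the same structure as the carry it received. A one-argument one_step(prev) that returns a tuple of a different length does not fit that contract. The sketch below shows the loop shape only; toy_evaluate and toy_env_step are hypothetical stand-ins for the vmapped evaluator and environment step, not TurboZero functions.

```python
import jax
import jax.numpy as jnp

BATCH, NUM_ACTIONS = 4, 5

def toy_evaluate(key, obs):
    # Hypothetical per-instance "search": picks a random action for one env.
    return jax.random.randint(key, (), 0, NUM_ACTIONS)

def toy_env_step(obs, action):
    # Hypothetical per-instance env step: records the chosen action.
    return obs.at[action].add(1.0)

def one_step(carry, _):
    # scan passes (carry, x); x is unused here because xs=None.
    obs, key, step = carry
    key, step_key = jax.random.split(key)
    step_keys = jax.random.split(step_key, BATCH)
    # Both per-instance functions are vmapped over the batch dimension.
    actions = jax.vmap(toy_evaluate)(step_keys, obs)
    obs = jax.vmap(toy_env_step)(obs, actions)
    # The new carry must have the same structure (and dtypes) as the old one.
    return (obs, key, step + 1), actions

init_carry = (jnp.zeros((BATCH, NUM_ACTIONS)), jax.random.PRNGKey(0), jnp.int32(0))
(final_obs, _, final_step), actions = jax.lax.scan(one_step, init_carry, None, length=10)
```

The per-step outputs (here the actions) are stacked along a new leading axis of length 10.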

@lowrollr
Owner

lowrollr commented Feb 3, 2024

From debugging similar errors before, I suspect some of the data was passed to evaluate without a batch dimension.

Trainer already implements the functions needed to collect episodes and advance the env and evaluator to the next state, so I'm curious why you are implementing these yourself. Is there a feature you need that's missing, or something confusing or unclear? I'm hoping most users won't need to provide anything besides environment dynamics functions.

@Nightbringers
Author

Nightbringers commented Feb 3, 2024

Yes, there is a feature I need that's missing: processing Go self-play data requires symmetry augmentation.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

Could you describe how the feature should work? I can help you find the right spot to integrate it.

@Nightbringers
Author

I also need to use multiple GPUs, and I don't see anything like pmap in the code. I need an example of how to train with multiple GPUs.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

Supporting multiple GPUs should be straightforward to add but is not currently included. I can work on this and let you know when it is supported. I've created a separate issue to track it.
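Until multi-GPU support exists, one common JAX pattern is data parallelism with jax.pmap: reshape the global batch to [num_devices, per_device_batch, ...] and pmap a vmapped per-instance function over the device axis. This is a generic sketch, not TurboZero's planned implementation; per_instance_fn is a hypothetical placeholder for something like a vmapped evaluate or env step. It also runs on a machine with a single CPU device.

```python
import jax
import jax.numpy as jnp

def per_instance_fn(x):
    # Hypothetical per-instance computation on a single (unbatched) board.
    return jnp.tanh(x).sum()

num_devices = jax.local_device_count()
per_device_batch = 8

# Shard the global batch: leading axis = devices, second axis = per-device batch.
global_batch = jnp.ones((num_devices * per_device_batch, 19, 19))
sharded = global_batch.reshape((num_devices, per_device_batch) + global_batch.shape[1:])

# vmap handles the per-device batch; pmap spreads the leading axis across devices.
parallel_fn = jax.pmap(jax.vmap(per_instance_fn))
out = parallel_fn(sharded)  # shape: (num_devices, per_device_batch)
```

For training, network parameters would typically be replicated to every device (for example with jax.device_put_replicated, or passed with in_axes=None), and gradients averaged across devices with jax.lax.pmean inside the pmapped update.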

@Nightbringers
Author

> Could you describe how the feature should work? I can help you find the right spot to integrate it.

In Go, self-play data is usually augmented with transforms like np.rot90 and np.fliplr.
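A square Go board has eight symmetries (the dihedral group D4): four rotations, each optionally mirrored. A minimal JAX sketch that generates all eight from an (H, W, C) observation, assuming the spatial axes come first, as in the (19, 19, 17) observation shape mentioned later in this thread. dihedral_variants is a hypothetical helper name.

```python
import jax.numpy as jnp

def dihedral_variants(board):
    """Return the 8 dihedral symmetries of an (H, W, C) board, stacked on axis 0.

    Variant i uses k = i % 4 quarter turns, followed by a left-right flip when i >= 4.
    """
    variants = []
    for flip in (False, True):
        for k in range(4):
            b = jnp.rot90(board, k=k, axes=(0, 1))
            if flip:
                b = jnp.flip(b, axis=1)  # equivalent to np.fliplr (flips axis 1)
            variants.append(b)
    return jnp.stack(variants)
```

Variant 0 is the original board, so the augmented set includes the unmodified sample.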

@Nightbringers
Author

About that error: could this part need vmap?

eval_state = output.eval_state
eval_state = az_evaluator.step(eval_state, output.action)

@lowrollr
Owner

lowrollr commented Feb 3, 2024

> About that error: could this part need vmap?
>
> eval_state = output.eval_state
> eval_state = az_evaluator.step(eval_state, output.action)

Yes. Every function that operates on an evaluator state assumes a single (unbatched) input rather than a batch, and should be vmapped.
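To illustrate the single-instance convention with a toy example (ToyTree and toy_step are hypothetical, not TurboZero classes): a function that calls jax.random.split on one tree's key is written for an unbatched input, which is the kind of code that produced the "split accepts a single key" error earlier in this thread. Wrapping it in jax.vmap maps it over the leading batch axis of every leaf in the pytree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyTree(NamedTuple):
    key: jax.Array     # one PRNG key per tree
    visits: jax.Array  # (max_nodes,) visit counts

def toy_step(tree: ToyTree, action: jax.Array) -> ToyTree:
    # Written for a single, unbatched tree: split() expects exactly one key.
    new_key = jax.random.split(tree.key)[0]
    return tree._replace(key=new_key, visits=tree.visits.at[action].add(1))

batch_size, max_nodes = 4, 8
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
trees = ToyTree(key=keys, visits=jnp.zeros((batch_size, max_nodes), jnp.int32))
actions = jnp.array([0, 1, 2, 3])

# vmap maps toy_step over axis 0 of every leaf (keys, visits, and actions).
new_trees = jax.vmap(toy_step)(trees, actions)
```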

> Could you describe how the feature should work? I can help you find the right spot to integrate it.
>
> In Go, self-play data is usually augmented with transforms like np.rot90 and np.fliplr.

If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

The relevant code is here: https://github.com/lowrollr/turbozero/blob/main/core/training/train.py#L103-L140

You should only need to extend the behavior of collect.

It might be a good idea for me to let users specify any number of transforms to apply to experiences before they are stored in replay memory; this is a fairly common use case, so ideally it should not require a custom class.
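The key constraint is that the policy targets must move with the board. A sketch assuming a flat policy of length N*N + 1, laid out in row-major board order with the pass move last (an assumption about the action layout; check your environment's): reshape the board part to (N, N), apply the same rotation/flip as the observation, and keep the pass entry fixed. The same transform applies to policy_mask and policy_weights. All function names here are hypothetical.

```python
import jax.numpy as jnp

def transform_board(x, k, flip):
    # Rotate the first two (spatial) axes by k quarter turns, then optionally mirror.
    x = jnp.rot90(x, k=k, axes=(0, 1))
    return jnp.flip(x, axis=1) if flip else x

def transform_policy(policy, board_size, k, flip):
    # Assumed layout: board_size**2 board actions (row-major), then one pass action.
    board_part = policy[:-1].reshape(board_size, board_size)
    board_part = transform_board(board_part, k, flip)
    return jnp.concatenate([board_part.reshape(-1), policy[-1:]])

def augment(observation, policy_weights, policy_mask, board_size, k, flip):
    # Apply one symmetry consistently to the observation and both policy arrays.
    return (
        transform_board(observation, k, flip),
        transform_policy(policy_weights, board_size, k, flip),
        transform_policy(policy_mask, board_size, k, flip),
    )
```

Looping k over range(4) and flip over (False, True) yields the eight augmented experiences per position; k and flip are plain Python values here, so each combination traces its own static transform.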

@lowrollr
Owner

lowrollr commented Feb 3, 2024

> If I were you, I would implement a custom class extending Trainer and override collect to place each of the symmetries into the replay buffer with self.memory_buffer.add_experience. You will need to make sure policy_mask and policy_weights are augmented consistently with your game state data.

import chex
import jax
import jax.numpy as jnp

# Trainer, CollectionState, and BaseExperience are TurboZero classes
# (Trainer lives in core/training/train.py); import them from the package.

class GoTrainer(Trainer):
    def collect(
        self,
        state: CollectionState,
        params: chex.ArrayTree
    ) -> CollectionState:
        step_key, new_key = jax.random.split(state.key)
        # Advance the evaluator and the environment by one step.
        eval_output, new_env_state, new_metadata, terminated, rewards = \
            self.step_train(
                key=step_key,
                env_state=state.env_state,
                env_state_metadata=state.metadata,
                eval_state=state.eval_state,
                params=params
            )
        # Store the current position; rewards are filled in when the episode ends.
        buffer_state = self.memory_buffer.add_experience(
            state=state.buffer_state,
            experience=BaseExperience(
                env_state=state.env_state,
                policy_mask=state.metadata.action_mask,
                policy_weights=eval_output.policy_weights,
                reward=jnp.empty_like(state.metadata.rewards)
            )
        )

        # generate symmetries here and add to replay memory just like above

        # On termination, assign the final rewards to the stored experiences.
        buffer_state = jax.lax.cond(
            terminated,
            lambda s: self.memory_buffer.assign_rewards(s, rewards),
            lambda s: s,
            buffer_state
        )

        return state.replace(
            key=new_key,
            eval_state=eval_output.eval_state,
            env_state=new_env_state,
            buffer_state=buffer_state,
            metadata=new_metadata
        )

@Nightbringers
Author

Yes, I'm going to try this, but without multi-GPU support training will be slow.

I changed it to this:

eval_state = output.eval_state
eval_state = jax.vmap(az_evaluator.step)(eval_state, output.action)

But the error occurs before the code reaches that point.
It is raised here:

output = jax.vmap(evaluate)(
            eval_state=eval_state,
            env_state=env_state,
            root_metadata=env_state_metadata)

Also, when I print(state.observation.shape) in eval_fn, it shows (19, 19, 17), without a batch dimension. Is this normal?
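That is expected if eval_fn is called under jax.vmap (an assumption about how TurboZero invokes it, consistent with vmapping evaluate as suggested above): inside a vmapped function, every argument appears with its batch axis removed. A standalone illustration with a hypothetical eval_fn:

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def eval_fn(observation):
    # Python side effects run once, at trace time, and see the per-example shape.
    seen_shapes.append(observation.shape)
    return observation.mean()

batch = jnp.zeros((50, 19, 19, 17))
values = jax.vmap(eval_fn)(batch)
print(seen_shapes[0], values.shape)  # (19, 19, 17) (50,)
```

If the network expects a leading batch axis, one option is to add it inside eval_fn (for example observation[None]) and index it back out of the outputs.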

@Nightbringers
Author

I think I've found why this error happened.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

I am going to close this issue as the original problem has been resolved.

I hope to continue to expand the documentation of this project so that it is more clear as to how to approach unique use-cases.

Thank you for your questions and feedback! Please create another issue if you run into problems, and feel free to email me if you have more questions.

@lowrollr lowrollr closed this as completed Feb 3, 2024
@Nightbringers
Author

The speed I measured was slow. I then tested the example at https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb, and judging by GPU utilization it also seems slow.

What does max_nodes mean? I found that max_nodes has a significant impact on speed: the larger max_nodes is, the slower it runs.
mctx-az behaves the same way, and in my tests mctx-az is faster than TurboZero.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

max_nodes reflects the maximum capacity of the tree. Trees cannot be sized dynamically, so the maximum number of nodes must be set before collection. It makes sense that performance degrades as max_nodes increases, since operations then run on larger arrays.
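A sketch of why capacity matters, assuming a layout similar in spirit to (but not copied from) TurboZero's tree: every per-node field is preallocated with a leading max_nodes axis, using an unbatched template embedding to fix the shapes, so memory and any full-array operation scale with max_nodes regardless of how many nodes a search actually fills. init_tree_storage and the field names are hypothetical; branching_factor=362 assumes 19x19 board moves plus pass.

```python
import jax
import jax.numpy as jnp

def init_tree_storage(template_embedding, max_nodes, branching_factor):
    # Each embedding leaf gets a leading max_nodes axis: [max_nodes, *leaf.shape].
    embeddings = jax.tree_util.tree_map(
        lambda x: jnp.zeros((max_nodes,) + x.shape, x.dtype), template_embedding
    )
    return {
        "embeddings": embeddings,
        "visit_counts": jnp.zeros((max_nodes,), jnp.int32),
        "values": jnp.zeros((max_nodes,), jnp.float32),
        "children": jnp.full((max_nodes, branching_factor), -1, jnp.int32),
    }

def storage_bytes(tree):
    # Total bytes across all preallocated arrays.
    return sum(leaf.size * leaf.dtype.itemsize for leaf in jax.tree_util.tree_leaves(tree))

template = {"observation": jnp.zeros((19, 19, 17), jnp.float32)}
small = init_tree_storage(template, max_nodes=32, branching_factor=362)
large = init_tree_storage(template, max_nodes=600, branching_factor=362)
print(storage_bytes(large) / storage_bytes(small))  # 18.75 (= 600 / 32)
```

Vmapping this initializer over a batch of trees adds one more leading axis, giving the [batch, max_nodes, ...] shapes that search operates on.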

I am aware that the backend is currently less performant than mctx. It is a priority for me to fix this.

@Nightbringers
Author

I'm still confused about max_nodes.
If max_nodes=100, does it mean every node has 100 child nodes, or that this evaluate call will search at most 100 nodes?
Why do you suggest setting it larger than num_simulations? That would be very, very slow.

Yes, I look forward to it running faster.

@Nightbringers
Author

Nightbringers commented Feb 3, 2024

I think the main factor affecting speed is computation on the CPU side; the GPU is not fully utilized.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

max_nodes refers to the maximum capacity of the tree -- so yes it does mean a tree with max_nodes=100 will at most evaluate 100 distinct game states.

I advise setting it higher than num_iterations because this implementation re-uses subtrees from a previous search -- so most of the time the tree will already be partially populated when a new search is started. Setting max_nodes higher than num_iterations means that there will be room for num_iterations nodes in the tree more often. I have yet to document out-of-bounds behavior but it works similarly to https://github.com/lowrollr/mctx-az.
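The arithmetic behind this advice, as a rough sketch: if each search iteration expands at most one new node (the usual AlphaZero-style MCTS behavior, assumed here) and a reused subtree already occupies some slots, only the remaining free slots are available to the new search. expandable_iterations is a hypothetical helper, and what happens once capacity runs out is deliberately left unmodeled, since that behavior is not yet documented.

```python
def expandable_iterations(max_nodes: int, reused_nodes: int, num_iterations: int) -> int:
    """Iterations that can still add a fresh node, assuming one new node per iteration."""
    free_slots = max(max_nodes - reused_nodes, 0)
    return min(num_iterations, free_slots)

# With max_nodes == num_iterations, a reused subtree eats into the new search's budget.
print(expandable_iterations(max_nodes=128, reused_nodes=40, num_iterations=128))  # 88
# Headroom above num_iterations leaves room for every iteration to expand a node.
print(expandable_iterations(max_nodes=200, reused_nodes=40, num_iterations=128))  # 128
```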

Increasing max_nodes linearly increases the memory footprint of the search tree data structure. It's definitely a trade-off of speed vs. accuracy to set it higher/lower, and its value relative to 'num_iterations' should be problem-dependent (branching factor and # of iterations both matter). Setting max_nodes = num_iterations is fine if you're worried about speed.

@lowrollr
Owner

lowrollr commented Feb 3, 2024

> I think the main factor affecting speed is computation on the CPU side; the GPU is not fully utilized.

Do you have evidence of this? I'm not aware of any CPU-bound portion of the training loop and in my experiments I've had no issues with GPU utilization.

@Nightbringers
Author

In my experiments, with num_simulations unchanged, max_nodes = 32 took 460 seconds and max_nodes = 600 took 2200 seconds, and GPU Pwr:Usage/Cap was much lower with max_nodes = 600 than with max_nodes = 32. I tested this with mctx-az; TurboZero shows the same behavior.
Since num_simulations is unchanged, the GPU's computational workload should stay the same, so is the problem on the CPU?

@lowrollr
Owner

lowrollr commented Feb 4, 2024

Some more details could be useful here.

What are you running, one call to search?
What is num_simulations set to?
What environment are you using?

I'm not sure this is entirely unexpected behavior.

Computational workload is higher when max_nodes is increased. Search operates on tensors of size [num_batches, max_nodes, ... ]

@Nightbringers
Author

environment: 19x19 Go
model size: as in the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410

There is another strange phenomenon: as the number of steps increases, GPU utilization keeps dropping.

@lowrollr
Owner

lowrollr commented Feb 5, 2024

Thank you for letting me know, I will look into what you are describing to see if I can replicate it and diagnose.

Are you running more than one call to search (MCTS.evaluate)? Some of the weird behavior you are describing could be down to JIT-compilation overhead on the first call.
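To separate JIT compilation from steady-state speed when benchmarking, time the first (compiling) call on its own and call block_until_ready() before stopping the clock, since JAX dispatches work asynchronously. A generic sketch; search_step is a hypothetical stand-in for a jitted search or self-play step.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def search_step(x):
    # Hypothetical stand-in for a jitted search / self-play step.
    return jnp.tanh(x @ x.T).sum()

def benchmark(fn, *args, repeats=5):
    t0 = time.perf_counter()
    fn(*args).block_until_ready()  # first call includes tracing and compilation
    first_call = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for async dispatch to finish
    steady = (time.perf_counter() - t0) / repeats
    return first_call, steady
```

Only the second number reflects per-call throughput; a large gap between the two points to compilation overhead rather than slow execution.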

@Nightbringers
Author

It uses jax.lax.scan, like this: jax.lax.scan(one_step, (env_state, rng_key, step, eval_state), None, length=410, unroll=1)

JIT compilation should only affect the first call.
environment: 19x19 Go
model size: as in the AlphaZero paper, 40 blocks, 256 channels
num_simulations: 128
batch size: 50 per device
steps: 410
With max_nodes = 32, it takes 460 seconds.
With max_nodes = 600, it takes 2200 seconds,
and GPU Pwr:Usage/Cap is much lower than with max_nodes = 32.
Tested with mctx-az; TurboZero shows the same behavior.

@lowrollr
Owner

lowrollr commented Feb 5, 2024

Thanks for pointing this out, I will see if I can replicate.

@Nightbringers
Author

Are you familiar with the MuZero algorithm? I have some questions and hope you can help me.

@lowrollr
Owner

lowrollr commented Feb 8, 2024

I haven't worked with MuZero specifically as much but have read the paper. Feel free to send me an email with your questions and I'll see if I can answer.

I still plan on looking into the issues you mention but have been very busy and have not had a chance yet.
