[dynamic shapes] Add some support for handling tracers as dimensions #9335
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The main goal here is to revert the changes to lax.py:_bradcasting_shape_rule
to use again _try_broadcast_shapes, and DimensionHandlers, so that the code
that deals with shapes in lax_numpy and the shaping rules is kept simple
and all details about handling dimension sizes that are non constant are
factored out in the DimensionHandler classes.
We do this by creating a DimensionHandler for the case when the dimension
size is a Tracer. (At the moment this works only for DynamicJaxprTracer; we
will need to relax this.)
In this PR the only conclusive operation on Tracer dimensions is that
two identical Tracers are deemed to represent equal dimensions.