
Replace jax.lax.select with jnp.where #25

Merged
merged 4 commits into google-deepmind:main from nlsfnr:main on Jan 19, 2023

Conversation

nlsfnr
Contributor

@nlsfnr commented on Aug 30, 2022

Thanks for the awesome work!

This PR fixes an issue where jax.lax.select complains about mismatched dtypes when adjusting a DynamicLossScale.

The exception's stack trace ends with:

File ".../jmp/_src/loss_scale.py", line 147, in adjust
    loss_scale = jax.lax.select(
TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

My code looks similar to:

scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
...
gradients, scale = gradient_fn(..., scale)
gradients = scale.unscale(gradients)
gradients_finite = jmp.all_finite(gradients)
scale = scale.adjust(gradients_finite)  # This line throws the exception
...
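
For reference, a minimal standalone reproduction of the dtype behaviour (a toy example, not from the PR):

import jax
import jax.numpy as jnp

pred = jnp.asarray(True)
on_true = jnp.asarray(1.0, jnp.float32)
on_false = jnp.asarray(2, jnp.int32)

jnp.where(pred, on_true, on_false)       # ok: promotes int32 to float32
jax.lax.select(pred, on_true, on_false)  # TypeError: same dtypes required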

@nlsfnr changed the title from "Replaced jax.lax.select with jnp.where" to "Replace jax.lax.select with jnp.where" on Aug 30, 2022
@tomhennigan
Collaborator

Hi @nlsfnr , thank you for the PR.

I think the issue in your code snippet is that we expect the loss scale not to be an integer; in our example we suggest using the half dtype for the loss scale:

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))

I think where in this case would be meaningfully different from select, because it will broadcast the condition variable (as well as doing dtype promotion). I would be concerned (without seeing benchmarks) that this would perform worse than select on large trees.
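
To illustrate the broadcasting difference (a toy example, not from this repo):

import jax.numpy as jnp
from jax import lax

cond = jnp.asarray([[True], [False]])  # shape (2, 1)
a = jnp.ones((2, 3), jnp.float32)
b = jnp.zeros((2, 3), jnp.float32)

jnp.where(cond, a, b)   # ok: the condition broadcasts from (2, 1) to (2, 3)
lax.select(cond, a, b)  # raises: pred must be scalar or match the operand shape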

Perhaps one way to have the best of both worlds would be to add a __post_init__ method in DynamicLossScale which raises an error if an integer value is used for self.loss_scale and keep the select as it is. WDYT?
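
A minimal sketch of that idea, assuming DynamicLossScale is a frozen dataclass (simplified here to a single field):

import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class DynamicLossScale:
  loss_scale: jnp.ndarray

  def __post_init__(self):
    dtype = jnp.asarray(self.loss_scale).dtype
    if not jnp.issubdtype(dtype, jnp.floating):
      raise TypeError(f'Expected a floating loss_scale, got {dtype}.')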

@nlsfnr
Contributor Author

nlsfnr commented Oct 12, 2022

Hi @tomhennigan,

I accidentally closed the PR because I undid the only commit in it.

The new commit enforces the dtype of loss_scale and min_loss_scale using a __post_init__ hook, as you suggested.

@nlsfnr reopened this on Oct 12, 2022
@tomhennigan requested review from tomhennigan and removed the review request for lorenrose1013 on October 12, 2022 17:45
Collaborator

@tomhennigan left a comment


Thanks, this looks good! We'll pull it into our internal copy later this week, and a robot account will mark this PR as merged shortly after.

@nlsfnr
Contributor Author

nlsfnr commented Oct 14, 2022

Hi @tomhennigan,

I noticed some shortcomings in my latest commits.

First, they might break downstream code. One example is the unit tests I had to change: DynamicLossScale(1) worked before but now throws a TypeError. Although this is arguably the right behaviour, it might break existing code, so I changed it to emit a warning instead of throwing a TypeError.

Second, and probably more importantly, the change seems to make DynamicLossScale incompatible with jax.pmap. An error similar to the one below occurs during compilation:

 File ".../training.py", line 174, in train
    tmp = train_step(p_tokens, p_params, p_opt_state, p_loss_scale, p_rng)
  File ".../jmp/jmp/_src/loss_scale.py", line 150, in tree_unflatten
    return cls(loss_scale, counter, period, factor)
  File "<string>", line 8, in __init__
  File ".../jmp/jmp/_src/loss_scale.py", line 126, in __post_init__
    loss_scale_dtype = jnp.asarray(self.loss_scale).dtype    <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
  File ".../.venv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1924, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File ".../.venv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1872, in array
    dtype = dtypes._lattice_result_type(*leaves)[0]
TypeError: Value '<object object at 0x7fc5940989c0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

I believe this is because jax will call the pmapped function with tracer objects for compilation (i.e. the <object object at 0x7fc5940989c0> object from the error), but I am not sure. As a fix, I added a try-except-else block around the highlighted line above. Please see my latest commit.
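
Roughly, the guard looks like this (a sketch; see the commit for the exact code):

import warnings
import jax.numpy as jnp

def __post_init__(self):
  try:
    loss_scale_dtype = jnp.asarray(self.loss_scale).dtype
  except TypeError:
    # During compilation jax rebuilds the PyTree from placeholder objects
    # that jnp.asarray cannot convert, so skip the validation in that case.
    pass
  else:
    if not jnp.issubdtype(loss_scale_dtype, jnp.floating):
      warnings.warn(f'Expected a floating loss_scale, got {loss_scale_dtype}')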

LMK what you think of this. If you find the two try-blocks in the DynamicLossScale.__post_init__ a bit ugly, I am happy to factor them out into functions.

@tomhennigan
Collaborator

Thanks for the thoughtful comments; I agree that the __post_init__ function is a bit busy now.

I think a useful helper function would be the following; it acts as the identity for tracers:

def as_jax_array(x) -> jax.Array:
  # Tracers are jax.Array instances, so they pass through unchanged.
  if not isinstance(x, jax.Array):
    x = jnp.asarray(x)
  return x

We can then extract a helper for the floating-dtype check:

def check_floating(name, x):
  if not jnp.issubdtype(x.dtype, jnp.floating):
    warnings.warn(f'Expected floating type for {name}, got {x.dtype}')

Finally, we can re-assign all fields and apply the checks in __post_init__:

def __post_init__(self):
  object.__setattr__(self, 'loss_scale', as_jax_array(self.loss_scale))
  object.__setattr__(self, 'min_loss_scale', as_jax_array(self.min_loss_scale))
  check_floating('loss_scale', self.loss_scale)
  check_floating('min_loss_scale', self.min_loss_scale)
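
With those helpers in place, constructing a loss scale from an integer warns instead of raising (hypothetical usage):

scale = DynamicLossScale(jnp.asarray(2 ** 15))               # int32: emits a warning
scale = DynamicLossScale(jnp.asarray(2 ** 15, jnp.float32))  # floating: silent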

@nlsfnr
Contributor Author

nlsfnr commented Oct 19, 2022

Hi @tomhennigan , I've had a deeper look at the issues behind the mysterious <object object at 0x7fc5940989c0> value from the exception.

The TLDR is that JAX needs to determine the structure of custom PyTrees during compilation. To do this, it passes raw object() values instead of actual arrays to the tree_unflatten method of the PyTree. The JAX documentation explicitly states that this can and probably will break input validation done in __init__ or __new__, just as in our case.

To avoid this, the documentation recommends checking whether type(x) is object (see the custom pytrees section of the JAX docs). The latest commit implements that (and cleans up the __post_init__ method).
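
That guard looks roughly like this (a sketch; the commit may differ slightly):

def as_jax_array(x) -> jax.Array:
  if type(x) is object:
    # jax passes raw object() sentinels through tree_unflatten while it
    # determines the tree structure; leave them untouched.
    return x
  if not isinstance(x, jax.Array):
    x = jnp.asarray(x)
  return x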

A notable alternative to all of this would be to convert loss_scale to a floating type inside DynamicLossScale.adjust.

@tomhennigan
Collaborator

Thank you for digging into this; I did not realise pmap passed object() instances into the tree leaves. I had assumed it would use tracers.

Looks like the unit tests pass, so we'll pull this into our internal repo and run tests to check internal usages pass. Should be merged soon 😄

copybara-service bot pushed a commit that referenced this pull request Jan 18, 2023
--
21900d0 by Nicolas Forstner <nls.forstner@gmail.com>:

enforce floating dtype for loss_scale and min_loss_scale in DynamicLossScale

--
4d9ce3d by Nicolas Forstner <nls.forstner@gmail.com>:

added unittests

FUTURE_COPYBARA_INTEGRATE_REVIEW=#25 from nlsfnr:main 4d9ce3d
PiperOrigin-RevId: 502849580
copybara-service bot merged commit b3b588d into google-deepmind:main on Jan 19, 2023