Skip to content

Commit

Permalink
remove comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed May 24, 2023
1 parent 0f7a783 commit 22cda84
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions tensor_parallel_emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,9 @@ def cumul_minibatch_step(carry, x_y):
cumul_loss, cumul_grads = carry
minibatch = x_y
loss, grads = grad_fn(to_bf16(params), minibatch)
# grads = with_sharding_constraint(grads, model_mp_spec)
cumul_loss, cumul_grads = jax.tree_util.tree_map(
jnp.add, (cumul_loss, cumul_grads), (loss, grads)
)
# jax.debug.print("{x}", x = jax.tree_map(lambda x: x.shape, grads))

# cumul_grads = with_sharding_constraint(cumul_grads, model_mp_spec)

)
return (cumul_loss, cumul_grads), None

grad_init = to_bf16(jax.tree_util.tree_map(jnp.zeros_like, params))
Expand Down

0 comments on commit 22cda84

Please sign in to comment.