cast params to match grad dtype in fwd pass (this was removed by accident)

fattorib committed May 29, 2023
1 parent 15e8a04 commit 3be5e62

Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tensor_parallel_shard_map.py
@@ -20,6 +20,7 @@
 from jax.experimental.shard_map import shard_map
 from jax.experimental.pjit import pjit
 from jax.lax import with_sharding_constraint
+import contextlib
 
 
 def parse():
@@ -47,6 +48,7 @@ def train_step(
 
     # reshape to add a microbatch dimension
     batch = batch.reshape(accum_steps, -1, context)
+    params = to_bf16(params)
 
     def loss_fn(params, batch):
         _, loss = model.apply(
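
For context, here is a minimal sketch (not the repository's exact code) of the mixed-precision pattern this hunk restores: params are cast to bfloat16 before the forward pass, so the gradients jax.grad produces come out in the same dtype. The params pytree and loss function below are hypothetical stand-ins.

import jax
import jax.numpy as jnp

def to_bf16(t):
    # Cast only float32 leaves; leave integer/bool leaves untouched.
    return jax.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
    )

# Hypothetical params pytree; the real model's structure differs.
params = {"w": jnp.ones((4, 4), jnp.float32)}
params = to_bf16(params)

def loss_fn(params, x):
    return jnp.mean(params["w"] @ x)

grads = jax.grad(loss_fn)(params, jnp.ones((4, 2), jnp.bfloat16))
# grads["w"].dtype is bfloat16, matching the dtype used in the forward pass.
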
@@ -117,9 +119,7 @@ def update_opt_state(
 
         def to_bf16(t):
             return t
-
-        import contextlib
-
+
         maybe_profiler = contextlib.nullcontext()
 
     else:
@@ -134,10 +134,10 @@ def to_bf16(t):
             return jax.tree_map(
                 lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
             )
-
-        maybe_profiler = jax.profiler.trace(
-            "jax-trace-pjit", create_perfetto_link=False
-        )
+        maybe_profiler = contextlib.nullcontext()
+        # maybe_profiler = jax.profiler.trace(
+        #     "jax-trace-pjit", create_perfetto_link=False
+        # )
 
     # Setting up device mesh
     if args.mp > 1:
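
For context, a minimal sketch of the maybe_profiler pattern used above, assuming a hypothetical profile toggle (the script hard-codes the choice per branch instead): contextlib.nullcontext() is a no-op context manager, so the training loop can always be written as `with maybe_profiler:` and tracing is enabled by swapping in jax.profiler.trace.

import contextlib
import jax
import jax.numpy as jnp

profile = False  # hypothetical toggle; not a flag in the actual script

if profile:
    # Writes a trace viewable in TensorBoard / Perfetto.
    maybe_profiler = jax.profiler.trace(
        "jax-trace-pjit", create_perfetto_link=False
    )
else:
    # No-op context manager: the `with` block below runs unchanged.
    maybe_profiler = contextlib.nullcontext()

with maybe_profiler:
    for step in range(3):
        _ = jnp.sum(jnp.ones((8, 8)))  # stand-in for a real train step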