reorder loss operations to reduce all-gathers
fattorib committed May 24, 2023
1 parent 2615cce commit d2239f3
Showing 2 changed files with 24 additions and 18 deletions.
src/utils/losses.py: 4 changes (2 additions & 2 deletions)
@@ -18,6 +18,6 @@ def cross_entropy_loss(labels: jnp.array, logits: jnp.array) -> jnp.array:
jnp.array: Loss
"""

return -jnp.mean(
jnp.sum(labels * nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)
return (
-jnp.sum(labels * nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)
)
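
Taken on its own, the losses.py change just drops the jnp.mean wrapper: cross_entropy_loss now returns a per-token loss array and leaves the reduction to the caller (see the train_step hunk below). A minimal sketch of the new form with the old line kept as a comment; the module's nn import is not shown in this hunk, so jax.nn is assumed here (flax.linen exposes the same log_softmax):

import jax.numpy as jnp
from jax import nn


def cross_entropy_loss(labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    # Old form (removed above): reduce inside the loss function, i.e.
    #   return -jnp.mean(jnp.sum(labels * nn.log_softmax(...), axis=-1))
    # New form: return the per-token loss unreduced; train_step decides where
    # (and under which sharding) the mean is taken.
    return -jnp.sum(labels * nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)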
tensor_parallel_emulation.py: 38 changes (22 additions & 16 deletions)
@@ -36,19 +36,19 @@ def train_step(
accum_steps: int = 8,
model: Any = None,
model_mp_spec: Any = None,
batch_loss_spec: Any = None,
batch_grad_spec: Any = None,
dp: int = 0
):
"""
Computes loss/grads for a single batch of data.
"""
# jax.debug.print("{x}", x = batch.shape)
_,context = batch.shape

batch = batch.reshape(accum_steps, -1, context)

params = with_sharding_constraint(params, model_mp_spec)

# force loss to only compute locally?
def loss_fn(params, batch):
_, loss = model.apply(
{"params": params["params"]},
@@ -57,7 +57,8 @@ def loss_fn(params, batch):
train=True,
rngs={"dropout": rng_key},
)
return loss
loss = with_sharding_constraint(loss, batch_loss_spec)
return loss.mean()

grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
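
This reordering is the point of the commit: loss_fn no longer returns a value that was already averaged inside cross_entropy_loss. It first pins the unreduced loss to batch_loss_spec (the P('dp') batch sharding defined later in the file) and only then takes the mean, the apparent intent being that under data parallelism the reduction runs on locally held shards plus a small cross-device reduce rather than gathering the full loss array onto every device. A minimal sketch of that pattern under an assumed 1-D dp-only mesh; the constant array and the hypothetical reduce_loss stand in for model.apply and the real loss_fn:

import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("dp",))        # assumed dp-only mesh
batch_loss_spec = NamedSharding(mesh, P("dp"))       # loss split along the batch axis


@jax.jit
def reduce_loss(per_token_loss):
    # Constrain first, reduce second: the constraint tells XLA the unreduced
    # loss should stay sharded across 'dp', so the mean is formed from local
    # partial sums instead of a gathered copy of the whole array.
    per_token_loss = with_sharding_constraint(per_token_loss, batch_loss_spec)
    return per_token_loss.mean()


per_token = jnp.ones((8, 1024))                      # stand-in for the model.apply output
loss = reduce_loss(per_token)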

@@ -77,7 +78,7 @@ def cumul_minibatch_step(carry, x_y):

# grads = jax.tree_map(lambda x: x.reshape([1, *x.shape]), grads)
# grads = jax.tree_map(lambda x: jax.numpy.repeat(x, dp, axis=0), grads)
# grads = with_sharding_constraint(grads, batch_grad_spec)
grads = with_sharding_constraint(grads, batch_grad_spec) # if dp, this tells XLA to replicate gradients across all devices
# grads = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0), grads)

loss, grads = jax.tree_util.tree_map(lambda x: x / accum_steps, (loss, grads))
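
The gradient constraint plays the same role on the other side of the accumulation: batch_grad_spec is set to the fully replicated no_shard later in the file, and, as the inline comment says, constraining the accumulated grads tells XLA up front that gradients are replicated across devices under data parallelism rather than leaving the layout to be inferred. A sketch of constraining a gradient pytree to a replicated layout and finishing the accumulation, with a hypothetical finish_accumulation and a toy two-leaf tree standing in for the real parameters:

import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("dp",))
replicated = NamedSharding(mesh, P())                # empty spec: every leaf replicated


@jax.jit
def finish_accumulation(grads, accum_steps):
    # A single sharding is broadcast over the whole pytree of gradient leaves.
    grads = with_sharding_constraint(grads, replicated)
    return jax.tree_util.tree_map(lambda g: g / accum_steps, grads)


grads = {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))}  # stand-in gradient pytree
grads = finish_accumulation(grads, 8)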
@@ -129,10 +130,14 @@ def to_bf16(t):
# 4-way dp / 2 way tp

# named sharding is easier to follow along with
mesh = Mesh(np.array(jax.devices()).reshape(args.dp,args.mp), ('dp','mp'))
if args.mp > 1:
mesh = Mesh(np.array(jax.devices()).reshape(args.dp,args.mp), ('dp','mp'))

else:
mesh = Mesh(np.array(jax.devices()).reshape(args.dp), ('dp'))

# indicates batch dim is split across dp axis
batch_sharding = jax.sharding.NamedSharding(mesh, P('dp', None))
batch_sharding = jax.sharding.NamedSharding(mesh, P('dp'))
no_shard = NamedSharding(mesh, None)

# # Setting up model + param spec
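
The mesh hunk above now branches on args.mp: a 2-D (dp, mp) mesh when tensor parallelism is emulated, a 1-D dp-only mesh otherwise, and batch_sharding drops the explicit None so it only names the leading batch axis. A sketch of the two layouts; the 4x2 reshape mirrors the "4-way dp / 2 way tp" comment and assumes 8 devices, so only the dp-only mesh is actually constructed here:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# dp-only layout (the args.mp == 1 branch): one named axis over all devices.
mesh = Mesh(np.array(jax.devices()), ("dp",))

# With tensor parallelism the same devices are reshaped onto two named axes,
# e.g. 8 devices as 4-way dp x 2-way mp:
#   mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("dp", "mp"))

# P("dp") splits an array's leading axis across 'dp' and replicates the rest;
# an empty spec replicates the whole array (the script builds its no_shard by
# passing a bare None instead).
batch_sharding = NamedSharding(mesh, P("dp"))
no_shard = NamedSharding(mesh, P())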
@@ -148,7 +153,9 @@ def to_bf16(t):

else:
param_spec = no_shard
batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_dp)
# batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_dp)
batch_grad_spec = no_shard
batch_loss_spec = batch_sharding

configs = OmegaConf.load("conf/model_config.yaml")
model_info = configs[MODEL_SIZE]
@@ -166,7 +173,7 @@ def to_bf16(t):
with mesh:
#TODO: Rng sharding is not currently correct
train_step_dp = jax.jit(
partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS, model_mp_spec = param_spec,batch_grad_spec = batch_grad_spec, dp = args.dp),
partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS, model_mp_spec = param_spec,batch_grad_spec = batch_grad_spec, batch_loss_spec = batch_loss_spec, dp = args.dp),
in_shardings=(param_spec, batch_sharding, no_shard),
out_shardings=(param_spec,no_shard)
)
@@ -175,8 +182,7 @@ def to_bf16(t):

init_batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)

# batch = jax.device_put(init_batch, batch_sharding)
batch = init_batch
batch = jax.device_put(init_batch, batch_sharding)
grads,metrics = train_step_dp(params, batch, dropout_rng)

start = time()
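
The warm-up batch is now placed onto the mesh with jax.device_put under batch_sharding, so it already matches the jitted step's in_shardings instead of arriving as a single host-resident array. A small standalone sketch; BATCH_SIZE and CTX_LEN here are stand-ins rather than the script's values, and the visualization call (kept commented out in the loop below) is optional:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("dp",))
batch_sharding = NamedSharding(mesh, P("dp"))

BATCH_SIZE, CTX_LEN = 8, 1024                         # stand-in constants
init_batch = jnp.ones((BATCH_SIZE, CTX_LEN), dtype=jnp.int32)

# Split the leading (batch) axis across 'dp' before the first call so the
# jitted step does not begin by resharding its input.
batch = jax.device_put(init_batch, batch_sharding)
# jax.debug.visualize_array_sharding(batch)   # optional peek; needs the rich package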
@@ -185,14 +191,14 @@ def to_bf16(t):
# jax.debug.visualize_array_sharding(batch)

for i in tqdm(range(NUM_PASSES)):
with jax.profiler.trace("jax-trace", create_perfetto_link=False):
dropout_rng, rng = jax.random.split(dropout_rng)
# with jax.profiler.trace("jax-trace", create_perfetto_link=True):
dropout_rng, rng = jax.random.split(dropout_rng)

batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)
# batch = jax.device_put(batch, batch_sharding)
grads, metrics = train_step_dp(params, batch, dropout_rng)
batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)
batch = jax.device_put(batch, batch_sharding)
grads, metrics = train_step_dp(params, batch, dropout_rng)

# params = jax.tree_map(lambda x,y : x - 0.01*y, params, grads)
params = jax.tree_map(lambda x,y : x - 0.01*y, params, grads)

print(metrics)
total_time = time() - start