Commit
drop unused args
fattorib committed May 25, 2023
1 parent 7224d90 commit 41f32f8
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions tensor_parallel_emulation.py
@@ -37,8 +37,6 @@ def train_step(
     model: Any = None,
     model_mp_spec: Any = None,
     batch_loss_spec: Any = None,
-    batch_grad_spec: Any = None,
-    dp: int = 0
 ):
     """
     Computes loss/grads for a single batch of data, optionally with gradient accumulation
@@ -48,8 +46,7 @@ def train_step(
     # reshape to add a microbatch dimension
     batch = batch.reshape(accum_steps, -1, context)
     batch = with_sharding_constraint(batch, batch_loss_spec)
-
-    params = with_sharding_constraint(params, model_mp_spec)
+    # params = with_sharding_constraint(params, model_mp_spec)

     def loss_fn(params, batch):
         _, loss = model.apply(
@@ -69,18 +66,16 @@ def cumul_minibatch_step(carry, x_y):
         minibatch = x_y

         loss, grads = grad_fn(to_bf16(params), minibatch)
-        cumul_loss, cumul_grads = jax.tree_map(
-            jnp.add, (cumul_loss, cumul_grads), (loss, grads)
+        cumul_grads = jax.tree_map(
+            jnp.add, cumul_grads, grads
         )
-        return (cumul_loss, cumul_grads), None
+        return (cumul_loss+loss, cumul_grads), None

     grad_init = to_bf16(jax.tree_util.tree_map(jnp.zeros_like, params))
     grad_init = with_sharding_constraint(grad_init, model_mp_spec) # gradients follow mp sharding

     (loss,grads), _ = jax.lax.scan(cumul_minibatch_step, init = (jnp.zeros(()), grad_init), xs = batch)

-    grads = with_sharding_constraint(grads, batch_grad_spec) # if dp only, this tells XLA to replicate gradients across all devices
-
     loss, grads = jax.tree_util.tree_map(lambda x: x / accum_steps, (loss, grads))

     metrics = {
@@ -127,14 +122,9 @@ def to_bf16(t):
         lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
     )

-maybe_profiler = jax.profiler.trace("jax-trace", create_perfetto_link=True)
+maybe_profiler = jax.profiler.trace("jax-trace-pjit", create_perfetto_link=False)

 # Setting up device mesh
-# following new jax array tutorial
-
-# 4-way dp / 2 way tp
-
-# named sharding is easier to follow along with
 if args.mp > 1:
     mesh = Mesh(np.array(jax.devices()).reshape(args.dp,args.mp), ('dp','mp'))

@@ ... @@

 else:
     param_spec = no_shard
-    # batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_dp)
-    batch_grad_spec = no_shard
-    # batch_loss_spec = batch_sharding
+    batch_grad_spec = no_shard
     batch_loss_spec = NamedSharding(mesh, P(None, 'dp'))

 configs = OmegaConf.load("conf/model_config.yaml")
@@ -179,7 +167,7 @@ def to_bf16(t):
 with mesh:
     #TODO: Rng sharding is not currently correct
     train_step_dp = jax.jit(
-        partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS, model_mp_spec = param_spec,batch_grad_spec = batch_grad_spec, batch_loss_spec = batch_loss_spec, dp = args.dp),
+        partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS, model_mp_spec = param_spec,batch_loss_spec = batch_loss_spec),
         in_shardings=(param_spec, batch_sharding, no_shard),
         out_shardings=(param_spec,no_shard),
     )
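
For reference, the accumulation pattern after this change keeps the scalar loss out of the tree_map and adds it directly into the scan carry, while gradients are still summed leaf-wise. Below is a minimal, self-contained sketch of that pattern, not the repo's code: `toy_loss`, `accumulate`, and the shapes are illustrative stand-ins for `model.apply` and the real batch, and all sharding constraints are omitted.

import jax
import jax.numpy as jnp

def toy_loss(params, batch):
    # stand-in for model.apply(...): a simple quadratic loss
    pred = batch @ params["w"] + params["b"]
    return jnp.mean(pred ** 2)

grad_fn = jax.value_and_grad(toy_loss)

def accumulate(params, microbatches):
    # carry = (scalar cumulative loss, pytree of cumulative grads)
    def step(carry, minibatch):
        cumul_loss, cumul_grads = carry
        loss, grads = grad_fn(params, minibatch)
        cumul_grads = jax.tree_util.tree_map(jnp.add, cumul_grads, grads)
        return (cumul_loss + loss, cumul_grads), None

    grad_init = jax.tree_util.tree_map(jnp.zeros_like, params)
    (loss, grads), _ = jax.lax.scan(step, (jnp.zeros(()), grad_init), microbatches)
    accum_steps = microbatches.shape[0]
    return jax.tree_util.tree_map(lambda x: x / accum_steps, (loss, grads))

params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
microbatches = jnp.ones((8, 16, 4))  # (accum_steps, per-step batch, features)
loss, grads = accumulate(params, microbatches)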
