Commit: branch cleanup
fattorib committed May 28, 2023
1 parent f89c2ad commit ced6393
Showing 2 changed files with 12 additions and 199 deletions.
190 changes: 0 additions & 190 deletions minimal_tensor_parallel_emulation.py

This file was deleted.

21 changes: 12 additions & 9 deletions tensor_parallel_emulation.py
@@ -46,7 +46,7 @@ def train_step(
     # reshape to add a microbatch dimension
     batch = batch.reshape(accum_steps, -1, context)
     batch = with_sharding_constraint(batch, batch_loss_spec)
-    params = with_sharding_constraint(params, model_mp_spec)
+    params = with_sharding_constraint(to_bf16(params), model_mp_spec)

     def loss_fn(params, batch):
         _, loss = model.apply(
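The change above casts the parameter tree to bfloat16 once, where the sharding constraint is applied, rather than re-casting inside the gradient function on every accumulation step (that call site changes in a hunk below). The file defines its own to_bf16 further down (see the @@ -148 hunk); as a rough sketch of what such a helper typically looks like, assuming it casts only float32 leaves:

import jax
import jax.numpy as jnp

def to_bf16(t):
    # Cast float32 leaves of a pytree to bfloat16; leave other dtypes alone.
    # (Sketch only -- the repo's actual definition may differ.)
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
    )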
@@ -56,23 +56,26 @@ def loss_fn(params, batch):
             train=True,
             rngs={"dropout": rng_key},
         )
-        return jnp.mean(loss)
+        return loss

-    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
+    def train_batch_loss(params, batch):
+        per_ex_loss = jax.vmap(loss_fn, in_axes=(None, 0), out_axes=0, axis_name="batch")(params, batch)
+        return jnp.mean(per_ex_loss, axis=0)
+
+    grad_fn = jax.value_and_grad(train_batch_loss)

     # accumulate gradients
     def cumul_minibatch_step(carry, x_y):
         cumul_loss, cumul_grads = carry
         minibatch = x_y

-        loss, grads = grad_fn(to_bf16(params), minibatch)
+        minibatch = jnp.expand_dims(minibatch, axis=1)  # yuck
+        loss, grads = grad_fn(params, minibatch)
         cumul_grads = jax.tree_map(
             jnp.add, cumul_grads, grads
         )
         return (cumul_loss + loss, cumul_grads), None

     grad_init = to_bf16(jax.tree_util.tree_map(jnp.zeros_like, params))
     grad_init = with_sharding_constraint(grad_init, model_mp_spec)  # gradients follow mp sharding

     (loss, grads), _ = jax.lax.scan(cumul_minibatch_step, init=(jnp.zeros(()), grad_init), xs=batch)
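Taken together, this hunk moves the mean over examples out of loss_fn: jax.vmap maps loss_fn over a named batch axis to get per-example losses, train_batch_loss averages them, and each microbatch gains a singleton dimension (the expand_dims) so vmap has an axis to map over. A self-contained sketch of the same scan-based accumulation pattern, with a toy quadratic loss standing in for model.apply (all names below are illustrative, not from the repo):

import jax
import jax.numpy as jnp

def toy_loss(params, batch):
    # Stand-in for model.apply: mean squared output of a linear layer.
    pred = batch @ params["w"]
    return jnp.mean(pred ** 2)

grad_fn = jax.value_and_grad(toy_loss)

def accumulate(params, microbatches):
    # microbatches: (accum_steps, microbatch_size, features)
    def step(carry, mb):
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(params, mb)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss_sum, grad_sum), _ = jax.lax.scan(step, init, microbatches)
    n = microbatches.shape[0]
    # Average the summed loss/grads here; the repo may normalize elsewhere.
    return loss_sum / n, jax.tree_util.tree_map(lambda g: g / n, grad_sum)

params = {"w": jnp.ones((4, 2))}
batches = jnp.ones((8, 16, 4))  # 8 accumulation steps of 16 examples
loss, grads = accumulate(params, batches)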

@@ -97,8 +100,8 @@ def cumul_minibatch_step(carry, x_y):

 if args.emulation:
     print("Emulating 8 TPU cores")
-    GRAD_ACCUM_STEPS = 8
-    BATCH_SIZE = 128
+    GRAD_ACCUM_STEPS = 64
+    BATCH_SIZE = 256
     CTX_LEN = 32
     NUM_PASSES = args.iter
     MODEL_SIZE = "test"
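With the new emulation settings, the global batch of 256 sequences is spread over 64 accumulation steps, so each microbatch holds 4 sequences of 32 tokens (assuming BATCH_SIZE is the global batch, as the earlier reshape(accum_steps, -1, context) suggests):

GRAD_ACCUM_STEPS = 64
BATCH_SIZE = 256
CTX_LEN = 32

microbatch_size = BATCH_SIZE // GRAD_ACCUM_STEPS  # 4 sequences per step
# batch.reshape(GRAD_ACCUM_STEPS, -1, CTX_LEN) -> shape (64, 4, 32)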
@@ -148,7 +151,7 @@ def to_bf16(t):
     batch_loss_spec = NamedSharding(mesh, P(None, 'dp', None))

 else:
-    param_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp, axis_name='dp')
+    param_spec = no_shard
     batch_grad_spec = no_shard
     batch_loss_spec = NamedSharding(mesh, P(None, 'dp', None))
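This branch drops the tensor-parallel partition rules: param_spec = no_shard replicates the parameters across the mesh, leaving only the batch dimension sharded over 'dp'. A sketch of the sharding objects referenced here, assuming an 8-device mesh with a single data-parallel axis and that no_shard is a fully replicated NamedSharding (both inferred from usage, not confirmed by the repo):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8])     # assumed 8 emulated devices
mesh = Mesh(devices, axis_names=("dp",))

no_shard = NamedSharding(mesh, P())                          # fully replicated
batch_loss_spec = NamedSharding(mesh, P(None, "dp", None))   # shard the microbatch axis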

