
Commit

cleanup + contextlib for easier profiling on CPU/TPU
fattorib committed May 25, 2023
1 parent d2239f3 commit 7224d90
Showing 1 changed file with 27 additions and 21 deletions.
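
The main change replaces an inline (previously commented-out) jax.profiler.trace call with a context manager chosen once up front: contextlib.nullcontext() when emulating TPU cores on CPU, and a real jax.profiler.trace(..., create_perfetto_link=True) on actual hardware, so the training loop reads the same in both modes. A minimal sketch of that pattern follows; the profiler_context helper and emulating flag are illustrative names, not part of the repository:

import contextlib

import jax


def profiler_context(emulating: bool):
    # No-op context when emulating TPU cores on CPU; real trace on hardware.
    if emulating:
        return contextlib.nullcontext()
    # The trace is written to "jax-trace" and can be opened through the Perfetto link.
    return jax.profiler.trace("jax-trace", create_perfetto_link=True)


maybe_profiler = profiler_context(emulating=True)

with maybe_profiler:
    pass  # run the jitted train step(s) to be profiled here
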
48 changes: 27 additions & 21 deletions tensor_parallel_emulation.py
@@ -10,7 +10,7 @@
from omegaconf import OmegaConf
from tqdm import tqdm
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding, NamedSharding
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from src.models.GPT import model_getter
from src.partitioning.partition import create_opt_spec, set_partitions_rules, _get_partition_rules_dp,_get_partition_rules_tp,_get_partition_rules_tp_dp
@@ -41,14 +41,16 @@ def train_step(
dp: int = 0
):
"""
Computes loss/grads for a single batch of data.
Computes loss/grads for a single batch of data, optionally with gradient accumulation
"""
_,context = batch.shape

# reshape to add a microbatch dimension
batch = batch.reshape(accum_steps, -1, context)
batch = with_sharding_constraint(batch, batch_loss_spec)

params = with_sharding_constraint(params, model_mp_spec)

# force loss to only compute locally?
def loss_fn(params, batch):
_, loss = model.apply(
{"params": params["params"]},
@@ -57,29 +59,27 @@ def loss_fn(params, batch):
train=True,
rngs={"dropout": rng_key},
)
loss = with_sharding_constraint(loss, batch_loss_spec)
return loss.mean()
return jnp.mean(loss)

grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

# accumulate gradients
def cumul_minibatch_step(carry, x_y):
cumul_loss, cumul_grads = carry
minibatch = x_y

loss, grads = grad_fn(to_bf16(params), minibatch)
cumul_loss, cumul_grads = jax.tree_util.tree_map(
cumul_loss, cumul_grads = jax.tree_map(
jnp.add, (cumul_loss, cumul_grads), (loss, grads)
)
return (cumul_loss, cumul_grads), None

grad_init = to_bf16(jax.tree_util.tree_map(jnp.zeros_like, params))
grad_init = with_sharding_constraint(grad_init, model_mp_spec) # gradients follow mp sharding

(loss,grads), _ = jax.lax.scan(cumul_minibatch_step, init = (0.0, grad_init), xs = batch)
(loss,grads), _ = jax.lax.scan(cumul_minibatch_step, init = (jnp.zeros(()), grad_init), xs = batch)

# grads = jax.tree_map(lambda x: x.reshape([1, *x.shape]), grads)
# grads = jax.tree_map(lambda x: jax.numpy.repeat(x, dp, axis=0), grads)
grads = with_sharding_constraint(grads, batch_grad_spec) # if dp, this tells XLA to replicate gradients across all devices
# grads = jax.tree_map(lambda x: jax.numpy.mean(x, axis=0), grads)
grads = with_sharding_constraint(grads, batch_grad_spec) # if dp only, this tells XLA to replicate gradients across all devices

loss, grads = jax.tree_util.tree_map(lambda x: x / accum_steps, (loss, grads))

@@ -102,14 +102,17 @@ def cumul_minibatch_step(carry, x_y):

if args.emulation:
print("Emulating 8 TPU cores")
GRAD_ACCUM_STEPS = 8
GRAD_ACCUM_STEPS = 16
BATCH_SIZE = 128
CTX_LEN = 32
NUM_PASSES = args.iter
MODEL_SIZE = "test"

def to_bf16(t):
return t

import contextlib
maybe_profiler = contextlib.nullcontext()

else:
GRAD_ACCUM_STEPS = 64
@@ -123,6 +126,8 @@ def to_bf16(t):
return jax.tree_map(
lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
)

maybe_profiler = jax.profiler.trace("jax-trace", create_perfetto_link=True)

# Setting up device mesh
# following new jax array tutorial
@@ -137,7 +142,7 @@ def to_bf16(t):
mesh = Mesh(np.array(jax.devices()).reshape(args.dp), ('dp'))

# indicates batch dim is split across dp axis
batch_sharding = jax.sharding.NamedSharding(mesh, P('dp'))
batch_sharding = NamedSharding(mesh, P('dp'))
no_shard = NamedSharding(mesh, None)

# # Setting up model + param spec
@@ -155,7 +160,8 @@ def to_bf16(t):
param_spec = no_shard
# batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_dp)
batch_grad_spec = no_shard
batch_loss_spec = batch_sharding
# batch_loss_spec = batch_sharding
batch_loss_spec = NamedSharding(mesh, P(None, 'dp'))

configs = OmegaConf.load("conf/model_config.yaml")
model_info = configs[MODEL_SIZE]
@@ -175,7 +181,7 @@ def to_bf16(t):
train_step_dp = jax.jit(
partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS, model_mp_spec = param_spec,batch_grad_spec = batch_grad_spec, batch_loss_spec = batch_loss_spec, dp = args.dp),
in_shardings=(param_spec, batch_sharding, no_shard),
out_shardings=(param_spec,no_shard)
out_shardings=(param_spec,no_shard),
)

rng, dropout_rng = jax.random.split(rng, 2)
@@ -191,14 +197,14 @@ def to_bf16(t):
# jax.debug.visualize_array_sharding(batch)

for i in tqdm(range(NUM_PASSES)):
# with jax.profiler.trace("jax-trace", create_perfetto_link=True):
dropout_rng, rng = jax.random.split(dropout_rng)
with maybe_profiler:
dropout_rng, rng = jax.random.split(dropout_rng)

batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)
batch = jax.device_put(batch, batch_sharding)
grads, metrics = train_step_dp(params, batch, dropout_rng)
batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)
batch = jax.device_put(batch, batch_sharding)
grads, metrics = train_step_dp(params, batch, dropout_rng)

params = jax.tree_map(lambda x,y : x - 0.01*y, params, grads)
# params = jax.tree_map(lambda x,y : x - 0.01*y, params, grads)

print(metrics)
total_time = time() - start
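
For readers following the train_step changes above, the accumulation scheme is: reshape the batch to add a leading microbatch dimension, jax.lax.scan a per-microbatch gradient function while summing loss and grads in the carry (initialised to zeros), then divide both by accum_steps. Below is a stripped-down, self-contained sketch of that pattern with a toy loss and no sharding constraints; all names are illustrative, not taken from the script:

import jax
import jax.numpy as jnp


def loss_fn(params, minibatch):
    # Toy scalar loss standing in for the model's language-modeling loss.
    return jnp.mean((minibatch @ params["w"]) ** 2)


def accumulated_grads(params, batch, accum_steps):
    # Add a leading microbatch dimension: (accum_steps, microbatch, features).
    batch = batch.reshape(accum_steps, -1, batch.shape[-1])
    grad_fn = jax.value_and_grad(loss_fn)

    def step(carry, minibatch):
        loss_acc, grads_acc = carry
        loss, grads = grad_fn(params, minibatch)
        # Sum loss and gradients into the carry.
        carry = jax.tree_util.tree_map(jnp.add, (loss_acc, grads_acc), (loss, grads))
        return carry, None

    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss, grads), _ = jax.lax.scan(step, init, batch)
    # Average over microbatches so the result matches a single large-batch step.
    return jax.tree_util.tree_map(lambda x: x / accum_steps, (loss, grads))


params = {"w": jnp.ones((8, 1))}
batch = jnp.ones((32, 8))
loss, grads = accumulated_grads(params, batch, accum_steps=4)
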
