optimizer sharding + updates
fattorib committed May 28, 2023
1 parent 03bee97 commit 5c76fc6
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions tensor_parallel_shard_map.py
@@ -9,21 +9,18 @@
from jax.sharding import Mesh
from omegaconf import OmegaConf
from tqdm import tqdm
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P
from src.models.GPT import model_getter
from src.partitioning.partition import (
create_opt_spec,
set_partitions_rules,
_get_partition_rules_dp,
_get_partition_rules_tp,
_get_partition_rules_tp_dp,
)
from src.training.training_utils import initialized
import jax.numpy as jnp
from typing import Any
from jax.lax import with_sharding_constraint
from jax.experimental.shard_map import shard_map
from jax.experimental.pjit import pjit
from jax.lax import with_sharding_constraint


def parse():
@@ -86,6 +83,19 @@ def cumul_minibatch_step(carry, x_y):

return grads, metrics

def update_opt_state(
    params: Any,
    grads: Any,
    opt_state: Any,
    optimizer: Any,
    tp_spec: Any
):
    # updates the optimizer state and params
    params = with_sharding_constraint(params, tp_spec)
    grads = with_sharding_constraint(grads, tp_spec)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
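# Note on update_opt_state: constraining both params and grads to the tensor-parallel
# spec before calling optimizer.update should let XLA keep the AdamW moment updates
# local to each shard instead of gathering full parameters, which is presumably why
# tp_spec is threaded through this function.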

# Emulating 8 TPU cores
import os
@@ -160,13 +170,39 @@ def init(rng, init_batch):
params = jax.jit(init)(rng, (jnp.ones((1, CTX_LEN), dtype=jnp.int32)))

param_shape = jax.tree_map(lambda x: x.size, params) # we literally just do this to get keys

if args.mp > 1:
    param_spec = set_partitions_rules(
        param_shape, mesh, _get_partition_rules_tp, axis_name="mp"
    )
    batch_loss_spec = P(None, "dp", None)
    params = shard_map(lambda x: x, mesh, in_specs=no_shard, out_specs=param_spec)(params)

    batch_loss_spec = P(None, "dp", None)

    # optimizer state init
    mask = jax.tree_map(
        lambda x: x.ndim != 1 and x.shape != (model.block_size, model.embedding_dim),
        params,
    )

    tx = optax.chain(
        optax.clip(1.0),
        optax.adamw(
            learning_rate=0.001,
            weight_decay=0.1,
            mask=mask,
            b2=0.95,
        ),
    )
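    # The decay mask above excludes 1-D parameters (biases and LayerNorm scales) and
    # any (block_size, embedding_dim)-shaped table, presumably the learned position
    # embedding, from weight decay, which is the usual GPT-style AdamW setup.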

    opt_state_shapes = jax.eval_shape(tx.init, params)
    opt_state_spec = create_opt_spec(param_spec, opt_state_shapes)

    with mesh:
        opt_state = pjit(
            tx.init, in_axis_resources=(param_spec,), out_axis_resources=opt_state_spec,
        )(params)
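    # create_opt_spec comes from src.partitioning.partition and its body is not
    # shown in this diff. A rough, hypothetical sketch of what such a helper might
    # do (an assumption, not the repo's implementation): reuse the parameter
    # PartitionSpecs for any optimizer-state sub-tree shaped like the params
    # (Adam's mu/nu) and replicate everything else, such as the step count.
    def _create_opt_spec_sketch(param_spec, opt_state_shapes):
        param_structure = jax.tree_util.tree_structure(param_spec)

        def is_param_like(x):
            # a sub-tree "looks like" the params if its pytree structure matches
            return jax.tree_util.tree_structure(x) == param_structure

        return jax.tree_util.tree_map(
            lambda piece: param_spec if is_param_like(piece) else None,  # None = replicate
            opt_state_shapes,
            is_leaf=is_param_like,  # stop descending at param-shaped sub-trees
        )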

else:
    param_spec = no_shard
    batch_loss_spec = P(None, "dp", None)
@@ -183,6 +219,7 @@ def init(rng, init_batch):

# shard model across desired TP axes
with mesh:

    train_step_tp = jax.jit(
        shard_map(
            partial(train_step, model=model, accum_steps=GRAD_ACCUM_STEPS),
@@ -193,12 +230,20 @@ def init(rng, init_batch):
        )
    )
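    # Note: after this change the forward/backward pass still runs under shard_map,
    # while the optimizer update below is expressed with pjit + PartitionSpecs, so
    # the two sharding styles appear to be mixed deliberately within one training step.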

    update_opt_step_tp = pjit(
        partial(update_opt_state, optimizer=tx, tp_spec=param_spec),
        in_axis_resources=(param_spec, param_spec, opt_state_spec),
        out_axis_resources=(param_spec, opt_state_spec),
    )
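    # Keeping param_spec / opt_state_spec as both the input and output resources
    # should leave the weights and optimizer state resident in their sharded
    # layout between steps, so the update adds no extra resharding traffic.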

    rng, dropout_rng = jax.random.split(rng, 2)

    init_batch = jax.numpy.ones(shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32)

    grads, metrics = train_step_tp(params, init_batch)

    params, opt_state = update_opt_step_tp(params, grads, opt_state)
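    # These first calls presumably act as a compilation warm-up; the timer only
    # starts below, so the traced/compiled step is excluded from the measurement.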

    start = time()

    # visualize array shardings
@@ -212,8 +257,7 @@ def init(rng, init_batch):
            shape=(BATCH_SIZE, CTX_LEN), dtype=jax.numpy.int32
        )
        grads, metrics = train_step_tp(params, batch)

        params = jax.tree_map(lambda x, y: x - 0.01 * y, params, grads)
        params, opt_state = update_opt_step_tp(params, grads, opt_state)

    jnp.zeros((10, 10)).block_until_ready()
    total_time = time() - start
