Skip to content

Commit

Permalink
attempt inference with v2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
kingoflolz committed Aug 27, 2021
1 parent 43e4ed8 commit 0f9b555
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 13 deletions.
106 changes: 96 additions & 10 deletions mesh_transformer/transformer_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ class CausalTransformerV2:
def __init__(self, config):
self.config = config
optimizer = config["optimizer"]
with_optimizer = optimizer is not None

head_print(f"with_optimizer: {with_optimizer}")

bf16_optimizer = config.get("bf16_optimizer", False)
early_cast = config.get("early_cast", False)
Expand All @@ -358,8 +361,20 @@ def residual(x, mask):
out = x + TransformerLayerShardV2(config, init_scale=2. / config["layers"])(x, mask)
return maybe_shard(out, P("dp", None, "mp"))

def init_decode(x, given_length, mask):
residual, decode_state = TransformerLayerShardV2(config, init_scale=2. / config["layers"])\
.get_init_decode_state(x, given_length, mask)
out = x + residual
return maybe_shard(out, P("dp", None, "mp")), decode_state

def iter_decode(decode_state, x):
residual, decode_state = TransformerLayerShardV2(config, init_scale=2. / config["layers"])\
.decode_once(decode_state, x, 0)
out = x + residual
return maybe_shard(out, P("dp", None, "mp")), decode_state

def transformer(x, mask):
return hk.remat(residual)(x, mask)
return hk.remat(residual, prevent_cse=False)(x, mask)

def projection(x):
return Projection(config)(x)
Expand All @@ -372,8 +387,10 @@ def init_fns():
return embed_init_fn, transformer_init_fn, projection_init_fn

def shard_strategy(shape_dtype, parallel):
if shape_dtype.ndim <= 1:
if shape_dtype.ndim == 0:
return P()
if shape_dtype.ndim == 1:
return P(None)
# embedding/projection layers
elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
return P(parallel, None)
Expand All @@ -385,7 +402,7 @@ def shard_strategy(shape_dtype, parallel):
if shape_dtype.ndim == 2:
# a channel wise variable (e.g. layernorm parameters)
# replicate it for speed
return P(None)
return P(None, None)
elif shape_dtype.ndim == 3:
# a weight matrix
matrix_size = shape_dtype.shape[1:]
Expand Down Expand Up @@ -423,12 +440,16 @@ def init_scan_fn(key, x):
"proj": projection_init_fn(p_key, jax.random.uniform(t_key, input_shape[1:], dtype=jnp.float32)),
}

return {
output_state = {
"params": (to_bf16 if early_cast else to_f32)(params),
"step": np.array(0),
"opt_state": optimizer.init((to_bf16 if bf16_optimizer else to_f32)(params))
}

if with_optimizer:
output_state["opt_state"] = optimizer.init((to_bf16 if bf16_optimizer else to_f32)(params))

return output_state

assert thread_resources.env.shape['mp'] == config["cores_per_replica"]

dp = thread_resources.env.shape['dp']
Expand All @@ -444,14 +465,19 @@ def init_scan_fn(key, x):
state_shard = {
"step": P(),

# zero level 1: shard optimizer states over both MP and DP
"opt_state": jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]), param_shapes["opt_state"]),

# fp32 params are also sharded (so this is like a weird mix between zero-1 and zero-3...)
"params": jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]), param_shapes["params"]),
}

if "opt_state" in param_shapes:
# zero level 1: shard optimizer states over both MP and DP
state_shard["opt_state"] = jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]), param_shapes["opt_state"])

self.state_shard = state_shard

head_print("sharding strategy:")
# head_print("state shard: ", state_shard)
# head_print("param_shapes: ", param_shapes)
jax.tree_multimap(head_print, state_shard, param_shapes)

self.init_pjit = pjit(init, in_axis_resources=(None, P("dp")), out_axis_resources=state_shard)
Expand Down Expand Up @@ -567,6 +593,62 @@ def eval(params, ctx, tgt, ctx_length):
P("dp"), P("dp"), P("dp")),
out_axis_resources=P("dp"))

def generate(params, key, ctx, ctx_length, aux, sampler_options):
sampler = config["sampler"]
gen_length = config["gen_length"]

embed_apply_fn, _ = apply_fns()
init_decode_apply = hk.without_apply_rng(hk.transform(init_decode)).apply
iter_decode_apply = hk.without_apply_rng(hk.transform(iter_decode)).apply

def get_inital(params, ctx, ctx_length):
x = embed_apply_fn(params["embed"], ctx)
mask = (jnp.arange(0, ctx.shape[1])[None, :] > ctx_length[:, None]) * -1e10

def apply_scan_fn(layer_in, layer_state):
x, mask = layer_in

x, decode_state = init_decode_apply(layer_state, x, mask)
return (x, mask), decode_state

_, init_state = jax.lax.scan(apply_scan_fn,
(to_bf16(x), mask),
xs=params["transformer"])

return (last.astype(jnp.uint32), init_state, hk.next_rng_key())

initial_state = get_inital(params, ctx, ctx_length)
initial_carry = ()

def generate_scan_fn(carry, sampler_input):
next_token, decode_state, sample_key = carry
sample_key, new_key = jax.random.split(sample_key)

x = embed_apply_fn(params["embed"], next_token)
mask = (jnp.arange(0, ctx.shape[1])[None, :] > ctx_length[:, None]) * -1e10

def layer_scan_fn(carry_in, layer_in):
x, mask = carry_in
layer_state, decode_state = layer_in

x, decode_state = iter_decode_apply(layer_state, decode_state, x)
return (x, mask), decode_state

(x, _), new_state = jax.lax.scan(layer_scan_fn,
(to_bf16(x), mask),
xs=params["transformer"])

projection_apply_fn = hk.without_apply_rng(hk.transform(Projection(config))).apply

logits = projection_apply_fn(params["proj"], x)
next_token, sample_info = sampler(sample_key, logits, sampler_input, **sampler_options)

new_carry = (next_token, new_state, new_key)
return new_carry, (next_token, sample_info)

final_state, outputs = jax.lax.scan(generate_scan_fn, initial_state, xs=aux, length=gen_length)
return final_state, outputs

self.move_weights_pjit = pjit(lambda x: to_bf16(x),
in_axis_resources=(state_shard["params"], ),
out_axis_resources=mp_shard_strategy if early_collect else state_shard["params"])
Expand All @@ -584,7 +666,11 @@ def eval(params, ctx, tgt, ctx_length):

self.state = self.init_pjit(next(key), x)
self.state_shard = state_shard
self.eval_weights = None

if with_optimizer:
self.eval_weights = None
else:
self.eval_weights = self.state["params"]

param_count = hk.data_structures.tree_size(self.state['params'])
head_print(f"Total parameters: {param_count * dp}")
Expand All @@ -593,7 +679,7 @@ def write_ckpt(self, path, _):
write_ckpt_v2(self.state, path)

def load_ckpt(self, path):
self.state = load_ckpt_v2(self.state, path)
self.state = load_ckpt_v2(self.state, path, self.state_shard, not self.config.get("eval_only", False))

def train(self, sample):
# print("train iter")
Expand Down
2 changes: 1 addition & 1 deletion scripts/init_ray.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:libtpu_202105
screen -d -m python -c 'import time; time.sleep(999999999)'

# initializes jax and installs ray on cloud TPUs
sudo pip install --upgrade jaxlib==0.1.67 jax==0.2.12 ray[default]==1.4.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
sudo pip install --upgrade jaxlib==0.1.67 jax==0.2.12 ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
2 changes: 1 addition & 1 deletion scripts/init_ray_v2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ screen -d -m python -c 'import time; time.sleep(999999999)'

# initializes jax and installs ray on cloud TPUs
sudo pip install "jax[tpu]>=0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
sudo pip install --upgrade ray[default]==1.4.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
sudo pip install --upgrade ray[default]==1.5.1 fabric dataclasses optax git+https://github.com/deepmind/dm-haiku tqdm cloudpickle smart_open[gcs] einops func_timeout
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def parse_args():
sample_size=params['seq'],
restore_state=train_load_restore)

global_val_batch = per_replica_batch * tpu_size // cores_per_replica
global_val_batch = int(per_replica_batch * tpu_size // cores_per_replica * params.get("val_batch_multiplier", 1))

val_sets = {}

Expand Down Expand Up @@ -164,4 +164,6 @@ def parse_args():
print(f"step {step} val results: {dumped}")
wandb.log(flat_results, step)
step += 1

pbar.set_postfix({'loss': loss, 'last_loss': last_loss})
pbar.update()

0 comments on commit 0f9b555

Please sign in to comment.