Skip to content

Commit

Permalink
slim_model: options for float16, checkpoint selection (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
nostalgebraist authored Jun 16, 2021
1 parent 180a587 commit 960e694
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 4 additions & 0 deletions mesh_transformer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def to_bf16(t):
return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)


def to_f16(t):
    """Cast every float32 leaf of the pytree ``t`` to float16.

    Leaves whose dtype is not float32 are passed through untouched, so
    integer arrays and leaves that are already reduced precision keep
    their dtype.

    Args:
        t: A pytree whose leaves expose ``.dtype`` and ``.astype``.

    Returns:
        A pytree with the same structure as ``t``.
    """

    def _cast(leaf):
        return leaf.astype(jnp.float16) if leaf.dtype == jnp.float32 else leaf

    # Use the jax.tree_util spelling: the top-level ``jax.tree_map`` alias
    # was deprecated and later removed from JAX, while this one is available
    # in both old and new releases.
    return jax.tree_util.tree_map(_cast, t)


# identity in forward pass, psum in backward
@jax.custom_vjp
def f_psum(x):
Expand Down
16 changes: 12 additions & 4 deletions slim_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from mesh_transformer.transformer_shard import CausalTransformer
from smart_open import open

from mesh_transformer.util import clip_by_global_norm, to_bf16
from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16


def parse_args():
    """Parse the command-line arguments of the checkpoint slimming script.

    Returns:
        argparse.Namespace with the attributes:
            config: path of the config file (required).
            ckpt_step: step number of the checkpoint to convert; ``-1``
                (the default) means "use the most recent checkpoint".
            f16: ``True`` to convert to float16 instead of bfloat16.

    Raises:
        SystemExit: raised by argparse when ``--config`` is missing or an
            argument cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    # The script opens the config file unconditionally right after parsing,
    # so a missing --config is reported here as a usage error instead of
    # failing later when ``None`` is passed to ``open``.
    parser.add_argument("--config", type=str, required=True, help="Config file location")
    parser.add_argument(
        "--ckpt-step",
        type=int,
        default=-1,
        help="Step number of the checkpoint to convert (if not specified, converts the most recent checkpoint)",
    )
    parser.add_argument(
        "--f16",
        default=False,
        action="store_true",
        help="Convert to float16 (instead of bfloat16)",
    )

    return parser.parse_args()
Expand All @@ -26,6 +28,7 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
params = json.load(open(args.config))
convert_fn = to_f16 if args.f16 else to_bf16

cores_per_replica = params["cores_per_replica"]

Expand Down Expand Up @@ -53,7 +56,10 @@ def parse_args():
with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
meta = json.load(f)

ckpt_step = meta["checkpoints"][-1]
if args.ckpt_step > -1:
ckpt_step = args.ckpt_step
else:
ckpt_step = meta["checkpoints"][-1]
print(f"using checkpoint {ckpt_step}")

with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
Expand All @@ -66,9 +72,11 @@ def parse_args():
start = time.time()
del network.state["opt_state"]

network.state["params"] = to_bf16(network.state["params"])
network.state["params"] = convert_fn(network.state["params"])
print(f"network converted in {time.time() - start:.06}s")

suffix = "_slim_f16" if args.f16 else "_slim"

for i in range(cores_per_replica):
write_ckpt(network.state, f"gs://{bucket}/{model_dir}_slim/step_{ckpt_step}/", i)
write_ckpt(network.state, f"gs://{bucket}/{model_dir}{suffix}/step_{ckpt_step}/", i)
print(f"written shard {i}")

0 comments on commit 960e694

Please sign in to comment.