Skip to content

Commit

Permalink
Lower memory consumption in Colab demo (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfbd authored Sep 15, 2021
1 parent 0c32f0c commit 8f071d0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
4 changes: 2 additions & 2 deletions colab_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@
"import optax\n",
"import transformers\n",
"\n",
"from mesh_transformer.checkpoint import read_ckpt\n",
"from mesh_transformer.checkpoint import read_ckpt_lowmem\n",
"from mesh_transformer.sampling import nucleaus_sample\n",
"from mesh_transformer.transformer_shard import CausalTransformer"
],
Expand Down Expand Up @@ -1517,7 +1517,7 @@
"\n",
"network = CausalTransformer(params)\n",
"\n",
"network.state = read_ckpt(network.state, \"step_383500/\", devices.shape[1])\n",
"network.state = read_ckpt_lowmem(network.state, \"step_383500/\", devices.shape[1])\n",
"\n",
"network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))"
],
Expand Down
53 changes: 53 additions & 0 deletions mesh_transformer/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,59 @@ def _unshard(shards, old_flattened):
return loaded_pytree


def read_ckpt_lowmem(pytree, dir, shards_in, shards_out=None, load_opt=True):
    """Load a sharded checkpoint into ``pytree`` one array at a time to keep
    host-memory usage low.

    Arrays are read per-key from ``{dir}shard_{i}/{j}.npz``, stacked across
    shards, and immediately placed on a device, so at most one stacked
    parameter is resident in host memory at any moment.

    Args:
        pytree: parameter/optimizer-state pytree; its leaves define the
            expected per-device shapes and its structure is the structure
            of the result. NOTE(review): mutated (``opt_state`` deleted)
            when the checkpoint has no optimizer state — matches the
            original behavior.
        dir: checkpoint directory prefix (assumed to end with "/").
        shards_in: number of shards the checkpoint was written with.
        shards_out: number of shards to reshard to; defaults to
            ``shards_in`` (no resharding).
        load_opt: if False, the optimizer state already in ``pytree`` is
            kept instead of the checkpointed one.

    Returns:
        A pytree with the same structure as ``pytree`` with leaves loaded
        from the checkpoint (optimizer state optionally preserved).
    """
    if shards_out is None:
        shards_out = shards_in

    old_flattened, structure = jax.tree_flatten(pytree)

    # Saved so the caller's optimizer state can be restored when load_opt
    # is False, or when the checkpoint turns out to contain no opt state.
    original_opt_state = pytree["opt_state"]

    def _unshard():
        start = time.time()
        unsharded = []
        devices = jax.devices()
        device_count = len(devices)
        device_index = 0

        # `pieces` is a module-level constant of checkpoint.py (number of
        # .npz files per shard directory).
        for file_index in range(pieces):
            # Shard 0 determines the key order; all shards store the same
            # keys. Using `with` closes the NpzFile's underlying zip handle
            # — the previous code leaked one open file handle per np.load
            # call (shards_in * pieces * num_keys handles in total).
            with np.load(f"{dir}shard_0/{file_index}.npz") as npz0:
                array_keys = list(npz0.keys())
            for array_index in range(len(array_keys)):
                unstacked = []
                for shard_index in range(shards_in):
                    # NpzFile.__getitem__ fully materializes the array, so
                    # it remains valid after the handle is closed.
                    with np.load(f"{dir}shard_{shard_index}/{file_index}.npz") as npz:
                        array = npz[array_keys[array_index]]
                    if array.dtype == 'V2':
                        # Raw 2-byte void data is bfloat16 round-tripped
                        # through numpy; reinterpret the buffer in place.
                        array.dtype = jnp.bfloat16
                    unstacked.append(array)

                # Round-robin device placement: put each stacked parameter
                # straight onto a device instead of accumulating on host.
                x = jax.device_put(jnp.stack(unstacked), device=devices[device_index % device_count])

                if shards_out != shards_in:
                    x = reshard(x, old_flattened[device_index].shape)
                unsharded.append(x)

                # NOTE(review): this assert doubles as control flow (see
                # the except below) — it would be stripped under `python -O`.
                assert x.shape == old_flattened[device_index].shape, f"Incompatible checkpoints {x.shape} vs {old_flattened[device_index].shape}"
                device_index += 1

        print(f"read from disk/gcs in {time.time() - start:.06}s")
        return unsharded

    try:
        unsharded = _unshard()
    except AssertionError:
        # Shape mismatch means the checkpoint was saved without optimizer
        # state; retry with opt_state removed from the expected pytree.
        load_opt = False  # no opt to load in ckpt
        del pytree['opt_state']
        old_flattened, structure = jax.tree_flatten(pytree)
        unsharded = _unshard()

    loaded_pytree = jax.tree_unflatten(structure, unsharded)

    if not load_opt:
        loaded_pytree['opt_state'] = original_opt_state
    return loaded_pytree


def parallel_write(arrays, fname):
# TODO: make this actually parallel
with open(fname, "wb") as f:
Expand Down

0 comments on commit 8f071d0

Please sign in to comment.