This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Include shard metadata in the resharded state dicts
tangbinh committed Oct 27, 2022
1 parent 022a59b commit 1cb84c9
Showing 4 changed files with 58 additions and 37 deletions.
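Editor's note: the resharded state dicts written by this script now carry a shard_metadata entry alongside the model weights and optimizer state. A rough sketch of what that entry looks like, as far as it can be read off this diff (the module path, buffer name, and any keys not referenced here, such as "names"/"shapes"/"numels", are assumptions):

shard_metadata = {
    "param_metadata": [
        {
            "fsdp_path": "decoder.layers.0",  # hypothetical FSDP module path
            "params": {
                "flat_param_0": {
                    "padding": 0,  # per-shard padding; updated by this commit when re-chunking
                    # "names", "shapes", "numels" for the flattened params are assumed here
                },
            },
        },
    ],
    "buffer_names": ["decoder.version"],  # hypothetical buffer name
}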
90 changes: 53 additions & 37 deletions metaseq/scripts/reshard_fsdp.py
@@ -1,8 +1,9 @@
 import logging
 import os
 import re
+from copy import deepcopy
 from glob import glob
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import fire
 import torch
@@ -21,7 +22,8 @@ def reshard_fsdp_checkpoints(
 ) -> None:
     """
     Reshard FSDP checkpoints and write outputs to files. The model weights and optimizer states
-    are merged from the input shards before the resharding logic applies.
+    are merged from the sharded checkpoints before the resharding logic applies. The sharded
+    checkpoints are expected to contain shard metadata.

     Args:
         :param input_glob_pattern: A glob pattern specifying the path names of the input shards.
@@ -59,40 +61,41 @@ def reshard_fsdp_checkpoints(
 def reshard_fsdp_state_dicts(
     shard_state_dicts: List[Dict[str, Any]],
     num_output_shards: int = 1,
-    unflatten_weights: bool = False,
+    unflatten_weights: bool = True,
     skip_optimizer_state: bool = False,
 ) -> List[Dict[str, Any]]:
     logger.info(f"Resharding state dicts from {len(shard_state_dicts)} shards")
     # Unshard model weights
-    resharded_state_dict = [{} for _ in range(num_output_shards)]
-    shard_weights = [state["model"] for state in shard_state_dicts]
-    shard_metadata = [state["shard_metadata"] for state in shard_state_dicts]
-    resharded_model_weights = reshard_fsdp_model_weights(
-        shard_weights,
-        shard_metadata,
+    resharded_state_dicts = [{} for _ in range(num_output_shards)]
+    resharded_weights, resharded_metadata = reshard_fsdp_model_weights(
+        [state["model"] for state in shard_state_dicts],
+        [state["shard_metadata"] for state in shard_state_dicts],
         num_output_shards,
         unflatten_weights=unflatten_weights,
     )
-    for shard_idx, model_weights in enumerate(resharded_model_weights):
-        resharded_state_dict[shard_idx]["model"] = model_weights
+    for shard_idx, (weight, metadata) in enumerate(
+        zip(resharded_weights, resharded_metadata)
+    ):
+        resharded_state_dicts[shard_idx]["model"] = weight
+        resharded_state_dicts[shard_idx]["shard_metadata"] = metadata

     # Unshard last optimizer state
     if not skip_optimizer_state and "last_optimizer_state" in shard_state_dicts[0]:
-        reshared_state_dicts = reshard_fsdp_optim_state(
+        reshared_optim_states = reshard_fsdp_optim_state(
             [state["last_optimizer_state"] for state in shard_state_dicts],
             num_output_shards,
         )
         # TODO: Support optimizer state unpadding
-        for shard_idx, optim_state in enumerate(reshared_state_dicts):
-            resharded_state_dict[shard_idx]["last_optimizer_state"] = optim_state
+        for shard_idx, optim_state in enumerate(reshared_optim_states):
+            resharded_state_dicts[shard_idx]["last_optimizer_state"] = optim_state

     # Copy other state values from the first shard
     for key in shard_state_dicts[0]:
-        if key not in {"model", "last_optimizer_state", "sharded_metadata"}:
+        if key not in {"model", "last_optimizer_state", "shard_metadata"}:
             for shard_idx in range(num_output_shards):
-                resharded_state_dict[shard_idx][key] = shard_state_dicts[0][key]
+                resharded_state_dicts[shard_idx][key] = shard_state_dicts[0][key]

-    return resharded_state_dict
+    return resharded_state_dicts


 def reshard_fsdp_model_weights(
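For reference, a minimal sketch of driving reshard_fsdp_state_dicts programmatically rather than through the fire CLI at the bottom of this file (the checkpoint paths are hypothetical and mirror the example usage added below):

import torch
from metaseq.scripts.reshard_fsdp import reshard_fsdp_state_dicts

# Hypothetical shard files produced by FSDP training; each is expected to hold
# "model", "last_optimizer_state", and (per this commit) "shard_metadata".
shard_paths = [
    "opt-2.7b/raw/checkpoint_last-model_part-0-shard0.pt",
    "opt-2.7b/raw/checkpoint_last-model_part-0-shard1.pt",
]
shard_state_dicts = [torch.load(path, map_location="cpu") for path in shard_paths]

# Consolidate into a single output shard with unflattened weights.
(resharded,) = reshard_fsdp_state_dicts(
    shard_state_dicts, num_output_shards=1, unflatten_weights=True
)
torch.save(resharded, "opt-2.7b/reshard/reshard-model_part-0.pt")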
@@ -108,18 +111,19 @@ def reshard_fsdp_model_weights(
         raise ValueError("Unflatten weights only if the number of output shards is 1.")

     resharded_weights = [{} for _ in range(num_output_shards)]
-    for idx, metadata in enumerate(shard_metadata[0]["param_metadata"]):
-        fsdp_path = metadata["fsdp_path"]
-        for flat_name, param_info in metadata["params"].items():
+    resharded_metadata = [deepcopy(shard_metadata[0]) for _ in range(num_output_shards)]
+    for idx, param_metadata in enumerate(shard_metadata[0]["param_metadata"]):
+        fsdp_path = param_metadata["fsdp_path"]
+        for flat_name, param_info in param_metadata["params"].items():
             full_key = ".".join([fsdp_path, flat_name]) if fsdp_path else flat_name
             if full_key not in shard_weights[0]:
                 raise ValueError(f"No weight found for key {full_key} in metadata.")

             # Unshard FSDP tensor weights
             sharded_weights = []
-            for weights, metadata in zip(shard_weights, shard_metadata):
+            for weight, metadata in zip(shard_weights, shard_metadata):
                 pad = metadata["param_metadata"][idx]["params"][flat_name]["padding"]
-                sharded_weights.append(_unpad_tensor(weights[full_key], pad))
+                sharded_weights.append(_unpad_tensor(weight[full_key], pad))
             unsharded_weights = torch.cat(sharded_weights, dim=0)

             # For single shard output, tensor weights can be unflattened
@@ -129,12 +133,18 @@ def reshard_fsdp_model_weights(
                 for n, t, s in zip(names, unsharded_weights.split(numels), shapes):
                     param_name = ".".join([fsdp_path, n]) if fsdp_path else n
                     resharded_weights[0][param_name] = t.view(s)
+                resharded_metadata = [{}] * num_output_shards
                 continue

             # Otherwise, reshard weights by chunking unsharded tensors
-            shards = _shard_and_pad_tensor(unsharded_weights, num_output_shards)
-            for shard_idx, shard in enumerate(shards):
-                resharded_weights[shard_idx][flat_name] = shard
+            weights, paddings = _shard_and_pad_tensor(
+                unsharded_weights, num_output_shards
+            )
+            for shard_idx, (weight, pad) in enumerate(zip(weights, paddings)):
+                resharded_weights[shard_idx][flat_name] = weight
+                resharded_metadata[shard_idx]["param_metadata"][idx]["params"][
+                    flat_name
+                ]["padding"] = pad

     # Copy shared parameters
     if unflatten_weights:
@@ -149,8 +159,7 @@ def reshard_fsdp_model_weights(
         for shard_idx in range(num_output_shards):
             resharded_weights[shard_idx][buffer_name] = shard_weights[0][buffer_name]

-    # TODO: Update and return sharded_metadata after
-    return resharded_weights
+    return resharded_weights, resharded_metadata


 def reshard_fsdp_optim_state(
@@ -183,25 +192,25 @@ def reshard_fsdp_optim_state(
                 torch.cat([shard["state"][idx][key] for shard in shard_optim_states]),
                 pad=padding[key] if padding and key in padding else 0,
             )
-            shards = _shard_and_pad_tensor(unsharded_value, num_out_shards)
-            for shard_idx, shard in enumerate(shards):
-                resharded_state_dict[shard_idx]["state"][idx][key] = shard
+            chunks, _ = _shard_and_pad_tensor(unsharded_value, num_out_shards)
+            for state_dict, chunk in zip(resharded_state_dict, chunks):
+                state_dict["state"][idx][key] = chunk

     return resharded_state_dict


 def _shard_and_pad_tensor(
     tensor: torch.Tensor, num_shards: int, dim: int = 0
-) -> List[torch.Tensor]:
+) -> Tuple[List[torch.Tensor], List[int]]:
     if num_shards == 1:
-        return [tensor]
+        return [tensor], [0]
     shards = tensor.chunk(num_shards, dim=dim)
     assert len(shards) == num_shards, len(shards)
-    for idx, shard in enumerate(shards):
-        num_to_pad = shards[0].numel() - shard.numel()
-        if num_to_pad > 0:
-            shards[idx] = F.pad(shard, [0, num_to_pad])
-    return shards
+    paddings = [shards[0].numel() - shard.numel() for shard in shards]
+    for idx, (shard, padding) in enumerate(zip(shards, paddings)):
+        if padding > 0:
+            shards[idx] = F.pad(shard, [0, padding])
+    return shards, paddings


 def _unpad_tensor(shard: torch.Tensor, pad: int) -> torch.Tensor:
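To illustrate the padding bookkeeping that _shard_and_pad_tensor now returns, a toy example (a sketch, not part of the commit):

import torch
import torch.nn.functional as F

tensor = torch.arange(10.0)
shards = list(tensor.chunk(4, dim=0))  # sizes 3, 3, 3, 1
paddings = [shards[0].numel() - s.numel() for s in shards]  # [0, 0, 0, 2]
shards = [F.pad(s, [0, p]) if p > 0 else s for s, p in zip(shards, paddings)]
# Every shard now has 3 elements; the recorded paddings let the merge step
# (via _unpad_tensor) strip the zeros off again before concatenating.
restored = torch.cat([s[: s.numel() - p] for s, p in zip(shards, paddings)])
assert torch.equal(restored, tensor)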
@@ -211,4 +220,11 @@ def _unpad_tensor(shard: torch.Tensor, pad: int) -> torch.Tensor:


 if __name__ == "__main__":
+    """
+    Example usage:
+        python -m metaseq.scripts.reshard_fsdp \
+        --input-glob-pattern "opt-2.7b/raw/checkpoint_last-model_part-0-shard*.pt" \
+        --output-shard-name "opt-2.7b/reshard/reshard-model_part-0.pt" \
+        --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
+    """
     fire.Fire(reshard_fsdp_checkpoints)
1 change: 1 addition & 0 deletions metaseq/scripts/reshard_mp.py
@@ -81,6 +81,7 @@ def reshard_mp(
     no_pad=False,
     drop_optimizer_state=False,
 ):
+    logger.info("Warning: This method is now deprecated in favor of the `metaseq.scripts.reshard_fsdp` script.")
     middle = f"model_part-{part}"
     do_pad = not no_pad
     if not Path(f"{save_prefix}-{middle}-shard0.pt").exists():
2 changes: 2 additions & 0 deletions metaseq/scripts/reshard_mp_launch.sh
@@ -4,6 +4,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+echo "Warning: This script is now deprecated in favor of the `metaseq.scripts.reshard_fsdp` script."
+
 prefix=$1
 save_dir=$2
 mparts=$3
2 changes: 2 additions & 0 deletions metaseq/scripts/reshard_mp_launch_no_slurm.sh
@@ -4,6 +4,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+echo "Warning: This script is now deprecated in favor of the `metaseq.scripts.reshard_fsdp` script."
+
 prefix=$1
 save_dir=$2
 mparts=$3
