-
Notifications
You must be signed in to change notification settings - Fork 722
/
convert_to_singleton.py
170 lines (139 loc) · 5 KB
/
convert_to_singleton.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python
"""
Script for backing out of the MP-resharded (reshard.pt) files and getting back
a non-flattened state dict.
Particularly useful for converting our models to other repositories.
Usage:
$ ls 125m
dict.txt
gpt2-merges.txt
gpt2-vocab.json
reshard-model_part-0.pt
reshard-model_part-1.pt
$ python -m metaseq.scripts.convert_to_singleton 125m
$ ls 125m
dict.txt
gpt2-merges.txt
gpt2-vocab.json
reshard-model_part-0.pt
reshard-model_part-1.pt
restored.pt
"""
import argparse
import glob
import logging
import os
import sys
import torch
from metaseq import options, tasks, checkpoint_utils, utils
from metaseq.dataclass.configs import MetaseqConfig
from metaseq.dataclass.utils import convert_namespace_to_omegaconf
from metaseq.distributed import utils as distributed_utils
from metaseq.distributed import fsdp_enable_wrap, fsdp_wrap
from metaseq.distributed.stitch_fsdp_ckpt import reshard_megatron_parts
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("convert_to_singleton")
def create_generation_config_with_defaults(model_path, ddp_backend="pytorch_ddp"):
    """
    Build a generation MetaseqConfig for the model stored in *model_path*.

    The model-parallel size (MP) is inferred from the number of
    ``reshard*.pt`` shard files in the directory; BPE vocab/merges paths are
    assumed to be the standard gpt2 files next to the shards.

    Args:
        model_path: directory containing ``reshard-model_part-*.pt`` shards,
            ``gpt2-merges.txt`` and ``gpt2-vocab.json``.
        ddp_backend: value forwarded as ``--ddp-backend``.

    Returns:
        A MetaseqConfig (omegaconf) with ``distributed_world_size`` set to MP.
    """
    files = glob.glob(os.path.join(model_path, "reshard*.pt"))
    MP = len(files)
    BPE_MERGES = os.path.join(model_path, "gpt2-merges.txt")
    BPE_VOCAB = os.path.join(model_path, "gpt2-vocab.json")

    # Skeleton out all the annoying command line args we can infer
    ARGS = [
        "--model-parallel-size",
        str(MP),
        "--distributed-world-size",
        str(MP),
        "--ddp-backend",
        ddp_backend,
        "--task",
        "language_modeling",
        "--bpe-merges",
        BPE_MERGES,
        "--merges-filename",
        BPE_MERGES,
        "--bpe-vocab",
        BPE_VOCAB,
        "--vocab-filename",
        BPE_VOCAB,
        "--bpe",
        "hf_byte_bpe",
        "--path",
        model_path + "/reshard.pt",
        "--checkpoint-shard-count",
        "1",
        "--use-sharded-state",
        model_path,
    ]
    # Use the module logger instead of a bare print() so the output goes
    # through the logging config set up at module import.
    logger.info("Inferred generation args: %s", ARGS)

    # build up the config file
    parser = options.get_generation_parser()
    # dumb defaults overriding
    parser.set_defaults(lr_scheduler=None, criterion=None)
    args = options.parse_args_and_arch(parser, input_args=ARGS)
    cfg = convert_namespace_to_omegaconf(args)
    cfg.distributed_training.distributed_world_size = MP
    return cfg
def worker_main(cfg: MetaseqConfig):
    """
    Load up the model on all workers for Model Parallelism, then
    unflatten, move to cpu, and save to "restored.pt".

    NOTE: this runs on every rank; all ranks must execute the same sequence
    of collective ops (the all_gather below), but only rank 0 writes the
    output file.
    """
    task = tasks.setup_task(cfg.task)

    def _build_model(cfg, task):
        # Hook passed to load_model_ensemble_and_task: build the model on
        # GPU (tensor-parallel init requires it) and wrap it with FSDP.
        cfg.model.tensor_parallel_init_model_on_gpu = True
        model = task.build_model(cfg.model).cuda()
        return fsdp_wrap(model)

    with fsdp_enable_wrap(
        cfg.distributed_training,
        use_sharded_state=cfg.distributed_training.use_sharded_state,
    ):
        models, _model_args, _task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=None,
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=True,
            num_shards=cfg.checkpoint.checkpoint_shard_count,
            build_model_hook=_build_model,
        )
        model = models[0]

    # consolidate everything on rank0
    mp_size = distributed_utils.get_model_parallel_world_size()
    # model_parts[r] will hold rank r's shard of every parameter, on CPU.
    model_parts = [{} for _ in range(mp_size)]

    with model.summon_full_params():
        for name, p in model.named_parameters():
            # Gather this parameter's shard from every rank; every rank ends
            # up with all mp_size shards (all ranks must make this call).
            gathered = [torch.zeros_like(p) for _ in range(mp_size)]
            torch.distributed.all_gather(
                gathered, p, group=distributed_utils.get_global_group()
            )
            for r, t in enumerate(gathered):
                model_parts[r][name] = t.cpu()

    # Stitch the MP shards back into a single (model_part_count=1) state dict.
    glued = reshard_megatron_parts(model_parts, new_model_part_count=1)[0]
    # glued['decoder.output_projection.weight'] = glued['decoder.embed_tokens.weight']

    # 'decoder.version' is not a sharded parameter; copy it from the local
    # state dict directly.
    glued["decoder.version"] = model.state_dict()["decoder.version"].cpu()

    # Drop the output projection — presumably tied to embed_tokens (see the
    # commented line above); TODO confirm against the loading side.
    if "decoder.output_projection.weight" in glued:
        del glued["decoder.output_projection.weight"]

    # Reuse the part-0 checkpoint as a template for metadata (cfg, optimizer
    # history, etc.) and swap in the consolidated weights.
    output_sd = checkpoint_utils.load_checkpoint_to_cpu(
        cfg.common_eval.path.replace("reshard.pt", "reshard-model_part-0.pt")
    )
    output_sd["model"] = utils.move_to_cpu(glued)
    # Rewrite the arch so the restored checkpoint loads as a plain
    # (non-model-parallel) transformer_lm.
    output_sd["cfg"]["model"].arch = "transformer_lm"
    output_sd["cfg"]["model"]._name = "transformer_lm"

    if distributed_utils.get_global_rank() == 0:
        with open(cfg.task.data + "/restored.pt", "wb") as f:
            torch.save(output_sd, f)
def main():
    """Parse the model directory from the CLI and launch the conversion.

    Invoked as shown in the module docstring:
    ``python -m metaseq.scripts.convert_to_singleton <model_dir>``.
    """
    # parser to be used like docstring shows
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("location")
    parsed = cli_parser.parse_args()
    generation_cfg = create_generation_config_with_defaults(parsed.location)
    distributed_utils.call_main(generation_cfg, worker_main)


if __name__ == "__main__":
    main()