-
Notifications
You must be signed in to change notification settings - Fork 465
/
tinyllama.py
395 lines (324 loc) · 14.5 KB
/
tinyllama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import math
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random
# Run identity: architecture name (resolved via Config.from_name in main), run name used for the
# output directory and the CSV logger, and the directory where train() writes checkpoints and
# main() looks for them when resuming.
model_name = "tiny_LLaMA_1b"
name = "tinyllama_1b"
out_dir = Path("out") / name
# Hyperparameters
# NOTE: every int/float/str (bool included, since bool subclasses int) bound at module scope
# before the `hparams` line is captured into `hparams` through locals(), so adding a new
# module-level name here also changes that dict.
# NOTE(review): num_of_devices only feeds the batch-size arithmetic below; it is not linked to
# setup(devices=...), so the two have to be kept in sync by hand.
num_of_devices = 8
# Sequences per optimizer step summed over all devices (see the derivation below).
global_batch_size = 512
learning_rate = 4e-4  # peak learning rate, reached at the end of warmup (see get_lr)
micro_batch_size = 8  # sequences per device per forward/backward pass (the DataLoader batch size)
max_step = 715256 * 2  # training length in optimizer steps
warmup_steps = 2000  # warmup length in optimizer steps
log_step_interval = 10  # logging interval in optimizer steps
eval_iters = 100  # maximum number of validation batches per validate() call
save_step_interval = 5000  # save a checkpoint every N optimizer steps
eval_step_interval = 5000  # run validation every N optimizer steps
# AdamW weight decay and betas (used when the optimizer is built in main).
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # max gradient norm passed to fabric.clip_gradients
decay_lr = True  # False keeps learning_rate constant instead of following get_lr
min_lr = 4e-5  # floor of the cosine schedule in get_lr
# Derived quantities. train() counts state["iter_num"] in micro-batch iterations and takes one
# optimizer step every `gradient_accumulation_steps` iterations, so the step-based settings
# above are converted to iteration counts here.
batch_size = global_batch_size // num_of_devices  # sequences per device per optimizer step
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0  # NOTE: this check is skipped under `python -O`
warmup_iters = warmup_steps * gradient_accumulation_steps
max_iters = max_step * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps
# Each entry is (file-name prefix, sampling weight). create_dataloader() globs
# `<data_dir>/<prefix>*` for every entry and normalizes the weights to sum to 1 before mixing the
# datasets with CombinedDataset. To include another dataset, add its prefix with a weight.
train_data_config = [
("train_slim", 0.693584),
("train_star", 0.306416),
]
val_data_config = [
("validation", 1.0),
]
# Snapshot of the scalar settings above; printed by setup() and stored in every checkpoint
# (it is part of the `state` dict that train() saves).
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
# Both loggers are constructed at import time. WandbLogger() uses its default arguments here
# (presumably project/run settings come from the environment — verify).
logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval)
wandb_logger = WandbLogger()
def setup(
    devices: int = 8,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    precision: Optional[str] = None,
    tpu: bool = False,
    resume: Union[bool, Path] = False,
) -> None:
    """Create the Lightning Fabric for this run and hand control to ``main``.

    Args:
        devices: Number of devices on this host. More than one selects a
            distributed strategy: XLA when ``tpu`` is set, FSDP otherwise.
        train_data_dir: Directory holding the packed training shards.
        val_data_dir: Directory holding the packed validation shards, or
            ``None`` to train without validation.
        precision: Fabric precision string. When falsy, the default supported
            training precision for the platform is used.
        tpu: Train on TPU through the XLA strategy.
        resume: ``True`` to resume from the newest checkpoint in ``out_dir``,
            an explicit checkpoint path, or ``False`` to start from scratch.
    """
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    if devices <= 1:
        strategy = "auto"
    elif tpu:
        # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
        devices = "auto"
        strategy = XLAStrategy(sync_module_states=False)
    else:
        # Wrap each transformer Block as its own FSDP unit; no activation checkpointing, no CPU offload.
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy=None,
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )

    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
    fabric.print(hparams)
    # main() is invoked directly instead of through fabric.launch(...); presumably an external
    # launcher starts one process per device — confirm against how the script is run.
    main(fabric, train_data_dir, val_data_dir, resume)
def _latest_checkpoint(directory: Path) -> Path:
    """Return the newest ``*.pth`` checkpoint in *directory*.

    ``train`` names checkpoints ``iter-{iter_num:06d}-ckpt.pth``. The ``06d``
    field only pads to six digits, so once ``iter_num`` reaches 1,000,000 a plain
    lexicographic sort stops matching iteration order (``"iter-1000000"`` sorts
    before ``"iter-960000"``). Files are therefore ranked by their parsed iteration
    number. Files that do not follow that pattern rank below every parsed one and
    are ordered by path among themselves.

    Raises:
        FileNotFoundError: if *directory* contains no ``*.pth`` file.
    """

    def sort_key(path: Path) -> Tuple[int, Path]:
        parts = path.stem.split("-")
        if len(parts) == 3 and parts[0] == "iter" and parts[2] == "ckpt" and parts[1].isdecimal():
            return int(parts[1]), path
        return -1, path

    candidates = list(directory.glob("*.pth"))
    if not candidates:
        raise FileNotFoundError(
            f"resume=True was passed, but no '*.pth' checkpoint was found in {str(directory)!r}."
        )
    return max(candidates, key=sort_key)


def main(fabric, train_data_dir, val_data_dir, resume):
    """Build data, model and optimizer, optionally restore a checkpoint, then train.

    Args:
        fabric: The configured ``lightning.Fabric`` instance.
        train_data_dir: Directory with packed training shards.
        val_data_dir: Directory with packed validation shards, or ``None``.
        resume: ``True`` to resume from the newest checkpoint in ``out_dir``
            (ranked by iteration number), a checkpoint path to resume from, or a
            falsy value to start from scratch.

    Raises:
        FileNotFoundError: if ``resume is True`` but ``out_dir`` holds no ``*.pth`` file.
    """
    monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)

    if fabric.global_rank == 0:
        out_dir.mkdir(parents=True, exist_ok=True)

    config = Config.from_name(model_name)

    train_dataloader, val_dataloader = create_dataloaders(
        batch_size=micro_batch_size,
        block_size=config.block_size,
        fabric=fabric,
        train_data_dir=train_data_dir,
        val_data_dir=val_data_dir,
        seed=3407,
    )
    if val_dataloader is None:
        train_dataloader = fabric.setup_dataloaders(train_dataloader)
    else:
        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    fabric.seed_everything(3407)  # same seed for every process to init model (FSDP)

    fabric.print(f"Loading model with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True):
        model = GPT(config)
        model.apply(partial(model._init_weights, n_layer=config.n_layer))
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
    )
    optimizer = fabric.setup_optimizers(optimizer)

    state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}

    if resume is True:
        # Pick the checkpoint with the highest iteration number, not the lexicographically last name.
        resume = _latest_checkpoint(out_dir)
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
    """Run the pre-training loop until ``max_iters`` micro-batch iterations.

    Each loop iteration consumes one micro-batch. Gradients are accumulated for
    ``gradient_accumulation_steps`` iterations before clipping and one optimizer
    step. ``state`` is updated in place (``iter_num``, ``step_count``). Validation
    runs every ``eval_step_interval`` optimizer steps. A checkpoint is saved every
    ``save_step_interval`` optimizer steps.

    Args:
        fabric: The configured ``lightning.Fabric`` instance.
        state: Dict with ``model``, ``optimizer``, ``hparams``, ``iter_num`` and
            ``step_count``; mutated in place and saved as the checkpoint.
        train_dataloader: Fabric-wrapped training loader.
        val_dataloader: Fabric-wrapped validation loader, or ``None``.
        monitor: Speed monitor that receives per-iteration throughput and loss.
        resume: Truthy when ``state`` was restored from a checkpoint. The first
            ``state["iter_num"]`` batches are then skipped so the data stream
            lines up with the restored iteration count.
    """
    model = state["model"]
    optimizer = state["optimizer"]

    if val_dataloader is not None:
        validate(fabric, model, val_dataloader)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        # measure_flops(meta_model, x) is not used: running it on the meta device triggers a
        # fused RMSNorm error, so the dummy input it would need is not allocated.
        del meta_model

    total_lengths = 0
    total_t0 = time.perf_counter()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    initial_iter = state["iter_num"]
    curr_iter = 0

    loss_func = FusedCrossEntropyLoss()
    for train_data in train_dataloader:
        # Resume the loader state by skipping the batches consumed before the checkpoint.
        # This is not elegant but it works. Should rewrite it in the future.
        if resume:
            if curr_iter < initial_iter:
                curr_iter += 1
                continue
            resume = False
            curr_iter = -1
            fabric.barrier()
            fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))

        if state["iter_num"] >= max_iters:
            break

        # determine and set the learning rate for this iteration
        lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        # Each sample holds block_size + 1 tokens; targets are the inputs shifted by one.
        input_ids = train_data[:, 0 : model.config.block_size].contiguous()
        targets = train_data[:, 1 : model.config.block_size + 1].contiguous()

        is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = loss_func(logits, targets)
            # loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            fabric.backward(loss / gradient_accumulation_steps)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1
        elif fabric.device.type == "xla":
            xm.mark_step()
        state["iter_num"] += 1

        # input_id: B L
        total_lengths += input_ids.size(1)
        t1 = time.perf_counter()
        # Read the loss once per iteration: every .item() is a device-to-host synchronization.
        loss_value = loss.item()
        # Average wall time per iteration since this run (re)started, times the iterations left.
        remaining_seconds = (t1 - total_t0) / (state["iter_num"] - initial_iter) * (max_iters - state["iter_num"])
        fabric.print(
            f"iter {state['iter_num']} step {state['step_count']}: loss {loss_value:.4f}, iter time:"
            f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            f" remaining time: {remaining_seconds / 3600:.2f} hours. "
            # print days as well
            f" or {remaining_seconds / 3600 / 24:.2f} days. "
        )

        monitor.on_train_batch_end(
            state["iter_num"] * micro_batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            flops_per_batch=estimated_flops,
            lengths=total_lengths,
            train_loss=loss_value,
        )

        if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
            val_t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader)
            val_seconds = time.perf_counter() - val_t0
            monitor.eval_end(val_seconds)
            fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {val_seconds * 1000:.2f}ms")
            val_loss_value = val_loss.item()
            # `iter_num` has already been incremented for this iteration, so it equals the number
            # of micro-batches consumed per process; adding 1 would over-count by one micro-batch.
            total_tokens = model.config.block_size * state["iter_num"] * micro_batch_size * fabric.world_size
            fabric.log_dict({"metric/val_loss": val_loss_value, "total_tokens": total_tokens}, state["step_count"])
            fabric.log_dict({"metric/val_ppl": math.exp(val_loss_value), "total_tokens": total_tokens}, state["step_count"])
            fabric.barrier()

        if not is_accumulating and state["step_count"] % save_step_interval == 0:
            checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
            fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
            fabric.save(checkpoint_path, state)
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
    """Return the mean validation loss over at most ``eval_iters`` batches.

    The model is switched to eval mode for the duration of the call. It is put
    back in train mode afterwards, even if evaluation raises. The loss is averaged
    only over the batches actually evaluated, so a loader shorter than ``eval_iters``
    does not have its mean pulled towards zero by unused slots. If the loader yields
    no batch at all, the result is the mean of an empty tensor (NaN).

    Args:
        fabric: Fabric instance; its ``device`` and ``print`` are used.
        model: Model to evaluate; must expose ``config.block_size``.
        val_dataloader: Loader whose batches have at least ``block_size + 1``
            columns (inputs plus targets shifted by one).

    Returns:
        A 0-dim tensor on ``fabric.device`` with this process's mean loss
        (no cross-process reduction is performed here).
    """
    fabric.print("Validating ...")
    model.eval()
    try:
        losses = torch.zeros(eval_iters, device=fabric.device)
        num_batches = 0
        for val_data in val_dataloader:
            if num_batches >= eval_iters:
                break
            input_ids = val_data[:, 0 : model.config.block_size].contiguous()
            targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
            logits = model(input_ids)
            loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            # loss_func = FusedCrossEntropyLoss()
            # loss = loss_func(logits, targets)
            # Tensor-to-tensor copy: same value as storing loss.item(), without a host sync per batch.
            losses[num_batches] = loss
            num_batches += 1
        out = losses[:num_batches].mean()
    finally:
        model.train()
    return out
def create_dataloader(
    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train"
) -> DataLoader:
    """Build a DataLoader over a weighted mix of packed datasets for one split.

    Each ``(prefix, weight)`` entry of the split's data config is processed in turn.
    The training config is used when ``split == "train"``, the validation config
    otherwise. For each entry, the files ``<data_dir>/<prefix>*`` are collected,
    put in a seed-determined order and wrapped in a ``PackedDataset`` with
    ``num_processes`` / ``process_rank`` taken from ``fabric``. The datasets are
    then mixed by ``CombinedDataset`` using the weights normalized to sum to 1.

    Args:
        batch_size: Samples per batch.
        block_size: Tokens per sample, forwarded to ``PackedDataset``.
        data_dir: Directory containing the packed shard files.
        fabric: Fabric instance; ``global_rank`` and ``world_size`` are used.
        shuffle: Forwarded to ``PackedDataset``.
        seed: Seed for the file order, for ``PackedDataset`` (offset by rank) and
            for ``CombinedDataset``.
        split: ``"train"`` selects the training config; any other value selects
            the validation config.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` (shuffling happens inside the
        iterable datasets) and ``pin_memory=True``.

    Raises:
        RuntimeError: if the data config is empty or a prefix matches no file.
    """
    data_config = train_data_config if split == "train" else val_data_config
    if not data_config:
        raise RuntimeError(
            f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
        )

    datasets = []
    for prefix, _ in data_config:
        filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
        if not filenames:
            # The empty-config check above cannot catch a prefix with no matching files;
            # without this check an empty file list would be handed to PackedDataset unchecked.
            raise RuntimeError(
                f"No data found at {data_dir} for prefix {prefix!r}. "
                "Make sure you ran prepare_redpajama.py to create the dataset."
            )
        # A private RNG seeded with `seed` yields the same permutation as
        # `random.seed(seed); random.shuffle(...)` without reseeding the global `random` state.
        random.Random(seed).shuffle(filenames)

        dataset = PackedDataset(
            filenames,
            # n_chunks control the buffer size.
            # Note that the buffer size also impacts the random shuffle
            # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer)
            n_chunks=8,
            block_size=block_size,
            shuffle=shuffle,
            seed=seed + fabric.global_rank,
            num_processes=fabric.world_size,
            process_rank=fabric.global_rank,
        )
        datasets.append(dataset)

    weights = [weight for _, weight in data_config]
    sum_weights = sum(weights)
    weights = [el / sum_weights for el in weights]

    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)

    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
def create_dataloaders(
    batch_size: int,
    block_size: int,
    fabric,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
    """Build the training loader and, when ``val_data_dir`` is truthy, the validation loader.

    Both loaders read samples of ``block_size + 1`` tokens so that targets can be
    taken as the inputs shifted by one. The training loader shuffles; the validation
    loader does not.

    Returns:
        ``(train_dataloader, val_dataloader)``, where ``val_dataloader`` is ``None``
        when no validation directory is given.
    """
    # One extra token per sample: the next-token target for the last input position.
    shared_kwargs = dict(batch_size=batch_size, block_size=block_size + 1, fabric=fabric, seed=seed)

    train_dataloader = create_dataloader(data_dir=train_data_dir, shuffle=True, split="train", **shared_kwargs)

    val_dataloader = None
    if val_data_dir:
        val_dataloader = create_dataloader(
            data_dir=val_data_dir, shuffle=False, split="validation", **shared_kwargs
        )

    return train_dataloader, val_dataloader
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    """Return the learning rate for micro-batch iteration ``it``.

    Linear warmup from 0 to ``learning_rate`` over ``warmup_iters`` iterations,
    then cosine decay from ``learning_rate`` down to ``min_lr`` until
    ``lr_decay_iters``, and a constant ``min_lr`` afterwards.
    """
    if it < warmup_iters:
        # warmup: ramp linearly towards the peak rate
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        # past the decay window the rate stays at its floor
        return min_lr
    # cosine interpolation: weight goes from 1 at the end of warmup to 0 at lr_decay_iters
    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= progress <= 1
    lr_span = learning_rate - min_lr
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine_weight * lr_span
if __name__ == "__main__":
    from jsonargparse import CLI

    # Allow float32 matmuls to use faster reduced-precision internal math where the hardware supports it.
    torch.set_float32_matmul_precision("high")
    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    CLI(setup)