Skip to content

Commit

Permalink
Fix warmup steps and minor issues in benchmarks (#334)
Browse files Browse the repository at this point in the history
The previous code was incorrect for the case of `warmup_steps != 1` (this mode was never used, but can be used in future).
  • Loading branch information
borzunov committed Jun 30, 2023
1 parent d126ee3 commit 10c72ac
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
14 changes: 8 additions & 6 deletions benchmarks/benchmark_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

Expand Down Expand Up @@ -47,9 +48,9 @@ def benchmark_forward(process_idx, args):
logger.info(f"Created model: {process_idx=} {model.device=}")

torch.manual_seed(42)
for step in range(args.n_steps):
if step == args.warmup_steps:
start_time = perf_counter()
step_times = []
for step in range(args.warmup_steps + args.n_steps):
start_time = perf_counter()

input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

Expand All @@ -59,10 +60,11 @@ def benchmark_forward(process_idx, args):
logger.info(f"{process_idx=} Fwd end")

if step >= args.warmup_steps:
speed = step / (perf_counter() - start_time) * input_ids.numel()
logger.info(f"{process_idx=} {step=} {speed=:.3f}")
step_times.append(perf_counter() - start_time)
speed = input_ids.numel() / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")

logger.info(f"Final result: {process_idx=} {speed=:.3f}")
logger.info(f"Final result: {process_idx=} {speed=:.2f}")


if __name__ == "__main__":
Expand Down
18 changes: 11 additions & 7 deletions benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer
Expand Down Expand Up @@ -38,26 +39,29 @@ def main():

@torch.inference_mode()
def benchmark_inference(process_idx, args):
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
# Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

model = AutoDistributedModelForCausalLM.from_pretrained(
args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
)
logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
logger.info(f"Created model: {process_idx=} {model.device=}")

result = ""
step_times = []
with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
for step in range(args.seq_len):
if step == args.warmup_steps:
start_time = perf_counter()
start_time = perf_counter()

outputs = model.generate(max_new_tokens=1, session=sess)
result += tokenizer.decode(outputs[0])

if step >= args.warmup_steps:
speed = step / (perf_counter() - start_time)
logger.info(f"{process_idx=} {step=} {speed=:.3f}")
step_times.append(perf_counter() - start_time)
speed = 1 / np.mean(step_times)
logger.info(f"{process_idx=} {step=} {speed=:.2f}")

logger.info(f"Final result: {process_idx=} {speed=:.3f}")
logger.info(f"Final result: {process_idx=} {speed=:.2f}")


if __name__ == "__main__":
Expand Down
12 changes: 7 additions & 5 deletions benchmarks/benchmark_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def benchmark_training(process_idx, args):
torch.manual_seed(42)
fwd_times = []
bwd_times = []
for step in range(args.n_steps):
for step in range(args.warmup_steps + args.n_steps):
input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
if args.task == "cls":
labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
Expand All @@ -78,20 +78,22 @@ def benchmark_training(process_idx, args):
logger.info(f"{process_idx=} {step=} Forward")
start_time = perf_counter()
outputs = model(input_ids, labels=labels)
fwd_times.append(perf_counter() - start_time)
if step >= args.warmup_steps:
fwd_times.append(perf_counter() - start_time)

logger.info(f"{process_idx=} {step=} Backward")
start_time = perf_counter()
outputs.loss.backward()
bwd_times.append(perf_counter() - start_time)
if step >= args.warmup_steps:
bwd_times.append(perf_counter() - start_time)

logger.info(f"{process_idx=} {step=} Optimizer step")
opt.step()
opt.zero_grad()

if step >= args.warmup_steps:
fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
fwd_speed = input_ids.numel() / np.mean(fwd_times)
bwd_speed = input_ids.numel() / np.mean(bwd_times)
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
Expand Down

0 comments on commit 10c72ac

Please sign in to comment.