2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:22.11-py3
FROM nvcr.io/nvidia/pytorch:23.01-py3

ARG USER=1000
ARG USERNAME=user
8 changes: 8 additions & 0 deletions Makefile
@@ -55,6 +55,14 @@ gpt-bigcode-mqa1:
gpt-bigcode-mqa2:
${RUN_HF} ${BIGCODE_ARGS} attention_type=3

.PHONY: santacoder-original
santacoder-original:
${RUN_HF} --pretrained_model=bigcode/santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS}

.PHONY: santacoder
santacoder:
${RUN_HF} --pretrained_model=bigcode/santacoder-fast-inference --tokenizer=bigcode/santacoder ${EXP_ARGS}

.PHONY: optimized-santacoder
optimized-santacoder:
${RUN_HF} --pretrained_model=olivierdehaene/optimized-santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS}
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
accelerate==0.15.0
bitsandbytes
safetensors
deepspeed==0.7.7
-e ./transformers

2 changes: 2 additions & 0 deletions src/main.py
@@ -14,11 +14,13 @@ def main(argv: Optional[List[str]] = None) -> None:
pipeline = pipeline_class(
model_type=args.model_type,
pretrained_model=args.pretrained_model,
pretrained_config=args.pretrained_config,
config_args=args.config_args,
tokenizer=args.tokenizer,
device=args.device,
dtype=args.dtype,
fast_init=args.fast_init,
trust_remote_code=args.trust_remote_code,
)

benchmark_end_to_end(
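With these two pass-throughs in place, the new flags can be exercised directly. A minimal usage sketch, assuming main() forwards argv to the argument parser defined in src/utils/arguments.py (model and tokenizer names taken from the Makefile targets above):

# Hypothetical direct invocation; roughly equivalent to the santacoder-original Makefile target.
from src.main import main

main([
    "--pretrained_model=bigcode/santacoder",
    "--tokenizer=bigcode/santacoder",
    "--trust_remote_code",
])
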
29 changes: 22 additions & 7 deletions src/pipelines/pipeline.py
@@ -9,6 +9,7 @@

from src.utils.fast_init import fast_init
from src.utils.logging import format_ms, log_rank_n
from src.utils.utils import parse_revision
from transformers import (
CONFIG_MAPPING,
AutoConfig,
@@ -41,12 +42,14 @@ def __init__(
self,
*,
model_type: Optional[str] = None,
pretrained_config: Optional[str] = None,
pretrained_model: Optional[str] = None,
config_args: Dict[str, Any],
tokenizer: str,
device: torch.device,
dtype: torch.dtype,
fast_init: bool = True,
trust_remote_code: bool = False,
):
self.initialization_metrics = {}
log_rank_n("*** Setting up tokenizer", logger.info)
@@ -60,10 +63,11 @@
self.dtype = dtype
self.is_int8 = self.dtype == torch.int8
self.fast_init = fast_init
self.trust_remote_code = trust_remote_code
if self.is_int8 and self.device != torch.device("cuda"):
raise ValueError(f"Model quantization not supported on device {self.device}")

self.config = self._get_config(model_type, pretrained_model, config_args)
self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args)
t2 = time.perf_counter()

logger.info(f"Model configuration: {self.config}")
@@ -86,7 +90,9 @@ def _create_model(self) -> PreTrainedModel:
log_rank_n("*** Creating model", logger.info)
with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
torch_dtype = torch.float16 if self.is_int8 else self.dtype
model = AutoModelForCausalLM.from_config(config=self.config, torch_dtype=torch_dtype)
model = AutoModelForCausalLM.from_config(
config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code
)
t1 = time.perf_counter()
log_rank_n("*** Moving to device", logger.info)
model.to(self.device)
@@ -98,6 +104,7 @@
self.initialization_metrics["model initialization"] = t1 - t0
self.initialization_metrics["move to device"] = t2 - t1
self.initialization_metrics["initialize weights"] = t3 - t2

return model

def _reload_model(self):
@@ -118,9 +125,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
log_rank_n(f"*** Loading model from {pretrained_model}", logger.info)
kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype}
with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
pretrained_model, revision = parse_revision(pretrained_model)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model,
revision=revision,
config=self.config,
trust_remote_code=self.trust_remote_code,
**kwargs,
)
t1 = time.perf_counter()
@@ -135,7 +145,7 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
def _get_config(
self,
model_type: Optional[str],
pretrained_model: Optional[str],
pretrained_config: Optional[str],
config_args: Dict[str, Any],
) -> PretrainedConfig:
config_args = {
@@ -145,15 +155,16 @@
}

if model_type is None:
if pretrained_model is None:
if pretrained_config is None:
raise ValueError("You need to provide either --model_type or --pretrained_model")
config_class = AutoConfig
elif model_type not in CONFIG_MAPPING:
raise ValueError(f"Unknown model type: {model_type}")
else:
config_class = CONFIG_MAPPING[model_type]
config_args["model_type"] = model_type

if pretrained_model is None:
if pretrained_config is None:
config_args.update(
{
"bos_token_id": self.tokenizer.bos_token_id,
@@ -163,7 +174,10 @@
)
config, unused = config_class.from_dict({}, **config_args)
else:
config, unused = config_class.from_pretrained(pretrained_model, **config_args)
pretrained_config, revision = parse_revision(pretrained_config)
config, unused = config_class.from_pretrained(
pretrained_config, revision=revision, trust_remote_code=self.trust_remote_code, **config_args
)

if unused:
raise ValueError(f"There were unused configuration parameters: {tuple(unused)}")
@@ -216,7 +230,8 @@ def aggregate_and_format_metrics(self, metrics: List[Dict[str, Any]]):
"Latency (decode)": format_ms(mean_metrics[DECODE_TIME]),
"Latency (max)": format_ms(max(all_metrics[END_TO_END_TIME])),
"Latency (min)": format_ms(min(all_metrics[END_TO_END_TIME])),
"Tokens generated": f"{mean_metrics[NUM_GENERATED_TOKENS]:.0f}",
"Tokens generated (average)": f"{mean_metrics[NUM_GENERATED_TOKENS]:.0f}",
"Tokens generated (total)": f"{np.sum(all_metrics[NUM_GENERATED_TOKENS]).item():.0f}",
"Throughput (model)": f"{model_throughput:.2f} tokens/s",
"Throughput (end to end)": f"{throughput:.2f} tokens/s",
"Token time (end to end)": f"{format_ms(throughput ** -1)}/token",
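For orientation, a rough sketch of how the new constructor arguments fit together, based only on the signature shown above; the concrete pipeline class selected in src/main.py and the "name:revision" value are assumptions, not code from this PR:

import torch
from src.pipelines.pipeline import Pipeline

# Hypothetical wiring: "name:revision" strings are split by parse_revision before
# reaching from_pretrained/from_config, and trust_remote_code is forwarded to both.
pipeline = Pipeline(
    model_type=None,
    pretrained_model="bigcode/santacoder-fast-inference:main",  # revision suffix is illustrative
    pretrained_config=None,  # config falls back to pretrained_model when not given
    config_args={},
    tokenizer="bigcode/santacoder",
    device=torch.device("cuda"),
    dtype=torch.float16,
    fast_init=True,
    trust_remote_code=True,  # needed for hub repos that ship custom modeling code
)
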
16 changes: 11 additions & 5 deletions src/utils/arguments.py
@@ -11,8 +11,10 @@ def get_arg_parser() -> ArgumentParser:

# Model
parser.add_argument("--model_type")
parser.add_argument("--pretrained_config")
parser.add_argument("--pretrained_model")
parser.add_argument("--tokenizer", default="gpt2")
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument("config_args", nargs="*")

# Runtime
@@ -47,10 +49,14 @@ def get_arg_parser() -> ArgumentParser:
def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]:
parsed_config_args = {}
for config_arg in config_args:
try:
key, value = [x.strip() for x in config_arg.split("=")]
except ValueError:
raise ValueError(f"Cannot parse argument: {config_arg}")
split_arg = [x.strip() for x in config_arg.split("=", 1)]
if len(split_arg) != 2:
raise ValueError(f"Cannot parse argument (not in 'key=value' format): {config_arg}")
key, value = split_arg
if not key.isidentifier():
raise ValueError(f"Invalid argument (not a python identifier): {key}")
if key in parsed_config_args:
raise ValueError(f"Duplicate argument: {key}")
if value.lower() == "true":
value = True
elif value.lower() == "false":
@@ -65,7 +71,7 @@ def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]:
value = float(value)
except ValueError:
pass
parsed_config_args[key.strip()] = value
parsed_config_args[key] = value
return parsed_config_args


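To illustrate the stricter parsing, a short sketch of what parse_config_args now accepts and rejects (the keys and values are illustrative, not taken from the repository):

from src.utils.arguments import parse_config_args

parse_config_args(["n_embd=2048", "attention_type=2", "layer_norm_epsilon=1e-5", "use_cache=true"])
# -> {"n_embd": 2048, "attention_type": 2, "layer_norm_epsilon": 1e-05, "use_cache": True}

parse_config_args(["n_embd"])                # ValueError: not in 'key=value' format
parse_config_args(["2fast=1"])               # ValueError: '2fast' is not a python identifier
parse_config_args(["n_embd=1", "n_embd=2"])  # ValueError: duplicate argument
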
48 changes: 34 additions & 14 deletions src/utils/benchmark.py
@@ -1,12 +1,13 @@
import contextlib
import gc
import logging
import time
from typing import List, Union

import torch

from src.pipelines.pipeline import Pipeline
from src.utils.logging import format_ms, log_dict, log_rank_n
from src.utils.logging import format_mib, format_ms, log_dict, log_rank_n


logger = logging.getLogger(__name__)
@@ -91,8 +92,27 @@ def benchmark_end_to_end(
else:
profiler = contextlib.nullcontext()

benchmark_stats = {
"Model parameters": pipeline.get_num_parameters(),
"Batch size": len(inputs),
**generate_kwargs,
**pipeline.get_initialization_metrics(),
"Warmup cycles": skip + warmup,
"Benchmark cycles": cycles,
"Total cycles": skip + warmup + cycles,
}

if pipeline.device.type == "cuda":
benchmark_stats["Initial memory used"] = format_mib(torch.cuda.memory_allocated())
benchmark_stats["Initial memory reserved"] = format_mib(torch.cuda.memory_reserved())
torch.cuda.reset_peak_memory_stats()

t0 = time.perf_counter()
with profiler as p:
for step in range(skip + warmup + cycles):
if step == skip + warmup:
t1 = time.perf_counter()
benchmark_stats["Warmup time"] = format_ms(t1 - t0)
generated_text, metrics = pipeline(inputs, **generate_kwargs)
if profile:
p.step()
@@ -108,18 +128,18 @@
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
if pipeline.device.type == "cuda":
benchmark_stats["Memory used"] = format_mib(torch.cuda.memory_allocated())
benchmark_stats["Memory reserved"] = format_mib(torch.cuda.memory_reserved())
benchmark_stats["Max memory used"] = format_mib(torch.cuda.max_memory_allocated())
benchmark_stats["Max memory reserved"] = format_mib(torch.cuda.max_memory_reserved())

t2 = time.perf_counter()
benchmark_stats["Benchmark time"] = format_ms(t2 - t1)
benchmark_stats["Total time"] = format_ms(t2 - t0)

if len(all_metrics) > 0:
log_rank_n("*** Performance metrics:", logger.info)
log_dict(pipeline.aggregate_and_format_metrics(all_metrics), logger.info)

log_rank_n("*** Benchmarking stats:", logger.info)
log_dict(
{
"Model parameters": pipeline.get_num_parameters(),
"Batch size": len(inputs),
**generate_kwargs,
**pipeline.get_initialization_metrics(),
},
logger.info,
)
benchmark_stats.update(pipeline.aggregate_and_format_metrics(all_metrics))

log_rank_n("*** Benchmark results:", logger.info)
log_dict(benchmark_stats, logger.info)
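The memory columns follow the standard PyTorch accounting pattern: snapshot allocated/reserved bytes before the run, reset the peak counters, then read the peaks afterwards. A self-contained sketch of that pattern, independent of this repository:

import torch

def measure_cuda_memory(fn):
    """Run fn() and report CUDA memory usage in MiB."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    initial = torch.cuda.memory_allocated()
    result = fn()
    torch.cuda.synchronize()
    stats = {
        "initial allocated": initial / 2**20,
        "final allocated": torch.cuda.memory_allocated() / 2**20,
        "final reserved": torch.cuda.memory_reserved() / 2**20,
        "peak allocated": torch.cuda.max_memory_allocated() / 2**20,
        "peak reserved": torch.cuda.max_memory_reserved() / 2**20,
    }
    return result, stats

# Example: _, stats = measure_cuda_memory(lambda: torch.zeros(1024, 1024, device="cuda"))
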
4 changes: 4 additions & 0 deletions src/utils/logging.py
@@ -43,3 +43,7 @@ def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0):

def format_ms(t: float):
return f"{1000 * t:.2f} ms"


def format_mib(m: float):
return f"{m/2**20:.0f} MiB"
11 changes: 10 additions & 1 deletion src/utils/utils.py
@@ -1,6 +1,6 @@
import time
from functools import partial
from typing import Any, List, Tuple, Union
from typing import Any, List, Optional, Tuple, Union


def run_and_log_time(execs: Union[List[partial], partial]) -> Tuple[Union[List[Any], Any], float]:
@@ -16,3 +16,12 @@ def run_and_log_time(execs: Union[List[partial], partial]) -> Tuple[Union[List[Any], Any], float]:

time_elapsed = time.perf_counter() - start_time
return results, time_elapsed


def parse_revision(pretrained_model: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
revision = None
if pretrained_model is not None:
pretrained_split = pretrained_model.split(":", 1)
if len(pretrained_split) == 2:
pretrained_model, revision = pretrained_split
return pretrained_model, revision
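
A quick illustration of the "model:revision" convention that parse_revision introduces (the revision names are illustrative):

from src.utils.utils import parse_revision

parse_revision("bigcode/santacoder-fast-inference:main")  # -> ("bigcode/santacoder-fast-inference", "main")
parse_revision("bigcode/santacoder")                      # -> ("bigcode/santacoder", None)
parse_revision(None)                                      # -> (None, None)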