diff --git a/Dockerfile b/Dockerfile
index 58c34e2..5249004 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:22.11-py3
+FROM nvcr.io/nvidia/pytorch:23.01-py3
 
 ARG USER=1000
 ARG USERNAME=user
diff --git a/Makefile b/Makefile
index 9915de8..84e7ba5 100644
--- a/Makefile
+++ b/Makefile
@@ -55,6 +55,14 @@ gpt-bigcode-mqa1:
 gpt-bigcode-mqa2:
 	${RUN_HF} ${BIGCODE_ARGS} attention_type=3
 
+.PHONY: santacoder-original
+santacoder-original:
+	${RUN_HF} --pretrained_model=bigcode/santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS}
+
 .PHONY: santacoder
 santacoder:
 	${RUN_HF} --pretrained_model=bigcode/santacoder-fast-inference --tokenizer=bigcode/santacoder ${EXP_ARGS}
+
+.PHONY: optimized-santacoder
+optimized-santacoder:
+	${RUN_HF} --pretrained_model=olivierdehaene/optimized-santacoder --tokenizer=bigcode/santacoder --trust_remote_code ${EXP_ARGS}
diff --git a/requirements.txt b/requirements.txt
index 2b8ca55..6e91e92 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 accelerate==0.15.0
 bitsandbytes
+safetensors
 deepspeed==0.7.7
 -e ./transformers
 
diff --git a/src/main.py b/src/main.py
index 8d7ea80..900fa31 100644
--- a/src/main.py
+++ b/src/main.py
@@ -14,11 +14,13 @@ def main(argv: Optional[List[str]] = None) -> None:
     pipeline = pipeline_class(
         model_type=args.model_type,
         pretrained_model=args.pretrained_model,
+        pretrained_config=args.pretrained_config,
         config_args=args.config_args,
         tokenizer=args.tokenizer,
         device=args.device,
         dtype=args.dtype,
         fast_init=args.fast_init,
+        trust_remote_code=args.trust_remote_code,
     )
 
     benchmark_end_to_end(
diff --git a/src/pipelines/pipeline.py b/src/pipelines/pipeline.py
index 0e212c4..86a495c 100644
--- a/src/pipelines/pipeline.py
+++ b/src/pipelines/pipeline.py
@@ -9,6 +9,7 @@
 
 from src.utils.fast_init import fast_init
 from src.utils.logging import format_ms, log_rank_n
+from src.utils.utils import parse_revision
 from transformers import (
     CONFIG_MAPPING,
     AutoConfig,
@@ -41,12 +42,14 @@ def __init__(
         self,
         *,
         model_type: Optional[str] = None,
+        pretrained_config: Optional[str] = None,
         pretrained_model: Optional[str] = None,
         config_args: Dict[str, Any],
         tokenizer: str,
         device: torch.device,
         dtype: torch.dtype,
         fast_init: bool = True,
+        trust_remote_code: bool = False,
     ):
         self.initialization_metrics = {}
         log_rank_n("*** Setting up tokenizer", logger.info)
@@ -60,10 +63,11 @@ def __init__(
         self.dtype = dtype
         self.is_int8 = self.dtype == torch.int8
         self.fast_init = fast_init
+        self.trust_remote_code = trust_remote_code
         if self.is_int8 and self.device != torch.device("cuda"):
             raise ValueError(f"Model quantization not supported on device {self.device}")
 
-        self.config = self._get_config(model_type, pretrained_model, config_args)
+        self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args)
         t2 = time.perf_counter()
         logger.info(f"Model configuration: {self.config}")
 
@@ -86,7 +90,9 @@ def _create_model(self) -> PreTrainedModel:
         log_rank_n("*** Creating model", logger.info)
         with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
             torch_dtype = torch.float16 if self.is_int8 else self.dtype
-            model = AutoModelForCausalLM.from_config(config=self.config, torch_dtype=torch_dtype)
+            model = AutoModelForCausalLM.from_config(
+                config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code
+            )
             t1 = time.perf_counter()
             log_rank_n("*** Moving to device", logger.info)
             model.to(self.device)
@@ -98,6 +104,7 @@ def _create_model(self) -> PreTrainedModel:
         self.initialization_metrics["model initialization"] = t1 - t0
         self.initialization_metrics["move to device"] = t2 - t1
         self.initialization_metrics["initialize weights"] = t3 - t2
+
         return model
 
     def _reload_model(self):
@@ -118,9 +125,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
         log_rank_n(f"*** Loading model from {pretrained_model}", logger.info)
         kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype}
         with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
+            pretrained_model, revision = parse_revision(pretrained_model)
             model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model,
+                revision=revision,
                 config=self.config,
+                trust_remote_code=self.trust_remote_code,
                 **kwargs,
             )
             t1 = time.perf_counter()
@@ -135,7 +145,7 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
     def _get_config(
         self,
         model_type: Optional[str],
-        pretrained_model: Optional[str],
+        pretrained_config: Optional[str],
         config_args: Dict[str, Any],
     ) -> PretrainedConfig:
         config_args = {
@@ -145,15 +155,16 @@ def _get_config(
         }
 
         if model_type is None:
-            if pretrained_model is None:
+            if pretrained_config is None:
                 raise ValueError("You need to provide either --model_type or --pretrained_model")
             config_class = AutoConfig
         elif model_type not in CONFIG_MAPPING:
             raise ValueError(f"Unknown model type: {model_type}")
         else:
             config_class = CONFIG_MAPPING[model_type]
+            config_args["model_type"] = model_type
 
-        if pretrained_model is None:
+        if pretrained_config is None:
             config_args.update(
                 {
                     "bos_token_id": self.tokenizer.bos_token_id,
@@ -163,7 +174,10 @@ def _get_config(
             )
             config, unused = config_class.from_dict({}, **config_args)
         else:
-            config, unused = config_class.from_pretrained(pretrained_model, **config_args)
+            pretrained_config, revision = parse_revision(pretrained_config)
+            config, unused = config_class.from_pretrained(
+                pretrained_config, revision=revision, trust_remote_code=self.trust_remote_code, **config_args
+            )
 
         if unused:
             raise ValueError(f"There were unused configuration parameters: {tuple(unused)}")
@@ -216,7 +230,8 @@ def aggregate_and_format_metrics(self, metrics: List[Dict[str, Any]]):
             "Latency (decode)": format_ms(mean_metrics[DECODE_TIME]),
             "Latency (max)": format_ms(max(all_metrics[END_TO_END_TIME])),
             "Latency (min)": format_ms(min(all_metrics[END_TO_END_TIME])),
-            "Tokens generated": f"{mean_metrics[NUM_GENERATED_TOKENS]:.0f}",
+            "Tokens generated (average)": f"{mean_metrics[NUM_GENERATED_TOKENS]:.0f}",
+            "Tokens generated (total)": f"{np.sum(all_metrics[NUM_GENERATED_TOKENS]).item():.0f}",
             "Throughput (model)": f"{model_throughput:.2f} tokens/s",
             "Throughput (end to end)": f"{throughput:.2f} tokens/s",
             "Token time (end to end)": f"{format_ms(throughput ** -1)}/token",
diff --git a/src/utils/arguments.py b/src/utils/arguments.py
index ccc5684..d5b4d84 100644
--- a/src/utils/arguments.py
+++ b/src/utils/arguments.py
@@ -11,8 +11,10 @@ def get_arg_parser() -> ArgumentParser:
 
     # Model
     parser.add_argument("--model_type")
+    parser.add_argument("--pretrained_config")
     parser.add_argument("--pretrained_model")
     parser.add_argument("--tokenizer", default="gpt2")
+    parser.add_argument("--trust_remote_code", action="store_true")
     parser.add_argument("config_args", nargs="*")
 
     # Runtime
@@ -47,10 +49,14 @@ def get_arg_parser() -> ArgumentParser:
 def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]:
     parsed_config_args = {}
     for config_arg in config_args:
-        try:
-            key, value = [x.strip() for x in config_arg.split("=")]
-        except ValueError:
-            raise ValueError(f"Cannot parse argument: {config_arg}")
+        split_arg = [x.strip() for x in config_arg.split("=", 1)]
+        if len(split_arg) != 2:
+            raise ValueError(f"Cannot parse argument (not in 'key=value' format): {config_arg}")
+        key, value = split_arg
+        if not key.isidentifier():
+            raise ValueError(f"Invalid argument (not a python identifier): {key}")
+        if key in parsed_config_args:
+            raise ValueError(f"Duplicate argument: {key}")
         if value.lower() == "true":
             value = True
         elif value.lower() == "false":
             value = False
@@ -65,7 +71,7 @@ def parse_config_args(config_args: List[str]) -> typing.Dict[str, Any]:
                 value = float(value)
             except ValueError:
                 pass
-        parsed_config_args[key.strip()] = value
+        parsed_config_args[key] = value
 
     return parsed_config_args
 
diff --git a/src/utils/benchmark.py b/src/utils/benchmark.py
index 7882f7f..ef67dd3 100644
--- a/src/utils/benchmark.py
+++ b/src/utils/benchmark.py
@@ -1,12 +1,13 @@
 import contextlib
 import gc
 import logging
+import time
 from typing import List, Union
 
 import torch
 
 from src.pipelines.pipeline import Pipeline
-from src.utils.logging import format_ms, log_dict, log_rank_n
+from src.utils.logging import format_mib, format_ms, log_dict, log_rank_n
 
 logger = logging.getLogger(__name__)
 
@@ -91,8 +92,27 @@ def benchmark_end_to_end(
     else:
         profiler = contextlib.nullcontext()
 
+    benchmark_stats = {
+        "Model parameters": pipeline.get_num_parameters(),
+        "Batch size": len(inputs),
+        **generate_kwargs,
+        **pipeline.get_initialization_metrics(),
+        "Warmup cycles": skip + warmup,
+        "Benchmark cycles": cycles,
+        "Total cycles": skip + warmup + cycles,
+    }
+
+    if pipeline.device.type == "cuda":
+        benchmark_stats["Initial memory used"] = format_mib(torch.cuda.memory_allocated())
+        benchmark_stats["Initial memory reserved"] = format_mib(torch.cuda.memory_reserved())
+        torch.cuda.reset_peak_memory_stats()
+
+    t0 = time.perf_counter()
     with profiler as p:
         for step in range(skip + warmup + cycles):
+            if step == skip + warmup:
+                t1 = time.perf_counter()
+                benchmark_stats["Warmup time"] = format_ms(t1 - t0)
             generated_text, metrics = pipeline(inputs, **generate_kwargs)
             if profile:
                 p.step()
@@ -108,18 +128,18 @@ def benchmark_end_to_end(
                 torch.cuda.synchronize()
                 gc.collect()
                 torch.cuda.empty_cache()
+    if pipeline.device.type == "cuda":
+        benchmark_stats["Memory used"] = format_mib(torch.cuda.memory_allocated())
+        benchmark_stats["Memory reserved"] = format_mib(torch.cuda.memory_reserved())
+        benchmark_stats["Max memory used"] = format_mib(torch.cuda.max_memory_allocated())
+        benchmark_stats["Max memory reserved"] = format_mib(torch.cuda.max_memory_reserved())
+
+    t2 = time.perf_counter()
+    benchmark_stats["Benchmark time"] = format_ms(t2 - t1)
+    benchmark_stats["Total time"] = format_ms(t2 - t0)
 
     if len(all_metrics) > 0:
-        log_rank_n("*** Performance metrics:", logger.info)
-        log_dict(pipeline.aggregate_and_format_metrics(all_metrics), logger.info)
-
-    log_rank_n("*** Benchmarking stats:", logger.info)
-    log_dict(
-        {
-            "Model parameters": pipeline.get_num_parameters(),
-            "Batch size": len(inputs),
-            **generate_kwargs,
-            **pipeline.get_initialization_metrics(),
-        },
-        logger.info,
-    )
+        benchmark_stats.update(pipeline.aggregate_and_format_metrics(all_metrics))
+
+    log_rank_n("*** Benchmark results:", logger.info)
+    log_dict(benchmark_stats, logger.info)
diff --git a/src/utils/logging.py b/src/utils/logging.py
index 4ec8a39..9c28276 100644
--- a/src/utils/logging.py
+++ b/src/utils/logging.py
@@ -43,3 +43,7 @@ def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0):
 
 def format_ms(t: float):
     return f"{1000 * t:.2f} ms"
+
+
+def format_mib(m: float):
+    return f"{m/2**20:.0f} MiB"
diff --git a/src/utils/utils.py b/src/utils/utils.py
index d678fe3..f18079f 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,6 +1,6 @@
 import time
 from functools import partial
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 
 def run_and_log_time(execs: Union[List[partial], partial]) -> Tuple[Union[List[Any], Any], float]:
@@ -16,3 +16,12 @@ def run_and_log_time(execs: Union[List[partial], partial]) -> Tuple[Union[List[A
     time_elapsed = time.perf_counter() - start_time
 
     return results, time_elapsed
+
+
+def parse_revision(pretrained_model: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
+    revision = None
+    if pretrained_model is not None:
+        pretrained_split = pretrained_model.split(":", 1)
+        if len(pretrained_split) == 2:
+            pretrained_model, revision = pretrained_split
+    return pretrained_model, revision
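
For reference, a minimal usage sketch of the `parse_revision` helper added in `src/utils/utils.py` above, which resolves the `model:revision` convention now accepted by `--pretrained_model` and `--pretrained_config` (the `main` revision here is only an illustrative value, not something the diff requires):

from src.utils.utils import parse_revision

# "name:revision" splits on the first ":" only; a bare name or None leaves the revision as None.
assert parse_revision("bigcode/santacoder:main") == ("bigcode/santacoder", "main")
assert parse_revision("bigcode/santacoder") == ("bigcode/santacoder", None)
assert parse_revision(None) == (None, None)

Both `_load_pretrained` and `_get_config` in `src/pipelines/pipeline.py` route their arguments through this helper, so the same `name:revision` form works for model weights and for standalone configurations.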