From ac86197034057f833e9725f75b034c206eaf942c Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Wed, 18 Oct 2023 08:34:56 -0700
Subject: [PATCH] [python] Reformat python code (#1195)

---
 engines/python/setup/djl_python/deepspeed.py | 71 +++++++++++--------
 .../python/setup/djl_python/huggingface.py | 56 +++++++++------
 .../rolling_batch/scheduler_rolling_batch.py | 7 +-
 .../python/setup/djl_python/sm_log_filter.py | 15 ++--
 .../djl_python/tests/test_sm_log_filter.py | 1 -
 engines/python/setup/djl_python_engine.py | 3 +-
 tests/integration/llm/prepare.py | 1 -
 tests/integration/lmic_test_builder.py | 3 +-
 .../java/ai/djl/serving/wlm/ModelInfo.java | 2 +-
 9 files changed, 93 insertions(+), 66 deletions(-)

diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py
index fc527cd70..2da344ea3 100644
--- a/engines/python/setup/djl_python/deepspeed.py
+++ b/engines/python/setup/djl_python/deepspeed.py
@@ -38,10 +38,7 @@
 from typing import Optional
 from peft import PeftConfig, PeftModel
 
-SUPPORTED_QUANTIZATION_MODE = [
-    "smoothquant",
-    "dynamic_int8"
-]
+SUPPORTED_QUANTIZATION_MODE = ["smoothquant", "dynamic_int8"]
 
 SMOOTHQUANT_SUPPORTED_MODEL_TYPES = {
     "gpt2",
@@ -204,9 +201,13 @@ def _parse_smoothquant_properties(self, properties):
             if 0 <= float(properties['smoothquant_alpha']) <= 1:
                 self.smoothquant_alpha = properties['smoothquant_alpha']
             else:
-                raise ValueError(f"${properties['smoothquant_alpha']} is not between 0 and 1.")
+                raise ValueError(
+                    f"${properties['smoothquant_alpha']} is not between 0 and 1."
+                )
         except ValueError:
-            raise ValueError(f"${properties['smoothquant_alpha']} cannot convert to float number.")
+            raise ValueError(
+                f"${properties['smoothquant_alpha']} cannot convert to float number."
+            )
 
     def _get_ds_config(self, properties: dict):
         ds_config = {
@@ -235,12 +236,12 @@ def _get_ds_config(self, properties: dict):
                     "dtype should also be provided for checkpoint loading")
 
         if self.quantize_mode:
-            ds_config['dynamic_quant'] = {'enabled': True, 'use_cutlass': False}
+            ds_config['dynamic_quant'] = {
+                'enabled': True,
+                'use_cutlass': False
+            }
             if self.quantize_mode == 'smoothquant':
-                smoothing_value = {
-                    'smooth': True,
-                    'calibrate': True
-                }
+                smoothing_value = {'smooth': True, 'calibrate': True}
                 if self.smoothquant_alpha:
                     smoothing_value['alpha'] = self.smoothquant_alpha
                 ds_config['smoothing'] = smoothing_value
@@ -265,7 +266,9 @@ def _validate_model_type_and_task(self):
 
         if self.quantize_mode == \
                 'smoothquant' and self.model_config.model_type not in SMOOTHQUANT_SUPPORTED_MODEL_TYPES:
-            raise ValueError(f"${self.quantize_mode} does not support model ${self.model_config.model_type}")
+            raise ValueError(
+                f"${self.quantize_mode} does not support model ${self.model_config.model_type}"
+            )
 
     def _read_model_config(self):
         try:
@@ -298,14 +301,19 @@ def infer_task_from_model_architecture(self, config: PretrainedConfig):
             f"Task could not be inferred from model config. "
" f"Please manually set `task` in serving.properties.") - def get_model_pretrained(self, model_id_or_path, torch_dtype='auto', **kwargs): + def get_model_pretrained(self, + model_id_or_path, + torch_dtype='auto', + **kwargs): tokenizer = AutoTokenizer.from_pretrained(model_id_or_path) - model = TASK_TO_MODEL[self.task].from_pretrained(model_id_or_path, torch_dtype=torch_dtype, **kwargs) + model = TASK_TO_MODEL[self.task].from_pretrained( + model_id_or_path, torch_dtype=torch_dtype, **kwargs) return model, tokenizer def get_model_from_config(self, model_id_or_path, **kwargs): tokenizer = AutoTokenizer.from_pretrained(model_id_or_path) - model = TASK_TO_MODEL[self.task].from_config(self.model_config, **kwargs) + model = TASK_TO_MODEL[self.task].from_config(self.model_config, + **kwargs) return model, tokenizer def get_model(self, model_id_or_path, loading_method, **kwargs): @@ -316,16 +324,20 @@ def get_model(self, model_id_or_path, loading_method, **kwargs): elif loading_method == 'pretrained': return self.get_model_pretrained(model_id_or_path, **kwargs) else: - raise RuntimeError(f'Unsupported model loading method, this should not happen.') + raise RuntimeError( + f'Unsupported model loading method, this should not happen.') + + def load_model(self, model_id_or_path, loading_method, use_mmap_loader, + **kwargs): - def load_model(self, model_id_or_path, loading_method, use_mmap_loader, **kwargs): def load_model_with_mmap(model_id_or_path, loading_method): import mmaploader import accelerate with mmaploader.load_mmap_meta() as mmap_loader: with accelerate.init_empty_weights(): kwargs['low_cpu_mem_usage'] = False - model, tokenizer = self.get_model(model_id_or_path, loading_method, **kwargs) + model, tokenizer = self.get_model(model_id_or_path, + loading_method, **kwargs) return model, tokenizer, mmap_loader.state_dict_mmap state_dict_mmap = {} @@ -335,14 +347,18 @@ def load_model_with_mmap(model_id_or_path, loading_method): if use_mmap_loader: try: - model, tokenizer, state_dict_mmap = load_model_with_mmap(model_id_or_path, loading_method) + model, tokenizer, state_dict_mmap = load_model_with_mmap( + model_id_or_path, loading_method) done = True except: - self.logger.warning(f'failed to load model with mmap loader, will load model normally') + self.logger.warning( + f'failed to load model with mmap loader, will load model normally' + ) if not done: kwargs['low_cpu_mem_usage'] = True - model, tokenizer = self.get_model(model_id_or_path, loading_method, **kwargs) + model, tokenizer = self.get_model(model_id_or_path, loading_method, + **kwargs) return model, tokenizer, state_dict_mmap @@ -365,11 +381,8 @@ def create_model_pipeline(self): f"Please using quantization with a standard HuggingFace checkpoint or " f"turn off quantization and try again.") model, self.tokenizer, state_dict_mmap = self.load_model( - self.model_id_or_path, - 'from_config', - self.ds_config['replace_with_kernel_inject'], - **kwargs - ) + self.model_id_or_path, 'from_config', + self.ds_config['replace_with_kernel_inject'], **kwargs) elif self.peft_config is not None: self.logger.info( f"Peft Model detected. 
             )
@@ -382,7 +395,8 @@
             lora_model = PeftModel.from_pretrained(base_model,
                                                    self.model_id_or_path)
             model = lora_model.merge_and_unload()
-            self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.peft_config.base_model_name_or_path)
             self.logger.info(
                 f"Peft Model merged into base model for deepspeed compatibility"
             )
@@ -393,8 +407,7 @@
                 self.ds_config['replace_with_kernel_inject'],
                 low_cpu_mem_usage=self.low_cpu_mem_usage,
                 trust_remote_code=self.trust_remote_code,
-                **kwargs
-            )
+                **kwargs)
             if self.data_type:
                 self.ds_config["dtype"] = self.data_type
             else:
diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index ec635eb6c..4ffefeb59 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -16,13 +16,11 @@
 import re
 
 import torch
-from transformers import (pipeline, Pipeline, Conversation,
-                          AutoModelForCausalLM, AutoModelForSeq2SeqLM,
-                          AutoTokenizer, AutoConfig,
-                          AutoModelForSequenceClassification,
-                          AutoModelForTokenClassification,
-                          AutoModelForQuestionAnswering, StoppingCriteria,
-                          StoppingCriteriaList)
+from transformers import (
+    pipeline, Pipeline, Conversation, AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig,
+    AutoModelForSequenceClassification, AutoModelForTokenClassification,
+    AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList)
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from peft import PeftConfig, PeftModel, PeftModelForCausalLM
 
@@ -50,14 +48,18 @@
 }
 
 LMI_DIST_ADV_MODEL = {
-    "RWForCausalLM", "GPTNeoXForCausalLM", "T5ForConditionalGeneration",
-    "LlamaForCausalLM", "FalconForCausalLM", "MPTForCausalLM",
+    "RWForCausalLM",
+    "GPTNeoXForCausalLM",
+    "T5ForConditionalGeneration",
+    "LlamaForCausalLM",
+    "FalconForCausalLM",
+    "MPTForCausalLM",
     "GPTBigCodeForCausalLM",
 }
 
 # https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
 FLASH_2_SUPPORTED_MODELS = {
-    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
+    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
 }
 
 PEFT_MODEL_TASK_TO_CLS = {
@@ -114,14 +116,18 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
             return VLLMRollingBatch
     raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")
 
+
 class StopWord(StoppingCriteria):
+
     def __init__(self, tokenizer, stop_seq):
         StoppingCriteria.__init__(self)
         self.tokenizer = tokenizer
         self.stop_seq = stop_seq
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
-        decoded_input_ids = self.tokenizer.decode(input_ids[0][-len(self.stop_seq):])
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor,
+                 **kwargs):
+        decoded_input_ids = self.tokenizer.decode(
+            input_ids[0][-len(self.stop_seq):])
 
         matches = re.search(self.stop_seq, decoded_input_ids)
 
@@ -130,6 +136,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
 
         return False
 
+
 class HuggingFaceService(object):
 
     def __init__(self):
@@ -202,8 +209,8 @@ def initialize(self, properties: dict):
                 properties.get("dtype"))
         if "revision" in properties:
             kwargs["revision"] = properties.get('revision')
-        self.disable_flash_attn = properties.get(
-            "disable_flash_attn", "true").lower() == 'true'
"true").lower() == 'true' + self.disable_flash_attn = properties.get("disable_flash_attn", + "true").lower() == 'true' self.rolling_batch_type = properties.get("rolling_batch", None) self._read_model_config(model_id_or_path, @@ -249,9 +256,12 @@ def parse_stop_sequence_input(self, stop_sequence): Input: stop_sequence (string) Output: list of strings """ - assert stop_sequence[0] == '[' and stop_sequence[-1] == ']', "option.stop_sequence not properly formatted" + assert stop_sequence[0] == '[' and stop_sequence[ + -1] == ']', "option.stop_sequence not properly formatted" stop_sequence = stop_sequence.replace(", ", ",") - stop_seq_list = [element[1:-1] for element in stop_sequence[1:-1].split(",")] + stop_seq_list = [ + element[1:-1] for element in stop_sequence[1:-1].split(",") + ] return stop_seq_list def load_stopping_criteria_list(self, stop_sequence): @@ -262,7 +272,7 @@ def load_stopping_criteria_list(self, stop_sequence): """ if self.tokenizer is None: return - + stop_seq_list = self.parse_stop_sequence_input(stop_sequence) stopwords = [] @@ -505,7 +515,8 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs): model_cls = AutoModelForSeq2SeqLM else: model_cls = AutoModelForCausalLM - if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash() and not self.disable_flash_attn: + if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash( + ) and not self.disable_flash_attn: kwargs['use_flash_attention_2'] = True if self.peft_config is not None: @@ -631,15 +642,14 @@ def register_adapter(inputs: Input): adapter_path = inputs.get_property("src") if not os.path.exists(adapter_path): raise ValueError( - f"Only local LoRA models are supported. {adapter_path} is not a valid path") - logging.info( - f"Registering adapter {adapter_name} from {adapter_path}") + f"Only local LoRA models are supported. 
+        )
+    logging.info(f"Registering adapter {adapter_name} from {adapter_path}")
     if isinstance(_service.model, PeftModel):
         _service.model.load_adapter(adapter_path, adapter_name)
     else:
         _service.model = PeftModel.from_pretrained(_service.model,
-                                                   adapter_path,
-                                                   adapter_name)
+                                                   adapter_path, adapter_name)
     if isinstance(_service.hf_pipeline, Pipeline):
         _service.hf_pipeline.model = _service.model
diff --git a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
index 8b0836f22..e12f5ad6c 100644
--- a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
@@ -24,9 +24,10 @@
 DEFAULT_SEARCH_ALGORITHM = 'greedy'
 # https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
 FLASH_2_SUPPORTED_MODELS = {
-    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
+    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
 }
 
+
 def enable_flash():
     if torch.cuda.is_available():
         major, _ = torch.cuda.get_device_capability()
@@ -134,8 +135,8 @@ def _init_model_and_tokenizer(self,
             device_map = kwargs.pop('device_map')
 
         if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash():
-            if properties.get(
-                    "disable_flash_attn", "true").lower() != 'true':
+            if properties.get("disable_flash_attn",
+                              "true").lower() != 'true':
                 kwargs['use_flash_attention_2'] = True
 
         if "lmi_dist_sharding" == multi_gpu:
diff --git a/engines/python/setup/djl_python/sm_log_filter.py b/engines/python/setup/djl_python/sm_log_filter.py
index b135c5913..e8a8aefb6 100644
--- a/engines/python/setup/djl_python/sm_log_filter.py
+++ b/engines/python/setup/djl_python/sm_log_filter.py
@@ -11,7 +11,6 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
 import copy
 from collections import defaultdict
 from djl_python import __version__
@@ -27,18 +26,24 @@ def filter(self, record):
         try:
             if isinstance(record.msg, str):
                 for i in self.sm_log_markers:
-                    if record.msg.startswith(i+':'):
+                    if record.msg.startswith(i + ':'):
                         altered_record = copy.deepcopy(record)
-                        tag, metric_name, metric = [i.strip() for i in altered_record.msg.split(':')]
+                        tag, metric_name, metric = [
+                            i.strip() for i in altered_record.msg.split(':')
+                        ]
                         value, units = metric.split(' ')
-                        altered_metric_name = ''.join([word[0].upper()+word[1:] for word in metric_name.split(' ')])
+                        altered_metric_name = ''.join([
+                            word[0].upper() + word[1:]
+                            for word in metric_name.split(' ')
+                        ])
                         altered_record.msg = f"{tag}.Count:{self.count(altered_metric_name)}|#DJLServing:{__version__},{altered_metric_name}:{value} {units}"
                         return altered_record
                 return False
             else:
                 return False
         except Exception as exc:
-            logging.warning(f"Forwarding {str(record)} failed due to {str(exc)}")
+            logging.warning(
+                f"Forwarding {str(record)} failed due to {str(exc)}")
             return False
 
     def count(self, key):
diff --git a/engines/python/setup/djl_python/tests/test_sm_log_filter.py b/engines/python/setup/djl_python/tests/test_sm_log_filter.py
index f593ae0ac..daeea781b 100644
--- a/engines/python/setup/djl_python/tests/test_sm_log_filter.py
+++ b/engines/python/setup/djl_python/tests/test_sm_log_filter.py
@@ -40,4 +40,3 @@ def test_filter_miss(self):
         record.msg = f"LLM sharding and compilation latency: 845.62 : secs"
         actual = filter.filter(record)
         self.assertFalse(actual)
-
diff --git a/engines/python/setup/djl_python_engine.py b/engines/python/setup/djl_python_engine.py
index 839cc431c..becd2cb48 100644
--- a/engines/python/setup/djl_python_engine.py
+++ b/engines/python/setup/djl_python_engine.py
@@ -181,7 +181,8 @@ def configure_sm_logging():
     if 'SM_TELEMETRY_LOG_REV_2022_12' in os.environ:
         # https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
         root_logger = logging.getLogger()
-        sm_log_handler = logging.FileHandler(filename=os.getenv('SM_TELEMETRY_LOG_REV_2022_12'))
+        sm_log_handler = logging.FileHandler(
+            filename=os.getenv('SM_TELEMETRY_LOG_REV_2022_12'))
         sm_log_handler.addFilter(SMLogFilter())
         root_logger.addHandler(sm_log_handler)
 
diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py
index def86992f..b95cf32cb 100644
--- a/tests/integration/llm/prepare.py
+++ b/tests/integration/llm/prepare.py
@@ -378,7 +378,6 @@
         "option.quantize": "smoothquant",
         "option.smoothquant_alpha": 0.65
     }
-
 }
 
 transformers_neuronx_handler_list = {
diff --git a/tests/integration/lmic_test_builder.py b/tests/integration/lmic_test_builder.py
index eb565db2a..281f6b2fc 100644
--- a/tests/integration/lmic_test_builder.py
+++ b/tests/integration/lmic_test_builder.py
@@ -124,8 +124,7 @@ def _validate_test_series_config(self, series):
         required_parameters = []
         if series == "performance":
             required_parameters = [
-                "tensor_parallel", "batch_size",
-                "out_tokens", "count"
+                "tensor_parallel", "batch_size", "out_tokens", "count"
             ]
         for param in required_parameters:
             if param not in self.config or self.config[param] is None:
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
index 77e0f7f60..31891e22b 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -748,9 +748,9 @@ void checkAvailableMemory(Device device) throws IOException {
         long reservedMemory = intValue(prop, "reserved_memory_mb", defMemory) * 1024L * 1024;
         String tpDegreeStr = Utils.getenv("TENSOR_PARALLEL_DEGREE", "0");
         tpDegreeStr = prop.getProperty("option.tensor_parallel_degree", tpDegreeStr);
-        Engine eng = Engine.getEngine(engineName);
         int tpDegree;
         if ("max".equals(tpDegreeStr)) {
+            Engine eng = Engine.getEngine(engineName);
             if (eng.getGpuCount() > 0) {
                 tpDegree = eng.getGpuCount();
             } else {