[python] Reformat python code (#1195)
frankfliu committed Oct 18, 2023
1 parent ff0f654 commit ac86197
Showing 9 changed files with 93 additions and 66 deletions.
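The diff below is consistent with a single automated formatter pass over the Python sources. As a rough sketch of how such a reformat could be reproduced locally — assuming yapf is the formatter in use, which the commit message does not state — the touched Python files could be rewritten in place like this:

# Hypothetical sketch: reformat the changed Python files in place with yapf.
# Assumes yapf is installed (pip install yapf); the formatter choice and the
# file list are inferred from this diff, not stated by the commit itself.
from yapf.yapflib.yapf_api import FormatFile

changed_files = [
    "engines/python/setup/djl_python/deepspeed.py",
    "engines/python/setup/djl_python/huggingface.py",
    "engines/python/setup/djl_python/sm_log_filter.py",
    "engines/python/setup/djl_python_engine.py",
    "tests/integration/llm/prepare.py",
    "tests/integration/lmic_test_builder.py",
]

for path in changed_files:
    # in_place=True rewrites the file on disk; the default style is PEP 8.
    _, _, changed = FormatFile(path, in_place=True)
    print(f"{path}: {'reformatted' if changed else 'already clean'}")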
71 changes: 42 additions & 29 deletions engines/python/setup/djl_python/deepspeed.py
@@ -38,10 +38,7 @@
from typing import Optional
from peft import PeftConfig, PeftModel

SUPPORTED_QUANTIZATION_MODE = [
"smoothquant",
"dynamic_int8"
]
SUPPORTED_QUANTIZATION_MODE = ["smoothquant", "dynamic_int8"]

SMOOTHQUANT_SUPPORTED_MODEL_TYPES = {
"gpt2",
@@ -204,9 +201,13 @@ def _parse_smoothquant_properties(self, properties):
if 0 <= float(properties['smoothquant_alpha']) <= 1:
self.smoothquant_alpha = properties['smoothquant_alpha']
else:
raise ValueError(f"${properties['smoothquant_alpha']} is not between 0 and 1.")
raise ValueError(
f"${properties['smoothquant_alpha']} is not between 0 and 1."
)
except ValueError:
raise ValueError(f"${properties['smoothquant_alpha']} cannot convert to float number.")
raise ValueError(
f"${properties['smoothquant_alpha']} cannot convert to float number."
)

def _get_ds_config(self, properties: dict):
ds_config = {
@@ -235,12 +236,12 @@ def _get_ds_config(self, properties: dict):
"dtype should also be provided for checkpoint loading")

if self.quantize_mode:
ds_config['dynamic_quant'] = {'enabled': True, 'use_cutlass': False}
ds_config['dynamic_quant'] = {
'enabled': True,
'use_cutlass': False
}
if self.quantize_mode == 'smoothquant':
smoothing_value = {
'smooth': True,
'calibrate': True
}
smoothing_value = {'smooth': True, 'calibrate': True}
if self.smoothquant_alpha:
smoothing_value['alpha'] = self.smoothquant_alpha
ds_config['smoothing'] = smoothing_value
@@ -265,7 +266,9 @@ def _validate_model_type_and_task(self):

if self.quantize_mode == \
'smoothquant' and self.model_config.model_type not in SMOOTHQUANT_SUPPORTED_MODEL_TYPES:
raise ValueError(f"${self.quantize_mode} does not support model ${self.model_config.model_type}")
raise ValueError(
f"${self.quantize_mode} does not support model ${self.model_config.model_type}"
)

def _read_model_config(self):
try:
@@ -298,14 +301,19 @@ def infer_task_from_model_architecture(self, config: PretrainedConfig):
f"Task could not be inferred from model config. "
f"Please manually set `task` in serving.properties.")

def get_model_pretrained(self, model_id_or_path, torch_dtype='auto', **kwargs):
def get_model_pretrained(self,
model_id_or_path,
torch_dtype='auto',
**kwargs):
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
model = TASK_TO_MODEL[self.task].from_pretrained(model_id_or_path, torch_dtype=torch_dtype, **kwargs)
model = TASK_TO_MODEL[self.task].from_pretrained(
model_id_or_path, torch_dtype=torch_dtype, **kwargs)
return model, tokenizer

def get_model_from_config(self, model_id_or_path, **kwargs):
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
model = TASK_TO_MODEL[self.task].from_config(self.model_config, **kwargs)
model = TASK_TO_MODEL[self.task].from_config(self.model_config,
**kwargs)
return model, tokenizer

def get_model(self, model_id_or_path, loading_method, **kwargs):
@@ -316,16 +324,20 @@ def get_model(self, model_id_or_path, loading_method, **kwargs):
elif loading_method == 'pretrained':
return self.get_model_pretrained(model_id_or_path, **kwargs)
else:
raise RuntimeError(f'Unsupported model loading method, this should not happen.')
raise RuntimeError(
f'Unsupported model loading method, this should not happen.')

def load_model(self, model_id_or_path, loading_method, use_mmap_loader,
**kwargs):

def load_model(self, model_id_or_path, loading_method, use_mmap_loader, **kwargs):
def load_model_with_mmap(model_id_or_path, loading_method):
import mmaploader
import accelerate
with mmaploader.load_mmap_meta() as mmap_loader:
with accelerate.init_empty_weights():
kwargs['low_cpu_mem_usage'] = False
model, tokenizer = self.get_model(model_id_or_path, loading_method, **kwargs)
model, tokenizer = self.get_model(model_id_or_path,
loading_method, **kwargs)
return model, tokenizer, mmap_loader.state_dict_mmap

state_dict_mmap = {}
@@ -335,14 +347,18 @@ def load_model_with_mmap(model_id_or_path, loading_method):

if use_mmap_loader:
try:
model, tokenizer, state_dict_mmap = load_model_with_mmap(model_id_or_path, loading_method)
model, tokenizer, state_dict_mmap = load_model_with_mmap(
model_id_or_path, loading_method)
done = True
except:
self.logger.warning(f'failed to load model with mmap loader, will load model normally')
self.logger.warning(
f'failed to load model with mmap loader, will load model normally'
)

if not done:
kwargs['low_cpu_mem_usage'] = True
model, tokenizer = self.get_model(model_id_or_path, loading_method, **kwargs)
model, tokenizer = self.get_model(model_id_or_path, loading_method,
**kwargs)

return model, tokenizer, state_dict_mmap

@@ -365,11 +381,8 @@ def create_model_pipeline(self):
f"Please using quantization with a standard HuggingFace checkpoint or "
f"turn off quantization and try again.")
model, self.tokenizer, state_dict_mmap = self.load_model(
self.model_id_or_path,
'from_config',
self.ds_config['replace_with_kernel_inject'],
**kwargs
)
self.model_id_or_path, 'from_config',
self.ds_config['replace_with_kernel_inject'], **kwargs)
elif self.peft_config is not None:
self.logger.info(
f"Peft Model detected. Instantiating base model {self.peft_config.base_model_name_or_path}"
@@ -382,7 +395,8 @@ def create_model_pipeline(self):
lora_model = PeftModel.from_pretrained(base_model,
self.model_id_or_path)
model = lora_model.merge_and_unload()
self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(
self.peft_config.base_model_name_or_path)
self.logger.info(
f"Peft Model merged into base model for deepspeed compatibility"
)
@@ -393,8 +407,7 @@ def create_model_pipeline(self):
self.ds_config['replace_with_kernel_inject'],
low_cpu_mem_usage=self.low_cpu_mem_usage,
trust_remote_code=self.trust_remote_code,
**kwargs
)
**kwargs)
if self.data_type:
self.ds_config["dtype"] = self.data_type
else:
56 changes: 33 additions & 23 deletions engines/python/setup/djl_python/huggingface.py
@@ -16,13 +16,11 @@
import re

import torch
from transformers import (pipeline, Pipeline, Conversation,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoConfig,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering, StoppingCriteria,
StoppingCriteriaList)
from transformers import (
pipeline, Pipeline, Conversation, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig,
AutoModelForSequenceClassification, AutoModelForTokenClassification,
AutoModelForQuestionAnswering, StoppingCriteria, StoppingCriteriaList)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from peft import PeftConfig, PeftModel, PeftModelForCausalLM

@@ -50,14 +48,18 @@
}

LMI_DIST_ADV_MODEL = {
"RWForCausalLM", "GPTNeoXForCausalLM", "T5ForConditionalGeneration",
"LlamaForCausalLM", "FalconForCausalLM", "MPTForCausalLM",
"RWForCausalLM",
"GPTNeoXForCausalLM",
"T5ForConditionalGeneration",
"LlamaForCausalLM",
"FalconForCausalLM",
"MPTForCausalLM",
"GPTBigCodeForCausalLM",
}

# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
}

PEFT_MODEL_TASK_TO_CLS = {
@@ -114,14 +116,18 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
return VLLMRollingBatch
raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")


class StopWord(StoppingCriteria):

def __init__(self, tokenizer, stop_seq):
StoppingCriteria.__init__(self)
self.tokenizer = tokenizer
self.stop_seq = stop_seq

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
decoded_input_ids = self.tokenizer.decode(input_ids[0][-len(self.stop_seq):])
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor,
**kwargs):
decoded_input_ids = self.tokenizer.decode(
input_ids[0][-len(self.stop_seq):])

matches = re.search(self.stop_seq, decoded_input_ids)

@@ -130,6 +136,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):

return False


class HuggingFaceService(object):

def __init__(self):
@@ -202,8 +209,8 @@ def initialize(self, properties: dict):
properties.get("dtype"))
if "revision" in properties:
kwargs["revision"] = properties.get('revision')
self.disable_flash_attn = properties.get(
"disable_flash_attn", "true").lower() == 'true'
self.disable_flash_attn = properties.get("disable_flash_attn",
"true").lower() == 'true'
self.rolling_batch_type = properties.get("rolling_batch", None)

self._read_model_config(model_id_or_path,
@@ -249,9 +256,12 @@ def parse_stop_sequence_input(self, stop_sequence):
Input: stop_sequence (string)
Output: list of strings
"""
assert stop_sequence[0] == '[' and stop_sequence[-1] == ']', "option.stop_sequence not properly formatted"
assert stop_sequence[0] == '[' and stop_sequence[
-1] == ']', "option.stop_sequence not properly formatted"
stop_sequence = stop_sequence.replace(", ", ",")
stop_seq_list = [element[1:-1] for element in stop_sequence[1:-1].split(",")]
stop_seq_list = [
element[1:-1] for element in stop_sequence[1:-1].split(",")
]
return stop_seq_list

def load_stopping_criteria_list(self, stop_sequence):
@@ -262,7 +272,7 @@ def load_stopping_criteria_list(self, stop_sequence):
"""
if self.tokenizer is None:
return

stop_seq_list = self.parse_stop_sequence_input(stop_sequence)

stopwords = []
@@ -505,7 +515,8 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
model_cls = AutoModelForSeq2SeqLM
else:
model_cls = AutoModelForCausalLM
if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash() and not self.disable_flash_attn:
if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash(
) and not self.disable_flash_attn:
kwargs['use_flash_attention_2'] = True

if self.peft_config is not None:
@@ -631,15 +642,14 @@ def register_adapter(inputs: Input):
adapter_path = inputs.get_property("src")
if not os.path.exists(adapter_path):
raise ValueError(
f"Only local LoRA models are supported. {adapter_path} is not a valid path")
logging.info(
f"Registering adapter {adapter_name} from {adapter_path}")
f"Only local LoRA models are supported. {adapter_path} is not a valid path"
)
logging.info(f"Registering adapter {adapter_name} from {adapter_path}")
if isinstance(_service.model, PeftModel):
_service.model.load_adapter(adapter_path, adapter_name)
else:
_service.model = PeftModel.from_pretrained(_service.model,
adapter_path,
adapter_name)
adapter_path, adapter_name)

if isinstance(_service.hf_pipeline, Pipeline):
_service.hf_pipeline.model = _service.model
@@ -24,9 +24,10 @@
DEFAULT_SEARCH_ALGORITHM = 'greedy'
# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
}


def enable_flash():
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
@@ -134,8 +135,8 @@ def _init_model_and_tokenizer(self,
device_map = kwargs.pop('device_map')

if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash():
if properties.get(
"disable_flash_attn", "true").lower() != 'true':
if properties.get("disable_flash_attn",
"true").lower() != 'true':
kwargs['use_flash_attention_2'] = True

if "lmi_dist_sharding" == multi_gpu:
15 changes: 10 additions & 5 deletions engines/python/setup/djl_python/sm_log_filter.py
@@ -11,7 +11,6 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.


import copy
from collections import defaultdict
from djl_python import __version__
@@ -27,18 +26,24 @@ def filter(self, record):
try:
if isinstance(record.msg, str):
for i in self.sm_log_markers:
if record.msg.startswith(i+':'):
if record.msg.startswith(i + ':'):
altered_record = copy.deepcopy(record)
tag, metric_name, metric = [i.strip() for i in altered_record.msg.split(':')]
tag, metric_name, metric = [
i.strip() for i in altered_record.msg.split(':')
]
value, units = metric.split(' ')
altered_metric_name = ''.join([word[0].upper()+word[1:] for word in metric_name.split(' ')])
altered_metric_name = ''.join([
word[0].upper() + word[1:]
for word in metric_name.split(' ')
])
altered_record.msg = f"{tag}.Count:{self.count(altered_metric_name)}|#DJLServing:{__version__},{altered_metric_name}:{value} {units}"
return altered_record
return False
else:
return False
except Exception as exc:
logging.warning(f"Forwarding {str(record)} failed due to {str(exc)}")
logging.warning(
f"Forwarding {str(record)} failed due to {str(exc)}")
return False

def count(self, key):
@@ -40,4 +40,3 @@ def test_filter_miss(self):
record.msg = f"LLM sharding and compilation latency: 845.62 : secs"
actual = filter.filter(record)
self.assertFalse(actual)

3 changes: 2 additions & 1 deletion engines/python/setup/djl_python_engine.py
@@ -181,7 +181,8 @@ def configure_sm_logging():
if 'SM_TELEMETRY_LOG_REV_2022_12' in os.environ:
# https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
root_logger = logging.getLogger()
sm_log_handler = logging.FileHandler(filename=os.getenv('SM_TELEMETRY_LOG_REV_2022_12'))
sm_log_handler = logging.FileHandler(
filename=os.getenv('SM_TELEMETRY_LOG_REV_2022_12'))
sm_log_handler.addFilter(SMLogFilter())
root_logger.addHandler(sm_log_handler)

1 change: 0 additions & 1 deletion tests/integration/llm/prepare.py
@@ -378,7 +378,6 @@
"option.quantize": "smoothquant",
"option.smoothquant_alpha": 0.65
}

}

transformers_neuronx_handler_list = {
3 changes: 1 addition & 2 deletions tests/integration/lmic_test_builder.py
@@ -124,8 +124,7 @@ def _validate_test_series_config(self, series):
required_parameters = []
if series == "performance":
required_parameters = [
"tensor_parallel", "batch_size",
"out_tokens", "count"
"tensor_parallel", "batch_size", "out_tokens", "count"
]
for param in required_parameters:
if param not in self.config or self.config[param] is None:
2 changes: 1 addition & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -748,9 +748,9 @@ void checkAvailableMemory(Device device) throws IOException {
long reservedMemory = intValue(prop, "reserved_memory_mb", defMemory) * 1024L * 1024;
String tpDegreeStr = Utils.getenv("TENSOR_PARALLEL_DEGREE", "0");
tpDegreeStr = prop.getProperty("option.tensor_parallel_degree", tpDegreeStr);
Engine eng = Engine.getEngine(engineName);
int tpDegree;
if ("max".equals(tpDegreeStr)) {
Engine eng = Engine.getEngine(engineName);
if (eng.getGpuCount() > 0) {
tpDegree = eng.getGpuCount();
} else {