fix batch_size type conversion (#1299)
Co-authored-by: sindhuso <sindhuso@3c0630156b9b.ant.amazon.com>
sindhuvahinis and sindhuso committed Nov 9, 2023
1 parent eb9d0ef commit a4b3ac2
Showing 3 changed files with 55 additions and 11 deletions.
@@ -37,7 +37,7 @@ class DeepSpeedProperties(Properties):
     training_mp_size: Optional[int] = 1
     checkpoint: Optional[str] = None
     save_mp_checkpoint_path: Optional[str] = None
-    ds_config: Optional[Any] = None
+    ds_config: Optional[Any] = Field(default={}, alias='deepspeed_config_path')
 
     @validator('device', always=True)
     def set_device(cls, device):
@@ -111,10 +111,7 @@ def get_torch_dtype_from_str(dtype: str):
 
     @root_validator()
     def construct_ds_config(cls, properties):
-        if properties.get("deepspeed_config_path"):
-            with open(properties.get("deepspeed_config_path"), "r") as f:
-                properties['ds_config'] = json.load(f)
-        else:
+        if not properties.get("ds_config"):
             ds_config = {
                 "tensor_parallel": {
                     "tp_size": properties['tensor_parallel_degree']
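
Note: the hunks above replace the file-reading branch in construct_ds_config with a pydantic Field alias, so a deepspeed_config_path option now feeds ds_config directly and the root validator only builds a default config when none was supplied. Below is a minimal, self-contained sketch of that pattern, not DeepSpeedProperties itself; the JSON-loading pre-validator is an assumption inferred from the new test_deepspeed_configs_file test further down, since only the alias and the guard are visible in this hunk.

# Minimal pydantic v1 sketch of the alias pattern above (assumed class name;
# the JSON-loading validator is an assumption, not shown in this diff).
import json
from typing import Any, Optional

from pydantic import BaseModel, Field, root_validator, validator


class DsConfigSketch(BaseModel):
    tensor_parallel_degree: int = 1
    ds_config: Optional[Any] = Field(default={}, alias='deepspeed_config_path')

    @validator('ds_config', pre=True)
    def load_ds_config(cls, ds_config):
        # Assumption: a path string supplied via the alias is loaded as JSON.
        if isinstance(ds_config, str):
            with open(ds_config, "r") as f:
                return json.load(f)
        return ds_config

    @root_validator()
    def construct_ds_config(cls, properties):
        # Build a default config only when none was provided.
        if not properties.get("ds_config"):
            properties['ds_config'] = {
                "tensor_parallel": {
                    "tp_size": properties['tensor_parallel_degree']
                }
            }
        return properties


# DsConfigSketch(deepspeed_config_path='./ds.json').ds_config -> parsed file contents
# DsConfigSketch(tensor_parallel_degree=4).ds_config -> {'tensor_parallel': {'tp_size': 4}}
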
@@ -63,6 +63,7 @@ def validate_enable_streaming(cls, enable_streaming: str) -> str:
 
     @validator('batch_size', pre=True)
     def validate_batch_size(cls, batch_size, values):
+        batch_size = int(batch_size)
         if batch_size > 1:
             if not is_rolling_batch_enabled(
                     values['rolling_batch']) and is_streaming_enabled(
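
Note: the added int() cast is the core of this fix. Values coming from serving.properties arrive as strings, and because this validator runs with pre=True it sees the raw value before pydantic coerces the field, so comparing "4" > 1 would raise a TypeError. The fragment below is a stripped-down pydantic v1 sketch with assumed field names and a simplified guard, not the real Properties model.

# Stripped-down sketch of a pre=True validator coercing a string batch size
# (assumed field names; the rolling-batch/streaming check is simplified).
from pydantic import BaseModel, validator


class BatchSketch(BaseModel):
    rolling_batch: str = 'disable'
    enable_streaming: str = 'false'
    batch_size: int = 1

    @validator('batch_size', pre=True)
    def validate_batch_size(cls, batch_size, values):
        batch_size = int(batch_size)  # raw input may still be the string "4"
        if batch_size > 1:
            # Simplified stand-in for is_rolling_batch_enabled / is_streaming_enabled.
            rolling_off = values['rolling_batch'] == 'disable'
            streaming_on = values['enable_streaming'] == 'true'
            if rolling_off and streaming_on:
                raise ValueError(
                    'batch_size > 1 requires rolling batch when streaming is enabled')
        return batch_size


assert BatchSketch(batch_size="4").batch_size == 4  # "4" is now accepted
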
58 changes: 52 additions & 6 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
@@ -1,3 +1,5 @@
+import os
+import json
 import unittest
 from djl_python.properties_manager.properties import Properties
 from djl_python.properties_manager.tnx_properties import TransformerNeuronXProperties
@@ -19,9 +21,12 @@
     "model_dir": "model_dir",
     "rolling_batch": "disable",
     "tensor_parallel_degree": "4",
-    'batch_size': 4,
-    'max_rolling_batch_size': 2,
-    'enable_streaming': 'False'
+    'batch_size': "4",
+    'max_rolling_batch_size': '2',
+    'enable_streaming': 'False',
+    'dtype': 'fp16',
+    'revision': 'shdghdfgdfg',
+    'trust_remote_code': 'true'
 }


@@ -44,11 +49,32 @@ def test_common_configs(self):
         self.assertIsNone(configs.dtype)
         self.assertIsNone(configs.revision)
 
+    def test_all_common_configs(self):
+        configs = Properties(**common_properties)
+        self.assertEqual(configs.batch_size, 4)
+        self.assertEqual(configs.tensor_parallel_degree, 4)
+        self.assertEqual(common_properties['model_id'],
+                         configs.model_id_or_path)
+        self.assertEqual(common_properties['rolling_batch'],
+                         configs.rolling_batch)
+        self.assertEqual(int(common_properties['tensor_parallel_degree']),
+                         configs.tensor_parallel_degree)
+
+        self.assertEqual(int(common_properties['batch_size']),
+                         configs.batch_size)
+        self.assertEqual(int(common_properties['max_rolling_batch_size']),
+                         configs.max_rolling_batch_size)
+        self.assertEqual(configs.enable_streaming.value, 'false')
+
+        self.assertTrue(configs.trust_remote_code)
+        self.assertEqual(configs.dtype, common_properties['dtype'])
+        self.assertEqual(configs.revision, common_properties['revision'])
+
     def test_common_configs_error_case(self):
         other_properties = min_common_properties
         other_properties["rolling_batch"] = "disable"
         other_properties["enable_streaming"] = "true"
-        other_properties["batch_size"] = 2
+        other_properties["batch_size"] = '2'
         with self.assertRaises(ValueError):
             Properties(**other_properties)

@@ -130,7 +156,6 @@ def test_trtllm_rb_invalid():
     def test_ds_properties(self):
         ds_properties = {
             'quantize': "dynamic_int8",
-            'dtype': 'fp16',
             'max_tokens': "2048",
             'task': 'fill-mask',
             'low_cpu_mem_usage': "false",
@@ -173,6 +198,26 @@ def test_ds_basic_configs():
 
             self.assertDictEqual(ds_configs.ds_config, ds_config)
 
+        def test_deepspeed_configs_file():
+            ds_properties['deepspeed_config_path'] = './sample.json'
+            ds_config = {
+                'tensor_parallel': {
+                    'tp_size': 42
+                },
+                'save_mp_checkpoint_path': None,
+                'dynamic_quant': {
+                    'enabled': False,
+                    'use_cutlass': True
+                }
+            }
+            with open('sample.json', 'w') as fp:
+                json.dump(ds_config, fp)
+
+            ds_configs = DeepSpeedProperties(**ds_properties,
+                                             **common_properties)
+            self.assertDictEqual(ds_configs.ds_config, ds_config)
+            os.remove('sample.json')
+
         def test_ds_smoothquant_configs():
             ds_properties['quantize'] = 'smoothquant'
             ds_configs = DeepSpeedProperties(**ds_properties,
@@ -199,6 +244,7 @@ def test_ds_invalid_quant_method():
         test_ds_basic_configs()
         test_ds_smoothquant_configs()
         test_ds_invalid_quant_method()
+        test_deepspeed_configs_file()
 
     def test_ds_error_properties(self):
         ds_properties = {
@@ -263,7 +309,7 @@ def test_hf_quantize(self):
                          HFQuantizeMethods.bitsandbytes.value)
 
     def test_hf_error_case(self):
-        properties = {"model_id": "model_id", 'load_in_8bit': True}
+        properties = {"model_id": "model_id", 'load_in_8bit': 'true'}
         with self.assertRaises(ValueError):
             HuggingFaceProperties(**properties)

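
Note: the test changes follow the same theme as the fix: options read from serving.properties are always strings, so the fixtures now pass values such as "4", '2', and 'true' and rely on pydantic's built-in coercion. A tiny illustration with an assumed two-field model, not the real Properties class:

# Assumed two-field model just to show pydantic v1's string coercion.
from pydantic import BaseModel


class CoercionSketch(BaseModel):
    batch_size: int = 1
    trust_remote_code: bool = False


cfg = CoercionSketch(batch_size="4", trust_remote_code='true')
assert cfg.batch_size == 4
assert cfg.trust_remote_code is True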