Skip to content

Commit

Permalink
[UnitTest] adding more properties testing (#1306)
Browse files Browse the repository at this point in the history
Co-authored-by: Somasundaram <sindhuso@5c52309d1b09.ant.amazon.com>
  • Loading branch information
sindhuvahinis and Somasundaram committed Nov 10, 2023
1 parent b644e07 commit 14a6b45
Showing 1 changed file with 68 additions and 5 deletions.
73 changes: 68 additions & 5 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,38 @@ def test_tnx_configs(self):
self.assertEqual(tnx_configs.compiled_graph_path,
str(properties['compiled_graph_path']))

def test_tnx_all_configs(self):
# TODO: Replace with actual example of context_length_estimate

properties = {
"n_positions":
"2048",
"load_split_model":
"true",
"load_in_8bit":
"true",
"compiled_graph_path":
"s3://test/bucket/folder",
"low_cpu_mem_usage":
"true",
'context_length_estimate':
'{"context_length": "128", "variable_size": "12"}'
}
tnx_configs = TransformerNeuronXProperties(**common_properties,
**properties)
self.assertEqual(tnx_configs.n_positions, 2048)
self.assertEqual(tnx_configs.compiled_graph_path,
properties['compiled_graph_path'])

self.assertTrue(tnx_configs.load_split_model)
self.assertTrue(tnx_configs.load_in_8bit)
self.assertTrue(tnx_configs.low_cpu_mem_usage)

self.assertDictEqual(tnx_configs.context_length_estimate, {
'context_length': '128',
'variable_size': '12'
})

def test_tnx_configs_error_case(self):
properties = {
"n_positions": "256",
Expand Down Expand Up @@ -273,15 +305,43 @@ def test_ds_invalid_dtype():
test_ds_invalid_dtype()

def test_hf_configs(self):
properties = {
"model_id": "model_id",
"model_dir": "model_dir",
"low_cpu_mem_usage": "true",
"disable_flash_attn": "false",
"engine": "MPI",
}

hf_configs = HuggingFaceProperties(**properties)
self.assertIsNone(hf_configs.load_in_8bit)
self.assertIsNone(hf_configs.device)
self.assertTrue(hf_configs.low_cpu_mem_usage)
self.assertFalse(hf_configs.disable_flash_attn)
self.assertIsNone(hf_configs.device_map)
self.assertTrue(hf_configs.is_mpi)
self.assertDictEqual(hf_configs.kwargs, {
'trust_remote_code': False,
"low_cpu_mem_usage": True,
})

def test_hf_all_configs(self):
properties = {
"model_id": "model_id",
"model_dir": "model_dir",
"tensor_parallel_degree": "4",
"load_in_4bit": "false",
"load_in_8bit": "true",
"low_cpu_mem_usage": "True",
"low_cpu_mem_usage": "true",
"disable_flash_attn": "false",
"engine": "MPI",
"device_map": "auto"
"device_map": "cpu",
"quantize": "bitsandbytes8",
"output_formatter": "jsonlines",
"waiting_steps": '12',
"trust_remote_code": "true",
"rolling_batch": "auto",
"dtype": "bf16"
}

hf_configs = HuggingFaceProperties(**properties)
Expand All @@ -292,10 +352,13 @@ def test_hf_configs(self):
self.assertTrue(hf_configs.is_mpi)
self.assertDictEqual(
hf_configs.kwargs, {
'trust_remote_code': False,
'trust_remote_code': True,
"low_cpu_mem_usage": True,
"device_map": 'auto',
"load_in_8bit": True
"device_map": 'cpu',
"load_in_8bit": True,
"waiting_steps": 12,
"output_formatter": "jsonlines",
"torch_dtype": torch.bfloat16
})

def test_hf_quantize(self):
Expand Down

0 comments on commit 14a6b45

Please sign in to comment.