Skip to content

Commit

Permalink
Increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
thanawan-atc committed Jul 7, 2023
1 parent 2359406 commit 36cfeeb
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 15 deletions.
8 changes: 2 additions & 6 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,9 +1073,7 @@ def make_model_config_json(
print(
'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
config_json_file_path,
)
print(
"Please add in the config file or input in the argument for embedding_dimension."
". Please add in the config file or input in the argument for embedding_dimension.",
)
embedding_dimension = 768
except IOError:
Expand Down Expand Up @@ -1134,9 +1132,7 @@ def make_model_config_json(
print(
'Cannot find "pooling_mode_[mode]_token(s)" with value true in config.json file at ',
pooling_config_json_file_path,
)
print(
"Please add in the pooling config file or input in the argument for pooling_mode."
". Please add in the pooling config file or input in the argument for pooling_mode.",
)

except IOError:
Expand Down
88 changes: 79 additions & 9 deletions tests/ml_models/test_sentencetransformermodel_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ def test_make_model_config_json_for_torch_script():

assert (
"name" in model_config_data_torch
and model_config_data_torch["name"]
== "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
and model_config_data_torch["name"] == model_id
), "Missing or Wrong model name in torch script model config file"
assert (
"model_format" in model_config_data_torch
Expand Down Expand Up @@ -248,9 +247,7 @@ def test_make_model_config_json_for_onnx():
), f"Creating model config file for tracing in onnx raised an exception {exec}"

assert (
"name" in model_config_data_onnx
and model_config_data_onnx["name"]
== "sentence-transformers/paraphrase-MiniLM-L3-v2"
"name" in model_config_data_onnx and model_config_data_onnx["name"] == model_id
), "Missing or Wrong model name in onnx model config file'"
assert (
"model_format" in model_config_data_onnx
Expand Down Expand Up @@ -308,8 +305,7 @@ def test_overwrite_fields_in_model_config():

assert (
"name" in model_config_data_torch
and model_config_data_torch["name"]
== "sentence-transformers/all-distilroberta-v1"
and model_config_data_torch["name"] == model_id
), "Missing or Wrong model name in torch script model config file"
assert (
"model_format" in model_config_data_torch
Expand Down Expand Up @@ -355,8 +351,7 @@ def test_overwrite_fields_in_model_config():

assert (
"name" in model_config_data_torch
and model_config_data_torch["name"]
== "sentence-transformers/all-distilroberta-v1"
and model_config_data_torch["name"] == model_id
), "Missing or Wrong model name in torch script model config file"
assert (
"model_format" in model_config_data_torch
Expand All @@ -379,5 +374,80 @@ def test_overwrite_fields_in_model_config():
clean_test_folder(TEST_FOLDER)


def test_missing_fields_in_config_json():
    """Check make_model_config_json() fallbacks when source config data is missing.

    The test saves a TorchScript model, then deliberately damages the saved
    artifacts before generating the model config:

    * the ``1_Pooling`` folder is passed to ``clean_test_folder`` so that no
      pooling config can be read, and
    * the ``dim`` / ``hidden_size`` / ``d_model`` keys are stripped from
      ``config.json`` so the embedding dimension cannot be looked up.

    It then asserts that the generated config still carries the model name and
    format, that ``embedding_dimension`` is 768, that ``normalize_result`` is
    either absent or False, and that no ``pooling_mode`` key is present.
    """
    model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
    expected_model_config_data = {
        "embedding_dimension": 768,
        "normalize_result": False,
    }
    # Keys that may carry the embedding dimension in config.json; all of them
    # are removed so the "cannot find dimension" path is exercised.
    embedding_dimension_keys = ("dim", "hidden_size", "d_model")

    clean_test_folder(TEST_FOLDER)
    try:
        test_model9 = SentenceTransformerModel(
            folder_path=TEST_FOLDER,
            model_id=model_id,
        )

        test_model9.save_as_pt(model_id=model_id, sentences=["today is sunny"])

        # Drop the pooling config folder so no pooling mode can be detected.
        test_pooling_folder = os.path.join(TEST_FOLDER, "1_Pooling")
        clean_test_folder(test_pooling_folder)

        config_json_file_path = os.path.join(TEST_FOLDER, "config.json")
        try:
            with open(config_json_file_path, "r") as f:
                config_content = json.load(f)

            for mapping_item in embedding_dimension_keys:
                config_content.pop(mapping_item, None)

            with open(config_json_file_path, "w") as f:
                json.dump(config_content, f)
        except Exception as exc:
            # Raise explicitly (instead of `assert False`) so the failure is
            # not stripped under -O and the original cause stays chained.
            raise AssertionError(
                f"Modifying config file raised an exception {exc}"
            ) from exc

        model_config_path_torch = test_model9.make_model_config_json(
            model_format="TORCH_SCRIPT"
        )
        try:
            with open(model_config_path_torch) as json_file:
                model_config_data_torch = json.load(json_file)
        except Exception as exc:
            raise AssertionError(
                "Creating model config file for tracing in torch_script "
                f"raised an exception {exc}"
            ) from exc

        assert (
            "name" in model_config_data_torch
            and model_config_data_torch["name"] == model_id
        ), "Missing or Wrong model name in torch script model config file"
        assert (
            "model_format" in model_config_data_torch
            and model_config_data_torch["model_format"] == "TORCH_SCRIPT"
        )
        assert (
            "model_config" in model_config_data_torch
        ), "Missing 'model_config' in torch script model config file"

        model_config = model_config_data_torch["model_config"]
        for k, v in expected_model_config_data.items():
            matches_expected = k in model_config and model_config[k] == v
            # A falsy normalize_result is also accepted when the key is
            # omitted from the generated config altogether.
            omitted_falsy_normalize = (
                k not in model_config and k == "normalize_result" and not v
            )
            assert (
                matches_expected or omitted_falsy_normalize
            ), "make_model_config_json() does not generate an expected model config"

        # NOTE(review): this inspects the top-level dict, while the other
        # generated fields are checked under "model_config" - confirm whether
        # "pooling_mode" should be looked up in model_config instead.
        assert (
            "pooling_mode" not in model_config_data_torch
        ), "make_model_config_json() does not generate an expected model config"
    finally:
        # Always clean up, even when an assertion above fails.
        clean_test_folder(TEST_FOLDER)


# Final cleanup calls for the two test folders used by this test file.
# NOTE(review): indentation is lost in this view; these appear to be
# module-level statements (run at import/collection time) - confirm.
# clean_test_folder is defined outside this view; presumably it removes the
# given folder - verify against its definition.
clean_test_folder(TEST_FOLDER)
clean_test_folder(TESTDATA_UNZIP_FOLDER)

0 comments on commit 36cfeeb

Please sign in to comment.