
Fix make_model_config_json function #188

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

### Fixed
- Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185))
- Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188))

## [1.0.0]

107 changes: 90 additions & 17 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -978,11 +978,14 @@
self,
model_name: str = None,
version_number: str = 1,
model_format: str = "TORCH_SCRIPT",
embedding_dimension: int = None,
pooling_mode: str = None,
normalize_result: bool = None,
all_config: str = None,
model_type: str = None,
verbose: bool = False,
) -> None:
) -> str:
"""
Parse the config.json file of a pre-trained hugging-face model to generate a ml-commons_model_config.json file. If all required
fields are given by the user, use the given parameters and skip reading the config.json file.
@@ -991,12 +994,21 @@
Optional, The name of the model. If None, default to parse from model id, for example,
'msmarco-distilbert-base-tas-b'
:type model_name: string
:param model_format:
Optional, The format of the model. Default is "TORCH_SCRIPT".
:type model_format: string
:param version_number:
Optional, The version number of the model. default is 1
Optional, The version number of the model. Default is 1
:type version_number: string
:param embedding_dimension: Optional, the embedding_dimension of the model. If None, parse embedding_dimension
from the config file of pre-trained hugging-face model, if not found, default to be 768
:param embedding_dimension: Optional, the embedding dimension of the model. If None, parse embedding_dimension
from the config file of pre-trained hugging-face model. If not found, default to be 768
:type embedding_dimension: int
:param pooling_mode: Optional, the pooling mode of the model. If None, parse pooling_mode
from the config file of pre-trained hugging-face model. If not found, do not include it.
:type pooling_mode: string
:param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder
exists in the pre-trained hugging-face model folder. If not found, do not include it.
:type normalize_result: bool
:param all_config:
Optional, the all_config of the model. If None, parse all contents from the config file of pre-trained
hugging-face model
@@ -1008,8 +1020,8 @@
:param verbose:
Optional, whether to print more logs. Default is False
:type verbose: bool
:return: no return value expected
:rtype: None
:return: the file path where the model config file is saved
:rtype: string
"""
folder_path = self.folder_path
config_json_file_path = os.path.join(folder_path, "config.json")
@@ -1057,27 +1069,27 @@
if mapping_item in config_content.keys():
embedding_dimension = config_content[mapping_item]
break
else:
print(
'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
config_json_file_path,
)
print(
"Please add in the config file or input in the argument for embedding_dimension "
)
embedding_dimension = 768
else:
print(
'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
config_json_file_path,
)
print(
"Please add in the config file or input in the argument for embedding_dimension."
)
embedding_dimension = 768
except IOError:
print(
"Cannot open in config.json file at ",
config_json_file_path,
". Please check the config.son ",
". Please check the config.json ",
"file in the path.",
)
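The embedding-dimension fallback chain that the diff implements can be sketched in isolation. This is a minimal sketch, assuming only that the parsed config is a plain dict; the helper name `find_embedding_dimension` is hypothetical, not part of the library:

```python
def find_embedding_dimension(config_content: dict, default: int = 768) -> int:
    """Return the embedding dimension from a hugging-face config dict.

    Mirrors the diff's fallback chain: try "dim", "hidden_size", then
    "d_model"; fall back to 768 when none of them is present.
    """
    for key in ("dim", "hidden_size", "d_model"):
        if key in config_content:
            return config_content[key]
    return default
```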

model_config_content = {
"name": model_name,
"version": version_number,
"model_format": "TORCH_SCRIPT",
"model_format": model_format,
"model_task_type": "TEXT_EMBEDDING",
"model_config": {
"model_type": model_type,
@@ -1086,6 +1098,65 @@
"all_config": json.dumps(all_config),
},
}

if pooling_mode is not None:
model_config_content["model_config"]["pooling_mode"] = pooling_mode
else:
pooling_config_json_file_path = os.path.join(
folder_path, "1_Pooling", "config.json"
)
if os.path.exists(pooling_config_json_file_path):
try:
with open(pooling_config_json_file_path) as f:
if verbose:
print(
"reading pooling config file from: "
+ pooling_config_json_file_path
)
pooling_config_content = json.load(f)
if pooling_mode is None:
pooling_mode_mapping_dict = {
"pooling_mode_cls_token": "CLS",
"pooling_mode_mean_tokens": "MEAN",
"pooling_mode_max_tokens": "MAX",
"pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
}
for mapping_item in pooling_mode_mapping_dict:
if (
mapping_item in pooling_config_content.keys()
and pooling_config_content[mapping_item]
):
pooling_mode = pooling_mode_mapping_dict[
mapping_item
]
model_config_content["model_config"][
"pooling_mode"
] = pooling_mode
break
else:
print(
'Cannot find "pooling_mode_[mode]_token(s)" with value true in config.json file at ',
pooling_config_json_file_path,
)
print(
"Please add in the pooling config file or input in the argument for pooling_mode."
)

except IOError:
print(
"Cannot open in config.json file at ",
pooling_config_json_file_path,
". Please check the config.json ",
"file in the path.",
)
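The pooling-mode resolution above can be sketched as a standalone helper. This is a hedged sketch built from the mapping dict visible in the diff; the helper name `parse_pooling_mode` is hypothetical, and the input is assumed to be the already-parsed contents of 1_Pooling/config.json:

```python
def parse_pooling_mode(pooling_config: dict):
    """Map a sentence-transformers 1_Pooling config to an ml-commons
    pooling mode string; return None when no flag is set to true."""
    pooling_mode_mapping_dict = {
        "pooling_mode_cls_token": "CLS",
        "pooling_mode_mean_tokens": "MEAN",
        "pooling_mode_max_tokens": "MAX",
        "pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
    }
    for key, mode in pooling_mode_mapping_dict.items():
        if pooling_config.get(key):
            return mode
    return None
```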

if normalize_result is not None:
model_config_content["model_config"]["normalize_result"] = normalize_result
else:
normalize_result_json_file_path = os.path.join(folder_path, "2_Normalize")
if os.path.exists(normalize_result_json_file_path):
model_config_content["model_config"]["normalize_result"] = True
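When normalize_result is not supplied, the diff infers it from the presence of a 2_Normalize sub-folder in the saved model directory. A minimal sketch of that check, with a hypothetical helper name:

```python
import os


def detect_normalize_result(folder_path: str) -> bool:
    """True when the saved sentence-transformers model contains a
    2_Normalize sub-folder, which is how the diff infers
    normalize_result when the caller does not pass it."""
    return os.path.exists(os.path.join(folder_path, "2_Normalize"))
```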

if verbose:
print("generating ml-commons_model_config.json file...\n")
print(model_config_content)
@@ -1100,6 +1171,8 @@
"ml-commons_model_config.json file is saved at : ", model_config_file_path
)

return model_config_file_path
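Putting the pieces together: the config dict now gains pooling_mode and normalize_result only when they are resolved. A sketch with hypothetical values, showing only the fields visible in this diff (other fields elided in the collapsed hunks are omitted):

```python
import json

# Hypothetical example values; field names follow the diff above.
pooling_mode = "MEAN"
normalize_result = True

model_config_content = {
    "name": "msmarco-distilbert-base-tas-b",  # hypothetical model name
    "version": 1,
    "model_format": "TORCH_SCRIPT",
    "model_task_type": "TEXT_EMBEDDING",
    "model_config": {
        "model_type": "distilbert",  # hypothetical model type
        "all_config": json.dumps({"hidden_size": 768}),
    },
}

# The optional fields are attached only when they could be resolved.
if pooling_mode is not None:
    model_config_content["model_config"]["pooling_mode"] = pooling_mode
if normalize_result is not None:
    model_config_content["model_config"]["normalize_result"] = normalize_result
```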

# private methods
def __qryrem(self, x):
# for removing the "QRY:" token if they exist in passages