Fix make_model_config_json function #188

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

### Fixed
- Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185))
+- Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188))

## [1.0.0]

100 changes: 83 additions & 17 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -978,11 +978,14 @@
        self,
        model_name: str = None,
        version_number: str = 1,
+        model_format: str = "TORCH_SCRIPT",
        embedding_dimension: int = None,
+        pooling_mode: str = None,
+        normalize_result: bool = None,
        all_config: str = None,
        model_type: str = None,
        verbose: bool = False,
-    ) -> None:
+    ) -> str:
        """
        parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file. If all required
        fields are given by users, use the given parameters and will skip reading the config.json
@@ -991,12 +994,21 @@
            Optional, The name of the model. If None, default to parse from model id, for example,
            'msmarco-distilbert-base-tas-b'
        :type model_name: string
+        :param model_format:
+            Optional, The format of the model. Default is "TORCH_SCRIPT".
+        :type model_format: string
        :param version_number:
-            Optional, The version number of the model. default is 1
+            Optional, The version number of the model. Default is 1
        :type version_number: string
-        :param embedding_dimension: Optional, the embedding_dimension of the model. If None, parse embedding_dimension
-            from the config file of pre-trained hugging-face model, if not found, default to be 768
+        :param embedding_dimension: Optional, the embedding dimension of the model. If None, parse embedding_dimension
+            from the config file of pre-trained hugging-face model. If not found, default to be 768
        :type embedding_dimension: int
+        :param pooling_mode: Optional, the pooling mode of the model. If None, parse pooling_mode
+            from the config file of pre-trained hugging-face model. If not found, do not include it.
+        :type pooling_mode: string
+        :param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder
+            exists in the pre-trained hugging-face model folder. If not found, do not include it.
+        :type normalize_result: bool
        :param all_config:
            Optional, the all_config of the model. If None, parse all contents from the config file of pre-trained
            hugging-face model
@@ -1008,8 +1020,8 @@
        :param verbose:
            optional, use printing more logs. Default as false
        :type verbose: bool
-        :return: no return value expected
-        :rtype: None
+        :return: model config file path. The file path where the model config file is being saved
+        :rtype: string
        """
        folder_path = self.folder_path
        config_json_file_path = os.path.join(folder_path, "config.json")
@@ -1057,27 +1069,25 @@
                        if mapping_item in config_content.keys():
                            embedding_dimension = config_content[mapping_item]
                            break
-                    else:
-                        print(
-                            'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
-                            config_json_file_path,
-                        )
-                        print(
-                            "Please add in the config file or input in the argument for embedding_dimension "
-                        )
-                        embedding_dimension = 768
+                    else:
+                        print(
+                            'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
+                            config_json_file_path,
+                            ". Please add in the config file or input in the argument for embedding_dimension.",
+                        )
+                        embedding_dimension = 768
        except IOError:
            print(
                "Cannot open in config.json file at ",
                config_json_file_path,
-                ". Please check the config.son ",
+                ". Please check the config.json ",
                "file in the path.",
            )

        model_config_content = {
            "name": model_name,
            "version": version_number,
-            "model_format": "TORCH_SCRIPT",
+            "model_format": model_format,
            "model_task_type": "TEXT_EMBEDDING",
            "model_config": {
                "model_type": model_type,
@@ -1086,6 +1096,60 @@
"all_config": json.dumps(all_config),
},
}
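The `model_config_content` assembled above takes `embedding_dimension` from the Hugging Face `config.json` when the caller does not supply one, falling back to 768. A minimal standalone sketch of that lookup (the helper name `infer_embedding_dimension` is hypothetical, not part of opensearch-py-ml):

```python
def infer_embedding_dimension(config_content: dict, default: int = 768) -> int:
    # Probe the parsed config.json dict for the first known size key,
    # mirroring the "dim" / "hidden_size" / "d_model" mapping list in the diff.
    for key in ("dim", "hidden_size", "d_model"):
        if key in config_content:
            return config_content[key]
    # None of the known keys present: fall back to the documented default.
    return default

print(infer_embedding_dimension({"hidden_size": 384}))   # 384
print(infer_embedding_dimension({"vocab_size": 30522}))  # 768
```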

+        if pooling_mode is not None:
+            model_config_content["model_config"]["pooling_mode"] = pooling_mode
+        else:
+            pooling_config_json_file_path = os.path.join(
+                folder_path, "1_Pooling", "config.json"
+            )
+            if os.path.exists(pooling_config_json_file_path):
+                try:
+                    with open(pooling_config_json_file_path) as f:
+                        if verbose:
+                            print(
+                                "reading pooling config file from: "
+                                + pooling_config_json_file_path
+                            )
+                        pooling_config_content = json.load(f)
+                        pooling_mode_mapping_dict = {
+                            "pooling_mode_cls_token": "CLS",
+                            "pooling_mode_mean_tokens": "MEAN",
+                            "pooling_mode_max_tokens": "MAX",
+                            "pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
+                        }
+                        for mapping_item in pooling_mode_mapping_dict:
+                            if (
+                                mapping_item in pooling_config_content.keys()
+                                and pooling_config_content[mapping_item]
+                            ):
+                                pooling_mode = pooling_mode_mapping_dict[mapping_item]
+                                model_config_content["model_config"][
+                                    "pooling_mode"
+                                ] = pooling_mode
+                                break
+                        else:
+                            print(
+                                'Cannot find "pooling_mode_[mode]_token(s)" with value true in config.json file at ',
+                                pooling_config_json_file_path,
+                                ". Please add in the pooling config file or input in the argument for pooling_mode.",
+                            )

+                except IOError:
+                    print(

"Cannot open in config.json file at ",
pooling_config_json_file_path,
". Please check the config.json ",
"file in the path.",
)
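The hunk above maps the pooling flags that sentence-transformers writes to `1_Pooling/config.json` onto ml-commons pooling names, taking the first flag set to true. A standalone sketch of that mapping (`detect_pooling_mode` is a hypothetical helper name, not part of the library):

```python
from typing import Optional

# Flags as they appear in sentence-transformers' 1_Pooling/config.json,
# mapped to the ml-commons pooling_mode names used in the diff above.
POOLING_MODE_MAP = {
    "pooling_mode_cls_token": "CLS",
    "pooling_mode_mean_tokens": "MEAN",
    "pooling_mode_max_tokens": "MAX",
    "pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
}

def detect_pooling_mode(pooling_config: dict) -> Optional[str]:
    # Mirror the for/else in the diff: the first flag that is true wins;
    # if none is set, report nothing so the field can be omitted.
    for flag, mode in POOLING_MODE_MAP.items():
        if pooling_config.get(flag):
            return mode
    return None

print(detect_pooling_mode({"pooling_mode_cls_token": False,
                           "pooling_mode_mean_tokens": True}))  # MEAN
```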

+        if normalize_result is not None:
+            model_config_content["model_config"]["normalize_result"] = normalize_result
+        else:
+            normalize_result_json_file_path = os.path.join(folder_path, "2_Normalize")
+            if os.path.exists(normalize_result_json_file_path):
+                model_config_content["model_config"]["normalize_result"] = True

        if verbose:
            print("generating ml-commons_model_config.json file...\n")
            print(model_config_content)
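The `normalize_result` branch added in this hunk infers normalization from the presence of a `2_Normalize` module folder in the saved model directory. A sketch of that check under the same assumption (`detect_normalize_result` is a hypothetical name):

```python
import os
import tempfile

def detect_normalize_result(model_folder: str) -> bool:
    # sentence-transformers saves a "2_Normalize" module folder when the model
    # ends with a Normalize layer; its presence implies normalize_result=True.
    return os.path.exists(os.path.join(model_folder, "2_Normalize"))

with tempfile.TemporaryDirectory() as folder:
    print(detect_normalize_result(folder))            # False
    os.mkdir(os.path.join(folder, "2_Normalize"))
    print(detect_normalize_result(folder))            # True
```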
@@ -1100,6 +1164,8 @@
"ml-commons_model_config.json file is saved at : ", model_config_file_path
)

return model_config_file_path

    # private methods
    def __qryrem(self, x):
        # for removing the "QRY:" token if they exist in passages