Skip to content

Commit

Permalink
Fix make_model_config_json function (#188) (#190)
Browse files Browse the repository at this point in the history
* Improve make_model_config_json function

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Fix linting issues

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Add CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Add unittest

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Correct linting

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Correct linting

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Fix bug

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Minor Edit + Add more tests

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Fix test

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Correct typos

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Increase test coverage

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Increase test coverage (2)

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Increase test coverage (3)

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Remove redundant line

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

---------

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>
(cherry picked from commit 13149e8)

Co-authored-by: Thanawan Atchariyachanvanit <latchari@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and thanawan-atc committed Jul 10, 2023
1 parent ff1d75c commit 58bf518
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

### Fixed
- Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185))
- Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188))

## [1.0.0]

Expand Down
100 changes: 83 additions & 17 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,11 +978,14 @@ def make_model_config_json(
self,
model_name: str = None,
version_number: str = 1,
model_format: str = "TORCH_SCRIPT",
embedding_dimension: int = None,
pooling_mode: str = None,
normalize_result: bool = None,
all_config: str = None,
model_type: str = None,
verbose: bool = False,
) -> None:
) -> str:
"""
parse from config.json file of pre-trained hugging-face model to generate a ml-commons_model_config.json file. If all required
fields are given by users, use the given parameters and will skip reading the config.json
Expand All @@ -991,12 +994,21 @@ def make_model_config_json(
Optional, The name of the model. If None, default to parse from model id, for example,
'msmarco-distilbert-base-tas-b'
:type model_name: string
:param model_format:
Optional, The format of the model. Default is "TORCH_SCRIPT".
:type model_format: string
:param version_number:
Optional, The version number of the model. default is 1
Optional, The version number of the model. Default is 1
:type version_number: string
:param embedding_dimension: Optional, the embedding_dimension of the model. If None, parse embedding_dimension
from the config file of pre-trained hugging-face model, if not found, default to be 768
:param embedding_dimension: Optional, the embedding dimension of the model. If None, parse embedding_dimension
from the config file of pre-trained hugging-face model. If not found, default to be 768
:type embedding_dimension: int
:param pooling_mode: Optional, the pooling mode of the model. If None, parse pooling_mode
from the config file of pre-trained hugging-face model. If not found, do not include it.
:type pooling_mode: string
:param normalize_result: Optional, whether to normalize the result of the model. If None, check if 2_Normalize folder
exists in the pre-trained hugging-face model folder. If not found, do not include it.
:type normalize_result: bool
:param all_config:
Optional, the all_config of the model. If None, parse all contents from the config file of pre-trained
hugging-face model
Expand All @@ -1008,8 +1020,8 @@ def make_model_config_json(
:param verbose:
optional, use printing more logs. Default as false
:type verbose: bool
:return: no return value expected
:rtype: None
:return: model config file path. The file path where the model config file is being saved
:rtype: string
"""
folder_path = self.folder_path
config_json_file_path = os.path.join(folder_path, "config.json")
Expand Down Expand Up @@ -1057,27 +1069,25 @@ def make_model_config_json(
if mapping_item in config_content.keys():
embedding_dimension = config_content[mapping_item]
break
else:
print(
'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
config_json_file_path,
)
print(
"Please add in the config file or input in the argument for embedding_dimension "
)
embedding_dimension = 768
else:
print(
'Cannot find "dim" or "hidden_size" or "d_model" in config.json file at ',
config_json_file_path,
". Please add in the config file or input in the argument for embedding_dimension.",
)
embedding_dimension = 768
except IOError:
print(
"Cannot open in config.json file at ",
config_json_file_path,
". Please check the config.son ",
". Please check the config.json ",
"file in the path.",
)

model_config_content = {
"name": model_name,
"version": version_number,
"model_format": "TORCH_SCRIPT",
"model_format": model_format,
"model_task_type": "TEXT_EMBEDDING",
"model_config": {
"model_type": model_type,
Expand All @@ -1086,6 +1096,60 @@ def make_model_config_json(
"all_config": json.dumps(all_config),
},
}

if pooling_mode is not None:
model_config_content["model_config"]["pooling_mode"] = pooling_mode
else:
pooling_config_json_file_path = os.path.join(
folder_path, "1_Pooling", "config.json"
)
if os.path.exists(pooling_config_json_file_path):
try:
with open(pooling_config_json_file_path) as f:
if verbose:
print(
"reading pooling config file from: "
+ pooling_config_json_file_path
)
pooling_config_content = json.load(f)
pooling_mode_mapping_dict = {
"pooling_mode_cls_token": "CLS",
"pooling_mode_mean_tokens": "MEAN",
"pooling_mode_max_tokens": "MAX",
"pooling_mode_mean_sqrt_len_tokens": "MEAN_SQRT_LEN",
}
for mapping_item in pooling_mode_mapping_dict:
if (
mapping_item in pooling_config_content.keys()
and pooling_config_content[mapping_item]
):
pooling_mode = pooling_mode_mapping_dict[mapping_item]
model_config_content["model_config"][
"pooling_mode"
] = pooling_mode
break
else:
print(
'Cannot find "pooling_mode_[mode]_token(s)" with value true in config.json file at ',
pooling_config_json_file_path,
". Please add in the pooling config file or input in the argument for pooling_mode.",
)

except IOError:
print(
"Cannot open in config.json file at ",
pooling_config_json_file_path,
". Please check the config.json ",
"file in the path.",
)

if normalize_result is not None:
model_config_content["model_config"]["normalize_result"] = normalize_result
else:
normalize_result_json_file_path = os.path.join(folder_path, "2_Normalize")
if os.path.exists(normalize_result_json_file_path):
model_config_content["model_config"]["normalize_result"] = True

if verbose:
print("generating ml-commons_model_config.json file...\n")
print(model_config_content)
Expand All @@ -1100,6 +1164,8 @@ def make_model_config_json(
"ml-commons_model_config.json file is saved at : ", model_config_file_path
)

return model_config_file_path

# private methods
def __qryrem(self, x):
# for removing the "QRY:" token if they exist in passages
Expand Down
Loading

0 comments on commit 58bf518

Please sign in to comment.