Skip to content

Commit

Permalink
[ML] Using model supplied mask token (elastic#162168)
Browse files Browse the repository at this point in the history
Fixes elastic#159577

Using the `mask_token` property from the model config for testing the
model.
This is shown in the input placeholder text, in the input validation and
for displaying the results.

<img width="433" alt="image"
src="https://github.com/elastic/kibana/assets/7405507/bc63f9e6-a3d5-402c-a451-8d80b758acbc">
  • Loading branch information
jgowdyelastic authored and dgieselaar committed Jul 23, 2023
1 parent 66aadaa commit 27c3db3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
Expand Up @@ -16,7 +16,7 @@ import { getGeneralInputComponent } from '../text_input';
import { getFillMaskOutputComponent } from './fill_mask_output';
import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';

const MASK = '[MASK]';
const DEFAULT_MASK_TOKEN = '[MASK]';

export class FillMaskInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.FILL_MASK;
Expand All @@ -30,6 +30,7 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
defaultMessage: 'Test how well the model predicts a missing word in a phrase.',
}),
];
private maskToken = DEFAULT_MASK_TOKEN;

constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
Expand All @@ -38,9 +39,14 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
deploymentId: string
) {
super(trainedModelsApi, model, inputType, deploymentId);
// @ts-expect-error mask_token is missing in type
const maskToken = model.inference_config?.[this.inferenceType]?.mask_token;
if (maskToken) {
this.maskToken = maskToken;
}

this.initialize([
this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(MASK)))),
this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(this.maskToken)))),
]);
}

Expand Down Expand Up @@ -71,16 +77,16 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>

public predictedValue(resp: TextClassificationResponse) {
const { response, inputText } = resp;
return response[0]?.value ? inputText.replace(MASK, response[0].value) : inputText;
return response[0]?.value ? inputText.replace(this.maskToken, response[0].value) : inputText;
}

public getInputComponent(): JSX.Element | null {
if (this.inputType === INPUT_TYPE.TEXT) {
const placeholder = i18n.translate(
'xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText',
{
defaultMessage:
'Enter a phrase to test. Use [MASK] as a placeholder for the missing words.',
defaultMessage: `Enter a phrase to test. Use {maskToken} as a placeholder for the missing words.`,
values: { maskToken: this.maskToken },
}
);

Expand Down
1 change: 0 additions & 1 deletion x-pack/plugins/translations/translations/fr-FR.json
Expand Up @@ -25599,7 +25599,6 @@
"xpack.ml.trainedModels.nodesList.totalAmountLabel": "Nombre total de nœuds Machine Learning",
"xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "ID de déploiement",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "Testez la capacité du modèle à prédire un mot manquant dans une phrase.",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "Entrez une expression à tester. Utilisez [MASK] comme espace réservé pour les mots manquants.",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "Masque de remplissage",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "Texte d'entrée",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "Texte d'entrée",
Expand Down
1 change: 0 additions & 1 deletion x-pack/plugins/translations/translations/ja-JP.json
Expand Up @@ -25598,7 +25598,6 @@
"xpack.ml.trainedModels.nodesList.totalAmountLabel": "合計機械学習ノード",
"xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "デプロイID",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "モデルがフレーズの不足している単語を予測する精度をテストします。",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "テストするフレーズを入力してください。足りない語句のプレースホルダーとして[MASK]を使用します。",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "マスクを塗りつぶす",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "入力テキスト",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "入力テキスト",
Expand Down
1 change: 0 additions & 1 deletion x-pack/plugins/translations/translations/zh-CN.json
Expand Up @@ -25597,7 +25597,6 @@
"xpack.ml.trainedModels.nodesList.totalAmountLabel": "Machine Learning 节点总数",
"xpack.ml.trainedModels.testModelsFlyout.deploymentIdLabel": "部署 ID",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.info1": "测试模型预测短语中缺失的词的表现。",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText": "输入短语以进行测试。将 [MASK] 用作缺失词的占位符。",
"xpack.ml.trainedModels.testModelsFlyout.fillMask.label": "填充掩码",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputText": "输入文本",
"xpack.ml.trainedModels.testModelsFlyout.generalTextInput.inputTitle": "输入文本",
Expand Down

0 comments on commit 27c3db3

Please sign in to comment.