Skip to content

Commit

Permalink
[Semantic Text UI] Handle the case when the model is not yet downloaded (#182040)
Browse files Browse the repository at this point in the history

When the trained model is not yet downloaded, it can't be deployed. This
PR covers the following:
- Download the model if it does not exist
- Tests to support this change

### How to test the changes locally
- Download the elasticsearch changes from GitHub
[branch](https://github.com/elastic/elasticsearch/tree/feature/semantic-text)
- Run Elasticsearch: `./gradlew :run -Drun.license_type=trial`
- Check out this PR's changes in your local Kibana and follow these
steps
+ Set isSemanticTextEnabled = true in this
[location](https://github.com/elastic/kibana/pull/180246/files#diff-92f4739f8a4a6917951a1b6e1af21a96d54313eaa2b5ce4c0e0553dd2ee11fcaL80)
    +  Run `yarn start`
  • Loading branch information
saikatsarkar056 committed May 8, 2024
1 parent d2fef07 commit fd44e1f
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ import { act } from 'react-dom/test-utils';
const mlMock: any = {
mlApi: {
inferenceModels: {
createInferenceEndpoint: jest.fn(),
createInferenceEndpoint: jest.fn().mockResolvedValue({}),
},
trainedModels: {
startModelAllocation: jest.fn(),
startModelAllocation: jest.fn().mockResolvedValue({}),
getTrainedModels: jest.fn().mockResolvedValue([
{
fully_defined: true,
},
]),
},
},
};
Expand Down Expand Up @@ -93,6 +98,11 @@ describe('useSemanticText', () => {
result.current.handleSemanticText(mockFieldData);
});

expect(mlMock.mlApi.trainedModels.startModelAllocation).toHaveBeenCalledWith('.elser_model_2');
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
value: mockFieldData,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
'elser_model_2',
'text_embedding',
Expand All @@ -105,16 +115,58 @@ describe('useSemanticText', () => {
},
}
);
expect(mlMock.mlApi.trainedModels.startModelAllocation).toHaveBeenCalledWith('.elser_model_2');
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
value: mockFieldData,
});

it('should invoke the download api if the model does not exist', async () => {
const mlMockWithModelNotDownloaded: any = {
mlApi: {
inferenceModels: {
createInferenceEndpoint: jest.fn(),
},
trainedModels: {
startModelAllocation: jest.fn(),
getTrainedModels: jest.fn().mockResolvedValue([
{
fully_defined: false,
},
]),
installElasticTrainedModelConfig: jest.fn().mockResolvedValue({}),
},
},
};
const { result } = renderHook(() =>
useSemanticText({
form,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMockWithModelNotDownloaded,
})
);

await act(async () => {
result.current.handleSemanticText(mockFieldData);
});

expect(
mlMockWithModelNotDownloaded.mlApi.trainedModels.installElasticTrainedModelConfig
).toHaveBeenCalledWith('.elser_model_2');
expect(
mlMockWithModelNotDownloaded.mlApi.trainedModels.startModelAllocation
).toHaveBeenCalledWith('.elser_model_2');
expect(
mlMockWithModelNotDownloaded.mlApi.inferenceModels.createInferenceEndpoint
).toHaveBeenCalledWith('elser_model_2', 'text_embedding', {
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.elser_model_2',
},
});
});

it('handles errors correctly', async () => {
const mockError = new Error('Test error');
mlMock.mlApi.inferenceModels.createInferenceEndpoint.mockImplementationOnce(() => {
mlMock.mlApi?.trainedModels.startModelAllocation.mockImplementationOnce(() => {
throw mockError;
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
*/

import { i18n } from '@kbn/i18n';
import { MlPluginStart } from '@kbn/ml-plugin/public';
import { useCallback } from 'react';
import { MlPluginStart, TrainedModelConfigResponse } from '@kbn/ml-plugin/public';
import React, { useEffect, useState } from 'react';
import { useComponentTemplatesContext } from '../../../../../../component_templates/component_templates_context';
import { useDispatch, useMappingsState } from '../../../../../mappings_state_context';
Expand Down Expand Up @@ -64,20 +65,37 @@ export function useSemanticText(props: UseSemanticTextProps) {
}
}, [form, inferenceId, inferenceToModelIdMap]);

const handleSemanticText = (data: Field) => {
data.inferenceId = inferenceValue;
const isModelDownloaded = useCallback(
async (modelId: string) => {
try {
const response: TrainedModelConfigResponse[] | undefined =
await ml?.mlApi?.trainedModels.getTrainedModels(modelId, {
include: 'definition_status',
});
return !!response?.[0]?.fully_defined;
} catch (error) {
if (error.body.statusCode !== 404) {
throw error;
}
}
return false;
},
[ml?.mlApi?.trainedModels]
);

const createInferenceEndpoint = (
trainedModelId: string,
defaultInferenceEndpoint: boolean,
data: Field
) => {
if (data.inferenceId === undefined) {
return;
}

const inferenceData = inferenceToModelIdMap?.[data.inferenceId];

if (!inferenceData) {
return;
throw new Error(
i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
defaultMessage: 'InferenceId is undefined while creating the inference endpoint.',
})
);
}

const { trainedModelId, defaultInferenceEndpoint, isDeployed, isDeployable } = inferenceData;

if (trainedModelId && defaultInferenceEndpoint) {
const modelConfig = {
service: 'elasticsearch',
Expand All @@ -87,28 +105,45 @@ export function useSemanticText(props: UseSemanticTextProps) {
model_id: trainedModelId,
},
};
try {
ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
data.inferenceId,
'text_embedding',
modelConfig
);
} catch (error) {
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
toasts?.addError(error.body && error.body.message ? new Error(error.body.message) : error, {
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.inferenceEndpointCreationErrorTitle',
{
defaultMessage: 'Inference endpoint creation failed',
}
),
});
}

ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
data.inferenceId,
'text_embedding',
modelConfig
);
}
};

const handleSemanticText = async (data: Field) => {
data.inferenceId = inferenceValue;
if (data.inferenceId === undefined) {
return;
}

const inferenceData = inferenceToModelIdMap?.[data.inferenceId];

if (!inferenceData) {
return;
}

const { trainedModelId, defaultInferenceEndpoint, isDeployed, isDeployable } = inferenceData;

if (isDeployable && trainedModelId && !isDeployed) {
if (isDeployable && trainedModelId) {
try {
ml?.mlApi?.trainedModels.startModelAllocation(trainedModelId);
const modelDownloaded: boolean = await isModelDownloaded(trainedModelId);

if (isDeployed) {
createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data);
} else if (modelDownloaded) {
ml?.mlApi?.trainedModels
.startModelAllocation(trainedModelId)
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
} else {
ml?.mlApi?.trainedModels
.installElasticTrainedModelConfig(trainedModelId)
.then(() => ml?.mlApi?.trainedModels.startModelAllocation(trainedModelId))
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
}
toasts?.addSuccess({
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/ml/public/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export const plugin: PluginInitializer<
> = (initializerContext: PluginInitializerContext) => new MlPlugin(initializerContext);

export type { MlPluginSetup, MlPluginStart };
export type { TrainedModelConfigResponse } from '../common/types/trained_models';

export type { MlCapabilitiesResponse } from '../common/types/capabilities';
export type { MlSummaryJob } from '../common/types/anomaly_detection_jobs';
Expand Down

0 comments on commit fd44e1f

Please sign in to comment.