Skip to content

Commit

Permalink
add test pipeline action for dfa trained models
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarezmelissa87 committed Oct 9, 2023
1 parent fb23474 commit c66f468
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 18 deletions.
Expand Up @@ -157,7 +157,7 @@ export const AddInferencePipelineFlyout: FC<AddInferencePipelineFlyoutProps> = (
/>
)}
{step === ADD_INFERENCE_PIPELINE_STEPS.TEST && (
<TestPipeline sourceIndex={sourceIndex} state={formState} />
<TestPipeline sourceIndex={sourceIndex} state={formState} mode="step" />
)}
{step === ADD_INFERENCE_PIPELINE_STEPS.CREATE && (
<ReviewAndCreatePipeline
Expand Down
Expand Up @@ -39,9 +39,10 @@ import type { MlInferenceState } from '../types';
interface Props {
sourceIndex?: string;
state: MlInferenceState;
mode: 'standAlone' | 'step';
}

export const TestPipeline: FC<Props> = memo(({ state, sourceIndex }) => {
export const TestPipeline: FC<Props> = memo(({ state, sourceIndex, mode }) => {
const [simulatePipelineResult, setSimulatePipelineResult] = useState<
undefined | estypes.IngestSimulateResponse
>();
Expand Down Expand Up @@ -147,13 +148,17 @@ export const TestPipeline: FC<Props> = memo(({ state, sourceIndex }) => {
<EuiFlexItem>
<EuiText color="subdued" size="s">
<p>
<strong>
{i18n.translate(
'xpack.ml.trainedModels.content.indices.pipelines.addInferencePipelineModal.steps.test.optionalCallout',
{ defaultMessage: 'This is an optional step.' }
)}
</strong>
&nbsp;
{mode === 'step' ? (
<>
<strong>
{i18n.translate(
'xpack.ml.trainedModels.content.indices.pipelines.addInferencePipelineModal.steps.test.optionalCallout',
{ defaultMessage: 'This is an optional step.' }
)}
</strong>
&nbsp;
</>
) : null}
<FormattedMessage
id="xpack.ml.trainedModels.content.indices.pipelines.addInferencePipelineModal.steps.test.description"
defaultMessage="Run a simulation of the pipeline to confirm it produces the anticipated results."
Expand Down
Expand Up @@ -29,10 +29,11 @@ import { useToastNotificationService } from '../services/toast_notification_serv
import { getUserInputModelDeploymentParamsProvider } from './deployment_setup';
import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana';
import { ML_PAGES } from '../../../common/constants/locator';
import { isTestable } from './test_models';
import { isTestable, isDfaTrainedModel } from './test_models';
import { ModelItem } from './models_list';

export function useModelActions({
onDfaTestAction,
onTestAction,
onModelsDeleteRequest,
onModelDeployRequest,
Expand All @@ -42,6 +43,7 @@ export function useModelActions({
modelAndDeploymentIds,
}: {
isLoading: boolean;
onDfaTestAction: (model: ModelItem) => void;
onTestAction: (model: ModelItem) => void;
onModelsDeleteRequest: (models: ModelItem[]) => void;
onModelDeployRequest: (model: ModelItem) => void;
Expand Down Expand Up @@ -463,13 +465,8 @@ export function useModelActions({
onModelDeployRequest(model);
},
available: (item) => {
const isDfaTrainedModel =
item.metadata?.analytics_config !== undefined ||
item.inference_config?.regression !== undefined ||
item.inference_config?.classification !== undefined;

return (
isDfaTrainedModel &&
isDfaTrainedModel(item) &&
!isBuiltInModel(item) &&
!item.putModelConfig &&
canManageIngestPipelines
Expand Down Expand Up @@ -540,7 +537,13 @@ export function useModelActions({
type: 'icon',
isPrimary: true,
available: isTestable,
onClick: (item) => onTestAction(item),
onClick: (item) => {
if (isDfaTrainedModel(item)) {
onDfaTestAction(item);
} else {
onTestAction(item);
}
},
enabled: (item) => {
return canTestTrainedModels && isTestable(item, true) && !isLoading;
},
Expand Down Expand Up @@ -599,6 +602,7 @@ export function useModelActions({
canDeleteTrainedModels,
isBuiltInModel,
onTestAction,
onDfaTestAction,
canTestTrainedModels,
canManageIngestPipelines,
]
Expand Down
Expand Up @@ -68,6 +68,7 @@ import { useFieldFormatter } from '../contexts/kibana/use_field_formatter';
import { useRefresh } from '../routing/use_refresh';
import { SavedObjectsWarning } from '../components/saved_objects_warning';
import { TestTrainedModelFlyout } from './test_models';
import { TestDfaModelsFlyout } from './test_dfa_models_flyout';
import { AddInferencePipelineFlyout } from '../components/ml_inference';
import { useEnabledFeatures } from '../contexts/ml';

Expand Down Expand Up @@ -162,6 +163,7 @@ export const ModelsList: FC<Props> = ({
{}
);
const [modelToTest, setModelToTest] = useState<ModelItem | null>(null);
const [dfaModelToTest, setDfaModelToTest] = useState<ModelItem | null>(null);

const isBuiltInModel = useCallback(
(item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG),
Expand Down Expand Up @@ -409,6 +411,7 @@ export const ModelsList: FC<Props> = ({
isLoading,
fetchModels: fetchModelsData,
onTestAction: setModelToTest,
onDfaTestAction: setDfaModelToTest,
onModelsDeleteRequest: setModelsToDelete,
onModelDeployRequest: setModelToDeploy,
onLoading: setIsLoading,
Expand Down Expand Up @@ -762,6 +765,9 @@ export const ModelsList: FC<Props> = ({
{modelToTest === null ? null : (
<TestTrainedModelFlyout model={modelToTest} onClose={setModelToTest.bind(null, null)} />
)}
{dfaModelToTest === null ? null : (
<TestDfaModelsFlyout model={dfaModelToTest} onClose={setDfaModelToTest.bind(null, null)} />
)}
{modelToDeploy !== undefined ? (
<AddInferencePipelineFlyout
onClose={setModelToDeploy.bind(null, undefined)}
Expand Down
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React, { FC, useMemo } from 'react';
import { EuiFlyout, EuiFlyoutBody, EuiFlyoutHeader, EuiSpacer, EuiTitle } from '@elastic/eui';
import { FormattedMessage } from '@kbn/i18n-react';

import { TestPipeline } from '../components/ml_inference/components/test_pipeline';
import { getInitialState } from '../components/ml_inference/state';
import type { ModelItem } from './models_list';

interface Props {
model: ModelItem;
onClose: () => void;
}

export const TestDfaModelsFlyout: FC<Props> = ({ model, onClose }) => {
const sourceIndex = useMemo(
() =>
Array.isArray(model.metadata?.analytics_config.source.index)
? model.metadata?.analytics_config.source.index.join()
: model.metadata?.analytics_config.source.index,
// eslint-disable-next-line react-hooks/exhaustive-deps
[model?.model_id]
);

const state = useMemo(
() => getInitialState(model),
// eslint-disable-next-line react-hooks/exhaustive-deps
[model?.model_id]
);
return (
<EuiFlyout size="l" onClose={onClose} data-test-subj="mlTestModelsFlyout">
<EuiFlyoutHeader hasBorder>
<EuiTitle size="m">
<h2>
<FormattedMessage
id="xpack.ml.trainedModels.testDfaModelsFlyout.headerLabel"
defaultMessage="Test trained model"
/>
</h2>
</EuiTitle>
<EuiSpacer size="s" />
<EuiTitle size="xs">
<h4>{model.model_id}</h4>
</EuiTitle>
</EuiFlyoutHeader>
<EuiFlyoutBody>
<TestPipeline state={state} sourceIndex={sourceIndex} mode="standAlone" />
</EuiFlyoutBody>
</EuiFlyout>
);
};
Expand Up @@ -6,4 +6,4 @@
*/

export { TestTrainedModelFlyout } from './test_flyout';
export { isTestable } from './utils';
export { isTestable, isDfaTrainedModel } from './utils';
Expand Up @@ -15,6 +15,14 @@ import type { ModelItem } from '../models_list';

const PYTORCH_TYPES = Object.values(SUPPORTED_PYTORCH_TASKS);

export function isDfaTrainedModel(modelItem: ModelItem) {
return (
modelItem.metadata?.analytics_config !== undefined ||
modelItem.inference_config?.regression !== undefined ||
modelItem.inference_config?.classification !== undefined
);
}

export function isTestable(modelItem: ModelItem, checkForState = false) {
if (
modelItem.model_type === TRAINED_MODEL_TYPE.PYTORCH &&
Expand All @@ -31,5 +39,9 @@ export function isTestable(modelItem: ModelItem, checkForState = false) {
return true;
}

if (isDfaTrainedModel(modelItem)) {
return true;
}

return false;
}

0 comments on commit c66f468

Please sign in to comment.