Skip to content

Commit

Permalink
[ML] Add E5 model configs (#172053)
Browse files Browse the repository at this point in the history
## Summary

- Adds E5 model configurations available for download, portable and x86
linux optimized.
- Adds `getCuratedModelConfig` shared service to retrieve the model ID
and configuration appropriate for the current cluster architecture.
- Updates description for the ELSER model 
- Renames tabs in the "Add trained model" flyout 
- Renames the `name` property in the `ModelDefinitionResponse` interface
with `model_id`

<img width="1835" alt="image"
src="https://github.com/elastic/kibana/assets/5236598/abaf4f47-d581-493a-af1b-c663a0af9da6">

### Checklist

- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
  • Loading branch information
darnautov committed Dec 1, 2023
1 parent 1f8c816 commit 823552f
Show file tree
Hide file tree
Showing 14 changed files with 303 additions and 83 deletions.
1 change: 1 addition & 0 deletions packages/kbn-doc-links/src/get_doc_links.ts
Expand Up @@ -520,6 +520,7 @@ export const getDocLinks = ({ kibanaBranch }: GetDocLinkOptions): DocLinks => {
trainedModels: `${MACHINE_LEARNING_DOCS}ml-trained-models.html`,
startTrainedModelsDeployment: `${MACHINE_LEARNING_DOCS}ml-nlp-deploy-model.html`,
nlpElser: `${MACHINE_LEARNING_DOCS}ml-nlp-elser.html`,
nlpE5: `${MACHINE_LEARNING_DOCS}ml-nlp-e5.html`,
nlpImportModel: `${MACHINE_LEARNING_DOCS}ml-nlp-import-model.html`,
},
transforms: {
Expand Down
3 changes: 2 additions & 1 deletion x-pack/packages/ml/trained_models_utils/index.ts
Expand Up @@ -19,7 +19,8 @@ export {
type ModelDefinition,
type ModelDefinitionResponse,
type ElserVersion,
type GetElserOptions,
type GetModelDownloadConfigOptions,
type ElasticCuratedModelName,
ELSER_ID_V1,
ELASTIC_MODEL_TAG,
ELASTIC_MODEL_TYPE,
Expand Down
Expand Up @@ -61,6 +61,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserDescription', {
defaultMessage: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2': {
modelName: 'elser',
Expand All @@ -74,6 +75,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2Description', {
defaultMessage: 'Elastic Learned Sparse EncodeR v2',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2_linux-x86_64': {
modelName: 'elser',
Expand All @@ -88,14 +90,49 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
description: i18n.translate('xpack.ml.trainedModels.modelsList.elserV2x86Description', {
defaultMessage: 'Elastic Learned Sparse EncodeR v2, optimized for linux-x86_64',
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.multilingual-e5-small': {
modelName: 'e5',
version: 1,
default: true,
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1Description', {
defaultMessage: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
}),
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
'.multilingual-e5-small_linux-x86_64': {
modelName: 'e5',
version: 1,
os: 'Linux',
arch: 'amd64',
config: {
input: {
field_names: ['text_field'],
},
},
description: i18n.translate('xpack.ml.trainedModels.modelsList.e5v1x86Description', {
defaultMessage:
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
}),
license: 'MIT',
type: ['pytorch', 'text_embedding'],
},
} as const);

export type ElasticCuratedModelName = 'elser' | 'e5';

export interface ModelDefinition {
/**
* Model name, e.g. elser
*/
modelName: string;
modelName: ElasticCuratedModelName;
version: number;
/**
* Default PUT model configuration
Expand All @@ -107,13 +144,15 @@ export interface ModelDefinition {
default?: boolean;
recommended?: boolean;
hidden?: boolean;
license?: string;
type?: readonly string[];
}

export type ModelDefinitionResponse = ModelDefinition & {
/**
* Complete model id, e.g. .elser_model_2_linux-x86_64
*/
name: string;
model_id: string;
};

export type ElasticModelId = keyof typeof ELASTIC_MODEL_DEFINITIONS;
Expand All @@ -129,6 +168,6 @@ export type ModelState = typeof MODEL_STATE[keyof typeof MODEL_STATE] | null;

export type ElserVersion = 1 | 2;

export interface GetElserOptions {
export interface GetModelDownloadConfigOptions {
version?: ElserVersion;
}
2 changes: 1 addition & 1 deletion x-pack/plugins/elastic_assistant/server/plugin.ts
Expand Up @@ -80,7 +80,7 @@ export class ElasticAssistantPlugin
const getElserId: GetElser = once(
async (request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) => {
return (await plugins.ml.trainedModelsProvider(request, savedObjectsClient).getELSER())
.name;
.model_id;
}
);

Expand Down
Expand Up @@ -179,7 +179,7 @@ export const TextExpansionCalloutLogic = kea<
afterMount: async () => {
const elserModel = await KibanaLogic.values.ml.elasticModels?.getELSER({ version: 2 });
if (elserModel != null) {
actions.setElserModelId(elserModel.name);
actions.setElserModelId(elserModel.model_id);
actions.fetchTextExpansionModel();
}
},
Expand Down
Expand Up @@ -42,52 +42,52 @@ export interface AddModelFlyoutProps {
onSubmit: (modelId: string) => void;
}

type FlyoutTabId = 'clickToDownload' | 'manualDownload';

/**
* Flyout for downloading elastic curated models and showing instructions for importing third-party models.
*/
export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, modelDownloads }) => {
const canCreateTrainedModels = usePermissionCheck('canCreateTrainedModels');
const isElserTabVisible = canCreateTrainedModels && modelDownloads.length > 0;
const isClickToDownloadTabVisible = canCreateTrainedModels && modelDownloads.length > 0;

const [selectedTabId, setSelectedTabId] = useState(isElserTabVisible ? 'elser' : 'thirdParty');
const [selectedTabId, setSelectedTabId] = useState<FlyoutTabId>(
isClickToDownloadTabVisible ? 'clickToDownload' : 'manualDownload'
);

const tabs = useMemo(() => {
return [
...(isElserTabVisible
...(isClickToDownloadTabVisible
? [
{
id: 'elser',
id: 'clickToDownload' as const,
name: (
<EuiFlexGroup gutterSize={'s'} alignItems={'center'}>
<EuiFlexItem grow={false}>
<EuiIcon type="logoElastic" size="m" />
</EuiFlexItem>
<EuiFlexItem grow={false}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.elserTabLabel"
defaultMessage="ELSER"
/>
</EuiFlexItem>
</EuiFlexGroup>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.clickToDownloadTabLabel"
defaultMessage="Click to Download"
/>
),
content: (
<ElserTabContent modelDownloads={modelDownloads} onModelDownload={onSubmit} />
<ClickToDownloadTabContent
modelDownloads={modelDownloads}
onModelDownload={onSubmit}
/>
),
},
]
: []),
{
id: 'thirdParty',
id: 'manualDownload' as const,
name: (
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.thirdPartyLabel"
defaultMessage="Third-party"
defaultMessage="Manual Download"
/>
),
content: <ThirdPartyTabContent />,
content: <ManualDownloadTabContent />,
},
];
}, [isElserTabVisible, modelDownloads, onSubmit]);
}, [isClickToDownloadTabVisible, modelDownloads, onSubmit]);

const selectedTabContent = useMemo(() => {
return tabs.find((obj) => obj.id === selectedTabId)?.content;
Expand Down Expand Up @@ -133,15 +133,18 @@ export const AddModelFlyout: FC<AddModelFlyoutProps> = ({ onClose, onSubmit, mod
);
};

interface ElserTabContentProps {
interface ClickToDownloadTabContentProps {
modelDownloads: ModelItem[];
onModelDownload: (modelId: string) => void;
}

/**
* ELSER tab content for selecting a model to download.
* Tab content for selecting a model to download.
*/
const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDownload }) => {
const ClickToDownloadTabContent: FC<ClickToDownloadTabContentProps> = ({
modelDownloads,
onModelDownload,
}) => {
const {
services: { docLinks },
} = useMlKibana();
Expand All @@ -157,26 +160,33 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
<React.Fragment key={modelName}>
{modelName === 'elser' ? (
<div>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserTitle"
defaultMessage="Elastic Learned Sparse EncodeR (ELSER)"
/>
</h3>
</EuiTitle>
<EuiFlexGroup gutterSize={'s'} alignItems={'center'}>
<EuiFlexItem grow={false}>
<EuiIcon type="logoElastic" size="l" />
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserTitle"
defaultMessage="ELSER (Elastic Learned Sparse EncodeR)"
/>
</h3>
</EuiTitle>
</EuiFlexItem>
</EuiFlexGroup>
<EuiSpacer size="s" />
<p>
<EuiText color={'subdued'} size={'s'}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.elserDescription"
defaultMessage="ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone."
defaultMessage="ELSER is Elastic's NLP model for English semantic search, utilizing sparse vectors. It prioritizes intent and contextual meaning over literal term matching, optimized specifically for English documents and queries on the Elastic platform."
/>
</EuiText>
</p>
<EuiSpacer size="s" />
<p>
<EuiLink href={docLinks.links.ml.nlpElser} external>
<EuiLink href={docLinks.links.ml.nlpElser} external target={'_blank'}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserViewDocumentationLinkLabel"
defaultMessage="View documentation"
Expand All @@ -187,6 +197,52 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
</div>
) : null}

{modelName === 'e5' ? (
<div>
<EuiTitle size={'s'}>
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.e5Title"
defaultMessage="E5 (EmbEddings from bidirEctional Encoder rEpresentations)"
/>
</h3>
</EuiTitle>
<EuiSpacer size="s" />
<p>
<EuiText color={'subdued'} size={'s'}>
<FormattedMessage
id="xpack.ml.trainedModels.addModelFlyout.e5Description"
defaultMessage="E5 is an NLP model that enables you to perform multi-lingual semantic search by using dense vector representations. This model performs best for non-English language documents and queries."
/>
</EuiText>
</p>
<EuiSpacer size="s" />
<EuiFlexGroup justifyContent={'spaceBetween'} gutterSize={'none'}>
<EuiFlexItem grow={false}>
<EuiLink href={docLinks.links.ml.nlpE5} external target={'_blank'}>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.elserViewDocumentationLinkLabel"
defaultMessage="View documentation"
/>
</EuiLink>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiBadge
color="hollow"
target={'_blank'}
href={'https://huggingface.co/elastic/multilingual-e5-small-optimized'}
>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.mitLicenseLabel"
defaultMessage="License: MIT"
/>
</EuiBadge>
</EuiFlexItem>
</EuiFlexGroup>
<EuiSpacer size={'l'} />
</div>
) : null}

<EuiFormFieldset
legend={{
children: (
Expand All @@ -197,7 +253,7 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
),
}}
>
{models.map((model) => {
{models.map((model, index) => {
return (
<React.Fragment key={model.model_id}>
<EuiCheckableCard
Expand Down Expand Up @@ -256,11 +312,12 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
checked={model.model_id === selectedModelId}
onChange={setSelectedModelId.bind(null, model.model_id)}
/>
<EuiSpacer size="m" />
{index < models.length - 1 ? <EuiSpacer size="m" /> : null}
</React.Fragment>
);
})}
</EuiFormFieldset>
<EuiSpacer size="xxl" />
</React.Fragment>
);
})}
Expand All @@ -279,9 +336,9 @@ const ElserTabContent: FC<ElserTabContentProps> = ({ modelDownloads, onModelDown
};

/**
* Third-party tab content for showing instructions for importing third-party models.
* Manual download tab content for showing instructions for importing third-party models.
*/
const ThirdPartyTabContent: FC = () => {
const ManualDownloadTabContent: FC = () => {
const {
services: { docLinks },
} = useMlKibana();
Expand Down
Expand Up @@ -262,17 +262,17 @@ export const ModelsList: FC<Props> = ({
);
const forDownload = await trainedModelsApiService.getTrainedModelDownloads();
const notDownloaded: ModelItem[] = forDownload
.filter(({ name, hidden, recommended }) => {
if (recommended && idMap.has(name)) {
idMap.get(name)!.recommended = true;
.filter(({ model_id: modelId, hidden, recommended }) => {
if (recommended && idMap.has(modelId)) {
idMap.get(modelId)!.recommended = true;
}
return !idMap.has(name) && !hidden;
return !idMap.has(modelId) && !hidden;
})
.map<ModelItem>((modelDefinition) => {
return {
model_id: modelDefinition.name,
type: [ELASTIC_MODEL_TYPE],
tags: [ELASTIC_MODEL_TAG],
model_id: modelDefinition.model_id,
type: modelDefinition.type,
tags: modelDefinition.type?.includes(ELASTIC_MODEL_TAG) ? [ELASTIC_MODEL_TAG] : [],
putModelConfig: modelDefinition.config,
description: modelDefinition.description,
state: MODEL_STATE.NOT_DOWNLOADED,
Expand Down

0 comments on commit 823552f

Please sign in to comment.