
Conversation

@jonathan-buttner (Contributor) commented Jan 24, 2025

This PR adds the first-iteration model ID for the Elastic Inference Service (EIS).

Model ID: rainbow-sprinkles

Default endpoint ID: .rainbow-sprinkles-elastic

Testing

Without EIS

GET _inference/_all

elastic should not be listed in the response

GET _inference/_services

.rainbow-sprinkles-elastic should not be listed in the response

With EIS

Get the right certs directory.

Run the gateway:

make TLS_VERIFY_CLIENT_CERTS=false run

Run ES:

./gradlew :run -Drun.license_type=trial -Dtests.es.xpack.inference.elastic.url=https://localhost:8443 -Dtests.es.xpack.inference.elastic.http.ssl.verification_mode=none
Retrieve all the default inference endpoints
GET _inference/_all
{
    "endpoints": [
        ...
        {
            "inference_id": ".rainbow-sprinkles-elastic",
            "task_type": "chat_completion",
            "service": "elastic",
            "service_settings": {
                "model_id": "rainbow-sprinkles",
                "rate_limit": {
                    "requests_per_minute": 240
                }
            }
        },
        ...
    ]
}

Retrieve all the available services for sparse embedding
GET _inference/_services/sparse_embedding
[
    ...
    {
        "service": "elastic",
        "name": "Elastic",
        "task_types": [
            "sparse_embedding",
            "chat_completion"
        ],
        "configurations": {
            "rate_limit.requests_per_minute": {
                "description": "Minimize the number of rate limit errors.",
                "label": "Rate Limit",
                "required": false,
                "sensitive": false,
                "updatable": false,
                "type": "int",
                "supported_task_types": [
                    "sparse_embedding",
                    "chat_completion"
                ]
            },
            "model_id": {
                "description": "The name of the model to use for the inference task.",
                "label": "Model ID",
                "required": true,
                "sensitive": false,
                "updatable": false,
                "type": "str",
                "supported_task_types": [
                    "sparse_embedding",
                    "chat_completion"
                ]
            },
            "max_input_tokens": {
                "description": "Allows you to specify the maximum number of tokens per input.",
                "label": "Maximum Input Tokens",
                "required": false,
                "sensitive": false,
                "updatable": false,
                "type": "int",
                "supported_task_types": [
                    "sparse_embedding"
                ]
            }
        }
    },
    ...
]

@jonathan-buttner added the >refactoring, :ml (Machine learning), Team:ML (Meta label for the ML team), auto-backport (Automatically create backport pull requests when merged), v9.0.0, and v8.18.0 labels Jan 24, 2025
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends ESRestTestCase {
jonathan-buttner (author):
Moved this to BaseMockEISAuthServerTest

private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
jonathan-buttner (author):

Model name

private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
jonathan-buttner (author):

Inference endpoint ID
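As a quick sketch of how the default endpoint ID is derived from the model ID, using plain `String.format` as a stand-in for Elasticsearch's `Strings.format` (the class name here is hypothetical):

```java
public class EndpointId {
    static final String MODEL_ID = "rainbow-sprinkles";

    public static void main(String[] args) {
        // Mirrors Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1)
        // from the snippet above.
        String endpointId = String.format(".%s-elastic", MODEL_ID);
        System.out.println(endpointId); // .rainbow-sprinkles-elastic
    }
}
```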

);
}

private record AuthorizedContent(
jonathan-buttner (author):

Just an aggregation of all the different pieces we need to expose (enabled task types, DefaultConfigId objects, and a list of models).

if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) {
logger.warn(
Strings.format(
"The authorization response included the default model: %s, "
jonathan-buttner (author):

In the unlikely event that the gateway and our definition of the default model have differing task types, we'll enable the model anyway, because the gateway is authoritative.

This would only happen if the authorization response returned something different from how we've set the task type for the model here.
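A minimal sketch of the warn-but-enable decision described here (the enum and class names are illustrative stand-ins, not the real Elasticsearch types):

```java
import java.util.EnumSet;

public class AuthTaskTypeCheck {
    // Hypothetical stand-in for org.elasticsearch.inference.TaskType.
    public enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION }

    // The gateway is treated as authoritative: on a mismatch the caller only
    // logs a warning, it does not drop the model.
    public static boolean shouldWarn(EnumSet<TaskType> authorizedTaskTypes, TaskType modelTaskType) {
        return authorizedTaskTypes.contains(modelTaskType) == false;
    }

    public static void main(String[] args) {
        var authorized = EnumSet.of(TaskType.SPARSE_EMBEDDING);
        System.out.println(shouldWarn(authorized, TaskType.CHAT_COMPLETION));  // mismatch: warn, still enable
        System.out.println(shouldWarn(authorized, TaskType.SPARSE_EMBEDDING)); // match: no warning
    }
}
```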

private Set<String> getEnabledDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
var enabledModels = auth.getEnabledModels();
var enabledDefaultModelIds = new HashSet<>(defaultModels.keySet());
enabledDefaultModelIds.retainAll(enabledModels);
jonathan-buttner (author):

Return the model IDs that appear in both the set the gateway authorized and the defaults we've defined (a set intersection).
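The intersection logic can be sketched like this (illustrative names; the real code lives in getEnabledDefaultModelIds above):

```java
import java.util.HashSet;
import java.util.Set;

public class EnabledDefaultModelIds {
    // Keep only the default model IDs that the gateway also authorized.
    public static Set<String> enabledDefaults(Set<String> defaultModelIds, Set<String> authorizedModelIds) {
        var enabled = new HashSet<>(defaultModelIds); // copy, so the defaults map is untouched
        enabled.retainAll(authorizedModelIds);        // in-place set intersection
        return enabled;
    }

    public static void main(String[] args) {
        var defaults = Set.of("rainbow-sprinkles", "elser-v2");
        var authorized = Set.of("rainbow-sprinkles", "some-other-model");
        System.out.println(enabledDefaults(defaults, authorized)); // [rainbow-sprinkles]
    }
}
```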

* This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}.
*/
public record ElasticInferenceServiceAuthorization(Map<String, EnumSet<TaskType>> enabledModels) {
public class ElasticInferenceServiceAuthorization {
jonathan-buttner (author):

I refactored this because we need both the authorized task types and the authorized models.

public record ElasticInferenceServiceAuthorization(Map<String, EnumSet<TaskType>> enabledModels) {
public class ElasticInferenceServiceAuthorization {

private final Map<TaskType, Set<String>> taskTypeToModels;
jonathan-buttner (author):

This mapping helps when we need to create a new object limited to what the service actually supports, so we can easily grab the models that were authorized for a particular task type.
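A sketch of building that inverse index from the enabledModels map (hypothetical names and simplified types, not the real Elasticsearch classes):

```java
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class TaskTypeIndex {
    // Hypothetical stand-in for org.elasticsearch.inference.TaskType.
    public enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION }

    // Invert model -> task types into task type -> models, so callers can
    // cheaply ask "which models were authorized for this task type?".
    public static Map<TaskType, Set<String>> index(Map<String, EnumSet<TaskType>> enabledModels) {
        var byTaskType = new HashMap<TaskType, Set<String>>();
        for (var entry : enabledModels.entrySet()) {
            for (var taskType : entry.getValue()) {
                byTaskType.computeIfAbsent(taskType, t -> new HashSet<>()).add(entry.getKey());
            }
        }
        return byTaskType;
    }

    public static void main(String[] args) {
        var enabled = Map.of(
            "rainbow-sprinkles", EnumSet.of(TaskType.CHAT_COMPLETION),
            "elser-v2", EnumSet.of(TaskType.SPARSE_EMBEDDING)
        );
        System.out.println(index(enabled).get(TaskType.CHAT_COMPLETION)); // [rainbow-sprinkles]
    }
}
```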

@jonathan-buttner jonathan-buttner marked this pull request as ready for review January 27, 2025 20:47
@elasticsearchmachine (Collaborator):

Pinging @elastic/ml-core (Team:ML)

davidkyle (Member) left a comment:

LGTM

joshdevins (Member) left a comment:

Had a quick scan only. Minor comments.

.setting("xpack.security.enabled", "true")
// Adding both settings unless one feature flag is disabled in a particular environment
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
Member:

I think this is gone now. @vidok?

jonathan-buttner (author):

Looks like it's being removed in this PR: #120842

"task_types": ["chat"]
},
{
"model_name": ".elser_model_2",
Member:

EIS will expose elser-v2. Not sure it matters for this test though.
See: #120981

jonathan-buttner (author):

I'll update it 👍 The model ID here isn't actually used, but we might as well align it now for the future when we do use it.

@jonathan-buttner jonathan-buttner merged commit 1fa1ba7 into elastic:main Jan 28, 2025
16 checks passed
@jonathan-buttner jonathan-buttner deleted the ml-eis-default-endpoint branch January 28, 2025 14:57
@elasticsearchmachine (Collaborator):

💔 Backport failed

Branch 8.x: commit could not be cherry-picked due to conflicts

You can use sqren/backport to manually backport by running backport --upstream elastic/elasticsearch --pr 120847

jonathan-buttner (author):

💚 All backports created successfully

Branch 8.x: success

Questions?

Please refer to the Backport tool documentation

jonathan-buttner added a commit to jonathan-buttner/elasticsearch that referenced this pull request Jan 28, 2025
…lastic#120847)

* Starting new auth class implementation

* Fixing some tests

* Working tests

* Refactoring

* Addressing feedback and pull main

(cherry picked from commit 1fa1ba7)
elasticsearchmachine pushed a commit that referenced this pull request Jan 28, 2025
…120847) (#121061)

* Starting new auth class implementation

* Fixing some tests

* Working tests

* Refactoring

* Addressing feedback and pull main

(cherry picked from commit 1fa1ba7)

Labels

auto-backport, backport pending, :ml (Machine learning), >refactoring, Team:ML, v8.18.0, v9.0.0
