[ML] Add default Elastic Inference Service chat completion endpoint #120847
Conversation
Force-pushed from eda2068 to 665e700
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends ESRestTestCase {
Moved this to BaseMockEISAuthServerTest
private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
Model name
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
Inference endpoint ID
    );
}

private record AuthorizedContent(
Just an aggregation of all the different pieces we need to expose (enabled task types, DefaultConfigId objects, and a list of models).
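For readers skimming the diff, a rough sketch of what that aggregation might look like; the field names and component types are assumptions based on the description above, not the PR's exact code.

```java
// Hypothetical shape only: a single record bundling the pieces the
// authorization flow needs to expose together.
private record AuthorizedContent(
    EnumSet<TaskType> enabledTaskTypes,   // task types the gateway authorized
    List<DefaultConfigId> configIds,      // DefaultConfigId objects for the default endpoints
    List<Model> models                    // the default models that were authorized
) {}
```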
if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) {
    logger.warn(
        Strings.format(
            "The authorization response included the default model: %s, "
In the unlikely event that the gateway and our definition of the default model have differing task types, we'll enable the model anyway, because that's what the gateway said to do.
This would only happen if the authorization response returned a task type different from the one we've set for the model here.
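A minimal sketch of that behaviour, continuing the truncated warn call from the diff above; the message text and the surrounding collection are illustrative assumptions, not the PR's exact code.

```java
// Hypothetical continuation: warn about the mismatch, but keep the model,
// since the gateway's authorization response is the source of truth.
if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) {
    logger.warn(
        Strings.format(
            "The authorization response included the default model: %s, "
                + "but did not authorize its task type: %s. Enabling the model anyway.",
            model.getInferenceEntityId(),
            model.getTaskType()
        )
    );
}
// The model is still added; the gateway's model list wins over the local task type definition.
modelsToEnable.add(model);
```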
private Set<String> getEnabledDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
    var enabledModels = auth.getEnabledModels();
    var enabledDefaultModelIds = new HashSet<>(defaultModels.keySet());
    enabledDefaultModelIds.retainAll(enabledModels);
Return the model ids where there is overlap between what the gateway authorized and the default ones we've defined.
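For illustration only (the second model id in each set is made up), here is the intersection that retainAll computes:

```java
// Hypothetical inputs: the locally defined defaults vs. what the gateway authorized.
Set<String> defaultModelIds = Set.of("rainbow-sprinkles", "some-other-default");
Set<String> gatewayAuthorizedModelIds = Set.of("rainbow-sprinkles", "unrelated-model");

var enabledDefaultModelIds = new HashSet<>(defaultModelIds);
enabledDefaultModelIds.retainAll(gatewayAuthorizedModelIds);
// enabledDefaultModelIds now contains only "rainbow-sprinkles".
```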
 * This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}.
 */
- public record ElasticInferenceServiceAuthorization(Map<String, EnumSet<TaskType>> enabledModels) {
+ public class ElasticInferenceServiceAuthorization {
I refactored this because we need both the authorized task types and the authorized models.
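A hedged outline of what the refactored class might look like, built only from what's visible in the diffs (getEnabledModels, getEnabledTaskTypes, taskTypeToModels); the constructor details and everything else are assumptions.

```java
import java.util.EnumMap;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.elasticsearch.inference.TaskType;

// Hypothetical outline, not the exact class from the PR.
public class ElasticInferenceServiceAuthorization {

    private final Map<String, EnumSet<TaskType>> enabledModels;   // model id -> authorized task types
    private final Map<TaskType, Set<String>> taskTypeToModels;    // reverse view of enabledModels
    private final EnumSet<TaskType> enabledTaskTypes;             // union of all authorized task types

    public ElasticInferenceServiceAuthorization(Map<String, EnumSet<TaskType>> enabledModels) {
        this.enabledModels = Map.copyOf(enabledModels);

        var taskTypes = EnumSet.noneOf(TaskType.class);
        var byTaskType = new EnumMap<TaskType, Set<String>>(TaskType.class);
        enabledModels.forEach((modelId, modelTaskTypes) -> {
            taskTypes.addAll(modelTaskTypes);
            for (var taskType : modelTaskTypes) {
                byTaskType.computeIfAbsent(taskType, t -> new HashSet<>()).add(modelId);
            }
        });
        this.enabledTaskTypes = taskTypes;
        this.taskTypeToModels = byTaskType;
    }

    // The model ids the gateway authorized.
    public Set<String> getEnabledModels() {
        return enabledModels.keySet();
    }

    // The task types the gateway authorized, across all models.
    public EnumSet<TaskType> getEnabledTaskTypes() {
        return EnumSet.copyOf(enabledTaskTypes);
    }
}
```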
- public record ElasticInferenceServiceAuthorization(Map<String, EnumSet<TaskType>> enabledModels) {
+ public class ElasticInferenceServiceAuthorization {

    private final Map<TaskType, Set<String>> taskTypeToModels;
This mapping helps when we need to create a new object that's limited to what the service actually supports, so we can easily grab the models that were authorized for a particular task type.
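As an illustration of that use case, here is a hypothetical method (the name newLimitedToTaskTypes is an assumption) that narrows an authorization down to the task types a service actually implements:

```java
// Hypothetical method on ElasticInferenceServiceAuthorization; relies on the
// taskTypeToModels field shown in the diff above.
public ElasticInferenceServiceAuthorization newLimitedToTaskTypes(EnumSet<TaskType> supportedTaskTypes) {
    Map<String, EnumSet<TaskType>> limited = new HashMap<>();

    for (TaskType taskType : supportedTaskTypes) {
        // Grab the models the gateway authorized for this particular task type.
        for (String modelId : taskTypeToModels.getOrDefault(taskType, Set.of())) {
            limited.computeIfAbsent(modelId, id -> EnumSet.noneOf(TaskType.class)).add(taskType);
        }
    }

    return new ElasticInferenceServiceAuthorization(limited);
}
```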
Pinging @elastic/ml-core (Team:ML)
davidkyle left a comment
LGTM
joshdevins left a comment
Had a quick scan only. Minor comments.
| .setting("xpack.security.enabled", "true") | ||
| // Adding both settings unless one feature flag is disabled in a particular environment | ||
| .setting("xpack.inference.elastic.url", mockEISServer::getUrl) | ||
| // TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL |
I think this is gone now. @vidok ?
Looks like it's being removed in this PR: #120842
| "task_types": ["chat"] | ||
| }, | ||
| { | ||
| "model_name": ".elser_model_2", |
EIS will expose elser-v2. Not sure it matters for this test though.
See: #120981
I'll update it 👍 The model ID here isn't actually being used, but we might as well align it for the future when we do use it.
💔 Backport failed
You can use sqren/backport to manually backport by running:
💚 All backports created successfully
Questions? Please refer to the Backport tool documentation.
[ML] Add default Elastic Inference Service chat completion endpoint (elastic#120847)
* Starting new auth class implementation
* Fixing some tests
* Working tests
* Refactoring
* Addressing feedback and pull main
(cherry picked from commit 1fa1ba7)
This PR adds the first-iteration model id for the Elastic Inference Service.

Model id: rainbow-sprinkles
The default endpoint id: .rainbow-sprinkles-elastic

Testing

Without EIS
* GET _inference/_all: elastic should not be listed in the response
* GET _inference/_services: .rainbow-sprinkles-elastic should not be listed in the response

With EIS
* Get the right certs directory.
* Run the gateway: make TLS_VERIFY_CLIENT_CERTS=false run
* Run ES:
* Retrieve all the default inference endpoints (a hypothetical REST-test sketch of this check follows below).
* Retrieve all the available services for sparse embedding.
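As referenced in the testing list above, a hypothetical REST-test sketch of the "retrieve all the default inference endpoints" check; the class name and assertion are assumptions, not the PR's test code.

```java
import static org.hamcrest.Matchers.containsString;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.test.rest.ESRestTestCase;

// Hypothetical test class, assuming EIS is configured for the cluster under test.
public class DefaultEndpointListingIT extends ESRestTestCase {

    public void testDefaultChatCompletionEndpointIsListed() throws Exception {
        // GET _inference/_all returns every configured inference endpoint.
        Response response = client().performRequest(new Request("GET", "/_inference/_all"));

        // The default chat completion endpoint should be present when EIS is reachable.
        String body = EntityUtils.toString(response.getEntity());
        assertThat(body, containsString(".rainbow-sprinkles-elastic"));
    }
}
```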