Skip to content

Commit

Permalink
Implement searchModelVersions() API in Java client (#7880)
Browse files Browse the repository at this point in the history
* add ModelVersionsPage
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add searchModelVersions signatures
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* finish searchModelVersions() methods
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add overloaded searchModelVersions()
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add test
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* fix http GET for model-versions/search
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* fix test: search by run id
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* remove unused imports
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* reformat line > 100 characters
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add args to makeSearchModelVersions()
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add args to searchModelVersions() & overloaded methods
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add args to ModelVersionsPage
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* add tests for pagination & new args
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>

* fix builder to handle null search filter & page token
Signed-off-by: gabrielfu <hfu.gabriel@gmail.com>
  • Loading branch information
gabrielfu committed Feb 24, 2023
1 parent cf4e533 commit 36c787a
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 0 deletions.
Expand Up @@ -929,6 +929,76 @@ public File downloadLatestModelVersion(String modelName, String stage) {
return downloadModelVersion(modelName, details.getVersion());
}

/**
* Return model versions that satisfy the search query.
*
* @param searchFilter SQL compatible search query string.
* Examples:
* - "name = 'model_name'"
* - "run_id = '...'"
* If null, the result will be equivalent to having an empty search filter.
* @param maxResults Maximum number of model versions desired in one page.
* @param orderBy List of properties to order by. Example: "name DESC".
*
* @return A page of model versions that satisfy the search filter.
*/
public ModelVersionsPage searchModelVersions(String searchFilter,
int maxResults,
List<String> orderBy) {
return searchModelVersions(searchFilter, maxResults, orderBy, null);
}

/**
* Return up to 1000 model versions.
*
* @return A page of model versions with up to 1000 items.
*/
public ModelVersionsPage searchModelVersions() {
return searchModelVersions("", 1000, new ArrayList<>(), null);
}

/**
* Return up to 1000 model versions that satisfy the search query.
*
* @param searchFilter SQL compatible search query string.
* Examples:
* - "name = 'model_name'"
* - "run_id = '...'"
* If null, the result will be equivalent to having an empty search filter.
*
* @return A page of model versions with up to 1000 items.
*/
public ModelVersionsPage searchModelVersions(String searchFilter) {
return searchModelVersions(searchFilter, 1000, new ArrayList<>(), null);
}

/**
* Return model versions that satisfy the search query.
*
* @param searchFilter SQL compatible search query string.
* Examples:
* - "name = 'model_name'"
* - "run_id = '...'"
* If null, the result will be equivalent to having an empty search filter.
* @param maxResults Maximum number of model versions desired in one page.
* @param orderBy List of properties to order by. Example: "name DESC".
* @param pageToken String token specifying the next page of results. It should be obtained from
* a call to {@link #searchModelVersions(String)}.
*
* @return A page of model versions that satisfy the search filter.
*/
public ModelVersionsPage searchModelVersions(String searchFilter,
int maxResults,
List<String> orderBy,
String pageToken) {
String json = sendGet(mapper.makeSearchModelVersions(
searchFilter, maxResults, orderBy, pageToken
));
SearchModelVersions.Response response = mapper.toSearchModelVersionsResponse(json);
return new ModelVersionsPage(response.getModelVersionsList(), response.getNextPageToken(),
searchFilter, maxResults, orderBy, this);
}

/**
* Closes the MlflowClient and releases any associated resources.
*/
Expand Down
Expand Up @@ -6,6 +6,7 @@

import java.lang.Iterable;
import java.net.URISyntaxException;
import java.util.List;

import org.apache.http.client.utils.URIBuilder;
import org.mlflow.api.proto.ModelRegistry.*;
Expand Down Expand Up @@ -195,6 +196,30 @@ String makeGetModelVersionDownloadUri(String modelName, String modelVersion) {
}
}

String makeSearchModelVersions(String searchFilter,
int maxResults,
List<String> orderBy,
String pageToken) {
try {
URIBuilder builder = new URIBuilder("model-versions/search")
.addParameter("max_results", Integer.toString(maxResults));
if (searchFilter != null && searchFilter != "") {
builder.addParameter("filter", searchFilter);
}
if (pageToken != null && pageToken != "") {
builder.addParameter("page_token", pageToken);
}
for( String order: orderBy) {
builder.addParameter("order_by", order);
}
return builder.build().toString();
} catch (URISyntaxException e) {
throw new MlflowClientException(
"Failed to construct request URI for search model version.",
e);
}
}

String toJson(MessageOrBuilder mb) {
return print(mb);
}
Expand Down Expand Up @@ -272,6 +297,12 @@ String toGetModelVersionDownloadUriResponse(String json) {
return builder.getArtifactUri();
}

SearchModelVersions.Response toSearchModelVersionsResponse(String json) {
SearchModelVersions.Response.Builder builder = SearchModelVersions.Response.newBuilder();
merge(json, builder);
return builder.build();
}

private String print(MessageOrBuilder message) {
try {
return JsonFormat.printer().preservingProtoFieldNames().print(message);
Expand Down
@@ -0,0 +1,83 @@
package org.mlflow.tracking;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import org.mlflow.api.proto.ModelRegistry.*;

public class ModelVersionsPage implements Page<ModelVersion> {

private final String token;
private final List<ModelVersion> mvs;

private final MlflowClient client;
private final String searchFilter;
private final List<String> orderBy;
private final int maxResults;

/**
* Creates a fixed size page of ModelVersions.
*/
ModelVersionsPage(List<ModelVersion> mvs,
String token,
String searchFilter,
int maxResults,
List<String> orderBy,
MlflowClient client) {
this.mvs = Collections.unmodifiableList(mvs);
this.token = token;
this.searchFilter = searchFilter;
this.orderBy = orderBy;
this.maxResults = maxResults;
this.client = client;
}

/**
* @return The number of model versions in the page.
*/
public int getPageSize() {
return this.mvs.size();
}

/**
* @return True if a token for the next page exists and isn't empty. Otherwise returns false.
*/
public boolean hasNextPage() {
return this.token != null && this.token != "";
}

/**
* @return An optional with the token for the next page.
* Empty if the token doesn't exist or is empty.
*/
public Optional<String> getNextPageToken() {
if (this.hasNextPage()) {
return Optional.of(this.token);
} else {
return Optional.empty();
}
}

/**
* @return The next page of model versions matching the search criteria.
* If there are no more pages, an {@link org.mlflow.tracking.EmptyPage} will be returned.
*/
public Page<ModelVersion> getNextPage() {
if (this.hasNextPage()) {
return this.client.searchModelVersions(this.searchFilter,
this.maxResults,
this.orderBy,
this.token);
} else {
return new EmptyPage();
}
}

/**
* @return An iterable over the model versions in this page.
*/
public List<ModelVersion> getItems() {
return mvs;
}

}
Expand Up @@ -10,11 +10,13 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.mlflow.api.proto.ModelRegistry.ModelVersion;
import org.mlflow.api.proto.ModelRegistry.RegisteredModel;
import org.mlflow.api.proto.Service;
import org.mlflow.api.proto.Service.RunInfo;
import org.mockito.Mockito;
import org.slf4j.Logger;
Expand Down Expand Up @@ -160,4 +162,71 @@ private void validateDetailedModelVersion(ModelVersion details, String modelName
Assert.assertEquals(details.getName(), modelName);
Assert.assertEquals(details.getVersion(), version);
}

@Test
public void testSearchModelVersions() {
List<ModelVersion> mvsBefore = client.searchModelVersions().getItems();

// create new model version of existing registered model
String newVersionRunId = "newVersionRunId";
String newVersionSource = "newVersionSource";
client.sendPost("model-versions/create",
mapper.makeCreateModelVersion(modelName, newVersionRunId, newVersionSource));

// create new registered model
String modelName2 = "modelName2";
String runId2 = "runId2";
String source2 = "source2";
client.sendPost("registered-models/create",
mapper.makeCreateModel(modelName2));
client.sendPost("model-versions/create",
mapper.makeCreateModelVersion(modelName2, runId2, source2));

List<ModelVersion> mvsAfter = client.searchModelVersions().getItems();
Assert.assertEquals(mvsAfter.size(), 2 + mvsBefore.size());

String filter1 = String.format("name = '%s'", modelName);
List<ModelVersion> mvs1 = client.searchModelVersions(filter1).getItems();
Assert.assertEquals(mvs1.size(), 2);
Assert.assertEquals(mvs1.get(0).getName(), modelName);
Assert.assertEquals(mvs1.get(1).getName(), modelName);

String filter2 = String.format("name = '%s'", modelName2);
List<ModelVersion> mvs2 = client.searchModelVersions(filter2).getItems();
Assert.assertEquals(mvs2.size(), 1);
Assert.assertEquals(mvs2.get(0).getName(), modelName2);
Assert.assertEquals(mvs2.get(0).getVersion(), "1");

String filter3 = String.format("run_id = '%s'", newVersionRunId);
List<ModelVersion> mvs3 = client.searchModelVersions(filter3).getItems();
Assert.assertEquals(mvs3.size(), 1);
Assert.assertEquals(mvs3.get(0).getName(), modelName);
Assert.assertEquals(mvs3.get(0).getVersion(), "2");

ModelVersionsPage page1 = client.searchModelVersions(
"", 1, Arrays.asList("creation_timestamp ASC")
);
Assert.assertEquals(page1.getItems().size(), 1);
Assert.assertEquals(page1.getItems().get(0).getName(), modelName);
Assert.assertTrue(page1.getNextPageToken().isPresent());

ModelVersionsPage page2 = client.searchModelVersions(
"",
2,
Arrays.asList("creation_timestamp ASC"),
page1.getNextPageToken().get()
);
Assert.assertEquals(page2.getItems().size(), 2);
Assert.assertEquals(page2.getItems().get(0).getName(), modelName);
Assert.assertEquals(page2.getItems().get(0).getRunId(), newVersionRunId);
Assert.assertEquals(page2.getItems().get(1).getName(), modelName2);
Assert.assertEquals(page2.getItems().get(1).getRunId(), runId2);
Assert.assertFalse(page2.getNextPageToken().isPresent());

ModelVersionsPage nextPageFromPrevPage = (ModelVersionsPage) page1.getNextPage();
Assert.assertEquals(nextPageFromPrevPage.getItems().size(), 1);
Assert.assertEquals(page2.getItems().get(0).getName(), modelName);
Assert.assertEquals(page2.getItems().get(0).getRunId(), newVersionRunId);
Assert.assertTrue(nextPageFromPrevPage.getNextPageToken().isPresent());
}
}

0 comments on commit 36c787a

Please sign in to comment.