Skip to content

Commit

Permalink
Join run and experiment in graphql (#11173)
Browse files Browse the repository at this point in the history
Signed-off-by: edwardfeng-db <jinghao.feng@databricks.com>
  • Loading branch information
edwardfeng-db committed Feb 20, 2024
1 parent 8915365 commit 601122d
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 17 deletions.
1 change: 1 addition & 0 deletions dev/proto_to_graphql/autogeneration_utils.py
Expand Up @@ -5,6 +5,7 @@

INDENT = " " * 4
INDENT2 = INDENT * 2
SCHEMA_EXTENSION_MODULE = "mlflow.server.graphql.graphql_schema_extensions"
SCHEMA_EXTENSION = "mlflow/server/graphql/graphql_schema_extensions.py"
AUTOGENERATED_SCHEMA = "mlflow/server/graphql/autogenerated_graphql_schema.py"
DUMMY_FIELD = (
Expand Down
7 changes: 5 additions & 2 deletions dev/proto_to_graphql/schema_autogeneration.py
Expand Up @@ -5,6 +5,7 @@
INDENT,
INDENT2,
SCHEMA_EXTENSION,
SCHEMA_EXTENSION_MODULE,
get_descriptor_full_pascal_name,
get_method_name,
method_descriptor_to_generated_pb2_file_name,
Expand Down Expand Up @@ -82,7 +83,6 @@ def generate_schema(state):
schema_builder += "# Run python3 ./dev/proto_to_graphql/code_generator.py to regenerate.\n"
schema_builder += "import graphene\n"
schema_builder += "import mlflow\n"
schema_builder += "from mlflow.server.graphql import graphql_schema_extensions\n"
schema_builder += "from mlflow.server.graphql.graphql_custom_scalars import LongString\n"
schema_builder += "from mlflow.utils.proto_json_utils import parse_dict\n"
schema_builder += "\n"
Expand Down Expand Up @@ -149,7 +149,10 @@ def generate_schema(state):

def apply_schema_extension(referenced_class_name):
if referenced_class_name in EXTENDED_TO_EXTENDING:
return f"graphql_schema_extensions.{EXTENDED_TO_EXTENDING[referenced_class_name]}"
# Using dotted module path as pointed out in the linked GitHub issue.r
# This is an undocumented feature of Graphene.
# https://github.com/graphql-python/graphene/issues/110#issuecomment-1219737639
return f"'{SCHEMA_EXTENSION_MODULE}.{EXTENDED_TO_EXTENDING[referenced_class_name]}'"
else:
return referenced_class_name

Expand Down
17 changes: 8 additions & 9 deletions mlflow/server/graphql/autogenerated_graphql_schema.py
Expand Up @@ -2,7 +2,6 @@
# Run python3 ./dev/proto_to_graphql/code_generator.py to regenerate.
import graphene
import mlflow
from mlflow.server.graphql import graphql_schema_extensions
from mlflow.server.graphql.graphql_custom_scalars import LongString
from mlflow.utils.proto_json_utils import parse_dict

Expand Down Expand Up @@ -81,7 +80,7 @@ class MlflowRun(graphene.ObjectType):


class MlflowGetRunResponse(graphene.ObjectType):
run = graphene.Field(MlflowRun)
run = graphene.Field('mlflow.server.graphql.graphql_schema_extensions.MlflowRunExtension')


class MlflowExperimentTag(graphene.ObjectType):
Expand Down Expand Up @@ -113,21 +112,21 @@ class MlflowGetExperimentInput(graphene.InputObjectType):


class QueryType(graphene.ObjectType):
mlflow_get_run = graphene.Field(MlflowGetRunResponse, input=MlflowGetRunInput())
mlflow_get_experiment = graphene.Field(MlflowGetExperimentResponse, input=MlflowGetExperimentInput())

def resolve_mlflow_get_run(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.service_pb2.GetRun()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.get_run_impl(request_message)
mlflow_get_run = graphene.Field(MlflowGetRunResponse, input=MlflowGetRunInput())

def resolve_mlflow_get_experiment(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.service_pb2.GetExperiment()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.get_experiment_impl(request_message)

def resolve_mlflow_get_run(self, info, input):
input_dict = vars(input)
request_message = mlflow.protos.service_pb2.GetRun()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.get_run_impl(request_message)


class MutationType(graphene.ObjectType):
pass
20 changes: 19 additions & 1 deletion mlflow/server/graphql/graphql_schema_extensions.py
@@ -1,6 +1,13 @@
import graphene

from mlflow.server.graphql.autogenerated_graphql_schema import MutationType, QueryType
import mlflow
from mlflow.server.graphql.autogenerated_graphql_schema import (
MlflowExperiment,
MlflowRun,
MutationType,
QueryType,
)
from mlflow.utils.proto_json_utils import parse_dict


class Test(graphene.ObjectType):
Expand All @@ -11,6 +18,17 @@ class TestMutation(graphene.ObjectType):
output = graphene.String(description="Echoes the input string")


class MlflowRunExtension(MlflowRun):
experiment = graphene.Field(MlflowExperiment)

def resolve_experiment(self, info):
experiment_id = self.info.experiment_id
input_dict = {"experiment_id": experiment_id}
request_message = mlflow.protos.service_pb2.GetExperiment()
parse_dict(input_dict, request_message)
return mlflow.server.handlers.get_experiment_impl(request_message).experiment


class Query(QueryType):
test = graphene.Field(Test, input_string=graphene.String(), description="Simple echoing field")

Expand Down
36 changes: 31 additions & 5 deletions tests/tracking/test_rest_tracking.py
Expand Up @@ -1740,8 +1740,7 @@ def test_upload_artifact_handler(mlflow_client):


def test_graphql_handler(mlflow_client):
response = requests.request(
"post",
response = requests.post(
f"{mlflow_client.tracking_uri}/graphql",
json={
"query": 'query testQuery {test(inputString: "abc") { output }}',
Expand All @@ -1754,8 +1753,7 @@ def test_graphql_handler(mlflow_client):

def test_get_experiment_graphql(mlflow_client):
experiment_id = mlflow_client.create_experiment("GraphqlTest")
response = requests.request(
"post",
response = requests.post(
f"{mlflow_client.tracking_uri}/graphql",
json={
"query": 'query testQuery {mlflowGetExperiment(input: {experimentId: "'
Expand All @@ -1766,4 +1764,32 @@ def test_get_experiment_graphql(mlflow_client):
headers={"content-type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
assert "GraphqlTest" in response.text
json = response.json()
assert json["data"]["mlflowGetExperiment"]["experiment"]["name"] == "GraphqlTest"


def test_get_run_and_experiment_graphql(mlflow_client):
experiment_id = mlflow_client.create_experiment("GraphqlTest")
created_run = mlflow_client.create_run(experiment_id)
run_id = created_run.info.run_id
response = requests.post(
f"{mlflow_client.tracking_uri}/graphql",
json={
"query": f"""
query testQuery {{
mlflowGetRun(input: {{runId: "{run_id}"}}) {{
run {{
experiment {{
name
}}
}}
}}
}}
""",
"operationName": "testQuery",
},
headers={"content-type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
json = response.json()
assert json["data"]["mlflowGetRun"]["run"]["experiment"]["name"] == "GraphqlTest"

0 comments on commit 601122d

Please sign in to comment.