Skip to content
This repository has been archived by the owner on Sep 22, 2023. It is now read-only.

Commit

Permalink
fix: Support new standard-complient GQL endpoint (#168)
Browse files Browse the repository at this point in the history
* refactor: Reuse Admin._query() for all other GQL invocations

Backported-From: main (21.09)
Backported-To: 20.09
  • Loading branch information
achimnol committed Jul 19, 2021
1 parent 65ba8c3 commit b9d93dd
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 420 deletions.
1 change: 1 addition & 0 deletions changes/168.fix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support the new standard-compliant GQL endpoint in the manager with the API version v6.20210815
2 changes: 1 addition & 1 deletion src/ai/backend/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Undefined(enum.Enum):
_config = None
_undefined = Undefined.token

API_VERSION = (6, '20200815')
API_VERSION = (6, '20210815')

DEFAULT_CHUNK_SIZE = 16 * (2**20) # 16 MiB
MAX_INFLIGHT_CHUNKS = 4
Expand Down
36 changes: 31 additions & 5 deletions src/ai/backend/client/func/admin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Mapping, Optional

from .base import api_function, BaseFunction
from ..exceptions import BackendAPIError
from ..request import Request
from ..session import api_session

__all__ = (
'Admin',
Expand All @@ -22,7 +24,8 @@ class Admin(BaseFunction):
@api_function
@classmethod
async def query(
cls, query: str,
cls,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Any:
"""
Expand All @@ -35,11 +38,34 @@ async def query(
:returns: The object parsed from the response JSON string.
"""
return await cls._query(query, variables)

@classmethod
async def _query(
cls,
query: str,
variables: Optional[Mapping[str, Any]] = None,
) -> Any:
"""
Internal async implementation of the query() method,
which may be reused by other functional APIs to make GQL requests.
"""
gql_query = {
'query': query,
'variables': variables if variables else {},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json(gql_query)
async with rqst.fetch() as resp:
return await resp.json()
if api_session.get().api_version >= (6, '20210815'):
rqst = Request('POST', '/admin/gql')
rqst.set_json(gql_query)
async with rqst.fetch() as resp:
response = await resp.json()
errors = response.get("errors", [])
if errors:
raise BackendAPIError(400, reason="GraphQL-generated error", data=errors)
else:
return response["data"]
else:
rqst = Request('POST', '/admin/graphql')
rqst.set_json(gql_query)
async with rqst.fetch() as resp:
return await resp.json()
11 changes: 3 additions & 8 deletions src/ai/backend/client/func/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .base import api_function, BaseFunction
from ..request import Request
from ..session import api_session
from ..pagination import generate_paginated_results

__all__ = (
Expand Down Expand Up @@ -86,14 +87,8 @@ async def detail(
""")
query = query.replace('$fields', ' '.join(fields))
variables = {'agent_id': agent_id}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['agent']
data = await api_session.get().Admin._query(query, variables)
return data['agent']


class AgentWatcher(BaseFunction):
Expand Down
61 changes: 13 additions & 48 deletions src/ai/backend/client/func/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterable, Sequence

from .base import api_function, BaseFunction
from ..request import Request
from ..session import api_session

__all__ = (
'Domain',
Expand Down Expand Up @@ -39,13 +39,8 @@ async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]:
}
""")
query = query.replace('$fields', ' '.join(fields))
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['domains']
data = await api_session.get().Admin._query(query)
return data['domains']

@api_function
@classmethod
Expand All @@ -67,14 +62,8 @@ async def detail(cls, name: str, fields: Iterable[str] = None) -> Sequence[dict]
""")
query = query.replace('$fields', ' '.join(fields))
variables = {'name': name}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['domain']
data = await api_session.get().Admin._query(query, variables)
return data['domain']

@api_function
@classmethod
Expand Down Expand Up @@ -109,14 +98,8 @@ async def create(cls, name: str, description: str = '', is_active: bool = True,
'integration_id': integration_id,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['create_domain']
data = await api_session.get().Admin._query(query, variables)
return data['create_domain']

@api_function
@classmethod
Expand Down Expand Up @@ -149,14 +132,8 @@ async def update(cls, name: str, new_name: str = None, description: str = None,
'integration_id': integration_id,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['modify_domain']
data = await api_session.get().Admin._query(query, variables)
return data['modify_domain']

@api_function
@classmethod
Expand All @@ -172,14 +149,8 @@ async def delete(cls, name: str):
}
""")
variables = {'name': name}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['delete_domain']
data = await api_session.get().Admin._query(query, variables)
return data['delete_domain']

@api_function
@classmethod
Expand All @@ -195,11 +166,5 @@ async def purge(cls, name: str):
}
""")
variables = {'name': name}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
return data['purge_domain']
data = await api_session.get().Admin._query(query, variables)
return data['purge_domain']
66 changes: 9 additions & 57 deletions src/ai/backend/client/func/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterable, Sequence

from .base import api_function, BaseFunction
from ..request import Request
from ..session import api_session

__all__ = (
'Group',
Expand Down Expand Up @@ -42,13 +42,7 @@ async def list(cls, domain_name: str,
""")
query = query.replace('$fields', ' '.join(fields))
variables = {'domain_name': domain_name}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['groups']

@api_function
Expand All @@ -70,13 +64,7 @@ async def detail(cls, gid: str, fields: Iterable[str] = None) -> Sequence[dict]:
""")
query = query.replace('$fields', ' '.join(fields))
variables = {'gid': gid}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['group']

@api_function
Expand Down Expand Up @@ -111,13 +99,7 @@ async def create(cls, domain_name: str, name: str, description: str = '',
'integration_id': integration_id,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['create_group']

@api_function
Expand Down Expand Up @@ -149,13 +131,7 @@ async def update(cls, gid: str, name: str = None, description: str = None,
'integration_id': integration_id,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['modify_group']

@api_function
Expand All @@ -172,13 +148,7 @@ async def delete(cls, gid: str):
}
""")
variables = {'gid': gid}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['delete_group']

@api_function
Expand All @@ -195,13 +165,7 @@ async def purge(cls, gid: str):
}
""")
variables = {'gid': gid}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['purge_group']

@api_function
Expand All @@ -226,13 +190,7 @@ async def add_users(cls, gid: str, user_uuids: Iterable[str],
'user_uuids': user_uuids,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['modify_group']

@api_function
Expand All @@ -257,11 +215,5 @@ async def remove_users(cls, gid: str, user_uuids: Iterable[str],
'user_uuids': user_uuids,
},
}
rqst = Request('POST', '/admin/graphql')
rqst.set_json({
'query': query,
'variables': variables,
})
async with rqst.fetch() as resp:
data = await resp.json()
data = await api_session.get().Admin._query(query, variables)
return data['modify_group']

0 comments on commit b9d93dd

Please sign in to comment.