Skip to content

Commit

Permalink
feat: add stream endpoint (#504)
Browse files Browse the repository at this point in the history
* feat: add stream endpoint

* feat: add stream response endpoint

* chore: add changelog

* feat: update return type hints

* feat: handle exception if run is not ready
  • Loading branch information
bwanglzu committed Aug 16, 2022
1 parent 5508b98 commit 7af6234
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Enable wandb callback. ([#494](https://github.com/jina-ai/finetuner/pull/494))

- Support log streaming in finetuner client. ([#504](https://github.com/jina-ai/finetuner/pull/504))

### Removed

### Changed
Expand Down
3 changes: 2 additions & 1 deletion docs/walkthrough/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ run = finetuner.fit(
)

print(run.name)
print(run.logs())
for log_entry in run.stream_logs():
print(log_entry)

# When ready
run.save_artifact(directory='experiment')
Expand Down
4 changes: 3 additions & 1 deletion docs/walkthrough/save-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ experiment = finetuner.get_experiment('finetune-flickr-dataset')
run = experiment.get_run('finetune-flickr-dataset-efficientnet-1')
print(f'Run status: {run.status()}')
print(f'Run artifact id: {run.artifact_id}')
print(f'Run logs: {run.logs()}')
# Once run status is `STARTED`, you can stream logs with:
for log_entry in run.stream_logs():
print(log_entry)
# save the artifact.
run.save_artifact('tuned_model')
```
Expand Down
21 changes: 16 additions & 5 deletions finetuner/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,24 @@ def _handle_request(
method: str,
params: Optional[dict] = None,
json_data: Optional[dict] = None,
) -> Union[dict, List[dict], str]:
stream: bool = False,
) -> Union[dict, List[dict], str, requests.Response]:
"""The base request handler.
:param url: The url of the request.
:param method: The request type (GET, POST or DELETE).
:param params: Optional parameters for the request.
:param json_data: Optional data payloads to be sent along with the request.
:param stream: If the request is a streaming request set to True.
:return: Response to the request.
"""
response = self._session.request(
url=url, method=method, json=json_data, params=params, allow_redirects=False
url=url,
method=method,
json=json_data,
params=params,
allow_redirects=False,
stream=stream,
)
if response.status_code == 307:
response = self._session.request(
Expand All @@ -72,13 +79,17 @@ def _handle_request(
json=json_data,
params=params,
allow_redirects=False,
stream=stream,
)
if not response.ok:
raise FinetunerServerError(
message=response.reason,
code=response.status_code,
details=response.json()['detail'],
)
if TEXT in response.headers['content-type']:
return response.text
return response.json()
if stream:
return response
else:
if TEXT in response.headers['content-type']:
return response.text
return response.json()
25 changes: 24 additions & 1 deletion finetuner/client/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Iterator, List, Optional

from finetuner.client.base import _BaseClient
from finetuner.constants import (
Expand All @@ -12,6 +12,7 @@
GET,
GPUS,
LOGS,
LOGSTREAM,
NAME,
POST,
RUNS,
Expand Down Expand Up @@ -166,6 +167,28 @@ def get_run_logs(self, experiment_name: str, run_name: str) -> str:
)
return self._handle_request(url=url, method=GET)

def stream_run_logs(self, experiment_name: str, run_name: str) -> Iterator[str]:
"""Streaming log events to the client as ServerSentEvents.
:param experiment_name: The name of the experiment.
:param run_name: The name of the run.
:yield: A log entry.
"""
url = self._construct_url(
self._base_url,
API_VERSION,
EXPERIMENTS,
experiment_name,
RUNS,
run_name,
LOGSTREAM,
)
response = self._handle_request(url=url, method=GET, stream=True)
for entry in response.iter_lines():
entry = entry.decode('utf-8', errors='ignore')
if entry:
yield entry

def create_run(
self,
experiment_name: str,
Expand Down
1 change: 1 addition & 0 deletions finetuner/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RUNS = 'runs'
STATUS = 'status'
LOGS = 'logs'
LOGSTREAM = 'logstream'
EXPERIMENTS = 'experiments'
API_VERSION = 'api/v1'
AUTHORIZATION = 'Authorization'
Expand Down
4 changes: 4 additions & 0 deletions finetuner/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class RunInProgressError(Exception):
...


class RunPreparingError(Exception):
...


class RunFailedError(Exception):
...

Expand Down
29 changes: 25 additions & 4 deletions finetuner/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Iterator

from finetuner.client import FinetunerV1Client
from finetuner.constants import (
ARTIFACT_ID,
Expand All @@ -7,7 +9,7 @@
STARTED,
STATUS,
)
from finetuner.exception import RunFailedError, RunInProgressError
from finetuner.exception import RunFailedError, RunInProgressError, RunPreparingError
from finetuner.hubble import download_artifact


Expand Down Expand Up @@ -67,11 +69,22 @@ def logs(self) -> str:
:returns: A string dump of the run logs.
"""
self._check_run_status_started()
return self._client.get_run_logs(
experiment_name=self._experiment_name, run_name=self._name
)

def _check_run_status(self):
def stream_logs(self) -> Iterator[str]:
"""Stream the run logs.
:yield: An iterators keep stream the logs from server.
"""
self._check_run_status_started()
return self._client.stream_run_logs(
experiment_name=self._experiment_name, run_name=self._name
)

def _check_run_status_finished(self):
status = self.status()[STATUS]
if status in [CREATED, STARTED]:
raise RunInProgressError(
Expand All @@ -82,13 +95,21 @@ def _check_run_status(self):
'The run failed, please check the `logs` for detailed information.'
)

def _check_run_status_started(self):
status = self.status()[STATUS]
if status == CREATED:
raise RunPreparingError(
'The run is preparing to run, logs will be ready to pull when '
'`status` is `STARTED`.'
)

def save_artifact(self, directory: str = ARTIFACTS_DIR) -> str:
"""Save artifact if the run is finished.
:param directory: Directory where the artifact will be stored.
:returns: A string object that indicates the download path.
"""
self._check_run_status()
self._check_run_status_finished()
return download_artifact(
client=self._client,
artifact_id=self._run[ARTIFACT_ID],
Expand All @@ -107,5 +128,5 @@ def artifact_id(self):
:return: Artifact id as string object.
"""
self._check_run_status()
self._check_run_status_finished()
return self._run[ARTIFACT_ID]

0 comments on commit 7af6234

Please sign in to comment.