Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

feat(labeler): gently terminate the labler UI from frontend #177

Merged
merged 2 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/_static/favicon.ico
Binary file not shown.
12 changes: 5 additions & 7 deletions finetuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def fit(
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
device: str = 'cpu',
) -> 'Summary':
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -49,7 +49,7 @@ def fit(
output_dim: Optional[int] = None,
freeze: bool = False,
device: str = 'cpu',
) -> 'Summary':
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -67,7 +67,7 @@ def fit(
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
device: str = 'cpu',
) -> None:
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -91,13 +91,13 @@ def fit(
output_dim: Optional[int] = None,
freeze: bool = False,
device: str = 'cpu',
) -> None:
) -> Tuple['AnyDNN', 'Summary']:
...


def fit(
model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs
) -> Optional[Tuple['AnyDNN', 'Summary']]:
) -> Tuple['AnyDNN', Optional['Summary']]:
if kwargs.get('to_embedding_model', False):
from .tailor import to_embedding_model

Expand All @@ -106,8 +106,6 @@ def fit(
if kwargs.get('interactive', False):
from .labeler import fit

# TODO: atm return will never hit as labeler UI hangs the
# flow via `.block()`
return model, fit(model, train_data, *args, **kwargs)
else:
from .tuner import fit
Expand Down
8 changes: 7 additions & 1 deletion finetuner/labeler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
import threading
import webbrowser
from typing import Optional

Expand Down Expand Up @@ -36,11 +37,15 @@ def fit(
:param kwargs: Additional keyword arguments.
"""
dam_path = tempfile.mkdtemp()
stop_event = threading.Event()

class MyExecutor(FTExecutor):
def get_embed_model(self):
return embed_model

def get_stop_event(self):
return stop_event

f = (
Flow(
protocol='http',
Expand Down Expand Up @@ -70,6 +75,7 @@ def get_embed_model(self):
f.expose_endpoint(
'/save'
) #: for signaling the backend to save the current state of the model
f.expose_endpoint('/terminate') #: for terminating the flow from frontend

def extend_rest_function(app):
"""Allow FastAPI frontend to serve finetuner UI as a static webpage"""
Expand Down Expand Up @@ -107,4 +113,4 @@ def open_frontend_in_browser(req):
show_progress=True,
on_done=open_frontend_in_browser,
)
f.block()
f.block(stop_event=stop_event)
8 changes: 8 additions & 0 deletions finetuner/labeler/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def __init__(
def get_embed_model(self):
...

@abc.abstractmethod
def get_stop_event(self):
...

@cached_property
def _embed_model(self):
return self.get_embed_model()
Expand Down Expand Up @@ -65,6 +69,10 @@ def save(self, parameters: Dict, **kwargs):
save(self._embed_model, model_path)
print(f'model is saved to {model_path}')

@requests(on='/terminate')
def terminate(self, **kwargs):
self.get_stop_event().set()


class DataIterator(Executor):
def __init__(
Expand Down
Binary file added finetuner/labeler/ui/favicon.ico
Binary file not shown.
2 changes: 1 addition & 1 deletion finetuner/labeler/ui/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
</div>
<sidebar :labeler-config="labeler_config" :view-template="view_template" :tags="tags" :is-busy="is_busy"
:progress-stats="progress_stats" :positive-rate="positive_rate" :negative-rate="negative_rate"
:advanced-config="advanced_config" :save-progress="saveProgress"
:advanced-config="advanced_config" :save-progress="saveProgress" :terminate-flow="terminateFlow"
:next-batch="next_batch"></sidebar>
<div class="b-example-divider"></div>
<div class="flex-grow-1 p-1 overflow-hidden">
Expand Down
7 changes: 6 additions & 1 deletion finetuner/labeler/ui/js/components/sidebar.vue.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const sidebar = {
advancedConfig: Object,
saveProgress: Function,
nextBatch: Function,
terminateFlow: Function,
},
template: `
<div class="d-flex flex-column flex-shrink-0 p-3 sidebar">
Expand Down Expand Up @@ -124,10 +125,14 @@ const sidebar = {
</div>
</div>
<div class="my-3 d-flex justify-content-center">
<button class="btn btn btn-outline-primary"
<button class="btn btn btn-outline-primary m-2"
v-on:click="saveProgress()">
Save model
</button>
<button class="btn btn btn-outline-secondary m-2"
v-on:click="terminateFlow()">
Terminate
</button>
</div>
</div>
</div>
Expand Down
33 changes: 29 additions & 4 deletions finetuner/labeler/ui/js/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ const app = new Vue({
server_address: `http://localhost`,
next_endpoint: '/next',
fit_endpoint: '/fit',
saveEndpoint: '/save',
save_endpoint: '/save',
stop_endpoint: '/terminate',
},
advanced_config: {
pos_value: {text: 'Positive label', value: 1, type: 'number'},
Expand Down Expand Up @@ -80,8 +81,11 @@ const app = new Vue({
fit_address: function () {
return `${this.host_address}${this.general_config.fit_endpoint}`
},
saveAddress: function () {
return `${this.host_address}${this.general_config.saveEndpoint}`
save_address: function () {
return `${this.host_address}${this.general_config.save_endpoint}`
},
stop_address: function () {
return `${this.host_address}${this.general_config.stop_endpoint}`
},
positive_rate: function () {
return this.progress_stats.positive.value / (this.progress_stats.positive.value + this.progress_stats.negative.value) * 100
Expand Down Expand Up @@ -208,7 +212,7 @@ const app = new Vue({
app.is_conn_broken = false
$.ajax({
type: "POST",
url: app.saveAddress,
url: app.save_address,
data: JSON.stringify({
data: [],
parameters: {
Expand All @@ -225,6 +229,27 @@ const app = new Vue({
app.is_busy = false
});
},
terminateFlow: () => {
app.is_busy = true
app.is_conn_broken = false
$.ajax({
type: "POST",
url: app.stop_address,
data: JSON.stringify({
data: [],
parameters: {
}
}),
contentType: "application/json; charset=utf-8",
dataType: "json",
}).success(function (data, textStatus, jqXHR) {
app.is_busy = false
close();
}).fail(function () {
console.error("Error: ", error)
app.is_busy = false
});
},
handleKeyPress(event) {
let key = event.key
if (event.target instanceof HTMLInputElement) {
Expand Down
5 changes: 5 additions & 0 deletions finetuner/tuner/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def __init__(self, *records: ScalarSequence):
"""Create a collection of summaries. """
self._records = [r for r in records if r]

def __iadd__(self, other: 'Summary'):
if isinstance(other, Summary):
self._records += other._records
return self

def save(self, filepath: str):
"""Store all summary into a JSON file"""
with open(filepath, 'w') as fp:
Expand Down
15 changes: 14 additions & 1 deletion tests/integration/labeler/test_tune_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ def _run(framework_name, loss, port_expose):
),
}

fit(
rv1, rv2 = fit(
embed_models[framework_name](),
generate_qa_match(num_total=10, num_neg=0),
loss=loss,
interactive=True,
port_expose=port_expose,
)

assert rv1
assert not rv2


all_test_losses = [
'CosineSiameseLoss',
Expand Down Expand Up @@ -144,6 +147,16 @@ def test_all_frameworks(framework, loss, tmpdir):
assert req.status_code == 200
assert os.path.isfile(model_path)

req = requests.post(
f'http://localhost:{port}/terminate',
json={
'data': [],
'parameters': {},
},
)
assert req.status_code == 200
assert os.path.isfile(model_path)

except:
raise
finally:
Expand Down
28 changes: 26 additions & 2 deletions tests/integration/labeler/test_tune_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,24 @@ def _run(framework_name, loss, port_expose):
),
}

fit(
rv1, rv2 = fit(
embed_models[framework_name](),
generate_fashion_match(num_total=10, num_pos=0, num_neg=0),
loss=loss,
interactive=True,
port_expose=port_expose,
)

assert rv1
assert not rv2


# 'keras' does not work under this test setup
# Exception ... ust be from the same graph as Tensor ...
# TODO: add keras backend back to the test
@pytest.mark.parametrize('framework', ['pytorch', 'paddle'])
@pytest.mark.parametrize('loss', all_test_losses)
def test_all_frameworks(framework, loss):
def test_all_frameworks(framework, loss, tmpdir):
port = random_port()
p = multiprocessing.Process(
target=_run,
Expand Down Expand Up @@ -125,6 +128,27 @@ def test_all_frameworks(framework, loss):
json={'data': rj['data']['docs'], 'parameters': {'epochs': 10}},
)
assert req.status_code == 200

model_path = os.path.join(tmpdir, 'model.train')
req = requests.post(
f'http://localhost:{port}/save',
json={
'data': [],
'parameters': {'model_path': model_path},
},
)
assert req.status_code == 200
assert os.path.isfile(model_path)

req = requests.post(
f'http://localhost:{port}/terminate',
json={
'data': [],
'parameters': {},
},
)
assert req.status_code == 200

except:
raise
finally:
Expand Down