Skip to content

Commit

Permalink
feat(labeler): gently terminate the labler UI from frontend (#177)
Browse files Browse the repository at this point in the history
* feat(labeler): gently terminate the labler UI from frontend

* feat(labeler): gently terminate the labler UI from frontend
  • Loading branch information
hanxiao committed Oct 26, 2021
1 parent 40261d4 commit df19264
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 17 deletions.
Binary file added docs/_static/favicon.ico
Binary file not shown.
12 changes: 5 additions & 7 deletions finetuner/__init__.py
Expand Up @@ -26,7 +26,7 @@ def fit(
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
device: str = 'cpu',
) -> 'Summary':
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -49,7 +49,7 @@ def fit(
output_dim: Optional[int] = None,
freeze: bool = False,
device: str = 'cpu',
) -> 'Summary':
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -67,7 +67,7 @@ def fit(
optimizer: str = 'adam',
optimizer_kwargs: Optional[Dict] = None,
device: str = 'cpu',
) -> None:
) -> Tuple['AnyDNN', 'Summary']:
...


Expand All @@ -91,13 +91,13 @@ def fit(
output_dim: Optional[int] = None,
freeze: bool = False,
device: str = 'cpu',
) -> None:
) -> Tuple['AnyDNN', 'Summary']:
...


def fit(
model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs
) -> Optional[Tuple['AnyDNN', 'Summary']]:
) -> Tuple['AnyDNN', Optional['Summary']]:
if kwargs.get('to_embedding_model', False):
from .tailor import to_embedding_model

Expand All @@ -106,8 +106,6 @@ def fit(
if kwargs.get('interactive', False):
from .labeler import fit

# TODO: atm return will never hit as labeler UI hangs the
# flow via `.block()`
return model, fit(model, train_data, *args, **kwargs)
else:
from .tuner import fit
Expand Down
8 changes: 7 additions & 1 deletion finetuner/labeler/__init__.py
@@ -1,5 +1,6 @@
import os
import tempfile
import threading
import webbrowser
from typing import Optional

Expand Down Expand Up @@ -36,11 +37,15 @@ def fit(
:param kwargs: Additional keyword arguments.
"""
dam_path = tempfile.mkdtemp()
stop_event = threading.Event()

class MyExecutor(FTExecutor):
def get_embed_model(self):
return embed_model

def get_stop_event(self):
return stop_event

f = (
Flow(
protocol='http',
Expand Down Expand Up @@ -70,6 +75,7 @@ def get_embed_model(self):
f.expose_endpoint(
'/save'
) #: for signaling the backend to save the current state of the model
f.expose_endpoint('/terminate') #: for terminating the flow from frontend

def extend_rest_function(app):
"""Allow FastAPI frontend to serve finetuner UI as a static webpage"""
Expand Down Expand Up @@ -107,4 +113,4 @@ def open_frontend_in_browser(req):
show_progress=True,
on_done=open_frontend_in_browser,
)
f.block()
f.block(stop_event=stop_event)
8 changes: 8 additions & 0 deletions finetuner/labeler/executor.py
Expand Up @@ -25,6 +25,10 @@ def __init__(
def get_embed_model(self):
...

@abc.abstractmethod
def get_stop_event(self):
...

@cached_property
def _embed_model(self):
return self.get_embed_model()
Expand Down Expand Up @@ -65,6 +69,10 @@ def save(self, parameters: Dict, **kwargs):
save(self._embed_model, model_path)
print(f'model is saved to {model_path}')

@requests(on='/terminate')
def terminate(self, **kwargs):
self.get_stop_event().set()


class DataIterator(Executor):
def __init__(
Expand Down
Binary file added finetuner/labeler/ui/favicon.ico
Binary file not shown.
2 changes: 1 addition & 1 deletion finetuner/labeler/ui/index.html
Expand Up @@ -22,7 +22,7 @@
</div>
<sidebar :labeler-config="labeler_config" :view-template="view_template" :tags="tags" :is-busy="is_busy"
:progress-stats="progress_stats" :positive-rate="positive_rate" :negative-rate="negative_rate"
:advanced-config="advanced_config" :save-progress="saveProgress"
:advanced-config="advanced_config" :save-progress="saveProgress" :terminate-flow="terminateFlow"
:next-batch="next_batch"></sidebar>
<div class="b-example-divider"></div>
<div class="flex-grow-1 p-1 overflow-hidden">
Expand Down
7 changes: 6 additions & 1 deletion finetuner/labeler/ui/js/components/sidebar.vue.js
Expand Up @@ -10,6 +10,7 @@ const sidebar = {
advancedConfig: Object,
saveProgress: Function,
nextBatch: Function,
terminateFlow: Function,
},
template: `
<div class="d-flex flex-column flex-shrink-0 p-3 sidebar">
Expand Down Expand Up @@ -124,10 +125,14 @@ const sidebar = {
</div>
</div>
<div class="my-3 d-flex justify-content-center">
<button class="btn btn btn-outline-primary"
<button class="btn btn btn-outline-primary m-2"
v-on:click="saveProgress()">
Save model
</button>
<button class="btn btn btn-outline-secondary m-2"
v-on:click="terminateFlow()">
Terminate
</button>
</div>
</div>
</div>
Expand Down
33 changes: 29 additions & 4 deletions finetuner/labeler/ui/js/main.js
Expand Up @@ -41,7 +41,8 @@ const app = new Vue({
server_address: `http://localhost`,
next_endpoint: '/next',
fit_endpoint: '/fit',
saveEndpoint: '/save',
save_endpoint: '/save',
stop_endpoint: '/terminate',
},
advanced_config: {
pos_value: {text: 'Positive label', value: 1, type: 'number'},
Expand Down Expand Up @@ -80,8 +81,11 @@ const app = new Vue({
fit_address: function () {
return `${this.host_address}${this.general_config.fit_endpoint}`
},
saveAddress: function () {
return `${this.host_address}${this.general_config.saveEndpoint}`
save_address: function () {
return `${this.host_address}${this.general_config.save_endpoint}`
},
stop_address: function () {
return `${this.host_address}${this.general_config.stop_endpoint}`
},
positive_rate: function () {
return this.progress_stats.positive.value / (this.progress_stats.positive.value + this.progress_stats.negative.value) * 100
Expand Down Expand Up @@ -208,7 +212,7 @@ const app = new Vue({
app.is_conn_broken = false
$.ajax({
type: "POST",
url: app.saveAddress,
url: app.save_address,
data: JSON.stringify({
data: [],
parameters: {
Expand All @@ -225,6 +229,27 @@ const app = new Vue({
app.is_busy = false
});
},
terminateFlow: () => {
app.is_busy = true
app.is_conn_broken = false
$.ajax({
type: "POST",
url: app.stop_address,
data: JSON.stringify({
data: [],
parameters: {
}
}),
contentType: "application/json; charset=utf-8",
dataType: "json",
}).success(function (data, textStatus, jqXHR) {
app.is_busy = false
close();
}).fail(function () {
console.error("Error: ", error)
app.is_busy = false
});
},
handleKeyPress(event) {
let key = event.key
if (event.target instanceof HTMLInputElement) {
Expand Down
5 changes: 5 additions & 0 deletions finetuner/tuner/summary.py
Expand Up @@ -48,6 +48,11 @@ def __init__(self, *records: ScalarSequence):
"""Create a collection of summaries. """
self._records = [r for r in records if r]

def __iadd__(self, other: 'Summary'):
if isinstance(other, Summary):
self._records += other._records
return self

def save(self, filepath: str):
"""Store all summary into a JSON file"""
with open(filepath, 'w') as fp:
Expand Down
15 changes: 14 additions & 1 deletion tests/integration/labeler/test_tune_lstm.py
Expand Up @@ -55,14 +55,17 @@ def _run(framework_name, loss, port_expose):
),
}

fit(
rv1, rv2 = fit(
embed_models[framework_name](),
generate_qa_match(num_total=10, num_neg=0),
loss=loss,
interactive=True,
port_expose=port_expose,
)

assert rv1
assert not rv2


all_test_losses = [
'CosineSiameseLoss',
Expand Down Expand Up @@ -144,6 +147,16 @@ def test_all_frameworks(framework, loss, tmpdir):
assert req.status_code == 200
assert os.path.isfile(model_path)

req = requests.post(
f'http://localhost:{port}/terminate',
json={
'data': [],
'parameters': {},
},
)
assert req.status_code == 200
assert os.path.isfile(model_path)

except:
raise
finally:
Expand Down
28 changes: 26 additions & 2 deletions tests/integration/labeler/test_tune_mlp.py
Expand Up @@ -53,21 +53,24 @@ def _run(framework_name, loss, port_expose):
),
}

fit(
rv1, rv2 = fit(
embed_models[framework_name](),
generate_fashion_match(num_total=10, num_pos=0, num_neg=0),
loss=loss,
interactive=True,
port_expose=port_expose,
)

assert rv1
assert not rv2


# 'keras' does not work under this test setup
# Exception ... ust be from the same graph as Tensor ...
# TODO: add keras backend back to the test
@pytest.mark.parametrize('framework', ['pytorch', 'paddle'])
@pytest.mark.parametrize('loss', all_test_losses)
def test_all_frameworks(framework, loss):
def test_all_frameworks(framework, loss, tmpdir):
port = random_port()
p = multiprocessing.Process(
target=_run,
Expand Down Expand Up @@ -125,6 +128,27 @@ def test_all_frameworks(framework, loss):
json={'data': rj['data']['docs'], 'parameters': {'epochs': 10}},
)
assert req.status_code == 200

model_path = os.path.join(tmpdir, 'model.train')
req = requests.post(
f'http://localhost:{port}/save',
json={
'data': [],
'parameters': {'model_path': model_path},
},
)
assert req.status_code == 200
assert os.path.isfile(model_path)

req = requests.post(
f'http://localhost:{port}/terminate',
json={
'data': [],
'parameters': {},
},
)
assert req.status_code == 200

except:
raise
finally:
Expand Down

0 comments on commit df19264

Please sign in to comment.