Skip to content

Commit

Permalink
Introduce catalog + ndcg (#120)
Browse files Browse the repository at this point in the history
* feat: add ndcg calc

* feat: added hits metric and refactoring

* feat: fix ndcg

* feat: gpu support for eval

* feat: add gpu for eval

* fix: sample size

* feat: labeler working

* feat: paddle and torch

* fix: train data callable

* test: fixed

* feat: fmnist with catalog

* feat: used dam for catalog

* refactor: qa toy data

* fix: gpu tests

* test: fix test size

* test: fix wrong arg

* test: speed up test data generation

* feat: removed train metrics

* fix: next only called when needed

* feat: restored old toy data generation behavior
  • Loading branch information
maximilianwerk committed Oct 19, 2021
1 parent b0da1bf commit 0be69a4
Show file tree
Hide file tree
Showing 32 changed files with 613 additions and 320 deletions.
2 changes: 1 addition & 1 deletion docs/get-started/covid-qa.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ import finetuner
finetuner.fit(
embed_model,
train_data=generate_qa_match,
train_data=generate_qa_match(),
interactive=True)
```
Expand Down
2 changes: 1 addition & 1 deletion docs/get-started/fashion-mnist.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ import finetuner

finetuner.fit(
embed_model,
train_data=generate_fashion_match,
train_data=generate_fashion_match(),
interactive=True)
```

Expand Down
23 changes: 20 additions & 3 deletions finetuner/labeler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional

import jina.helper
from jina import Flow
from jina import Flow, DocumentArrayMemmap
from jina.logging.predefined import default_logger

from .executor import FTExecutor, DataIterator
Expand All @@ -14,13 +14,15 @@
def fit(
embed_model: AnyDNN,
train_data: DocumentArrayLike,
catalog: Optional[DocumentArrayLike] = None,
clear_labels_on_start: bool = False,
port_expose: Optional[int] = None,
runtime_backend: str = 'thread',
loss: str = 'CosineSiameseLoss',
**kwargs,
) -> None:
dam_path = tempfile.mkdtemp()
catalog_dam_path = init_catalog(dam_path, catalog, train_data)

class MyExecutor(FTExecutor):
def get_embed_model(self):
Expand All @@ -37,13 +39,14 @@ def get_embed_model(self):
uses=DataIterator,
uses_with={
'dam_path': dam_path,
'catalog_dam_path': catalog_dam_path,
'clear_labels_on_start': clear_labels_on_start,
},
)
.add(
uses=MyExecutor,
uses_with={
'dam_path': dam_path,
'catalog_dam_path': catalog_dam_path,
'loss': loss,
},
)
Expand Down Expand Up @@ -88,8 +91,22 @@ def open_frontend_in_browser(req):
f.post(
'/feed',
train_data,
request_size=10,
request_size=128,
show_progress=True,
on_done=open_frontend_in_browser,
)
f.block()


def init_catalog(
    dam_path: str, catalog: DocumentArrayLike, train_data: DocumentArrayLike
):
    """Resolve ``catalog`` into an on-disk ``DocumentArrayMemmap`` and return its path.

    If ``catalog`` is already a ``DocumentArrayMemmap``, its existing storage
    path is reused directly and nothing is copied.  Otherwise a new memmap is
    created under ``dam_path`` and populated from ``catalog`` — or, when
    ``catalog`` is None, from ``train_data`` (calling it first if it is a
    zero-argument callable, mirroring how ``fit`` accepts data factories).

    :param dam_path: base directory used for the labeler's memmap storage.
    :param catalog: optional catalog data, or an existing memmap to reuse.
    :param train_data: fallback data source when no catalog is given.
    :return: filesystem path of the catalog ``DocumentArrayMemmap``.
    """
    if isinstance(catalog, DocumentArrayMemmap):
        # Already memmapped on disk — reuse the existing storage as-is.
        return catalog.path

    # '/'-concatenation matches the sibling convention (dam_path + '/labeled')
    # used by DataIterator in executor.py.
    catalog_dam_path = dam_path + '/catalog'
    catalog_memmap = DocumentArrayMemmap(catalog_dam_path)
    source = (
        catalog
        if catalog is not None
        else (train_data() if callable(train_data) else train_data)
    )
    catalog_memmap.extend(source)
    return catalog_dam_path
30 changes: 19 additions & 11 deletions finetuner/labeler/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
class FTExecutor(Executor):
def __init__(
self,
dam_path: str,
catalog_dam_path: str,
metric: str = 'cosine',
loss: str = 'CosineSiameseLoss',
**kwargs,
):
super().__init__(**kwargs)
self._all_data = DocumentArrayMemmap(dam_path)
self._catalog = DocumentArrayMemmap(catalog_dam_path)
self._metric = metric
self._loss = loss

Expand All @@ -33,9 +33,9 @@ def _embed_model(self):
def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):
if not docs:
return
self._all_data.reload()
da = self._all_data.sample(
min(len(self._all_data), int(parameters.get('sample_size', 1000)))
self._catalog.reload()
da = self._catalog.sample(
min(len(self._catalog), int(parameters.get('sample_size', 1000)))
)

f_type = get_framework(self._embed_model)
Expand Down Expand Up @@ -76,6 +76,7 @@ def fit(self, docs, parameters: Dict, **kwargs):
fit(
self._embed_model,
docs,
self._catalog,
epochs=int(parameters.get('epochs', 10)),
loss=self._loss,
)
Expand All @@ -91,33 +92,40 @@ class DataIterator(Executor):
def __init__(
self,
dam_path: str,
catalog_dam_path: str,
labeled_dam_path: Optional[str] = None,
clear_labels_on_start: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self._all_data = DocumentArrayMemmap(dam_path)
self._catalog = DocumentArrayMemmap(catalog_dam_path)
if not labeled_dam_path:
labeled_dam_path = dam_path + '/labeled'
self._labeled_dam = DocumentArrayMemmap(labeled_dam_path)
if clear_labels_on_start:
self._labeled_dam.clear()

@requests(on='/feed')
def store_data(self, docs: DocumentArray, **kwargs):
self._all_data.extend(docs)
def store_data(self, docs: DocumentArray, parameters: Dict, **kwargs):
if parameters.get('type', 'query') == 'query':
self._all_data.extend(docs)
else:
self._catalog.extend(docs)

@requests(on='/next')
def take_batch(self, parameters: Dict, **kwargs):
st = int(parameters.get('start', 0))
ed = int(parameters.get('end', 1))
count = int(parameters.get('new_examples', 5))

self._all_data.reload()
return self._all_data[st:ed]
count = min(max(count, 0), len(self._all_data))
return self._all_data.sample(k=count)

@requests(on='/fit')
def add_fit_data(self, docs: DocumentArray, **kwargs):
for d in docs.traverse_flat(['r', 'm']):
for d in docs.traverse_flat(['r']):
d.content = self._all_data[d.id].content
for d in docs.traverse_flat(['m']):
d.content = self._catalog[d.id].content
self._labeled_dam.extend(docs)
return self._labeled_dam
8 changes: 4 additions & 4 deletions finetuner/labeler/ui/js/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,12 @@ const app = new Vue({
},
next_batch: function () {
let end_idx = app.labeler_config.start_idx + (app.labeler_config.example_per_view - app.cur_batch.length)
if (end_idx === app.labeler_config.start_idx) {
if (end_idx <= app.labeler_config.start_idx) {
return
}
let start_idx = app.labeler_config.start_idx
app.labeler_config.start_idx = end_idx
let new_examples = end_idx - start_idx
app.is_busy = true
app.is_conn_broken = false
$.ajax({
Expand All @@ -169,8 +170,7 @@ const app = new Vue({
data: JSON.stringify({
data: [],
parameters: {
'start': start_idx,
'end': end_idx,
'new_examples': new_examples,
'topk': app.labeler_config.topk_per_example,
'sample_size': app.advanced_config.sample_size.value
}
Expand Down Expand Up @@ -243,4 +243,4 @@ const app = new Vue({

Vue.nextTick(function () {
app.next_batch()
})
})
Loading

0 comments on commit 0be69a4

Please sign in to comment.