refactor(tuner): revert some catalog change before release (#150)
* refactor(tuner): move catalog data from constructor

* refactor(tuner): refactor logger and stats into summary
hanxiao committed Oct 19, 2021
1 parent 635cd4c commit 2916e9f
Showing 35 changed files with 447 additions and 740 deletions.
25 changes: 10 additions & 15 deletions docs/basics/data-format.md
@@ -4,8 +4,16 @@
Finetuner uses Jina [`Document`](https://docs.jina.ai/fundamentals/document/) as the primitive data type. In
particular, [`DocumentArray`](https://docs.jina.ai/fundamentals/document/documentarray-api/)
and [`DocumentArrayMemmap`](https://docs.jina.ai/fundamentals/document/documentarraymemmap-api/) are the input data types
for Tailor and Tuner. This means your training dataset and evaluation dataset should be stored in `DocumentArray`
or `DocumentArrayMemmap`, where each training or evaluation instance is a `Document` object.
in the high-level `finetuner.fit()` API. This means your training dataset and evaluation dataset should be stored in `DocumentArray`
or `DocumentArrayMemmap`, where each training or evaluation instance is a `Document` object:

```python
import finetuner

finetuner.fit(model,
              train_data=...,
              eval_data=...)
```
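
For instance, a training set can be assembled as a `DocumentArray` of `Document` objects. A minimal sketch (the text values are purely illustrative):

```python
from jina import Document, DocumentArray

# each training instance is a Document; the texts here are placeholders
train_data = DocumentArray([Document(text=t) for t in ('hello', 'world')])
```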

This chapter introduces how to construct a `Document` in a way that Finetuner will accept.

@@ -137,19 +145,6 @@ Yes. Labels should reflect the groundtruth as-is. If a Document contains only po
However, if all match labels from all Documents are the same, then Finetuner cannot learn anything useful.
```

### Catalog

In search, queries and search results are often distinct sets.
Specifying a `catalog` helps you keep this distinction during finetuning.
When using `finetuner.fit(train_data=..., eval_data=..., catalog=...)`, `train_data` and `eval_data` specify the potential queries, and the `catalog` specifies the potential results.
This distinction is mainly used:

- in the Labeler, when new sets of unlabeled results are generated, and
- during evaluation, for the NDCG calculation.

A `catalog` is either a `DocumentArray` or a `DocumentArrayMemmap`.
If no `catalog` is specified, Finetuner implicitly uses `train_data` as the catalog.
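
A minimal sketch of passing a separate catalog (`model`, the names, and the contents below are illustrative placeholders):

```python
import finetuner
from jina import Document, DocumentArray

queries = DocumentArray([Document(text='red dress')])  # potential queries
catalog = DocumentArray([Document(text=f'product {i}') for i in range(3)])  # potential results

finetuner.fit(model,
              train_data=queries,
              catalog=catalog)
```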

## Data source

After organizing the labeled `Document` into `DocumentArray` or `DocumentArrayMemmap`, you can feed them
9 changes: 5 additions & 4 deletions finetuner/__init__.py
@@ -9,7 +9,8 @@
from typing import Dict, Optional, overload, TYPE_CHECKING, Tuple

if TYPE_CHECKING:
    from .helper import AnyDNN, DocumentArrayLike, TunerReturnType
    from .helper import AnyDNN, DocumentArrayLike
    from .tuner.summary import SummaryCollection


# fit interface generated from Tuner
@@ -25,7 +26,7 @@ def fit(
    optimizer: str = 'adam',
    optimizer_kwargs: Optional[Dict] = None,
    device: str = 'cpu',
) -> 'TunerReturnType':
) -> 'SummaryCollection':
    ...


@@ -48,7 +49,7 @@ def fit(
    output_dim: Optional[int] = None,
    freeze: bool = False,
    device: str = 'cpu',
) -> 'TunerReturnType':
) -> 'SummaryCollection':
    ...


@@ -96,7 +97,7 @@

def fit(
    model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs
) -> Optional['TunerReturnType']:
) -> Optional['SummaryCollection']:
    if kwargs.get('to_embedding_model', False):
        from .tailor import to_embedding_model

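With this change, a completed run of the high-level API hands back a `SummaryCollection` instead of the old nested-dict `TunerReturnType`. A minimal usage sketch (`model`, `train_data`, and `eval_data` are placeholders):

```python
import finetuner

summary = finetuner.fit(model,
                        train_data=train_data,
                        eval_data=eval_data)
# `summary` is a SummaryCollection gathering the loss/metric records of the run
```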
3 changes: 0 additions & 3 deletions finetuner/helper.py
@@ -35,9 +35,6 @@
LayerInfoType = List[
    Dict[str, Any]
]  #: The type of embedding layer information used in Tailor
TunerReturnType = Dict[
    str, Dict[str, Any]
]  #: The type of loss, metric information Tuner returns


def get_framework(dnn_model: AnyDNN) -> str:
23 changes: 3 additions & 20 deletions finetuner/labeler/__init__.py
@@ -4,7 +4,7 @@
from typing import Optional

import jina.helper
from jina import Flow, DocumentArrayMemmap
from jina import Flow
from jina.logging.predefined import default_logger

from .executor import FTExecutor, DataIterator
@@ -14,15 +14,13 @@
def fit(
    embed_model: AnyDNN,
    train_data: DocumentArrayLike,
    catalog: Optional[DocumentArrayLike] = None,
    clear_labels_on_start: bool = False,
    port_expose: Optional[int] = None,
    runtime_backend: str = 'thread',
    loss: str = 'CosineSiameseLoss',
    **kwargs,
) -> None:
    dam_path = tempfile.mkdtemp()
    catalog_dam_path = init_catalog(dam_path, catalog, train_data)

    class MyExecutor(FTExecutor):
        def get_embed_model(self):
@@ -39,14 +37,13 @@ def get_embed_model(self):
    uses=DataIterator,
    uses_with={
        'dam_path': dam_path,
        'catalog_dam_path': catalog_dam_path,
        'clear_labels_on_start': clear_labels_on_start,
    },
)
.add(
    uses=MyExecutor,
    uses_with={
        'catalog_dam_path': catalog_dam_path,
        'dam_path': dam_path,
        'loss': loss,
    },
)
@@ -91,22 +88,8 @@ def open_frontend_in_browser(req):
f.post(
    '/feed',
    train_data,
    request_size=128,
    request_size=10,
    show_progress=True,
    on_done=open_frontend_in_browser,
)
f.block()


def init_catalog(
    dam_path: str, catalog: DocumentArrayLike, train_data: DocumentArrayLike
):
    if isinstance(catalog, DocumentArrayMemmap):
        catalog_dam_path = catalog.path
    else:
        catalog_dam_path = dam_path + '/catalog'
        catalog_memmap = DocumentArrayMemmap(catalog_dam_path)
        if catalog is None:
            catalog = train_data() if callable(train_data) else train_data
        catalog_memmap.extend(catalog)
    return catalog_dam_path
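
For reference, a minimal sketch of calling the labeler entrypoint after the revert (`embed_model` and `train_data` are placeholders):

```python
from finetuner.labeler import fit

# starts the labeler Flow and opens the UI in the browser
fit(embed_model,
    train_data,
    clear_labels_on_start=True,
    loss='CosineSiameseLoss')
```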
30 changes: 11 additions & 19 deletions finetuner/labeler/executor.py
@@ -11,13 +11,13 @@
class FTExecutor(Executor):
    def __init__(
        self,
        catalog_dam_path: str,
        dam_path: str,
        metric: str = 'cosine',
        loss: str = 'CosineSiameseLoss',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._catalog = DocumentArrayMemmap(catalog_dam_path)
        self._all_data = DocumentArrayMemmap(dam_path)
        self._metric = metric
        self._loss = loss

@@ -33,9 +33,9 @@ def _embed_model(self):
def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):
    if not docs:
        return
    self._catalog.reload()
    da = self._catalog.sample(
        min(len(self._catalog), int(parameters.get('sample_size', 1000)))
    self._all_data.reload()
    da = self._all_data.sample(
        min(len(self._all_data), int(parameters.get('sample_size', 1000)))
    )

    f_type = get_framework(self._embed_model)
@@ -77,7 +77,6 @@ def fit(self, docs: DocumentArray, parameters: Dict, **kwargs):
fit(
    self._embed_model,
    docs,
    self._catalog,
    epochs=int(parameters.get('epochs', 10)),
    loss=self._loss,
)
@@ -93,40 +92,33 @@ class DataIterator(Executor):
def __init__(
    self,
    dam_path: str,
    catalog_dam_path: str,
    labeled_dam_path: Optional[str] = None,
    clear_labels_on_start: bool = False,
    **kwargs,
):
    super().__init__(**kwargs)
    self._all_data = DocumentArrayMemmap(dam_path)
    self._catalog = DocumentArrayMemmap(catalog_dam_path)
    if not labeled_dam_path:
        labeled_dam_path = dam_path + '/labeled'
    self._labeled_dam = DocumentArrayMemmap(labeled_dam_path)
    if clear_labels_on_start:
        self._labeled_dam.clear()

@requests(on='/feed')
def store_data(self, docs: DocumentArray, parameters: Dict, **kwargs):
    if parameters.get('type', 'query') == 'query':
        self._all_data.extend(docs)
    else:
        self._catalog.extend(docs)
def store_data(self, docs: DocumentArray, **kwargs):
    self._all_data.extend(docs)

@requests(on='/next')
def take_batch(self, parameters: Dict, **kwargs):
    count = int(parameters.get('new_examples', 5))
    st = int(parameters.get('start', 0))
    ed = int(parameters.get('end', 1))

    self._all_data.reload()
    count = min(max(count, 0), len(self._all_data))
    return self._all_data.sample(k=count)
    return self._all_data[st:ed]

@requests(on='/fit')
def add_fit_data(self, docs: DocumentArray, **kwargs):
    for d in docs.traverse_flat(['r']):
    for d in docs.traverse_flat(['r', 'm']):
        d.content = self._all_data[d.id].content
    for d in docs.traverse_flat(['m']):
        d.content = self._catalog[d.id].content
    self._labeled_dam.extend(docs)
    return self._labeled_dam
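
The `/next` endpoint now pages through the stored data by index range instead of drawing a random sample. A sketch of a request, assuming the labeler Flow `f` from `finetuner/labeler/__init__.py` is running (parameter values are illustrative):

```python
# inside the `with f:` block of the labeler Flow
f.post(
    '/next',
    parameters={'start': 0, 'end': 5},  # index range replaces the old 'new_examples' count
    on_done=lambda resp: print(resp),
)
```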
2 changes: 1 addition & 1 deletion finetuner/labeler/ui/js/components/image-match-card.vue.js
@@ -10,7 +10,7 @@ const imageMatchCard = {
template: `
    <div class="card image-card">
        <div class="card-header">
            <p class="fs-6 fw-light mb-2">Select all images similar to the image on the right</p>
            <p class="fs-6 fw-light mb-2 hint-text">Select all images similar to the image on the right</p>
            <img v-bind:src="getContent(doc)" class="img-thumbnail img-fluid my-2">
        </div>
        <div class="card-body">
2 changes: 1 addition & 1 deletion finetuner/labeler/ui/js/components/mesh-match-card.vue.js
@@ -10,7 +10,7 @@ const meshMatchCard = {
template: `
    <div class="card mesh-card">
        <div class="card-header">
            <p class="fs-6 fw-light mb-2">Select all images similar to the image on the right</p>
            <p class="fs-6 fw-light mb-2 hint-text">Select all meshes similar to the image on the right</p>
            <model-viewer
                v-bind:src="getContent(doc)"
                v-on:click="toggleRelevance(match)"
7 changes: 4 additions & 3 deletions finetuner/labeler/ui/js/main.js
@@ -56,6 +56,7 @@ const app = new Vue({
grabCursor: false,
centeredSlides: true,
slidesPerView: 3,
allowTouchMove: false,
keyboard: {
    enabled: true,
},
@@ -156,12 +157,11 @@
},
next_batch: function () {
    let end_idx = app.labeler_config.start_idx + (app.labeler_config.example_per_view - app.cur_batch.length)
    if (end_idx <= app.labeler_config.start_idx) {
    if (end_idx === app.labeler_config.start_idx) {
        return
    }
    let start_idx = app.labeler_config.start_idx
    app.labeler_config.start_idx = end_idx
    let new_examples = end_idx - start_idx
    app.is_busy = true
    app.is_conn_broken = false
    $.ajax({
@@ -170,7 +170,8 @@
        data: JSON.stringify({
            data: [],
            parameters: {
                'new_examples': new_examples,
                'start': start_idx,
                'end': end_idx,
                'topk': app.labeler_config.topk_per_example,
                'sample_size': app.advanced_config.sample_size.value
            }
33 changes: 21 additions & 12 deletions finetuner/labeler/ui/main.css
@@ -238,24 +238,31 @@ footer a {
position: absolute;
top: 0;
right: 0;
margin-left: .8rem;
line-height: 0.8rem;
padding: .2rem .4rem;
color: #000000;
background-color: #fff;
border-radius: 1px;
box-shadow: 0 0 0 4px #d9d9d9,
2px 2.5px 4px #adb5bd,
0 -1px 2.5px #adb5bd;
cursor: pointer;

border: 1px solid gray;
font-size: 1.2em;
box-shadow: 1px 0 1px 0 #eee, 0 2px 0 2px #ccc, 0 2px 0 3px #444;
-webkit-border-radius: 3px;
-moz-border-radius: 3px;
border-radius: 3px;
padding: 2px 2px;
max-height: 1em;
font-family: monospace;
min-width: 1em;
}

.btn .kbd {
position: relative;
top: -10px;
right: -10px;
}

.text-card .kbd {
position: relative;
.hint-text {
max-width: 50%;
}

.text-card .positive-match::before {
@@ -346,14 +353,16 @@

.image-card .card-header .img-thumbnail {
margin-left: .25rem;
min-height: 20vh;
width: auto;
width: 20vh;
height: 20vh;
object-fit: contain;
padding: 0;
}

.image-card .card-body .img-thumbnail {
width: auto;
max-height: 20vh;
width: 20vh;
height: 20vh;
object-fit: contain;
padding: 0;
}
