
Commit

docs: fix generate docs (#144)
* docs: fix generate docs

* docs: added function signature
maximilianwerk committed Oct 19, 2021
1 parent ac2d23d commit 177a78d
Showing 2 changed files with 22 additions and 20 deletions.
28 changes: 14 additions & 14 deletions docs/components/tuner.md
@@ -113,8 +113,8 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,

finetuner.fit(
    embed_model,
-   train_data=lambda: generate_fashion_match(num_pos=10, num_neg=10),
-   eval_data=lambda: generate_fashion_match(num_pos=10, num_neg=10, is_testset=True)
+   train_data=generate_fashion_match(num_pos=10, num_neg=10),
+   eval_data=generate_fashion_match(num_pos=10, num_neg=10, is_testset=True)
)
```

@@ -127,16 +127,6 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,

1. Write an embedding model:

- ````{tab} Keras
- ```python
- import tensorflow as tf
- embed_model = tf.keras.Sequential([
-     tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
-     tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
-     tf.keras.layers.Dense(32)])
- ```
- ````

````{tab} PyTorch
```python
import torch
@@ -153,6 +143,16 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,
```
````

+ ````{tab} Keras
+ ```python
+ import tensorflow as tf
+ embed_model = tf.keras.Sequential([
+     tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
+     tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
+     tf.keras.layers.Dense(32)])
+ ```
+ ````

````{tab} Paddle
```python
import paddle
@@ -180,8 +180,8 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,

finetuner.fit(
    embed_model,
-   train_data=lambda: generate_qa_match(num_neg=5),
-   eval_data=lambda: generate_qa_match(num_neg=5)
+   train_data=generate_qa_match(num_neg=5),
+   eval_data=generate_qa_match(num_neg=5)
)
```

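Note on the docs change above: the `lambda` wrappers are dropped because the toy-data helpers now return the data directly (see the `toydata.py` changes below). A minimal sketch of the updated QA example, assembled from the snippets in this diff; the Keras model is the one from the tab shown above, and the call pattern mirrors the new docs rather than being a verbatim copy:

```python
import tensorflow as tf

import finetuner
from finetuner.toydata import generate_qa_match

# Embedding model taken from the Keras tab above.
embed_model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(32),
])

# After this commit the generator helpers are passed directly -- no lambda wrapper.
finetuner.fit(
    embed_model,
    train_data=generate_qa_match(num_neg=5),
    eval_data=generate_qa_match(num_neg=5),
)
```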
14 changes: 8 additions & 6 deletions finetuner/toydata.py
@@ -1,19 +1,19 @@
import base64
import copy
import csv
import gzip
import os
import urllib.request
from collections import defaultdict
from pathlib import Path
- from typing import Optional, Generator
+ from typing import Optional, Generator, Tuple

import numpy as np
from jina import Document, DocumentArray
from jina.logging.profile import ProgressBar
from jina.types.document import png_to_buffer

from finetuner import __default_tag_key__
from finetuner.helper import DocumentArrayLike


def _text_to_word_sequence(
@@ -56,7 +56,7 @@ def _text_to_int_sequence(text, vocab, max_len=None):
return vec


- def generate_qa_match(**kwargs):
+ def generate_qa_match(**kwargs) -> DocumentArrayLike:
return generate_qa_match_catalog(pre_init_generator=False, **kwargs)[0]


@@ -69,7 +69,7 @@ def generate_qa_match_catalog(
max_seq_len: int = 100,
is_testset: Optional[bool] = None,
pre_init_generator: bool = True,
- ) -> Generator[Document, None, None]:
+ ) -> Tuple[DocumentArrayLike, DocumentArray]:
"""Get a generator of QA data with synthetic negative matches.
:param num_total: the total number of documents to return
@@ -155,7 +155,9 @@ def generator():
return generator, catalog


- def generate_fashion_match(num_total=100, num_catalog=5000, **kwargs):
+ def generate_fashion_match(
+     num_total=100, num_catalog=5000, **kwargs
+ ) -> DocumentArrayLike:
return generate_fashion_match_catalog(
num_total=num_total,
num_catalog=num_catalog,
@@ -176,7 +178,7 @@ def generate_fashion_match_catalog(
channel_axis: int = -1,
is_testset: bool = False,
pre_init_generator: bool = True,
- ) -> Generator[Document, None, None]:
+ ) -> Tuple[DocumentArrayLike, DocumentArray]:
"""Get a Generator of fashion-mnist Documents with synthetic matches.
:param num_total: the total number of documents to return
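The new annotations in `toydata.py` spell out the contract the docs now rely on: the plain `generate_*_match` helpers return the data directly (a `DocumentArrayLike`), while the `*_catalog` variants return a `(data, catalog)` tuple. A rough usage sketch based only on the signatures visible in this diff; the parameter values are illustrative:

```python
from finetuner.toydata import generate_qa_match, generate_qa_match_catalog

# Plain helper: returns the train/eval data directly, so it can be passed
# straight to finetuner.fit without a lambda (as in the docs change above).
train_data = generate_qa_match(num_neg=5)

# Catalog variant: per the new return annotation, yields a
# (DocumentArrayLike, DocumentArray) pair -- the data plus the full catalog.
data, catalog = generate_qa_match_catalog(num_neg=5)
```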
