From 177a78ddb5097cd3df7f4f1dc119ef101a1816ea Mon Sep 17 00:00:00 2001
From: Maximilian Werk
Date: Tue, 19 Oct 2021 16:14:41 +0200
Subject: [PATCH] docs: fix generate docs (#144)

* docs: fix generate docs

* docs: added function signature
---
 docs/components/tuner.md | 28 ++++++++++++++--------------
 finetuner/toydata.py     | 14 ++++++++------
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/docs/components/tuner.md b/docs/components/tuner.md
index 388e7c6f2..b2ab72221 100644
--- a/docs/components/tuner.md
+++ b/docs/components/tuner.md
@@ -113,8 +113,8 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,
 
 finetuner.fit(
     embed_model,
-    train_data=lambda: generate_fashion_match(num_pos=10, num_neg=10),
-    eval_data=lambda: generate_fashion_match(num_pos=10, num_neg=10, is_testset=True)
+    train_data=generate_fashion_match(num_pos=10, num_neg=10),
+    eval_data=generate_fashion_match(num_pos=10, num_neg=10, is_testset=True)
 )
 ```
 
@@ -127,16 +127,6 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,
 
 1. Write an embedding model:
 
-   ````{tab} Keras
-   ```python
-   import tensorflow as tf
-   embed_model = tf.keras.Sequential([
-       tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
-       tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
-       tf.keras.layers.Dense(32)])
-   ```
-   ````
-
    ````{tab} PyTorch
    ```python
    import torch
@@ -153,6 +143,16 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,
    ```
    ````
 
+   ````{tab} Keras
+   ```python
+   import tensorflow as tf
+   embed_model = tf.keras.Sequential([
+       tf.keras.layers.Embedding(input_dim=5000, output_dim=64),
+       tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
+       tf.keras.layers.Dense(32)])
+   ```
+   ````
+
    ````{tab} Paddle
    ```python
    import paddle
@@ -180,8 +180,8 @@ Although siamese and triplet loss works on pair and triplet inputs respectively,
 
 finetuner.fit(
     embed_model,
-    train_data=lambda: generate_qa_match(num_neg=5),
-    eval_data=lambda: generate_qa_match(num_neg=5)
+    train_data=generate_qa_match(num_neg=5),
+    eval_data=generate_qa_match(num_neg=5)
 )
 ```
 
diff --git a/finetuner/toydata.py b/finetuner/toydata.py
index f7da9a95c..55a14af8f 100644
--- a/finetuner/toydata.py
+++ b/finetuner/toydata.py
@@ -1,12 +1,11 @@
 import base64
-import copy
 import csv
 import gzip
 import os
 import urllib.request
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional, Generator
+from typing import Optional, Generator, Tuple
 
 import numpy as np
 from jina import Document, DocumentArray
@@ -14,6 +13,7 @@
 from jina.types.document import png_to_buffer
 
 from finetuner import __default_tag_key__
+from finetuner.helper import DocumentArrayLike
 
 
 def _text_to_word_sequence(
@@ -56,7 +56,7 @@ def _text_to_int_sequence(text, vocab, max_len=None):
     return vec
 
 
-def generate_qa_match(**kwargs):
+def generate_qa_match(**kwargs) -> DocumentArrayLike:
     return generate_qa_match_catalog(pre_init_generator=False, **kwargs)[0]
 
 
@@ -69,7 +69,7 @@ def generate_qa_match_catalog(
     max_seq_len: int = 100,
     is_testset: Optional[bool] = None,
     pre_init_generator: bool = True,
-) -> Generator[Document, None, None]:
+) -> Tuple[DocumentArrayLike, DocumentArray]:
     """Get a generator of QA data with synthetic negative matches.
 
     :param num_total: the total number of documents to return
@@ -155,7 +155,9 @@ def generator():
     return generator, catalog
 
 
-def generate_fashion_match(num_total=100, num_catalog=5000, **kwargs):
+def generate_fashion_match(
+    num_total=100, num_catalog=5000, **kwargs
+) -> DocumentArrayLike:
     return generate_fashion_match_catalog(
         num_total=num_total,
         num_catalog=num_catalog,
@@ -176,7 +178,7 @@ def generate_fashion_match_catalog(
     channel_axis: int = -1,
     is_testset: bool = False,
     pre_init_generator: bool = True,
-) -> Generator[Document, None, None]:
+) -> Tuple[DocumentArrayLike, DocumentArray]:
     """Get a Generator of fashion-mnist Documents with synthetic matches.
 
     :param num_total: the total number of documents to return
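
Note (reviewer sketch, not part of the patch): the example below shows how the updated helpers are called after this change. The docs snippets now pass the generated data directly instead of wrapping it in a lambda, and the `*_catalog` variants return a `(data, catalog)` pair per the new `Tuple[DocumentArrayLike, DocumentArray]` annotation. It assumes the `finetuner` package from this repo; the keyword arguments are the ones used in the docs snippets in the diff.

```python
# Sketch only, not part of the patch. Assumes the finetuner package from
# this repo; helper names and keyword arguments come from the diff above.
from finetuner.toydata import generate_fashion_match, generate_qa_match_catalog

# The docs now pass the generated data directly (no lambda wrapper):
train_data = generate_fashion_match(num_pos=10, num_neg=10)
eval_data = generate_fashion_match(num_pos=10, num_neg=10, is_testset=True)

# The *_catalog variant returns a (data, catalog) pair, matching the new
# Tuple[DocumentArrayLike, DocumentArray] return annotation:
qa_data, catalog = generate_qa_match_catalog(num_neg=5)
```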