Skip to content

Commit

Permalink
Jm/bugfix generate (#42)
Browse files Browse the repository at this point in the history
* Bugfix on generation, RC2 prep

* Use SCM for install

* test

* Revert to full list decoding for gen chars

* Remove parallelism kwarg

Co-authored-by: John Myers <john@gretel.ai>
  • Loading branch information
johntmyers and John Myers committed Aug 4, 2020
1 parent f7cf83b commit 4d63b79
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 23 deletions.
32 changes: 13 additions & 19 deletions src/gretel_synthetics/generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Callable, List, Iterable, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import namedtuple
from io import StringIO

import cloudpickle

Expand Down Expand Up @@ -221,8 +220,7 @@ def _predict_chars(
# Here batch size == 1
model.reset_states()

buf = StringIO()
buf_len = 0
sentence_ids = []

while True:
predictions = model(input_eval)
Expand All @@ -237,19 +235,15 @@ def _predict_chars(
# We pass the predicted word as the next input to the model
# along with the previous hidden state
input_eval = tf.expand_dims([predicted_id], 0)
next_id = int(predicted_id)

decoded = sp.DecodeIds([next_id])
if store.field_delimiter is not None and decoded == store.field_delimiter_token:
decoded = store.field_delimiter

if decoded == "<n>":
break

buf.write(decoded)
buf_len += len(decoded)

if 0 < store.gen_chars <= buf_len:
break

return PredString(buf.getvalue())
sentence_ids.append(int(predicted_id))

decoded = sp.DecodeIds(sentence_ids)
if store.field_delimiter is not None:
decoded = decoded.replace(
store.field_delimiter_token, store.field_delimiter
)

if "<n>" in decoded:
return PredString(decoded.replace("<n>", ""))
elif 0 < store.gen_chars <= len(decoded):
return PredString(decoded)
3 changes: 2 additions & 1 deletion src/gretel_synthetics/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from pathlib import Path
import shutil
from typing import Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -217,7 +218,7 @@ def _create_dataset(store: BaseConfig, text: str, sp: spm.SentencePieceProcessor
return dataset


def _split_input_target(chunk: str) -> (str, str):
def _split_input_target(chunk: str) -> Tuple[str, str]:
"""
For each sequence, duplicate and shift it to form the input and target text
by using the map method to apply a simple function to each batch:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def random_cat():
@patch("tensorflow.random.categorical")
@patch("tensorflow.expand_dims")
def test_predict_chars(mock_dims, mock_cat, global_local_config, random_cat):
global_local_config.gen_chars = 0
global_local_config.gen_chars = 10
mock_model = Mock(return_value=[1.0])
mock_tensor = MagicMock()
mock_tensor[-1, 0].numpy.return_value = 1
mock_cat.return_value = mock_tensor

sp = Mock()
sp.DecodeIds.side_effect = ["this", " ", "is", " ", "the", " ", "end", "<n>"]
sp.DecodeIds.return_value = "this is the end<n>"
# sp.DecodeIds.side_effect = ["this", " ", "is", " ", "the", " ", "end", "<n>"]

line = _predict_chars(mock_model, sp, "\n", global_local_config)
assert line == PredString(data="this is the end")
Expand All @@ -33,7 +34,8 @@ def test_predict_chars(mock_dims, mock_cat, global_local_config, random_cat):
mock_cat.return_value = mock_tensor
global_local_config.gen_chars = 3
sp = Mock()
sp.DecodeIds.side_effect = ["a", "b", "c", "d"]
sp.DecodeIds.side_effect = ["a", "ab", "abc", "abcd"]
# sp.DecodeIds.side_effect = ["a", "b", "c", "d"]
line = _predict_chars(mock_model, sp, "\n", global_local_config)
assert line.data == "abc"

Expand Down

0 comments on commit 4d63b79

Please sign in to comment.