Skip to content

Commit

Permalink
Add a test case for loading a legacy model config (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
misberner committed Nov 13, 2020
1 parent bd18546 commit 5a58a06
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
python-version: 3.6
- name: Install
run:
- pip install -e .
+ pip install -e '.[tf]'
pip install -r test-requirements.txt
- name: Test
run: pytest -s -vv --cov src --cov-report term-missing tests-integration/
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
python-version: 3.6
- name: Install
run:
- pip install -e .
+ pip install -e '.[tf]'
pip install -r test-requirements.txt
- name: Lint
run: |
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_synthetics/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class TensorFlowConfig(BaseConfig):
def __post_init__(self):
if self.dp:
major, minor, micro = tf.__version__.split(".")
- if int(minor) < 4 and int(major) >= 2:
+ if (int(major), int(minor)) < (2, 4):
raise RuntimeError(
"Running in differential privacy mode requires TensorFlow 2.4.x or greater. "
"Please see the README for details"
Expand Down
1 change: 1 addition & 0 deletions src/gretel_synthetics/tensorflow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _create_dataset(
return total_token_count, dataset


+ @tf.autograph.experimental.do_not_convert
def _split_input_target(chunk: str) -> Tuple[str, str]:
"""
For each sequence, duplicate and shift it to form the input and target text
Expand Down
1 change: 0 additions & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
- tensorflow==2.3.1
flake8
pytest
pytest-cov
Expand Down
34 changes: 34 additions & 0 deletions tests/data/0.14.x/dp-model/model_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"max_lines": 0,
"epochs": 100,
"early_stopping": true,
"early_stopping_patience": 5,
"best_model_metric": "loss",
"batch_size": 64,
"buffer_size": 10000,
"seq_length": 100,
"embedding_dim": 256,
"rnn_units": 256,
"dropout_rate": 0.2,
"rnn_initializer": "glorot_uniform",
"field_delimiter": ",",
"field_delimiter_token": "<d>",
"vocab_size": 20000,
"character_coverage": 1.0,
"pretrain_sentence_count": 1000000,
"max_line_len": 2048,
"dp": true,
"dp_learning_rate": 0.001,
"dp_noise_multiplier": 1.1,
"dp_l2_norm_clip": 1.0,
"dp_microbatches": 256,
"gen_temp": 1.0,
"gen_chars": 0,
"gen_lines": 1000,
"predict_batch_size": 1,
"save_all_checkpoints": false,
"save_best_model": true,
"overwrite": false,
"checkpoint_dir": "/Users/mi/gretel/gretel-synthetics/tests/ckpoint",
"input_data_path": "/Users/mi/gretel/gretel-synthetics/tests/data/smol.txt"
}
34 changes: 34 additions & 0 deletions tests/data/0.14.x/non-dp-model/model_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"max_lines": 0,
"epochs": 100,
"early_stopping": true,
"early_stopping_patience": 5,
"best_model_metric": "loss",
"batch_size": 64,
"buffer_size": 10000,
"seq_length": 100,
"embedding_dim": 256,
"rnn_units": 256,
"dropout_rate": 0.2,
"rnn_initializer": "glorot_uniform",
"field_delimiter": ",",
"field_delimiter_token": "<d>",
"vocab_size": 20000,
"character_coverage": 1.0,
"pretrain_sentence_count": 1000000,
"max_line_len": 2048,
"dp": false,
"dp_learning_rate": 0.001,
"dp_noise_multiplier": 1.1,
"dp_l2_norm_clip": 1.0,
"dp_microbatches": 256,
"gen_temp": 1.0,
"gen_chars": 0,
"gen_lines": 1000,
"predict_batch_size": 1,
"save_all_checkpoints": false,
"save_best_model": true,
"overwrite": false,
"checkpoint_dir": "/Users/mi/gretel/gretel-synthetics/tests/ckpoint",
"input_data_path": "/Users/mi/gretel/gretel-synthetics/tests/data/smol.txt"
}
20 changes: 20 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pathlib import Path
import pytest

from gretel_synthetics.config import config_from_model_dir, TensorFlowConfig

# Directory holding the test fixtures: the "data" folder next to this test module.
test_data_dir = Path(__file__).parent / "data"


@pytest.mark.parametrize(
    "model_name,dp,expected_learning_rate",
    [
        ("non-dp-model", False, 0.01),
        ("dp-model", True, 0.001),
    ],
)
def test_load_legacy_config(model_name, dp, expected_learning_rate):
    """Check that a model directory saved by 0.14.x still loads as a config.

    The fixture ``model_params.json`` files under ``data/0.14.x`` carry a
    ``dp_learning_rate`` key. After loading, the config object must be a
    ``TensorFlowConfig`` that no longer has a ``dp_learning_rate`` attribute,
    exposes ``learning_rate`` with the expected value, and keeps the ``dp``
    flag from the file.

    NOTE(review): the non-DP expectation (0.01) is not present in the fixture
    file, so it presumably comes from the config's default -- confirm against
    ``TensorFlowConfig``.
    """
    loaded = config_from_model_dir(test_data_dir / "0.14.x" / model_name)

    assert isinstance(loaded, TensorFlowConfig)
    # The legacy-only field must not survive as an attribute on the new config.
    assert "dp_learning_rate" not in vars(loaded)
    assert loaded.learning_rate == expected_learning_rate
    assert loaded.dp == dp

0 comments on commit 5a58a06

Please sign in to comment.