Refactor check and common modules to utils.
cfezequiel committed Oct 28, 2020
1 parent 1db8664 commit 7f3f480
Showing 10 changed files with 76 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -23,7 +23,7 @@ Please delete options that are not relevant.
 - [ ] My code adheres to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
 - [ ] I ran `make pylint` and code is rated 10/10
 - [ ] I have added tests that prove my fix is effective or that my feature works
-- [ ] I ran `make test` and all tests pass
+- [ ] I ran `make test testnb` and all tests pass
 - [ ] I ran the tool and verified the change works
 - [ ] I have adequately commented my code, particularly in hard-to-understand areas
 - [ ] I have made relevant changes to the documentation, if needed
4 changes: 2 additions & 2 deletions Makefile
@@ -1,4 +1,4 @@
-all: init test pylint
+all: init testnb test pylint
 
 init:
 	pip install -r requirements.txt
@@ -12,4 +12,4 @@ testnb:
 pylint:
 	pylint -j 0 tfrecorder
 
-.PHONY: all init test pylint
+.PHONY: all init testnb test pylint
8 changes: 5 additions & 3 deletions README.md
@@ -175,7 +175,8 @@ Using Python interpreter:
 import tfrecorder
 
 tfrecorder.inspect(
-    file_pattern='/path/to/tfrecords/train*.tfrecord.gz',
+    tfrecord_dir='/path/to/tfrecords/',
+    split='TRAIN',
     num_records=5,
     output_dir='/tmp/output')
 ```
@@ -186,8 +187,9 @@ representing the images encoded into TFRecords.
 Using the command line:
 
 ```bash
-tfrecorder check-tfrecords \
-  --file_pattern=/path/to/tfrecords/train*.tfrecord.gz \
+tfrecorder inspect \
+  --tfrecord-dir=/path/to/tfrecords/ \
+  --split='TRAIN' \
   --num_records=5 \
   --output_dir=/tmp/output
 ```
5 changes: 2 additions & 3 deletions tfrecorder/__init__.py
@@ -18,7 +18,6 @@
 
 from tfrecorder import accessor
 from tfrecorder.converter import convert
-from tfrecorder.converter import convert_and_load
-# TODO(cezequiel): refactor check module
-from tfrecorder.check import check_tfrecords as inspect
 from tfrecorder.dataset_loader import load
+from tfrecorder.converter import convert_and_load
+from tfrecorder.utils import inspect
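
For context, a minimal sketch of how the re-exported public API reads after this change (paths and filenames are placeholders, not from the repo):

```python
import tfrecorder

# Convert a CSV, DataFrame, or image directory into TFRecords
# (see converter.convert below for the full argument list).
tfrecorder.convert(source='/path/to/data.csv', output_dir='/tmp/tfrecords')

# Inspect a few records from the generated TFRecord directory
# (see utils.inspect below).
tfrecorder.inspect(tfrecord_dir='/tmp/tfrecords', split='TRAIN', num_records=2)
```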
6 changes: 3 additions & 3 deletions tfrecorder/cli.py
@@ -19,15 +19,15 @@
 import fire
 
 from tfrecorder import converter
-from tfrecorder import check
+from tfrecorder import utils
 
 
 def main():
   """Entry point for command-line interface."""
 
   fire.Fire({
-      'create-tfrecords': converter.convert,
-      'check-tfrecords': check.check_tfrecords,
+      'convert': converter.convert,
+      'inspect': utils.inspect,
   })
 
 
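
`fire.Fire` turns each dictionary key into a CLI subcommand, so the shell entry points are renamed from `create-tfrecords`/`check-tfrecords` to `convert`/`inspect`. A rough sketch of the same dispatch pattern, with a hypothetical stand-in function used purely for illustration:

```python
import fire


def _fake_convert(source: str, output_dir: str = '/tmp/tfrecords'):
  """Hypothetical stand-in for converter.convert, to illustrate Fire dispatch."""
  print(f'convert {source} -> {output_dir}')


# Running `python cli_sketch.py convert --source=data.csv` would call
# _fake_convert, just as `tfrecorder convert ...` now calls converter.convert.
if __name__ == '__main__':
  fire.Fire({'convert': _fake_convert})
```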
42 changes: 0 additions & 42 deletions tfrecorder/common.py

This file was deleted.

53 changes: 0 additions & 53 deletions tfrecorder/common_test.py

This file was deleted.

13 changes: 7 additions & 6 deletions tfrecorder/converter.py
@@ -29,11 +29,11 @@
 import tensorflow as tf
 
 from tfrecorder import beam_pipeline
-from tfrecorder import common
 from tfrecorder import dataset_loader
 from tfrecorder import constants
 from tfrecorder import input_schema
 from tfrecorder import types
+from tfrecorder import utils
 
 
 # TODO(mikebernico) Add test for only one split_key.
@@ -152,7 +152,7 @@ def _get_job_name(job_label: str = None) -> str:
     insure uniqueness.
   """
 
-  job_name = 'tfrecorder-' + common.get_timestamp()
+  job_name = 'tfrecorder-' + utils.get_timestamp()
   if job_label:
     job_label = job_label.replace('_', '-')
     job_name += '-' + job_label
@@ -254,7 +254,7 @@ def convert(
     region: Optional[str] = None,
     tfrecorder_wheel: Optional[str] = None,
     dataflow_options: Optional[Dict[str, Any]] = None,
-    job_label: str = 'create-tfrecords',
+    job_label: str = 'convert',
     compression: Optional[str] = 'gzip',
     num_shards: int = 0) -> Dict[str, Any]:
   """Generates TFRecord files from given input data.
@@ -265,10 +265,10 @@
   Usage:
     import tfrecorder
 
-    job_id = tfrecorder.client.create_tfrecords(
+    job_id = tfrecorder.convert(
         train_df,
         output_dir='gcs://foo/bar/train',
-        runner='DirectFlowRunner)
+        runner='DirectRunner)
 
   Args:
     source: Pandas DataFrame, CSV file or image directory path.
@@ -277,6 +277,7 @@
     header: Indicates row/s to use as a header. Not used when `input_data` is
       a Pandas DataFrame.
       If 'infer' (default), header is taken from the first line of a CSV
+    names: List of column names to use for CSV or DataFrame input.
     runner: Beam runner. Can be 'DirectRunner' or 'DataFlowRunner'
     project: GCP project name (Required if DataflowRunner)
     region: GCP region name (Required if DataflowRunner)
@@ -353,7 +354,7 @@ def convert(
       'dataflow_url': url,
     }
     # Copy the logfile to GCS output dir
-    common.copy_logfile_to_gcs(logfile, output_dir)
+    utils.copy_logfile_to_gcs(logfile, output_dir)
 
   else:
     raise ValueError(f'Unsupported runner: {runner}')
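
A minimal sketch of how the newly documented `names` argument might be used for a header-less CSV. The file path and column names are placeholders, and pandas-style `header=None` handling is an assumption rather than something shown in this diff:

```python
import tfrecorder

# Column names supplied explicitly because the (hypothetical) CSV has no header row.
results = tfrecorder.convert(
    '/path/to/data.csv',
    names=['split', 'image_uri', 'label'],
    header=None,                 # assumed pandas-style header handling
    output_dir='/tmp/tfrecords',
    runner='DirectRunner')
```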
34 changes: 26 additions & 8 deletions tfrecorder/check.py → tfrecorder/utils.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Utilities for checking content of TFRecord files."""
+"""Miscellaneous utility functions."""
 
+from datetime import datetime
 from typing import Dict
 
 import csv
@@ -24,9 +25,8 @@
 import tensorflow as tf
 
 from tfrecorder import beam_image
-from tfrecorder import common
-# TODO(cezequiel): Rename `dataset` module to something else
-from tfrecorder import dataset_loader as _dataset
+from tfrecorder import constants
+from tfrecorder import dataset_loader
 
 _OUT_IMAGE_TEMPLATE = 'image_{:0>3d}.png'
 
@@ -48,12 +48,12 @@ def _save_image_from_record(record: Dict[str, tf.Tensor], outfile: str):
   image.save(outfile)
 
 
-def check_tfrecords(
+def inspect(
     tfrecord_dir: str,
     split: str = 'TRAIN',
     num_records: int = 1,
     output_dir: str = 'output'):
-  """Checks TFRecords from a TFRecord directory generated by TFRecorder.
+  """Prints contents of TFRecord files generated by TFRecorder.
 
   Args:
     tfrecord_dir: TFRecord directory.
@@ -65,12 +65,12 @@
     `ValueError` when data for a given `split` could not be loaded.
   """
 
-  dataset = _dataset.load(tfrecord_dir).get(split)
+  dataset = dataset_loader.load(tfrecord_dir).get(split)
   if not dataset:
     raise ValueError(f'Could not load data for {split}')
 
   data_dir = os.path.join(
-      output_dir, 'check-tfrecords-' + common.get_timestamp())
+      output_dir, 'check-tfrecords-' + get_timestamp())
   os.makedirs(data_dir)
 
   with open(os.path.join(data_dir, 'data.csv'), 'wt') as f:
@@ -99,3 +99,21 @@ def check_tfrecords(
   print('Output written to {}'.format(data_dir))
 
   return data_dir
+
+
+def get_timestamp() -> str:
+  """Returns current date and time as formatted string."""
+  return datetime.now().strftime('%Y%m%d-%H%M%S')
+
+
+def copy_logfile_to_gcs(logfile: str, output_dir: str):
+  """Copies a logfile from local to gcs storage."""
+  try:
+    with open(logfile, 'r') as log_reader:
+      out_log = os.path.join(output_dir, constants.LOGFILE)
+      with tf.io.gfile.GFile(out_log, 'w') as gcs_logfile:
+        log = log_reader.read()
+        gcs_logfile.write(log)
+  except FileNotFoundError as e:
+    raise FileNotFoundError("Unable to copy log file {} to gcs.".format(
+        e.filename)) from e
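
A quick sketch of the relocated helpers in use, with signatures taken from the diff above; the log path and output bucket are placeholders:

```python
from tfrecorder import utils

# Timestamp suffix such as '20201028-153045', used for job names and output dirs.
run_dir = '/tmp/output/run-' + utils.get_timestamp()

# Copy a local log file next to the generated TFRecords;
# raises FileNotFoundError if the local logfile does not exist.
utils.copy_logfile_to_gcs('/tmp/tfrecorder.log', 'gs://my-bucket/tfrecords')
```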
34 changes: 30 additions & 4 deletions tfrecorder/check_test.py → tfrecorder/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests `check.py`."""
"""Tests `utils.py`."""

import functools
import os
Expand All @@ -27,7 +27,8 @@
import tensorflow as tf

from tfrecorder import beam_image
from tfrecorder import check
from tfrecorder import constants
from tfrecorder import utils
from tfrecorder import test_utils
from tfrecorder import input_schema
from tfrecorder import dataset_loader as _dataset
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_valid_records(self, mock_fn):
num_records = len(self.data['image'])

with tempfile.TemporaryDirectory(dir='/tmp') as dir_:
actual_dir = check.check_tfrecords(
actual_dir = utils.inspect(
self.tfrecord_dir, split=self.split, num_records=num_records,
output_dir=dir_)
self.assertTrue('check-tfrecords-' in actual_dir)
Expand All @@ -103,8 +104,33 @@ def test_no_data_for_split(self, mock_fn):

mock_fn.return_value = {}
with self.assertRaisesRegex(ValueError, 'Could not load data for'):
check.check_tfrecords(self.tfrecord_dir, split='UNSUPPORTED')
utils.inspect(self.tfrecord_dir, split='UNSUPPORTED')


if __name__ == '__main__':
unittest.main()


class CopyLogTest(unittest.TestCase):
"""Misc tests for _copy_logfile_to_gcs."""

def test_valid_copy(self):
"""Test valid file copy."""
with tempfile.TemporaryDirectory() as tmpdirname:
text = 'log test log test'
infile = os.path.join(tmpdirname, 'foo.log')
with open(infile, 'w') as f:
f.write(text)
utils.copy_logfile_to_gcs(infile, tmpdirname)

outfile = os.path.join(tmpdirname, constants.LOGFILE)
with open(outfile, 'r') as f:
data = f.read()
self.assertEqual(text, data)

def test_invalid_copy(self):
"""Test invalid file copy."""
with tempfile.TemporaryDirectory() as tmpdirname:
infile = os.path.join(tmpdirname, 'foo.txt')
with self.assertRaises(FileNotFoundError):
utils.copy_logfile_to_gcs(infile, tmpdirname)
