Skip to content

Commit

Permalink
[Python] Improve hygiene around .read() on file-like objects
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 7af5e6d8b0c9c710127ece31c867674d6c713d9b
  • Loading branch information
misberner committed Mar 31, 2023
1 parent 616c5ad commit 02469e8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
5 changes: 3 additions & 2 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def load_validator_from_file(self):
p = Path(self.checkpoint_dir) / "validator.p.gz"
if p.exists():
with gzip.open(p, "r") as fin:
self.validator = cloudpickle.loads(fin.read())
self.validator = cloudpickle.load(fin)

def reset_gen_data(self):
"""Reset all objects that accumulate or track synthetic
Expand Down Expand Up @@ -180,7 +180,8 @@ def _create_batch_from_dir(batch_dir: str):

if not (path / HEADER_FILE).is_file(): # pragma: no cover
raise ValueError("missing headers")
headers = json.loads(open(path / HEADER_FILE).read())
with open(path / HEADER_FILE) as f:
headers = json.load(f)

if not (path / CONFIG_FILE).is_file(): # pragma: no cover
raise ValueError("missing model param file")
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ def config_from_model_dir(model_dir: str) -> BaseConfig:
package and will instantiate a TensorFlowConfig
"""
params_file = Path(model_dir) / const.MODEL_PARAMS
params_dict = json.loads(open(params_file).read())
with open(params_file) as f:
params_dict = json.load(f)
model_type = params_dict.pop(const.MODEL_TYPE, None)

# swap out the checkpoint dir location for the currently
Expand Down
8 changes: 4 additions & 4 deletions src/gretel_synthetics/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,13 @@ def _load_delimiter_data(self):
params_file = Path(self._model_dir) / self.settings_fname

if not params_file.is_file():
model_params_dict = json.loads(
open(Path(self._model_dir) / const.MODEL_PARAMS).read()
)
with open(Path(self._model_dir) / const.MODEL_PARAMS) as f:
model_params_dict = json.load(f)
self.field_delimiter = model_params_dict[FIELD_DELIM]
self.field_delimiter_token = model_params_dict[FIELD_DELIM_TOKEN]
else:
params_dict = json.loads(open(params_file).read())
with open(params_file) as f:
params_dict = json.load(f)
self.field_delimiter = params_dict.get(FIELD_DELIM, None)
self.field_delimiter_token = params_dict.get(FIELD_DELIM_TOKEN, None)

Expand Down

0 comments on commit 02469e8

Please sign in to comment.