Skip to content

Commit

Permalink
Jm/read mode update (#56)
Browse files Browse the repository at this point in the history
* Update read-only mode to support model loading from any location

* Use cloudpickle for validators

Co-authored-by: John Myers <john@gretel.ai>
  • Loading branch information
johntmyers and John Myers committed Sep 14, 2020
1 parent a22cabf commit dd9f615
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import gzip
from math import ceil
from typing import List, Type, Callable, Dict
import pickle
from copy import deepcopy
import logging
import io
Expand All @@ -22,6 +21,7 @@
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import cloudpickle

from gretel_synthetics.config import LocalConfig
from gretel_synthetics.generate import gen_text, generate_text
Expand Down Expand Up @@ -83,14 +83,14 @@ def set_validator(self, fn: Callable, save=True):
if save:
p = Path(self.checkpoint_dir) / "validator.p.gz"
with gzip.open(p, "w") as fout:
fout.write(pickle.dumps(fn))
fout.write(cloudpickle.dumps(fn))

def load_validator_from_file(self):
"""Load a saved validation object if it exists """
p = Path(self.checkpoint_dir) / "validator.p.gz"
if p.exists():
with gzip.open(p, "r") as fin:
self.validator = pickle.loads(fin.read())
self.validator = cloudpickle.loads(fin.read())

def reset_gen_data(self):
"""Reset all objects that accumulate or track synthetic
Expand Down Expand Up @@ -141,9 +141,14 @@ def _create_batch_from_dir(batch_dir: str):
raise ValueError("missing model param file")
config = json.loads(open(path / CONFIG_FILE).read())

if not (path / TRAIN_FILE).is_file(): # pragma: no cover
raise ValueError("missing training data")
train_path = str(path / TRAIN_FILE)
# training path can be empty, since we will not need access
# to training data simply for read-only data generation
train_path = ""

# overwrite the previously saved config with the location that we are reading
# the model data in from. this enables a model to be loaded from a different
# location other than the exact location the data was stored during training
config["checkpoint_dir"] = batch_dir

batch = Batch(
checkpoint_dir=batch_dir,
Expand Down

0 comments on commit dd9f615

Please sign in to comment.