Skip to content

Commit

Permalink
convert actgan params to pure scalars for model creation
Browse files Browse the repository at this point in the history
* convert actgan params to pure scalars for model creation

* add modified torch unpickling to actgan directly

GitOrigin-RevId: ed4687ef4682200a311ea9d2563c8e0f42174fc2
  • Loading branch information
johntmyers committed Dec 22, 2022
1 parent 805da13 commit cc933ea
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ sdv<0.18.0
tensorflow_estimator==2.8
tensorflow_privacy==0.7.3
tensorflow_probability==0.16.0
torch==1.13.1
torch==1.11.0
tqdm<5.0
13 changes: 13 additions & 0 deletions src/gretel_synthetics/actgan/actgan_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from gretel_synthetics.actgan.actgan import ACTGANSynthesizer
from gretel_synthetics.detectors.sdv import SDVTableMetadata
from gretel_synthetics.utils import torch_utils
from sdv.tabular.base import BaseTabularModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -159,6 +160,18 @@ def save(self, path: str) -> None:
self._model._epoch_callback = _tmp_callback
self._model_kwargs[EPOCH_CALLBACK] = _tmp_callback

@classmethod
def load_v2(cls, path: str) -> ACTGAN:
    """
    Alternate loader for a pickled model that yields an instance usable
    for sampling on either a CPU or a GPU.

    The target device comes from ``torch_utils.determine_device()``. The
    file at ``path`` is read in binary mode through
    ``torch_utils.patched_torch_unpickle`` with that device, and the
    wrapped model's ``set_device`` is then called with the same device.
    """
    target_device = torch_utils.determine_device()
    with open(path, "rb") as model_file:
        restored: ACTGAN = torch_utils.patched_torch_unpickle(
            model_file, target_device
        )
    # Point the underlying synthesizer at the device the pickle was mapped to.
    restored._model.set_device(target_device)
    return restored


class ACTGAN(_ACTGANModel):
"""
Expand Down
40 changes: 40 additions & 0 deletions src/gretel_synthetics/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Utils for PyTorch
"""
import pickle

from io import BytesIO
from typing import BinaryIO

import torch


def determine_device() -> str:
    """
    Pick the device on which generation should run.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"


def patched_torch_unpickle(file_handle: BinaryIO, device: str) -> object:
    """
    Unpickle one object from ``file_handle`` using the patched unpickler.

    ``device`` is handed to ``_PyTorchPatchedUnpickler`` as its
    ``map_location``; the unpickler substitutes a ``torch.load`` call that
    carries that ``map_location`` wherever the pickle refers to
    ``torch.storage._load_from_bytes``.

    Background for the workaround:
    https://github.com/pytorch/pytorch/issues/16797#issuecomment-777059657
    """
    return _PyTorchPatchedUnpickler(file_handle, map_location=device).load()


class _PyTorchPatchedUnpickler(pickle.Unpickler):
    """
    ``pickle.Unpickler`` that swaps ``torch.storage._load_from_bytes`` for a
    loader bound to a caller-chosen ``map_location``.
    """

    def __init__(self, *args, map_location: str, **kwargs):
        """
        Forward ``*args``/``**kwargs`` to ``pickle.Unpickler`` and remember
        the keyword-only ``map_location`` for use in ``find_class``.
        """
        super().__init__(*args, **kwargs)
        self._map_location = map_location

    def find_class(self, module, name):
        """
        Resolve a pickled global, intercepting only the torch storage loader.

        Every other ``(module, name)`` pair is resolved by the stock
        ``pickle.Unpickler.find_class``.
        """
        wants_storage_loader = (module, name) == (
            "torch.storage",
            "_load_from_bytes",
        )
        if not wants_storage_loader:
            return super().find_class(module, name)
        return _load_with_map_location(self._map_location)


def _load_with_map_location(map_location: str) -> callable:
    """
    Build a one-argument loader that deserializes raw bytes with
    ``torch.load``, passing the given ``map_location``.
    """

    def _load_from_bytes(raw):
        # Wrap the raw bytes in a file-like object, as torch.load reads from one.
        return torch.load(BytesIO(raw), map_location=map_location)

    return _load_from_bytes
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
faker==15.3.3
faker==4.1.1
flake8==4.0.1
numpy>=1.18.0
pandas>=1.1.0
Expand Down
Empty file removed tests/__init__.py
Empty file.
3 changes: 1 addition & 2 deletions tests/actgan/test_actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def test_auto_transform_datetimes(test_df):
model._fit = Mock()
model.fit(test_df)

_, args, _ = model._fit.mock_calls[0]
transformed_df = args[0]
transformed_df = model._fit.mock_calls[0].args[0]
assert is_number(transformed_df[transformed_df.columns[0]][0])


Expand Down

0 comments on commit cc933ea

Please sign in to comment.