Skip to content

Commit

Permalink
Add save and load functions to DGAN model.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 829c03baa8c4e565935071b29dd88d02d62cd497
  • Loading branch information
kboyd committed May 4, 2022
1 parent 5e78b94 commit 33d76f1
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/gretel_synthetics/timeseries_dgan/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from enum import Enum


Expand Down Expand Up @@ -122,3 +122,12 @@ class DGANConfig:
generator_rounds: int = 1

cuda: bool = True

def to_dict(self):
    """Convert this DGANConfig into a plain dictionary.

    Returns:
        Dictionary keyed by member variable name. The result can be
        used to initialize a new config object, e.g.,
        `DGANConfig(**config.to_dict())`
    """
    config_as_dict = asdict(self)
    return config_as_dict
73 changes: 73 additions & 0 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"""


from __future__ import annotations

import logging

from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -134,6 +136,9 @@ def __init__(
"feature_outputs and attribute_ouputs must either both be given or both be None"
)

self.attribute_column_names = None
self.feature_column_names = None

def train_numpy(
self,
attributes: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -821,3 +826,71 @@ def _extract_from_dataframe(
features = np.expand_dims(features_df.to_numpy(), axis=-1)

return attributes, features

def save(self, file_name: str, **kwargs):
    """Serialize this DGAN model to a file using torch.save.

    The stored dictionary holds the config (as a dict), the attribute and
    feature outputs, the generator and discriminator state dicts, and,
    when attribute column names are set, both lists of column names.

    Args:
        file_name: location to save serialized model
        kwargs: additional parameters passed to torch.save
    """
    checkpoint = {
        "config": self.config.to_dict(),
        "attribute_outputs": self.attribute_outputs,
        "feature_outputs": self.feature_outputs,
        # NOTE: the key is spelled "generate" (not "generator"); load()
        # reads the same key, so the spelling must stay in sync.
        "generate_state_dict": self.generator.state_dict(),
        "feature_discriminator_state_dict": (
            self.feature_discriminator.state_dict()
        ),
    }

    # The attribute discriminator is optional; only store it when present.
    attribute_discriminator = self.attribute_discriminator
    if attribute_discriminator is not None:
        checkpoint["attribute_discriminator_state_dict"] = (
            attribute_discriminator.state_dict()
        )

    # Both column name lists are written together, gated on the
    # attribute column names being set.
    if self.attribute_column_names is not None:
        checkpoint.update(
            attribute_column_names=self.attribute_column_names,
            feature_column_names=self.feature_column_names,
        )

    torch.save(checkpoint, file_name, **kwargs)

@classmethod
def load(cls, file_name: str, **kwargs) -> DGAN:
    """Load DGAN model instance from a file.

    The file is expected to contain the dictionary written by `save`:
    the config, the attribute and feature outputs, the network state
    dicts and, optionally, the attribute and feature column names.

    NOTE: torch.load unpickles the file's contents, so only load files
    that come from a trusted source.

    Args:
        file_name: location to load from
        kwargs: additional parameters passed to torch.load, for example, use
            map_location=torch.device("cpu") to load a model saved for GPU on
            a machine without cuda

    Returns:
        Model instance of the class this method was called on (DGAN, or
        the subclass when invoked as `SubClass.load(...)`)

    Raises:
        RuntimeError: if the file holds attribute discriminator weights
            but the rebuilt model has no attribute discriminator
    """
    state = torch.load(file_name, **kwargs)

    config = DGANConfig(**state["config"])
    # Construct through cls (not a hard-coded DGAN) so subclasses that call
    # load() get an instance of their own type.
    dgan = cls(config)

    # Rebuild the networks first so there are modules to load weights into.
    dgan._build(state["attribute_outputs"], state["feature_outputs"])

    dgan.generator.load_state_dict(state["generate_state_dict"])
    dgan.feature_discriminator.load_state_dict(
        state["feature_discriminator_state_dict"]
    )

    # The attribute discriminator is optional: weights are only present in
    # the file when the saved model had one.
    if "attribute_discriminator_state_dict" in state:
        if dgan.attribute_discriminator is None:
            raise RuntimeError(
                "Error deserializing model: found unexpected attribute discriminator state in file"
            )

        dgan.attribute_discriminator.load_state_dict(
            state["attribute_discriminator_state_dict"]
        )

    # Column names are only restored when both lists are present.
    if "attribute_column_names" in state and "feature_column_names" in state:
        dgan.attribute_column_names = state["attribute_column_names"]
        dgan.feature_column_names = state["feature_column_names"]

    return dgan
54 changes: 54 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,57 @@ def test_extract_from_dataframe(config):

assert attributes.shape == (6, 0)
assert features.shape == (6, 4, 1)


@pytest.mark.parametrize(
    "use_attribute_discriminator,apply_example_scaling,noise_dim,sample_len",
    itertools.product([False, True], [False, True], [10, 25], [2, 5]),
)
def test_save_and_load(
    attribute_data,
    feature_data,
    config: DGANConfig,
    tmp_path,
    use_attribute_discriminator,
    apply_example_scaling,
    noise_dim,
    sample_len,
):
    """Check that a saved and reloaded DGAN reproduces the original output.

    A model is trained for one epoch, used to generate data from fixed
    noise, saved to a temporary file, and loaded back. The loaded model
    must generate matching attributes and features from the same noise.
    """
    train_attributes, attribute_types = attribute_data
    train_features, feature_types = feature_data

    # Apply the parametrized settings on top of the fixture config.
    config.epochs = 1
    config.use_attribute_discriminator = use_attribute_discriminator
    config.apply_example_scaling = apply_example_scaling
    config.attribute_noise_dim = noise_dim
    config.feature_noise_dim = noise_dim
    config.sample_len = sample_len

    model = DGAN(config=config)
    model.train_numpy(
        attributes=train_attributes,
        features=train_features,
        attribute_types=attribute_types,
        feature_types=feature_types,
    )

    # Reuse the same noise for both models so their outputs are comparable.
    num_examples = 25
    attribute_noise = model.attribute_noise_func(num_examples)
    feature_noise = model.feature_noise_func(num_examples)
    noise_inputs = {
        "attribute_noise": attribute_noise,
        "feature_noise": feature_noise,
    }

    expected_attributes, expected_features = model.generate_numpy(**noise_inputs)

    model_path = str(tmp_path / "model.pt")
    model.save(model_path)
    restored_model = DGAN.load(model_path)

    actual_attributes, actual_features = restored_model.generate_numpy(
        **noise_inputs
    )

    np.testing.assert_allclose(actual_attributes, expected_attributes)
    np.testing.assert_allclose(actual_features, expected_features)

0 comments on commit 33d76f1

Please sign in to comment.