Skip to content

Commit

Permalink
RDS-636: Refactor ACTGAN in preparation of anyway training
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 3527a13404d9cb7af5f8ea4f8202b788371bcd60
  • Loading branch information
kboyd committed Jun 21, 2023
1 parent 051c6b7 commit 0ad0296
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 85 deletions.
270 changes: 188 additions & 82 deletions src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from gretel_synthetics.actgan.train_data import TrainData
from gretel_synthetics.typing import DFLike
from packaging import version
from rdt.transformers.base import BaseTransformer
from torch import optim
from torch.nn import (
BatchNorm1d,
Expand All @@ -35,6 +34,23 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# NOTE on data terminology used in ACTGAN: This module operates with 3 different
# representations of the training data (and generated synthetic data).
#
# - original - input data as received in API calls, as a DataFrame; synthetic
#   samples are returned in the same format and style.
# - transformed - compact representation after applying DataTransformer (and
#   usually stored in TrainData instances), also known as decoded. Columns
#   here are always numeric, but may be in a more compact form than what the
#   actual DNN works on; in particular, one-hot or binary encoded columns are
#   stored as integer indices instead of multiple columns.
# - encoded - representation passed directly to DNNs; it should be in proper
#   float32 dtype.
#
# During training we apply the transformations original -> transformed ->
# encoded. For generation the process reverses, going from the encoded
# representation back to the original format.


class Discriminator(Module):
"""Discriminator for the ACTGANSynthesizer."""
Expand Down Expand Up @@ -104,7 +120,7 @@ def forward(self, input_):
class Generator(Module):
"""Generator for the ACTGANSynthesizer."""

def __init__(self, embedding_dim, generator_dim, data_dim):
def __init__(self, embedding_dim: int, generator_dim: Sequence[int], data_dim: int):
super(Generator, self).__init__()
dim = embedding_dim
seq = []
Expand All @@ -120,7 +136,13 @@ def forward(self, input_):
return data


def _gumbel_softmax_stabilized(logits, tau=1, hard=False, eps=1e-10, dim=-1):
def _gumbel_softmax_stabilized(
logits: torch.Tensor,
tau: float = 1,
hard: bool = False,
eps: float = 1e-10,
dim: int = -1,
):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
For more details about the issue:
https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
Expand Down Expand Up @@ -225,7 +247,6 @@ def __init__(
pac: int = 10,
cuda: bool = True,
):

if batch_size % 2 != 0:
raise ValueError("`batch_size` must be divisible by 2")

Expand Down Expand Up @@ -276,6 +297,66 @@ def __init__(
else _gumbel_softmax_stabilized
)

def _make_noise(self) -> torch.Tensor:
"""Create new random noise tensors for a batch.
Returns:
Tensor of random noise used as (part of the) input to generator
network. Shape is [batch_size, embedding_dim].
"""
# NOTE: speedup may be possible if we can reuse the mean and std tensors
# here across calls to _make_noise.
mean = torch.zeros((self._batch_size, self._embedding_dim), device=self._device)
std = mean + 1.0
return torch.normal(mean, std)

def _apply_generator(
self, fakez: torch.Tensor, fake_cond_vec: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply generator network.
Args:
fakez: Random noise (z-vectors), shape is [batch_size,
embedding_dim]
fake_cond_vec: Optional conditional vectors to guide generation,
shape is [batch_size, cond_vec_dim]
Returns:
Tuple of direct generator output, and output after applying
activation functions. Shape of both tensor outputs is [batch_size,
data_dim]
"""
if fake_cond_vec is None:
input = fakez
else:
input = torch.cat([fakez, fake_cond_vec], dim=1)

fake = self._generator(input)
fakeact = self._apply_activate(fake)
return fake, fakeact

def _apply_discriminator(
    self,
    encoded: torch.Tensor,
    cond_vec: Optional[torch.Tensor],
    discriminator: Discriminator,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply discriminator network.

    Args:
        encoded: Tensor of data in encoded representation to evaluate.
        cond_vec: Optional conditional vector. When given, it is
            concatenated to ``encoded`` along dim 1 before evaluation.
        discriminator: Discriminator network that is called on the
            assembled input.

    Returns:
        Tuple of full input to the discriminator network and the output.
    """
    # Named discriminator_input (not input) so the builtin is not shadowed.
    if cond_vec is None:
        discriminator_input = encoded
    else:
        discriminator_input = torch.cat([encoded, cond_vec], dim=1)
    y = discriminator(discriminator_input)
    return discriminator_input, y

def _apply_activate(self, data):
"""Apply proper activation function to the output of the generator."""
data_t = [
Expand Down Expand Up @@ -326,6 +407,54 @@ def _validate_discrete_columns(
if invalid_columns:
raise ValueError(f"Invalid columns found: {invalid_columns}")

def _prepare_batch(
    self, data_sampler: DataSampler
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    torch.Tensor,
]:
    """Select a random subset of training data for one batch.

    Also prepares other required Tensors such as conditional vectors, for
    generator and discriminator training.

    Args:
        data_sampler: DataSampler instance that performs sampling

    Returns:
        Tuple of:
          - torch.Tensor or None, fake conditional vector (part of input to
            generator)
          - torch.Tensor or None, real conditional vector associated with
            the encoded real sample returned
          - torch.Tensor or None, column mask indicating which columns (in
            transformed representation) are set in the fake conditional
            vector
          - torch.Tensor, encoded real sample
    """
    fake_cond_vec, fake_column_mask, col, opt = data_sampler.sample_condvec(
        self._batch_size
    )

    if fake_cond_vec is None:
        # No conditional vectors: sample real data without constraints.
        # fake_column_mask is passed through as returned by the sampler
        # (presumably None in this case -- confirm against DataSampler).
        real_encoded = data_sampler.sample_data(self._batch_size, None, None)
        real_cond_vec = None
    else:
        fake_cond_vec = torch.from_numpy(fake_cond_vec).to(self._device)
        fake_column_mask = torch.from_numpy(fake_column_mask).to(self._device)

        # Real rows are sampled for a shuffled order of the (col, opt)
        # pairs; the same permutation is applied to the conditional vector
        # so each real row stays aligned with its condition.
        perm = np.random.permutation(self._batch_size)
        real_encoded = data_sampler.sample_data(
            self._batch_size, col[perm], opt[perm]
        )
        real_cond_vec = fake_cond_vec[perm]

    # The networks consume float32 tensors on the configured device.
    real_encoded = torch.from_numpy(real_encoded.astype("float32")).to(self._device)

    return (
        fake_cond_vec,
        real_cond_vec,
        fake_column_mask,
        real_encoded,
    )

@random_state
def fit(
self, train_data: DFLike, discrete_columns: Optional[Sequence[str]] = None
Expand Down Expand Up @@ -384,7 +513,7 @@ def _actual_fit(self, train_data: TrainData) -> None:
"""Fit the ACTGAN Synthesizer models to the training data.
Args:
train_data: Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
train_data: training data as a TrainData instance
"""

epochs = self._epochs
Expand All @@ -398,13 +527,13 @@ def _actual_fit(self, train_data: TrainData) -> None:
data_dim = train_data.encoded_dim

self._generator = Generator(
self._embedding_dim + data_sampler.dim_cond_vec(),
self._embedding_dim + data_sampler.cond_vec_dim,
self._generator_dim,
data_dim,
).to(self._device)

discriminator = Discriminator(
data_dim + data_sampler.dim_cond_vec(),
data_dim + data_sampler.cond_vec_dim,
self._discriminator_dim,
pac=self.pac,
).to(self._device)
Expand All @@ -423,47 +552,27 @@ def _actual_fit(self, train_data: TrainData) -> None:
weight_decay=self._discriminator_decay,
)

mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
std = mean + 1

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
for i in range(epochs):
for _ in range(steps_per_epoch):

for n in range(self._discriminator_steps):
fakez = torch.normal(mean=mean, std=std)

condvec = data_sampler.sample_condvec(self._batch_size)
if condvec is None:
c1, m1, col, opt = None, None, None, None
real = data_sampler.sample_data(self._batch_size, col, opt)
else:
c1, m1, col, opt = condvec
c1 = torch.from_numpy(c1).to(self._device)
m1 = torch.from_numpy(m1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)

perm = np.arange(self._batch_size)
np.random.shuffle(perm)
real = data_sampler.sample_data(
self._batch_size, col[perm], opt[perm]
)
c2 = c1[perm]

fake = self._generator(fakez)
fakeact = self._apply_activate(fake)

real = torch.from_numpy(real.astype("float32")).to(self._device)

if c1 is not None:
fake_cat = torch.cat([fakeact, c1], dim=1)
real_cat = torch.cat([real, c2], dim=1)
else:
real_cat = real
fake_cat = fakeact

y_fake = discriminator(fake_cat)
y_real = discriminator(real_cat)
for _ in range(self._discriminator_steps):
# Optimize discriminator
fakez = self._make_noise()
(
fake_cond_vec,
real_cond_vec,
fake_column_mask,
real_encoded,
) = self._prepare_batch(data_sampler)

fake, fakeact = self._apply_generator(fakez, fake_cond_vec)

fake_cat, y_fake = self._apply_discriminator(
fakeact, fake_cond_vec, discriminator
)
real_cat, y_real = self._apply_discriminator(
real_encoded, real_cond_vec, discriminator
)

pen = discriminator.calc_gradient_penalty(
real_cat, fake_cat, self._device, self.pac
Expand All @@ -475,31 +584,30 @@ def _actual_fit(self, train_data: TrainData) -> None:
loss_d.backward()
optimizerD.step()

fakez = torch.normal(mean=mean, std=std)
condvec = data_sampler.sample_condvec(self._batch_size)

if condvec is None:
c1, m1, col, opt = None, None, None, None
else:
c1, m1, col, opt = condvec
c1 = torch.from_numpy(c1).to(self._device)
m1 = torch.from_numpy(m1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)

fake = self._generator(fakez)
fakeact = self._apply_activate(fake)

if c1 is not None:
y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
else:
y_fake = discriminator(fakeact)
# Optimize generator
fakez = self._make_noise()
(
fake_cond_vec,
real_cond_vec,
fake_column_mask,
# Real data is unused here, possible speedup if we skip
# creating this Tensor for CTGAN style conditional vectors
_,
) = self._prepare_batch(data_sampler)

fake, fakeact = self._apply_generator(fakez, fake_cond_vec)
fake_cat, y_fake = self._apply_discriminator(
fakeact, fake_cond_vec, discriminator
)

if condvec is None:
cross_entropy = 0
if fake_cond_vec is None:
loss_reconstruction = 0.0
else:
cross_entropy = self._cond_loss(fake, c1, m1)
loss_reconstruction = self._cond_loss(
fake, fake_cond_vec, fake_column_mask
)

loss_g = -torch.mean(y_fake) + cross_entropy
loss_g = -torch.mean(y_fake) + loss_reconstruction

optimizerG.zero_grad()
loss_g.backward()
Expand Down Expand Up @@ -553,36 +661,34 @@ def sample(

# Switch generator to eval mode for inference
self._generator.eval()
steps = n // self._batch_size + 1
steps = (n - 1) // self._batch_size + 1
data = []
for _ in range(steps):
mean = torch.zeros(self._batch_size, self._embedding_dim)
std = mean + 1
fakez = torch.normal(mean=mean, std=std).to(self._device)

if global_condition_vec is not None:
condvec = global_condition_vec.copy()
condvec_numpy = global_condition_vec.copy()
else:
condvec = self._condvec_sampler.sample_original_condvec(
condvec_numpy = self._condvec_sampler.sample_original_condvec(
self._batch_size
)

if condvec is not None:
c1 = condvec
c1 = torch.from_numpy(c1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)
fakez = self._make_noise()

if condvec_numpy is not None:
condvec = torch.from_numpy(condvec_numpy).to(self._device)
else:
condvec = None

fake, fakeact = self._apply_generator(fakez, condvec)

fake = self._generator(fakez)
fakeact = self._apply_activate(fake)
data.append(fakeact.detach().cpu().numpy())

# Switch generator back to train mode now that inference is complete
self._generator.train()
data = np.concatenate(data, axis=0)
data = data[:n]

transformed_data = self._transformer.inverse_transform(data)
return transformed_data
original_repr_data = self._transformer.inverse_transform(data)
return original_repr_data

def set_device(self, device: str) -> None:
"""Set the `device` to be used ('GPU' or 'CPU)."""
Expand Down

0 comments on commit 0ad0296

Please sign in to comment.