PROD-394: Factor out original condvec sampler for ACTGAN
GitOrigin-RevId: 394e2ef9b5b57b2728b3521b62a86b003079cc3e
misberner committed Mar 30, 2023
1 parent 059c669 commit 5d387d9
Showing 3 changed files with 85 additions and 29 deletions.
25 changes: 13 additions & 12 deletions src/gretel_synthetics/actgan/actgan.py

@@ -259,7 +259,7 @@ def __init__(
         self._device = torch.device(device)

         self._transformer = None
-        self._data_sampler = None
+        self._condvec_sampler = None
         self._generator = None

         self._activation_fns: List[
@@ -389,21 +389,22 @@ def _actual_fit(self, train_data: TrainData) -> None:

         epochs = self._epochs

-        self._data_sampler = DataSampler(
+        data_sampler = DataSampler(
             train_data,
             self._log_frequency,
         )
+        self._condvec_sampler = data_sampler.condvec_sampler

         data_dim = train_data.encoded_dim

         self._generator = Generator(
-            self._embedding_dim + self._data_sampler.dim_cond_vec(),
+            self._embedding_dim + data_sampler.dim_cond_vec(),
             self._generator_dim,
             data_dim,
         ).to(self._device)

         discriminator = Discriminator(
-            data_dim + self._data_sampler.dim_cond_vec(),
+            data_dim + data_sampler.dim_cond_vec(),
             self._discriminator_dim,
             pac=self.pac,
         ).to(self._device)
@@ -432,12 +433,10 @@ def _actual_fit(self, train_data: TrainData) -> None:
                 for n in range(self._discriminator_steps):
                     fakez = torch.normal(mean=mean, std=std)

-                    condvec = self._data_sampler.sample_condvec(self._batch_size)
+                    condvec = data_sampler.sample_condvec(self._batch_size)
                     if condvec is None:
                         c1, m1, col, opt = None, None, None, None
-                        real = self._data_sampler.sample_data(
-                            self._batch_size, col, opt
-                        )
+                        real = data_sampler.sample_data(self._batch_size, col, opt)
                     else:
                         c1, m1, col, opt = condvec
                         c1 = torch.from_numpy(c1).to(self._device)
@@ -446,7 +445,7 @@ def _actual_fit(self, train_data: TrainData) -> None:

                         perm = np.arange(self._batch_size)
                         np.random.shuffle(perm)
-                        real = self._data_sampler.sample_data(
+                        real = data_sampler.sample_data(
                             self._batch_size, col[perm], opt[perm]
                         )
                         c2 = c1[perm]
@@ -477,7 +476,7 @@ def _actual_fit(self, train_data: TrainData) -> None:
                     optimizerD.step()

                 fakez = torch.normal(mean=mean, std=std)
-                condvec = self._data_sampler.sample_condvec(self._batch_size)
+                condvec = data_sampler.sample_condvec(self._batch_size)

                 if condvec is None:
                     c1, m1, col, opt = None, None, None, None
@@ -545,7 +544,7 @@ def sample(
                 condition_column, condition_value
             )
            global_condition_vec = (
-                self._data_sampler.generate_cond_from_condition_column_info(
+                self._condvec_sampler.generate_cond_from_condition_column_info(
                     condition_info, self._batch_size
                 )
            )
@@ -562,7 +561,9 @@ def sample(
            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
            else:
-                condvec = self._data_sampler.sample_original_condvec(self._batch_size)
+                condvec = self._condvec_sampler.sample_original_condvec(
+                    self._batch_size
+                )

            if condvec is not None:
                c1 = condvec
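Net effect of the actgan.py changes: the data-backed DataSampler is demoted from instance state (self._data_sampler) to a local variable inside _actual_fit, and only its lightweight condvec_sampler is retained for use in sample(). A minimal self-contained sketch of this retention pattern (the *Stub/*Sketch classes are illustrative stand-ins, not part of the real module):

    # Sketch of the pattern applied in this commit; stub classes only.
    class CondVecSamplerStub:
        def sample_original_condvec(self, batch_size):
            return None  # the real sampler returns a one-hot float32 matrix


    class DataSamplerStub:
        """Stand-in for DataSampler: wraps the (potentially large) training data."""

        def __init__(self, train_data):
            self._train_data = train_data
            self.condvec_sampler = CondVecSamplerStub()


    class SynthesizerSketch:
        def __init__(self):
            # Only the small condvec sampler is kept as instance state.
            self._condvec_sampler = None

        def fit(self, train_data):
            # The data-backed sampler stays local to fit(); the synthesizer
            # holds no reference to it (or its training data) afterwards.
            data_sampler = DataSamplerStub(train_data)
            self._condvec_sampler = data_sampler.condvec_sampler
            # ... the training loop would use data_sampler directly ...

        def sample(self, batch_size):
            # Generation needs conditional vectors only, not the original rows.
            return self._condvec_sampler.sample_original_condvec(batch_size)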
87 changes: 71 additions & 16 deletions src/gretel_synthetics/actgan/data_sampler.py

@@ -86,13 +86,17 @@ def __init__(
         self._n_discrete_columns = n_discrete_columns
         self._n_categories = sum(col.num_values for col in self._discrete_columns)

+        category_freqs = []
         for i, col in enumerate(self._discrete_columns):
             category_freq = np.bincount(col.data, minlength=max_category)
+            category_freqs.append(category_freq[: col.num_values])
             if log_frequency:
                 category_freq = np.log(category_freq + 1)
             category_prob = category_freq / np.sum(category_freq)
             self._discrete_column_category_prob_cum[i] = category_prob.cumsum()

+        self._condvec_sampler = CondVecSampler(category_freqs)
+
     def _random_choice_prob_index(self, discrete_column_id):
         probs_cum = self._discrete_column_category_prob_cum[discrete_column_id]
         r = np.expand_dims(np.random.rand(probs_cum.shape[0]), axis=1)
@@ -129,21 +133,6 @@ def sample_condvec(self, batch):

         return cond, mask, discrete_column_id, category_id_in_col

-    def sample_original_condvec(self, batch):
-        """Generate the conditional vector for generation use original frequency."""
-        if self._n_discrete_columns == 0:
-            return None
-
-        cond = np.zeros((batch, self._n_categories), dtype="float32")
-
-        for i in range(batch):
-            row_idx = np.random.randint(0, self._n_rows)
-            col_idx = np.random.randint(0, self._n_discrete_columns)
-            pick = self._discrete_columns[col_idx].data[row_idx]
-            cond[i, pick + self._discrete_column_cond_st[col_idx]] = 1
-
-        return cond
-
     def sample_data(self, n, col, opt):
         """Sample data from original training data satisfying the sampled conditional vector.
@@ -164,9 +153,75 @@ def dim_cond_vec(self) -> int:
         """Return the total number of categories."""
         return self._n_categories

+    @property
+    def condvec_sampler(self) -> CondVecSampler:
+        return self._condvec_sampler
+
+
+class CondVecSampler:
+
+    _n_discrete_columns: int
+    """The total number of discrete columns."""
+    _n_categories: int
+    """The cumulative number of categories across all discrete columns."""
+    _categories_uniform_prob_cum: np.ndarray
+    """A vector with cumulative probabilities for selecting column/category pairs."""
+    _discrete_column_cond_st: np.ndarray
+    """Starting offset for each discrete column in the conditional vector."""
+
+    def __init__(self, category_freqs: List[np.ndarray]):
+        """Constructor.
+
+        Args:
+            category_freqs:
+                For each discrete column, this list contains a 1D array with
+                absolute category frequencies.
+        """
+        self._n_discrete_columns = len(category_freqs)
+        self._n_categories = sum(len(a) for a in category_freqs)
+
+        if self._n_discrete_columns == 0:
+            return
+
+        # Calculate a probability vector for selecting a single (column, category) pair,
+        # where the column is chosen uniformly at random among all discrete columns, and
+        # the category is chosen according to its relative frequency within the column.
+        categories_uniform_prob = np.concatenate([a / a.sum() for a in category_freqs])
+        categories_uniform_prob = (
+            categories_uniform_prob / categories_uniform_prob.sum()
+        )
+        self._categories_uniform_prob_cum = categories_uniform_prob.cumsum()
+
+        # Calculate the starting offset for each discrete column in the conditional vector.
+        # This is the cumulative number of categories in all previous discrete columns.
+        self._discrete_column_cond_st = np.array(
+            [0] + [len(a) for a in category_freqs[:-1]]
+        ).cumsum()
+
+    def sample_original_condvec(self, batch_size: int):
+        """Generate the conditional vector for generation using the original frequencies."""
+        if self._n_discrete_columns == 0:
+            return None
+
+        r = np.random.rand(batch_size, 1)
+        pick = np.argmax(r < self._categories_uniform_prob_cum, axis=1)
+
+        cond = np.zeros((batch_size, self._n_categories), dtype="float32")
+        cond[np.arange(batch_size), pick] = 1
+
+        return cond
+
     def generate_cond_from_condition_column_info(
-        self, condition_info: ColumnIdInfo, batch_size: int
+        self,
+        condition_info: ColumnIdInfo,
+        batch_size: int,
     ) -> np.ndarray:
         """Generate the condition vector."""
+        if condition_info.discrete_column_id >= self._n_discrete_columns:
+            raise ValueError(
+                f"invalid discrete column ID {condition_info.discrete_column_id}, "
+                + f"there are only {self._n_discrete_columns} discrete columns"
+            )
         vec = np.zeros((batch_size, self._n_categories), dtype="float32")
         id_ = self._discrete_column_cond_st[condition_info.discrete_column_id]
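The new CondVecSampler replaces the removed per-row loop in DataSampler.sample_original_condvec with vectorized inverse-CDF sampling. A self-contained sketch with toy category frequencies (values assumed for illustration) showing how _categories_uniform_prob_cum is built and used:

    import numpy as np

    # Two discrete columns: counts [30, 10] and [5, 5, 10].
    category_freqs = [np.array([30, 10]), np.array([5, 5, 10])]

    # Normalize within each column, concatenate, then renormalize globally.
    # Dividing by the total (== number of columns) makes each column equally
    # likely, while categories keep their within-column relative frequencies.
    prob = np.concatenate([a / a.sum() for a in category_freqs])
    prob = prob / prob.sum()
    cum = prob.cumsum()  # array([0.375, 0.5, 0.625, 0.75, 1.0])

    # Inverse-CDF sampling, vectorized over the batch: for each uniform draw,
    # argmax picks the first index whose cumulative probability exceeds it.
    batch_size = 4
    r = np.random.rand(batch_size, 1)
    pick = np.argmax(r < cum, axis=1)

    # One-hot conditional vectors, one (column, category) pick per row.
    cond = np.zeros((batch_size, prob.size), dtype="float32")
    cond[np.arange(batch_size), pick] = 1
    print(cond)

Since the last cumulative value is 1.0 and np.random.rand draws from [0, 1), the comparison always has at least one True entry, so argmax never falls through.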
2 changes: 1 addition & 1 deletion src/gretel_synthetics/actgan/data_transformer.py

@@ -91,7 +91,7 @@ def __init__(

     def _fit_continuous(self, data: pd.DataFrame) -> ColumnTransformInfo:
         """Train Bayesian GMM for continuous columns."""
-        if self._cbn_sample_size is not None and self._cbn_sample_size < len(data):
+        if self._cbn_sample_size and self._cbn_sample_size < len(data):
             # Train on only a sample of the data, if requested.
             data = data.sample(n=self._cbn_sample_size)
         column_name = data.columns[0]
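The data_transformer.py change swaps an explicit None check for a truthiness check. The observable difference is when _cbn_sample_size is 0 (values below assumed for illustration): the old condition would call data.sample(n=0), while the new one skips sampling and trains on the full data:

    sample_size = 0
    data_len = 100

    old_cond = sample_size is not None and sample_size < data_len  # True
    new_cond = bool(sample_size and sample_size < data_len)        # False
    print(old_cond, new_cond)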
