Skip to content

Commit

Permalink
Vectorize conditional row selection ACTGAN sampling
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 9a0722db6bd5053d0c54d66e00f13202e6ad6948
  • Loading branch information
misberner committed May 30, 2023
1 parent bb078b9 commit f1a116e
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions src/gretel_synthetics/actgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,42 @@ def __init__(

n_discrete_columns = len(self._discrete_columns)

# Prepare an interval matrix for efficiently sampling the conditional vector
max_category = max(
(col.num_values for col in self._discrete_columns), default=0
)

# Store the row id for each category in each discrete column.
# For example _rid_by_cat_cols[a][b] is a list of all rows whose
# a-th discrete column is equal to value b.
self._rid_by_cat_cols: List[List[np.ndarray]] = [
rid_by_cat_cols: List[List[np.ndarray]] = [
[np.nonzero(col.data == j)[0] for j in range(col.num_values)]
for col in self._discrete_columns
]

# Prepare an interval matrix for efficiently sampling the conditional vector
max_category = max(
(col.num_values for col in self._discrete_columns), default=0
# Store rid_by_cat_cols as a two-dimensional numpy array. While the
# lengths of rid_by_cat_cols[i][j] are non-uniform, the length of the
# concatenation over all j for a fixed i is uniform (the total row count).
# Thus, to locate the elements of rid_by_cat_cols[i][j] within the
# two-dimensional array, we'll need to map each (i, j) pair to a range
# along the second axis.
self._rid_by_cat_cols = (
np.stack([np.concatenate(rids) for rids in rid_by_cat_cols])
if rid_by_cat_cols
else np.zeros((0, 0), dtype=int)
)
# This stores len(rid_by_cat_cols[i, j]) at index (i, j)
self._num_rows_by_cat_cols = np.zeros(
(n_discrete_columns, max_category),
dtype=np.int32,
)
for col, rid_by_cat in enumerate(rid_by_cat_cols):
for cat, rids in enumerate(rid_by_cat):
self._num_rows_by_cat_cols[col, cat] = len(rids)

# This stores the start offset of the row index range for (i, j).
self._rid_ofs_by_cat_cols = np.pad(
self._num_rows_by_cat_cols.cumsum(axis=1),
[(0, 0), (1, 0)],
)[:, :-1]

# Calculate the start position of each discrete column in a conditional
# vector. I.e., the (ordinal) value b of the a-th discrete column is
Expand All @@ -81,7 +105,7 @@ def __init__(
# The probability that the (ordinal) value of the a-th discrete column is
# less than or equal to b is _discrete_column_category_prob_cum[a, b].
self._discrete_column_category_prob_cum = np.zeros(
(n_discrete_columns, max_category)
(n_discrete_columns, max_category),
)
self._n_discrete_columns = n_discrete_columns
self._n_categories = sum(col.num_values for col in self._discrete_columns)
Expand Down Expand Up @@ -143,11 +167,19 @@ def sample_data(self, n, col, opt):
idx = np.random.randint(len(self._train_data), size=n)
return self._train_data.to_numpy_encoded(row_indices=idx)

idx = []
for c, o in zip(col, opt):
idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))
# For each column/option pair (c, o), generate a random integer in
# [0, k), where k is the number of rows with c=o. This will be a
# "local" index in the range corresponding to `o` within
# self._rid_by_cat_cols[c].
rid_idxs = np.random.randint(self._num_rows_by_cat_cols[col, opt])

# Translate the local indices to global indices via the _rid_by_cat_cols
# map, by adding the respective range offset for c before indexing.
row_indices = self._rid_by_cat_cols[
col, self._rid_ofs_by_cat_cols[col, opt] + rid_idxs
]

return self._train_data.to_numpy_encoded(row_indices=idx)
return self._train_data.to_numpy_encoded(row_indices=row_indices)

def dim_cond_vec(self) -> int:
"""Return the total number of categories."""
Expand Down

0 comments on commit f1a116e

Please sign in to comment.