Skip to content

Commit

Permalink
PROD-389: Fix sampling of conditional vectors in ACTGAN
Browse files Browse the repository at this point in the history
GitOrigin-RevId: c5574321e5634d77c081d121d038a5bc6280c951
  • Loading branch information
misberner committed Mar 20, 2023
1 parent 0a80292 commit be1e267
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/gretel_synthetics/actgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, data, output_info: List[List[SpanInfo]], log_frequency):
self._discrete_column_category_prob[
current_id, : span_info.dim
] = category_prob
self._discrete_column_matrix_st[current_id] = st
self._discrete_column_cond_st[current_id] = current_cond_st
self._discrete_column_n_category[current_id] = span_info.dim
current_cond_st += span_info.dim
Expand Down Expand Up @@ -171,7 +172,7 @@ def generate_cond_from_condition_column_info(
) -> np.ndarray:
"""Generate the condition vector."""
vec = np.zeros((batch_size, self._n_categories), dtype="float32")
id_ = self._discrete_column_matrix_st[condition_info.discrete_column_id]
id_ = self._discrete_column_cond_st[condition_info.discrete_column_id]
id_ += condition_info.value_id
vec[:, id_] = 1
return vec

0 comments on commit be1e267

Please sign in to comment.