Skip to content

Commit

Permalink
RDS-653: Use eval mode for inference during sampling in ACTGAN.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 046cf888d7018be1d3552c966cb411d84b40de7c
  • Loading branch information
kboyd authored and theonlyrob committed Jun 9, 2023
1 parent eee8958 commit 0957a28
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ def sample(
else:
global_condition_vec = None

# Switch generator to eval mode for inference
self._generator.eval()
steps = n // self._batch_size + 1
data = []
for _ in range(steps):
Expand All @@ -574,6 +576,8 @@ def sample(
fakeact = self._apply_activate(fake)
data.append(fakeact.detach().cpu().numpy())

# Switch generator back to train mode now that inference is complete
self._generator.train()
data = np.concatenate(data, axis=0)
data = data[:n]

Expand Down

0 comments on commit 0957a28

Please sign in to comment.