RDS-252: Adding codes for speedups
* RDS-252: Adding codes for speedups

Adding mixed precision: instances of torch.cuda.amp.autocast enable autocasting for chosen regions. Autocasting automatically chooses the precision for GPU operations to improve performance while maintaining accuracy. Instances of torch.cuda.amp.GradScaler help perform the steps of gradient scaling conveniently. Gradient scaling improves convergence for networks with float16 gradients by minimizing gradient underflow, as explained in the PyTorch AMP documentation.
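
Below is a minimal, self-contained sketch of the autocast/GradScaler pattern this change adopts; the toy model, optimizer, and data are illustrative placeholders, not objects from this repository.

import torch

use_amp = torch.cuda.is_available()            # autocast/GradScaler become no-ops when disabled
device = "cuda" if use_amp else "cpu"
model = torch.nn.Linear(16, 1).to(device)      # placeholder model for illustration
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    inputs = torch.randn(32, 16, device=device)
    targets = torch.randn(32, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(inputs)                # forward ops run in float16 where safe
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()              # scale the loss so float16 gradients do not underflow
    scaler.step(optimizer)                     # unscale gradients, then call optimizer.step()
    scaler.update()                            # adjust the scale factor for the next step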

Enable async data loading: DataLoader supports asynchronous data loading and data augmentation in separate worker subprocesses. The default setting is num_workers=0, which means data loading is synchronous and done in the main process, so the main training process has to wait for data to be available before continuing execution. Setting num_workers > 0 enables asynchronous data loading and overlap between training and data loading. num_workers should be tuned based on the workload, CPU, GPU, and location of the training data.
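
A hedged sketch of the corresponding DataLoader settings; the dataset and batch size below are placeholders, and the worker values mirror the ones chosen in this commit but should be tuned per machine.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 20), torch.randn(1024, 1))  # placeholder data
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=2,             # >0 moves loading into worker subprocesses
    prefetch_factor=4,         # batches each worker keeps pre-fetched
    persistent_workers=True,   # keep workers alive between epochs
    pin_memory=True,           # page-locked host memory speeds host-to-GPU copies
)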

Adding a config variable to turn CUDA mixed-precision training on or off. The default is config.mixed_precision_training = False; enable it by setting config.mixed_precision_training = True.
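
A short usage sketch of the new flag; the DGANConfig constructor arguments shown here are assumptions for illustration and may differ from the actual signature.

from gretel_synthetics.timeseries_dgan.config import DGANConfig

config = DGANConfig(max_sequence_len=20, sample_len=5)  # constructor args are assumed placeholders
config.cuda = True                       # mixed precision targets the CUDA path
config.mixed_precision_training = True   # default is False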

GitOrigin-RevId: 4746c00b8bea10a95a4bf03a6c6b30184ac1b5fa
santhosh97 committed Jun 23, 2022
1 parent 53e1881 commit 6e6bbfc
Showing 3 changed files with 126 additions and 90 deletions.
5 changes: 5 additions & 0 deletions src/gretel_synthetics/timeseries_dgan/config.py
@@ -82,6 +82,10 @@ class DGANConfig:
batch
generator_rounds: training steps for the generator in each batch
cuda: use GPU if available
mixed_precision_training: enabling automatic mixed precision while training
in order to reduce memory costs, bandwidth, and time by identifying the
steps that require full precision and using 32-bit floating point for
only those steps while using 16-bit floating point everywhere else.
"""

# Model structure
@@ -122,6 +126,7 @@ class DGANConfig:
generator_rounds: int = 1

cuda: bool = True
mixed_precision_training: bool = False

def to_dict(self):
"""Return dictionary representation of DGANConfig.
172 changes: 101 additions & 71 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
@@ -321,8 +321,8 @@ def generate_numpy(
raise RuntimeError(
"generate() must receive either n or both attribute_noise and feature_noise"
)
attribute_noise = attribute_noise.to(self.device)
feature_noise = feature_noise.to(self.device)
attribute_noise = attribute_noise.to(self.device, non_blocking=True)
feature_noise = feature_noise.to(self.device, non_blocking=True)

internal_data = self._generate(attribute_noise, feature_noise)

@@ -436,7 +436,8 @@ def _build(
self.config.feature_num_units,
self.config.feature_num_layers,
)
self.generator.to(self.device)
self.generator.to(self.device, non_blocking=True)

if self.attribute_outputs is None:
self.attribute_outputs = []
attribute_dim = sum(output.dim for output in self.attribute_outputs)
@@ -454,7 +455,7 @@
num_layers=5,
num_units=200,
)
self.feature_discriminator.to(self.device)
self.feature_discriminator.to(self.device, non_blocking=True)

self.attribute_discriminator = None
if not self.additional_attribute_outputs and not self.attribute_outputs:
@@ -466,7 +467,7 @@
num_layers=5,
num_units=200,
)
self.attribute_discriminator.to(self.device)
self.attribute_discriminator.to(self.device, non_blocking=True)

self.attribute_noise_func = lambda batch_size: torch.randn(
batch_size, self.config.attribute_noise_dim, device=self.device
@@ -529,7 +530,14 @@ def _train(
"""

loader = DataLoader(
dataset, self.config.batch_size, shuffle=True, drop_last=True
dataset,
self.config.batch_size,
shuffle=True,
drop_last=True,
num_workers=2,
prefetch_factor=4,
persistent_workers=True,
pin_memory=True,
)

opt_discriminator = torch.optim.Adam(
@@ -556,94 +564,116 @@

# Set torch modules to training mode
self._set_mode(True)
scaler = torch.cuda.amp.GradScaler(enabled=self.config.mixed_precision_training)

for epoch in range(self.config.epochs):
logger.info(f"epoch: {epoch}")

for real_batch in loader:
global_step += 1
attribute_noise = self.attribute_noise_func(self.config.batch_size)
feature_noise = self.feature_noise_func(self.config.batch_size)
with torch.cuda.amp.autocast(
enabled=self.config.mixed_precision_training
):
attribute_noise = self.attribute_noise_func(
real_batch[0].shape[0]
).to(self.device, non_blocking=True)
feature_noise = self.feature_noise_func(real_batch[0].shape[0]).to(
self.device, non_blocking=True
)

# Both real and generated batch are always three element tuple of
# tensors. The tuple is structured as follows: (attribute_output,
# additional_attribute_output, feature_output). If self.attribute_output
# and/or self.additional_attribute_output is empty, the respective
# tuple index will be filled with a placeholder nan-filled tensor.
# These nan-filled tensors get filtered out in the _discriminate,
# _get_gradient_penalty, and _discriminate_attributes functions.
# Both real and generated batch are always three element tuple of
# tensors. The tuple is structured as follows: (attribute_output,
# additional_attribute_output, feature_output). If self.attribute_output
# and/or self.additional_attribute_output is empty, the respective
# tuple index will be filled with a placeholder nan-filled tensor.
# These nan-filled tensors get filtered out in the _discriminate,
# _get_gradient_penalty, and _discriminate_attributes functions.

generated_batch = self.generator(attribute_noise, feature_noise)
real_batch = [x.to(self.device) for x in real_batch]
generated_batch = self.generator(attribute_noise, feature_noise)
real_batch = [
x.to(self.device, non_blocking=True) for x in real_batch
]

for _ in range(self.config.discriminator_rounds):
opt_discriminator.zero_grad()
generated_output = self._discriminate(generated_batch)
real_output = self._discriminate(real_batch)

loss_generated = torch.mean(generated_output)
loss_real = -torch.mean(real_output)
loss_gradient_penalty = self._get_gradient_penalty(
generated_batch, real_batch, self._discriminate
opt_discriminator.zero_grad(
set_to_none=self.config.mixed_precision_training
)

loss = (
loss_generated
+ loss_real
+ self.config.gradient_penalty_coef * loss_gradient_penalty
)

loss.backward(retain_graph=True)
opt_discriminator.step()

if opt_attribute_discriminator is not None:
opt_attribute_discriminator.zero_grad()
# Exclude features (last element of batches) for
# attribute discriminator
generated_output = self._discriminate_attributes(
generated_batch[:-1]
)
real_output = self._discriminate_attributes(real_batch[:-1])
with torch.cuda.amp.autocast(enabled=True):
generated_output = self._discriminate(generated_batch)
real_output = self._discriminate(real_batch)

loss_generated = torch.mean(generated_output)
loss_real = -torch.mean(real_output)
loss_gradient_penalty = self._get_gradient_penalty(
generated_batch[:-1],
real_batch[:-1],
self._discriminate_attributes,
generated_batch, real_batch, self._discriminate
)

attribute_loss = (
loss = (
loss_generated
+ loss_real
+ self.config.attribute_gradient_penalty_coef
* loss_gradient_penalty
+ self.config.gradient_penalty_coef * loss_gradient_penalty
)

attribute_loss.backward(retain_graph=True)
opt_attribute_discriminator.step()
scaler.scale(loss).backward(retain_graph=True)
scaler.step(opt_discriminator)
scaler.update()

for _ in range(self.config.generator_rounds):
opt_generator.zero_grad()
generated_output = self._discriminate(generated_batch)

if self.attribute_discriminator:
# Exclude features (last element of batch) before
# calling attribute discriminator
attribute_generated_output = self._discriminate_attributes(
generated_batch[:-1]
)

loss = -torch.mean(
generated_output
) + self.config.attribute_loss_coef * -torch.mean(
attribute_generated_output
)
else:
loss = -torch.mean(generated_output)
if opt_attribute_discriminator is not None:
opt_attribute_discriminator.zero_grad(set_to_none=False)
# Exclude features (last element of batches) for
# attribute discriminator
with torch.cuda.amp.autocast(
enabled=self.config.mixed_precision_training
):
generated_output = self._discriminate_attributes(
generated_batch[:-1]
)
real_output = self._discriminate_attributes(real_batch[:-1])

loss_generated = torch.mean(generated_output)
loss_real = -torch.mean(real_output)
loss_gradient_penalty = self._get_gradient_penalty(
generated_batch[:-1],
real_batch[:-1],
self._discriminate_attributes,
)

attribute_loss = (
loss_generated
+ loss_real
+ self.config.attribute_gradient_penalty_coef
* loss_gradient_penalty
)

scaler.scale(attribute_loss).backward(retain_graph=True)
scaler.step(opt_attribute_discriminator)
scaler.update()

loss.backward()
opt_generator.step()
for _ in range(self.config.generator_rounds):
opt_generator.zero_grad(set_to_none=False)
with torch.cuda.amp.autocast(
enabled=self.config.mixed_precision_training
):
generated_output = self._discriminate(generated_batch)

if self.attribute_discriminator:
# Exclude features (last element of batch) before
# calling attribute discriminator
attribute_generated_output = self._discriminate_attributes(
generated_batch[:-1]
)

loss = -torch.mean(
generated_output
) + self.config.attribute_loss_coef * -torch.mean(
attribute_generated_output
)
else:
loss = -torch.mean(generated_output)

scaler.scale(loss).backward()
scaler.step(opt_generator)
scaler.update()

def _generate(
self, attribute_noise: torch.Tensor, feature_noise: torch.Tensor
39 changes: 20 additions & 19 deletions src/gretel_synthetics/timeseries_dgan/torch_modules.py
@@ -78,9 +78,9 @@ def __init__(self, input_dim: int, outputs: List[Output], dim_index: int):
[
(
"linear",
torch.nn.Linear(input_dim, output.dim),
torch.nn.Linear(int(input_dim), int(output.dim)),
),
("softmax", torch.nn.Softmax(dim=dim_index)),
("softmax", torch.nn.Softmax(dim=int(dim_index))),
]
)
)
@@ -101,7 +101,7 @@ def __init__(self, input_dim: int, outputs: List[Output], dim_index: int):
[
(
"linear",
torch.nn.Linear(input_dim, output.dim),
torch.nn.Linear(int(input_dim), int(output.dim)),
),
("normalization", normalizer),
]
@@ -188,7 +188,6 @@ def __init__(
attribute_num_units,
attribute_num_layers,
)

(
self.additional_attribute_gen,
additional_attribute_dim,
@@ -198,18 +197,19 @@
attribute_num_units,
attribute_num_layers,
)

self.feature_gen = torch.nn.Sequential(
OrderedDict(
[
(
"lstm",
torch.nn.LSTM(
attribute_dim
+ additional_attribute_dim
+ feature_noise_dim,
feature_num_units,
feature_num_layers,
int(
attribute_dim
+ additional_attribute_dim
+ feature_noise_dim
),
int(feature_num_units),
int(feature_num_layers),
batch_first=True,
),
),
@@ -219,7 +219,7 @@ def __init__(
Merger(
[
OutputDecoder(
feature_num_units, feature_outputs, dim_index=2
int(feature_num_units), feature_outputs, dim_index=2
)
for _ in range(self.sample_len)
],
@@ -254,16 +254,16 @@ def _make_attribute_generator(
if not outputs:
return None, 0
seq = []
last_dim = input_dim
last_dim = int(input_dim)
for _ in range(num_layers):
seq.append(torch.nn.Linear(last_dim, num_units))
seq.append(torch.nn.Linear(int(last_dim), int(num_units)))
seq.append(torch.nn.ReLU())
seq.append(torch.nn.BatchNorm1d(num_units))
last_dim = num_units
seq.append(torch.nn.BatchNorm1d(int(num_units)))
last_dim = int(num_units)

seq.append(OutputDecoder(last_dim, outputs, dim_index=1))
seq.append(OutputDecoder(int(last_dim), outputs, dim_index=1))
attribute_dim = sum(output.dim for output in outputs)
return torch.nn.Sequential(*seq), attribute_dim
return torch.nn.Sequential(*seq), int(attribute_dim)

def forward(
self, attribute_noise: torch.Tensor, feature_noise: torch.Tensor
@@ -285,6 +285,7 @@ def forward(
"""

# Attribute features exist

empty_tensor = torch.Tensor(np.full((1, 1), np.nan))

if self.attribute_gen is not None:
@@ -368,11 +369,11 @@ def __init__(self, input_dim: int, num_layers: int = 5, num_units: int = 200):
seq = []
last_dim = input_dim
for _ in range(num_layers):
seq.append(torch.nn.Linear(last_dim, num_units))
seq.append(torch.nn.Linear(int(last_dim), int(num_units)))
seq.append(torch.nn.ReLU())
last_dim = num_units

seq.append(torch.nn.Linear(last_dim, 1))
seq.append(torch.nn.Linear(int(last_dim), 1))

self.seq = torch.nn.Sequential(*seq)

