Enable bfloat16 training (#108)
* Enable bfloat16 training

* Bump up to v0.0.26
erogol committed Jun 9, 2023
1 parent 71223c4 commit 9879d3d
Showing 2 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion trainer/VERSION
@@ -1 +1 @@
-v0.0.25
+v0.0.26
27 changes: 22 additions & 5 deletions trainer/trainer.py
@@ -156,6 +156,12 @@ class TrainerConfig(Coqpit):
     )
     # Fields for training specs
     mixed_precision: bool = field(default=False, metadata={"help": "Use mixed precision training. Defaults to False"})
+    precision: str = field(
+        default="fp16",
+        metadata={
+            "help": "Precision to use in mixed precision training. `fp16` for float16 and `bf16` for bfloat16. Defaults to 'fp16'"
+        },
+    )
     epochs: int = field(default=1000, metadata={"help": "Number of epochs to train. Defaults to 1000"})
     batch_size: int = field(default=32, metadata={"help": "Batch size to use. Defaults to 32"})
     eval_batch_size: int = field(default=16, metadata={"help": "Batch size to use for eval. Defaults to 16"})
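The new `precision` field works together with the existing `mixed_precision` flag. A minimal usage sketch (the import path and construction below are assumptions for illustration, not part of this diff):

    from trainer import TrainerConfig  # assumed export; TrainerConfig is defined in trainer/trainer.py

    # Hypothetical config enabling bfloat16 mixed precision:
    config = TrainerConfig(
        mixed_precision=True,  # turn autocast on
        precision="bf16",      # new field in this commit; "fp16" is the default
    )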
@@ -438,7 +444,11 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.keep_avg_eval = None

         self.use_apex = self._is_apex_available()
-        self.use_amp_scaler = self.use_cuda if self.config.mixed_precision else self.config.use_grad_scaler
+        self.use_amp_scaler = (
+            self.use_cuda
+            if self.config.mixed_precision and self.config.precision == "fp16"
+            else self.config.use_grad_scaler
+        )

         if train_samples is not None:
             # use the provided samples
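The scaler is now used only when `precision == "fp16"`: gradient scaling exists to counter float16 underflow, and bfloat16 shares float32's exponent range, so it needs no loss scaling. A standalone sketch of the underlying PyTorch pattern (CUDA assumed; not the Trainer's actual training step):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=True)  # only needed for fp16; no-op when disabled

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, then steps the optimizer
    scaler.update()                # adjusts the scale factor for the next iteration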
@@ -993,12 +1003,19 @@ def _model_train_step(
             return model.module.train_step(*input_args)
         return model.train_step(*input_args)

-    def _get_autocast_args(self, mixed_precision: bool):
+    def _get_autocast_args(self, mixed_precision: bool, precision: str):
         device = "cpu"
         dtype = torch.get_autocast_cpu_dtype()
         if self.use_cuda:
             device = "cuda"
-            dtype = torch.float16 if mixed_precision else torch.float32
+            dtype = torch.float32
+            if mixed_precision:
+                if precision == "fp16":
+                    dtype = torch.float16
+                elif precision == "bf16":
+                    dtype = torch.bfloat16
+                else:
+                    raise ValueError(f" ❗ Unknown precision {precision}")
         elif mixed_precision:
             dtype = torch.bfloat16
         return device, dtype
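The helper now maps (`mixed_precision`, `precision`) to an autocast dtype: float16 or bfloat16 on CUDA, bfloat16 on CPU (CPU autocast only supported bfloat16 at the time). A quick standalone check of what this selection does (illustrative; CUDA assumed):

    import torch

    if torch.cuda.is_available():
        a = torch.randn(4, 4, device="cuda")
        # Ops inside the context compute in the requested autocast dtype.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out = a @ a
        assert out.dtype == torch.bfloat16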
@@ -1057,7 +1074,7 @@ def optimize(
         step_start_time = time.time()

         # forward pass and loss computation
-        device, dtype = self._get_autocast_args(config.mixed_precision)
+        device, dtype = self._get_autocast_args(config.mixed_precision, config.precision)
         with torch.autocast(device_type=device, dtype=dtype, enabled=config.mixed_precision):
             if optimizer_idx is not None:
                 outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
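Passing `enabled=config.mixed_precision` lets one forward-pass code path serve both full and mixed precision: when disabled, the context is a no-op and activations stay in float32. A minimal CPU-runnable sketch of that pattern (not the Trainer's actual step):

    import torch

    model = torch.nn.Linear(8, 2)
    mixed_precision = False  # illustrative flag standing in for config.mixed_precision
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=mixed_precision):
        out = model(torch.randn(4, 8))
    assert out.dtype == torch.float32  # autocast disabled, so everything stays fp32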
@@ -1170,7 +1187,7 @@ def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float)
         if isimplemented(self.model, "optimize"):  # pylint: disable=too-many-nested-blocks
             # custom optimize for the model
             step_time = time.time()
-            device, dtype = self._get_autocast_args(self.config.mixed_precision)
+            device, dtype = self._get_autocast_args(self.config.mixed_precision, self.config.precision)
             with torch.autocast(device_type=device, dtype=dtype, enabled=self.config.mixed_precision):
                 outputs, loss_dict_new = self.model.optimize(
                     batch,
