Merge pull request huggingface#53 from pytorch-tpu/jonbolin/step-time
Add step time metrics via `xla_execution_time_step`
jonb377 committed Dec 20, 2023
2 parents 57b505c + 87439d2 commit 9834eec
Showing 2 changed files with 25 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/transformers/trainer.py
@@ -1841,7 +1841,7 @@ def zero_grad(x):
             profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
             profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
             for step, inputs in enumerate(epoch_iterator):
-                if step == 0 and epoch == 0:
+                if self.state.global_step == 0:
                     print('input sharding', {k: (v.shape, torch_xla._XLAC._get_xla_sharding_spec(v)) for k, v in inputs.items()})
                 total_batched_samples += 1
                 if rng_to_sync:
@@ -1941,6 +1941,29 @@ def zero_grad(x):
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

+                if self.args.xla_execution_time_step is not None:
+                    # To ensure the step time is accurate, we need the measured wall
+                    # time to only reflect execution of the single target step.
+                    if self.state.global_step == self.args.xla_execution_time_step:
+                        # After tracing is complete for the target step, wait for all device ops to
+                        # complete before the `mark_step` call starts its execution on devices.
+                        xm.wait_device_ops()
+                        execution_time_start = time.time()
+                    elif self.state.global_step == self.args.xla_execution_time_step + 1:
+                        # The time taken to reach this point in the next step is the tracing time.
+                        tracing_time = time.time() - execution_time_start
+                        # Wait for the target step's execution to complete before measuring the
+                        # execution's wall time.
+                        xm.wait_device_ops()
+                        step_wall_time = time.time() - execution_time_start
+                        # Tracing must be faster than device execution by a measurable
+                        # amount, otherwise the measured time may not actually reflect device
+                        # execution time.
+                        assert step_wall_time - tracing_time > 0.1, \
+                            f"Tracing time ({tracing_time}s) too close to overall step wall time ({step_wall_time}s)"
+                        metrics = {'step_wall_time': step_wall_time, 'tracing_time': tracing_time}
+                        self.log(metrics)
+
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
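
For reference, the timing idea above can also be expressed as a standalone helper. The sketch below is not part of this commit, and the names (timed_step, model, batch, optimizer) are hypothetical; it uses the public torch_xla APIs (xm.mark_step, xm.wait_device_ops) to separate lazy tracing time from on-device execution time, in a slightly simpler single-step variant of the commit's measurement, which instead times tracing by when the next step reaches the same point.

import time

import torch_xla.core.xla_model as xm


def timed_step(model, batch, optimizer):
    """Run one training step and return (tracing_time, step_wall_time) in seconds."""
    xm.wait_device_ops()          # drain any previously queued device work
    start = time.time()

    loss = model(**batch).loss    # lazy tracing only: builds the XLA graph
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    xm.mark_step()                # cut the traced graph and dispatch it for execution
    tracing_time = time.time() - start

    xm.wait_device_ops()          # block until the device finishes the step
    step_wall_time = time.time() - start
    return tracing_time, step_wall_time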

1 change: 1 addition & 0 deletions src/transformers/training_args.py
@@ -634,6 +634,7 @@ class TrainingArguments:
     output_dir: str = field(
         metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
     )
+    xla_execution_time_step: int = field(default=None, metadata={"help": "Global step to measure the on-device step execution time when using torch_xla."})
     checkpoint_manager_path: Optional[str] = field(
         default=None,
         metadata={"help": "Specify the path for CheckpointManager will checkpoint to. This flag controls whether or not CheckpointManager will be used - if unspecified, CheckpointManager will not be used."},
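
As a usage sketch (assuming this fork with the new field is installed; stock transformers does not define it), the target step can be set directly on TrainingArguments, or passed as --xla_execution_time_step to scripts that parse their arguments with HfArgumentParser:

from transformers import TrainingArguments

# Ask the trainer to time global step 100; the measured values are logged as
# {'step_wall_time': ..., 'tracing_time': ...} once step 101 begins.
args = TrainingArguments(
    output_dir="/tmp/run",            # hypothetical output path
    xla_execution_time_step=100,
)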
Expand Down
