Commit
support 4.30 transformers (#1019)
xin3he committed Jun 14, 2023
1 parent 8671d14 commit 256c1dd
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions intel_extension_for_transformers/optimization/trainer.py
@@ -107,6 +107,11 @@
 xm = LazyImport('torch_xla.core.xla_model')
 timeit = LazyImport('timeit')

+if version.parse(__version__) < version.parse("4.30"):
+    NEW_DEEPSPEED_FLAG = False
+else:
+    NEW_DEEPSPEED_FLAG = True
+

 class BaseTrainer():
     """The base class of trainer."""
@@ -1149,9 +1154,10 @@ def training_step(
 if self.args.n_gpu > 1:
     loss = loss.mean()  # mean() to average on multi-gpu parallel training

-if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
-    # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
-    loss = loss / self.args.gradient_accumulation_steps
+if not NEW_DEEPSPEED_FLAG:
+    if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
+        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+        loss = loss / self.args.gradient_accumulation_steps

 if self.compression_ctrl is not None:
     compression_loss = self.compression_ctrl.loss()
@@ -1175,6 +1181,9 @@ def training_step(
 elif self.use_apex:
     with amp.scale_loss(loss, self.optimizer) as scaled_loss:
         scaled_loss.backward()
+elif NEW_DEEPSPEED_FLAG:
+    self.accelerator.backward(loss)
+    loss = loss / self.args.gradient_accumulation_steps
 elif self.deepspeed:
     # loss gets scaled under gradient_accumulation_steps in deepspeed
     loss = self.deepspeed.backward(loss)
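
Taken together, the two hunks above move the loss scaling: on transformers < 4.30 the trainer divides the loss before calling backward (unless DeepSpeed does it inside its own backward), while on 4.30+ the call goes through self.accelerator.backward and only the value returned for logging is divided. The self-contained sketch below mirrors that control flow with the Accelerate/DeepSpeed objects replaced by plain-PyTorch stand-ins; the stub classes, the final loss.backward() fallback, and the detached return value are illustrative assumptions, not lines from this commit.

import torch

NEW_DEEPSPEED_FLAG = True  # would come from the transformers version check above

class _Args:
    n_gpu = 1
    gradient_accumulation_steps = 4

class _Accelerator:
    def backward(self, loss):
        loss.backward()  # real Accelerate also handles AMP/DeepSpeed scaling here

class SketchTrainer:
    def __init__(self):
        self.args = _Args()
        self.accelerator = _Accelerator()
        self.deepspeed = None  # DeepSpeed engine when enabled on the old path

    def backward_step(self, loss):
        if not NEW_DEEPSPEED_FLAG:
            # pre-4.30 path: scale before backward unless DeepSpeed scales internally
            if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
                loss = loss / self.args.gradient_accumulation_steps

        if NEW_DEEPSPEED_FLAG:
            self.accelerator.backward(loss)
            loss = loss / self.args.gradient_accumulation_steps  # scale the reported loss only
        elif self.deepspeed:
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()
        return loss.detach()

w = torch.tensor(3.0, requires_grad=True)
print(SketchTrainer().backward_step(w * w))  # tensor(2.2500); w.grad is 6.0

The apex and grad-scaler branches visible in the diff are omitted here for brevity; the branch order otherwise follows the commit.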
@@ -1229,9 +1238,10 @@ def training_step_length_adaptive(
 if self.args.n_gpu > 1:
     loss = loss.mean()  # mean() to average on multi-gpu parallel training

-if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
-    # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
-    loss = loss / self.args.gradient_accumulation_steps
+if not NEW_DEEPSPEED_FLAG:
+    if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
+        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+        loss = loss / self.args.gradient_accumulation_steps

 if self.compression_ctrl is not None:  # TODO- should be added here?
     compression_loss = self.compression_ctrl.loss()
@@ -1260,6 +1270,9 @@ def training_step_length_adaptive(
 elif self.use_apex:
     with amp.scale_loss(loss, self.optimizer) as scaled_loss:
         scaled_loss.backward()
+elif NEW_DEEPSPEED_FLAG:
+    self.accelerator.backward(loss)
+    loss = loss / self.args.gradient_accumulation_steps
 elif self.deepspeed:
     # loss gets scaled under gradient_accumulation_steps in deepspeed
     loss = self.deepspeed.backward(loss)
@@ -1318,9 +1331,10 @@ def training_step_length_adaptive(
 if self.args.n_gpu > 1:
     loss = loss.mean()  # mean() to average on multi-gpu parallel training

-if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
-    # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
-    loss = loss / self.args.gradient_accumulation_steps
+if not NEW_DEEPSPEED_FLAG:
+    if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
+        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+        loss = loss / self.args.gradient_accumulation_steps

 if self.compression_ctrl is not None:  # TODO- should be added here?
     compression_loss = self.compression_ctrl.loss()
@@ -1351,6 +1365,9 @@ def training_step_length_adaptive(
 elif self.use_apex:
     with amp.scale_loss(loss, self.optimizer) as scaled_loss:
         scaled_loss.backward()
+elif NEW_DEEPSPEED_FLAG:
+    self.accelerator.backward(loss)
+    loss = loss / self.args.gradient_accumulation_steps
 elif self.deepspeed:
     # loss gets scaled under gradient_accumulation_steps in deepspeed
     loss = self.deepspeed.backward(loss)
