Skip to content

Commit 4528b8e

Browse files
committed
ref: inner train loop (intermediate step) 3/n
1 parent 7a57042 commit 4528b8e

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
938938
result = self.train_loop.training_step(split_batch, batch_idx, opt_idx, hiddens)
939939

940940
# backward pass
941-
with self.profiler.profile('model_backward'):
942-
result.closure_loss = self.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
941+
self.train_loop.backward(result, optimizer, opt_idx)
943942

944943
# hook
945944
self.train_loop.on_after_backward(result.training_step_output, batch_idx, result.untouched_loss)

pytorch_lightning/trainer/training_loop_temp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def get_optimizers_iterable(self):
134134
opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
135135
return [(opt_idx, self.trainer.optimizers[opt_idx])]
136136

137+
def backward(self, result, optimizer, opt_idx):
138+
# backward pass
139+
with self.trainer.profiler.profile('model_backward'):
140+
result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
141+
137142
def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
138143
is_result_obj = isinstance(training_step_output, Result)
139144

0 commit comments

Comments
 (0)