Skip to content

Commit b96dd21

Browse files
authored
Update new project code sample (Lightning-AI#2287)
1 parent f278ac4 commit b96dd21

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

docs/source/new-project.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ Under the hood, lightning does (in high-level pseudocode):
7676
.. code-block:: python
7777
7878
model = LitModel()
79+
torch.set_grad_enabled(True)
80+
model.train()
7981
train_dataloader = model.train_dataloader()
8082
optimizer = model.configure_optimizers()
8183
@@ -139,13 +141,15 @@ Under the hood in pseudocode, lightning does the following:
139141
# ...
140142

141143
if validate_at_some_point:
144+
torch.set_grad_enabled(False)
142145
model.eval()
143146
val_outs = []
144147
for val_batch in model.val_dataloader:
145148
val_out = model.validation_step(val_batch)
146149
val_outs.append(val_out)
147150

148151
model.validation_epoch_end(val_outs)
152+
torch.set_grad_enabled(True)
149153
model.train()
150154

151155
The beauty of Lightning is that it handles the details of when to validate, when to call .eval(),
@@ -196,6 +200,7 @@ Again, under the hood, lightning does the following in (pseudocode):
196200

197201
.. code-block:: python
198202
203+
torch.set_grad_enabled(False)
199204
model.eval()
200205
test_outs = []
201206
for test_batch in model.test_dataloader:

0 commit comments

Comments
 (0)