When calling trainer.test(), train_dataloader is also validated, which makes no sense #19745
asusdisciple added the `bug` (Something isn't working) and `needs triage` (Waiting to be triaged by maintainers) labels on Apr 8, 2024.
This should not happen. Can you update the snippet below to show the problem?

```python
import os

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    # If trainer.test() called either of these hooks, the RuntimeError would surface.
    def train_dataloader(self):
        raise RuntimeError

    def val_dataloader(self):
        raise RuntimeError

    # Only the test dataloader is expected to be used by trainer.test().
    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.test(model)


if __name__ == "__main__":
    run()
```
carmocca added the `waiting on author` (Waiting on user action, correction, or update) and `data handling` (Generic data-related topic) labels and removed the `needs triage` (Waiting to be triaged by maintainers) label on Apr 9, 2024.
I found the bug. It appears when you use the strategy "deepspeed" in the trainer. Code below :)
carmocca added the `strategy: deepspeed` label and removed the `waiting on author` (Waiting on user action, correction, or update) and `data handling` (Generic data-related topic) labels on Apr 11, 2024.
Bug description
In the current logic of PyTorch Lightning, every time I call `trainer.test()`, it also checks whether the `train_dataloader()` function is valid. This is problematic. For example, I use a `WeightedRandomSampler` only in `train_dataloader`, for obvious reasons. For this to work, I calculate the `weights` and `num_samples` parameters in the `stage="fit"` branch of `setup()`. Of course, when I trigger `trainer.test()`, this code is not executed, so `weights` and `num_samples` are never calculated, which leads to an error when Lightning validates the `train_dataloader` function. I don't see any best practice to avoid this, and no reason to validate code that is never executed.
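To make the failure mode concrete, here is a minimal plain-Python sketch of the pattern described above. It imports neither Lightning nor PyTorch: `HypotheticalDataModule`, `inverse_frequency_weights`, and `train_dataloader_lazy` are illustrative names, not Lightning APIs, and inverse-class-frequency weighting is an assumed choice of weights. The comments mark where a real `WeightedRandomSampler(weights, num_samples)` would be built.

```python
from collections import Counter
from typing import List, Optional, Sequence, Tuple


def inverse_frequency_weights(labels: Sequence[int]) -> Tuple[List[float], int]:
    """Per-sample weights (1 / count of the sample's class) and num_samples.

    Assumption: num_samples is the dataset length; any positive int is a
    valid num_samples for torch.utils.data.WeightedRandomSampler.
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels], len(labels)


class HypotheticalDataModule:
    """Plain-Python stand-in for a LightningDataModule (no Lightning import)."""

    def __init__(self, train_labels: Sequence[int]) -> None:
        self.train_labels = list(train_labels)
        self.weights: Optional[List[float]] = None
        self.num_samples: Optional[int] = None

    def setup(self, stage: str) -> None:
        # Mirrors the report: sampler parameters exist only after the "fit" stage.
        if stage == "fit":
            self.weights, self.num_samples = inverse_frequency_weights(self.train_labels)

    def train_dataloader(self) -> Tuple[List[float], int]:
        # Real code would return
        # DataLoader(ds, sampler=WeightedRandomSampler(self.weights, self.num_samples)).
        # This hypothetical check stands in for whatever error the real call raises.
        if self.weights is None:
            raise RuntimeError("setup('fit') never ran; sampler weights are missing")
        return self.weights, self.num_samples

    def train_dataloader_lazy(self) -> Tuple[List[float], int]:
        # Hypothetical workaround: compute the weights on demand instead of
        # relying on setup("fit") having run first.
        if self.weights is None:
            self.weights, self.num_samples = inverse_frequency_weights(self.train_labels)
        return self.weights, self.num_samples


if __name__ == "__main__":
    dm = HypotheticalDataModule([0, 0, 0, 1])
    dm.setup("test")  # only the test stage runs, as with trainer.test()
    try:
        dm.train_dataloader()
    except RuntimeError as err:
        print("eager version fails:", err)
    print("lazy version:", dm.train_dataloader_lazy())
```

The lazy variant is only one possible guard; it works when the labels are available before `setup()` runs, which may not hold for every dataset.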
What version are you seeing the problem on?
v2.2
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
More info
No response
cc @justusschock @awaelchli