When calling trainer.test() train_dataloader is also validated, which makes no sense #19745

Open
asusdisciple opened this issue Apr 8, 2024 · 2 comments
Labels
bug Something isn't working strategy: deepspeed

Comments

@asusdisciple

asusdisciple commented Apr 8, 2024

Bug description

In the current logic of pytorch-lightning, every time I call trainer.test(), Lightning also validates the train_dataloader() hook. This is problematic.

For example, I use a WeightedRandomSampler only in train_dataloader(), because only the training data should be rebalanced. For this to work, I calculate the weights and num_samples parameters in the stage == "fit" branch of setup().
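The issue doesn't say how the weights are calculated. A common choice for class balancing, assumed here purely for illustration, is inverse class frequency. Note that WeightedRandomSampler expects one weight per sample (not one per class), and num_samples is typically the dataset length. The helper name inverse_frequency_weights is hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return one weight per sample so that every class gets the same total weight."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


labels = [0, 0, 0, 1]  # toy labels: class 0 is three times as common as class 1
weights = inverse_frequency_weights(labels)
num_samples = len(labels)
print(weights)  # each class-0 sample gets 1/3, the single class-1 sample gets 1.0
```

These values would then be passed as WeightedRandomSampler(weights, num_samples=num_samples, replacement=True).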

When I call trainer.test(), that branch is not executed, so weights and num_samples are never calculated. This leads to an error when Lightning validates the train_dataloader() hook.
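The issue doesn't include the traceback, but the failing step can be shown with torch alone: WeightedRandomSampler checks num_samples in its constructor, so building the train loader while the "fit"-only attributes are still None raises before any data is loaded. This is a minimal sketch; build_sampler_from_unset_params is a hypothetical helper, and the exact error message may vary between torch versions.

```python
from torch.utils.data import WeightedRandomSampler


def build_sampler_from_unset_params():
    """Mimic train_dataloader() after only setup(stage="test") has run."""
    weights = None  # would have been set in the "fit" branch of setup()
    num_samples = None  # likewise never assigned during trainer.test()
    try:
        WeightedRandomSampler(weights=weights, num_samples=num_samples, replacement=True)
    except ValueError as err:
        return err
    return None


err = build_sampler_from_unset_params()
print(f"{type(err).__name__}: {err}")
```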

I don't see a best practice for avoiding this, and I see no reason to validate a dataloader that is never used during testing.
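One possible user-side workaround is to make train_dataloader() build its own inputs on demand instead of relying on setup(stage="fit") having run. The sketch below is not a pattern prescribed by Lightning: LazyTrainData and _build_train_split are hypothetical names, and the class deliberately does not subclass L.LightningDataModule so that it runs with torch alone.

```python
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class RandomDataset(Dataset):
    """Toy dataset of random feature vectors."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class LazyTrainData:
    """Hypothetical stand-in for a LightningDataModule using the same hook names."""

    def __init__(self):
        self.train_data = None
        self.test_data = None
        self.weights = None
        self.num_samples = None

    def _build_train_split(self):
        # Idempotent: the training split and sampler inputs are built at most once.
        if self.train_data is None:
            self.train_data = RandomDataset(32, 14)
            # WeightedRandomSampler expects one weight per training sample.
            self.weights = [1.0] * len(self.train_data)
            self.num_samples = len(self.train_data)

    def setup(self, stage):
        if stage == "fit":
            self._build_train_split()
        if stage == "test":
            self.test_data = RandomDataset(32, 14)

    def train_dataloader(self):
        # Do not assume setup(stage="fit") ran; build whatever is missing here.
        self._build_train_split()
        sampler = WeightedRandomSampler(self.weights, num_samples=self.num_samples, replacement=True)
        return DataLoader(self.train_data, sampler=sampler, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=2)


# Simulate trainer.test(): only the "test" stage is set up, yet the train loader still builds.
dm = LazyTrainData()
dm.setup("test")
train_loader = dm.train_dataloader()
print(len(train_loader))  # 14 sampled indices / batch_size 2 -> 7 batches
```

The trade-off is that the training split may now be materialized during testing, which costs time and memory for large datasets.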

What version are you seeing the problem on?

v2.2

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @justusschock @awaelchli

@asusdisciple added the bug and needs triage labels on Apr 8, 2024
@carmocca
Contributor

carmocca commented Apr 9, 2024

This should not happen. Can you update the snippet below to show the problem?

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        raise RuntimeError

    def val_dataloader(self):
        raise RuntimeError

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.test(model)


if __name__ == "__main__":
    run()

@carmocca added the waiting on author and data handling labels and removed the needs triage label on Apr 9, 2024
@asusdisciple
Author

I found the cause: the bug only appears when the Trainer uses strategy="deepspeed". Code below :)

import os

import torch

from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import lightning as L


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class LDataset(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.num_samples = None
        self.weights = None
        self.len = None
        self.train_data = None
        self.test_data = None

    def setup(self, stage: str):
        if stage == "fit":
            self.train_data = RandomDataset(32, 14)

            # only the training data is rebalanced, so the WeightedRandomSampler inputs are computed here
            self.weights = [1.0] * len(self.train_data)  # one weight per training sample
            self.num_samples = len(self.train_data)
        if stage == "test":
            self.test_data = RandomDataset(32, 14)

    def train_dataloader(self):
        return DataLoader(self.train_data,
                          sampler=WeightedRandomSampler(replacement=True,
                                                        weights=self.weights,
                                                        num_samples=self.num_samples
                                                        ),
                          batch_size=2)

    def val_dataloader(self):
        raise RuntimeError

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=2)


class BoringModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)




def run():
    model = BoringModel()
    mydata = LDataset()
    trainer = L.Trainer(
        strategy="deepspeed",
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        log_every_n_steps=10,
        enable_checkpointing=True,
        check_val_every_n_epoch=5
    )
    trainer.test(model, datamodule=mydata)


if __name__ == "__main__":
    run()
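Since the problem only shows up with strategy="deepspeed", a strategy-side workaround may also be possible. As I understand the DeepSpeedStrategy docstring, its logging_batch_size_per_gpu parameter defaults to "auto", in which case the strategy tries to infer the batch size from the train DataLoader; passing an explicit integer should presumably make that lookup unnecessary. This is an untested assumption, not a fix confirmed in this thread. It requires deepspeed to be installed, and the value 2 simply matches batch_size in the repro above.

```python
import lightning as L
from lightning.pytorch.strategies import DeepSpeedStrategy

# Assumption: with an explicit per-GPU batch size, DeepSpeedStrategy should not
# need to call train_dataloader() to infer it during trainer.test().
trainer = L.Trainer(
    strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=2),
    limit_test_batches=1,
    enable_model_summary=False,
)
# trainer.test(model, datamodule=mydata)  # model and mydata as in the repro above
```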

@carmocca added the strategy: deepspeed label and removed the waiting on author and data handling labels on Apr 11, 2024