
trainer.test results not reproduced when loading from a checkpoint #1640

Closed
robmarkcole opened this issue Oct 10, 2023 · 6 comments · Fixed by #1647
Labels: trainers (PyTorch Lightning trainers)
Milestone: 0.5.2

Comments

@robmarkcole
Contributor

Description

As per the title. In my train notebook:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│   test_AverageAccuracy    │     0.956980288028717     │
│       test_F1Score        │    0.9579629898071289     │
│     test_JaccardIndex     │    0.9181522130966187     │
│   test_OverallAccuracy    │    0.9579629898071289     │
│         test_loss         │     0.148747056722641     │
└───────────────────────────┴───────────────────────────┘

Then in the predict notebook, loading from the checkpoint:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│   test_AverageAccuracy    │    0.1227879747748375     │
│       test_F1Score        │    0.11518518626689911    │
│     test_JaccardIndex     │    0.03011714667081833    │
│   test_OverallAccuracy    │    0.11518518626689911    │
│         test_loss         │    20.139453887939453     │
└───────────────────────────┴───────────────────────────┘
This is EuroSAT with 10 classes, so the results are essentially a random guess.

Steps to reproduce

In train notebook:

model = ClassificationTask(
    model="resnet18",
    weights=True,  # standard ImageNet pretrained weights
    num_classes=10,
    in_channels=13,
    loss="ce", 
    patience=10
)

# tb_logger = TensorBoardLogger("tensorboard_logs", name="eurosat")
wandb_logger = WandbLogger(
    project="eurosat-tests", 
    name="imagenet", 
    log_model=True, 
    save_dir="wandb_logs_tests"
)

trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=wandb_logger,
    min_epochs=5,
    max_epochs=25,
    # enable_model_summary=False
)

# saves my.ckpt
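
The checkpoint_callback and early_stopping_callback referenced above are not shown in the report; a minimal sketch of how they might be defined with the standard Lightning callbacks (the monitored metric is an assumption):

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Assumed definitions for the callbacks referenced above (not part of the report)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor="val_loss", patience=10)

# trainer.fit(model=model, datamodule=datamodule)  # training step implied by "# saves my.ckpt"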

In predict notebook:

ckpt_path = "my.ckpt"
model = ClassificationTask.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
trainer = Trainer()
trainer.test(model=model, dataloaders=datamodule.test_dataloader())
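
One detail worth noting in this snippet: because the test data is passed as a bare dataloader, the trainer never sees the datamodule itself, so any preprocessing the datamodule applies through Lightning hooks (e.g. normalization in on_after_batch_transfer, as the later comments suggest) would be skipped. A sketch of the alternative call, assuming the same datamodule used for training is available:

# Passing the datamodule (rather than a raw dataloader) lets Lightning call the
# datamodule's setup and batch-transfer hooks, so normalization also runs at test time.
trainer = Trainer()
trainer.test(model=model, datamodule=datamodule)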

Version

0.5.0

@robmarkcole
Contributor Author

robmarkcole commented Oct 10, 2023

OK, I think I know what is going on. I updated my wandb_logger to use log_model="all", which should save a checkpoint every epoch, but only a SINGLE checkpoint is saved. Training completed at 6 epochs, yet the checkpoint is from epoch 3. So what could be happening is that the logger expects to use its own checkpoint callback, but torchgeo overrides it with its own callback..?

Note: this is NOT the case
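
For what it's worth, WandbLogger(log_model="all") can only upload the checkpoints that the ModelCheckpoint callback actually writes, and with the default save_top_k=1 only a single best checkpoint exists on disk. A sketch of a callback that keeps one checkpoint per epoch (standard Lightning ModelCheckpoint options; whether this matches the intended setup here is an assumption):

from lightning.pytorch.callbacks import ModelCheckpoint

# Keep every epoch's checkpoint instead of only the best one, so that
# WandbLogger(log_model="all") has one artifact per epoch to upload.
checkpoint_callback = ModelCheckpoint(save_top_k=-1, every_n_epochs=1)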

@adamjstewart
Collaborator

Ah, @roybenhayun reported the same bug but we couldn't figure out why it wasn't working. Now we know!

adamjstewart added this to the 0.5.1 milestone on Oct 10, 2023
adamjstewart added the trainers (PyTorch Lightning trainers) label on Oct 10, 2023
@roybenhayun

Opened #1645 for checkpoint saving not working as expected in 0.5.0.

@robmarkcole
Contributor Author

I suspected that the pre-processing might not be getting applied on the test run, and I think I have now confirmed it. I created a datamodule with no normalisation:

from typing import Any

import torch
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.datasets import EuroSAT


class EuroSATDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the EuroSAT dataset.

    Uses the train/val/test splits from the dataset.

    .. versionadded:: 0.2
    """

    mean = torch.zeros(13)
    std = torch.ones(13)

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new EuroSATDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.EuroSAT`.
        """
        super().__init__(EuroSAT, batch_size, num_workers, **kwargs)

And now the test results are reproduced.
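
This is consistent with the normalization living in the datamodule rather than in the dataset: with mean = 0 and std = 1 the normalization becomes a no-op, so it no longer matters whether the datamodule's batch-transfer hook runs at test time. A rough illustration of the pattern involved (the hook name is standard Lightning; the exact torchgeo implementation may differ):

import torch
from lightning.pytorch import LightningDataModule

class NormalizeInDataModuleSketch(LightningDataModule):
    """Illustrative only: preprocessing applied via a Lightning datamodule hook."""

    mean = torch.zeros(13)
    std = torch.ones(13)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # This hook runs only when the datamodule is attached to the Trainer;
        # if test data is passed as a bare dataloader instead, it is never called
        # and the model sees un-normalized inputs (unless mean=0 and std=1).
        image = batch["image"].float()
        batch["image"] = (image - self.mean[:, None, None]) / self.std[:, None, None]
        return batch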

CC @roybenhayun @calebrob6

@adamjstewart
Collaborator

Did we ever figure out what's wrong with this? Was this fixed by #1647 or is there a different issue?

adamjstewart modified the milestone from 0.5.1 to 0.5.2 on Nov 6, 2023
@robmarkcole
Contributor Author

This appears to be resolved now. Below are the test metrics from the train run (left) and from predicting with a checkpoint file (right):

[screenshot: test metric tables from the train and predict runs]
