Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Pytorch lightning unit test #4780

Merged
merged 4 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/checks-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ jobs:
sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev

# TODO(Shinichi): Remove the version constraint on SQLAlchemy
# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
- name: Install
run: |
python -m pip install -U pip
Expand All @@ -42,7 +41,6 @@ jobs:
pip install --progress-bar off -U bayesmark
pip install --progress-bar off -U kurobako
pip install "sqlalchemy<2.0.0"
pip install "pytorch-lightning<2.0.0"

- name: Output installed packages
run: |
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ jobs:
sudo apt-get update
sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -56,7 +55,6 @@ jobs:
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install --progress-bar off .[integration] --extra-index-url https://download.pytorch.org/whl/cpu
pip install "pytorch-lightning<2.0.0"
pip install "fakeredis<2.11.1"

echo 'import coverage; coverage.process_startup()' > sitecustomize.py
Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/mac-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ jobs:
restore-keys: |
${{ runner.os }}-3.8-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -53,7 +52,6 @@ jobs:
optuna --version
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install "pytorch-lightning<2.0.0"
pip install "fakeredis<2.11.1"

- name: Output installed packages
Expand Down Expand Up @@ -105,7 +103,6 @@ jobs:
brew install open-mpi
brew install openblas

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -119,7 +116,6 @@ jobs:
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install --progress-bar off .[integration]
pip install "pytorch-lightning<2.0.0"
pip install "fakeredis<2.11.1"

- name: Output installed packages
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/tests-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ jobs:
sudo apt-get update
sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -60,7 +59,6 @@ jobs:
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install --progress-bar off .[integration] --extra-index-url https://download.pytorch.org/whl/cpu
pip install "pytorch-lightning<2.0.0"
pip install "fakeredis<2.11.1"

- name: Output installed packages
Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/windows-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ jobs:
restore-keys: |
${{ runner.os }}-3.9-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -53,7 +52,6 @@ jobs:
optuna --version
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install "pytorch-lightning<2.0.0"
pip install PyQt6 # Install PyQT for using QtAgg as matplotlib backend.
pip install "fakeredis<2.11.1"

Expand Down Expand Up @@ -109,7 +107,6 @@ jobs:
with:
mpi: "msmpi"

# TODO(Shinichi): Remove the version constraint on PyTorch Lightning
# TODO(c-bata): Remove the version constraint on fakeredis
- name: Install
run: |
Expand All @@ -122,7 +119,6 @@ jobs:
pip install --progress-bar off .[test]
pip install --progress-bar off .[optional]
pip install --progress-bar off .[integration]
pip install "pytorch-lightning<2.0.0"
pip install "distributed<2023.3.2"
pip install "fakeredis<2.11.1"

Expand Down
52 changes: 17 additions & 35 deletions tests/integration_tests/test_pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import Sequence
from typing import Union
from __future__ import annotations

from collections.abc import Sequence

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you unintentionally changed the import statement here.

Suggested change
from collections.abc import Sequence
from typing import Sequence

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is recommended by PEP 585 (importing `Sequence` from `collections.abc` instead of `typing`).
If it would be better to separate the type-annotation update from this PR, I'll revert the change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for letting me know! I didn't know that 🙇

import numpy as np
import pytest
Expand All @@ -19,7 +15,6 @@
with try_import() as _imports:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
import torch
from torch import nn
from torch.multiprocessing.spawn import ProcessRaisedException
Expand All @@ -35,40 +30,34 @@ class Model(LightningModule):
def __init__(self) -> None:
super().__init__()
self._model = nn.Sequential(nn.Linear(4, 8))
self.validation_step_outputs: list["torch.Tensor"] = []

def forward(self, data: "torch.Tensor") -> "torch.Tensor":
return self._model(data)

def training_step(
self, batch: Sequence["torch.Tensor"], batch_nb: int
) -> Dict[str, "torch.Tensor"]:
) -> dict[str, "torch.Tensor"]:
data, target = batch
output = self.forward(data)
loss = F.nll_loss(output, target)
return {"loss": loss}

def validation_step(
self, batch: Sequence["torch.Tensor"], batch_nb: int
) -> Dict[str, "torch.Tensor"]:
def validation_step(self, batch: Sequence["torch.Tensor"], batch_nb: int) -> "torch.Tensor":
data, target = batch
output = self.forward(data)
pred = output.argmax(dim=1, keepdim=True)
accuracy = pred.eq(target.view_as(pred)).double().mean()
return {"validation_accuracy": accuracy}
self.validation_step_outputs.append(accuracy)
return accuracy

def validation_epoch_end(
def on_validation_epoch_end(
self,
outputs: Union[
Sequence[Union["torch.Tensor", Mapping[str, Any]]],
Sequence[Sequence[Union["torch.Tensor", Mapping[str, Any]]]],
],
) -> None:
if not len(outputs):
if not len(self.validation_step_outputs):
return

accuracy = sum(
x["validation_accuracy"] for x in cast(List[Dict[str, "torch.Tensor"]], outputs)
) / len(outputs)
accuracy = sum(self.validation_step_outputs) / len(self.validation_step_outputs)
self.log("accuracy", accuracy)

def configure_optimizers(self) -> "torch.optim.Optimizer":
Expand All @@ -91,9 +80,7 @@ class ModelDDP(Model):
def __init__(self) -> None:
super().__init__()

def validation_step(
self, batch: Sequence["torch.Tensor"], batch_nb: int
) -> Dict[str, "torch.Tensor"]:
def validation_step(self, batch: Sequence["torch.Tensor"], batch_nb: int) -> "torch.Tensor":
data, target = batch
output = self.forward(data)
pred = output.argmax(dim=1, keepdim=True)
Expand All @@ -104,15 +91,9 @@ def validation_step(
accuracy = torch.tensor(0.6)

self.log("accuracy", accuracy, sync_dist=True)
return {"validation_accuracy": accuracy}
return accuracy

def validation_epoch_end(
self,
output: Union[
Sequence[Union["torch.Tensor", Mapping[str, Any]]],
Sequence[Sequence[Union["torch.Tensor", Mapping[str, Any]]]],
],
) -> None:
def on_validation_epoch_end(self) -> None:
return


Expand All @@ -121,6 +102,7 @@ def objective(trial: optuna.trial.Trial) -> float:
callback = PyTorchLightningPruningCallback(trial, monitor="accuracy")
trainer = pl.Trainer(
max_epochs=2,
accelerator="cpu",
enable_checkpointing=False,
callbacks=[callback],
)
Expand Down Expand Up @@ -168,7 +150,7 @@ def objective(trial: optuna.trial.Trial) -> float:
devices=2,
enable_checkpointing=False,
callbacks=[callback],
strategy=DDPSpawnStrategy(find_unused_parameters=False),
strategy="ddp_spawn",
)

model = ModelDDP()
Expand Down Expand Up @@ -206,7 +188,7 @@ def objective(trial: optuna.trial.Trial) -> float:
devices=2,
enable_checkpointing=False,
callbacks=[callback],
strategy=DDPSpawnStrategy(find_unused_parameters=False),
strategy="ddp_spawn",
)

model = ModelDDP()
Expand Down
Loading