Commit

Merge pull request #4780 from Alnusjaponica/pytorch-lightning
Fix PyTorch Lightning unit test
c-bata committed Jul 10, 2023
2 parents f013fd8 + 989987f commit e3902d5
Showing 6 changed files with 17 additions and 49 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/checks-integration.yml
@@ -30,7 +30,6 @@ jobs:
           sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev
       # TODO(Shinichi): Remove the version constraint on SQLAlchemy
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       - name: Install
         run: |
           python -m pip install -U pip
@@ -42,7 +41,6 @@ jobs:
           pip install --progress-bar off -U bayesmark
           pip install --progress-bar off -U kurobako
           pip install "sqlalchemy<2.0.0"
-          pip install "pytorch-lightning<2.0.0"
       - name: Output installed packages
         run: |
2 changes: 0 additions & 2 deletions .github/workflows/coverage.yml
@@ -43,7 +43,6 @@ jobs:
           sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev
       # TODO(Shinichi): Remove the version constraint on Numpy
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -58,7 +57,6 @@ jobs:
           pip install --progress-bar off .[optional]
           pip install --progress-bar off .[integration] --extra-index-url https://download.pytorch.org/whl/cpu
           pip install "numpy<1.24.0"
-          pip install "pytorch-lightning<2.0.0"
           pip install "fakeredis<2.11.1"
           echo 'import coverage; coverage.process_startup()' > sitecustomize.py
4 changes: 0 additions & 4 deletions .github/workflows/mac-tests.yml
@@ -41,7 +41,6 @@ jobs:
           restore-keys: |
             ${{ runner.os }}-3.8-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -53,7 +52,6 @@ jobs:
           optuna --version
           pip install --progress-bar off .[test]
           pip install --progress-bar off .[optional]
-          pip install "pytorch-lightning<2.0.0"
           pip install "fakeredis<2.11.1"
       - name: Output installed packages
@@ -106,7 +104,6 @@ jobs:
           brew install openblas
       # TODO(Shinichi): Remove the version constraint on Numpy
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -121,7 +118,6 @@ jobs:
           pip install --progress-bar off .[optional]
           pip install --progress-bar off .[integration]
           pip install "numpy<1.24.0"
-          pip install "pytorch-lightning<2.0.0"
           pip install "fakeredis<2.11.1"
       - name: Output installed packages
2 changes: 0 additions & 2 deletions .github/workflows/tests-integration.yml
@@ -47,7 +47,6 @@ jobs:
           sudo apt-get -y install openmpi-bin libopenmpi-dev libopenblas-dev
       # TODO(Shinichi): Remove the version constraint on Numpy
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -62,7 +61,6 @@ jobs:
           pip install --progress-bar off .[optional]
           pip install --progress-bar off .[integration] --extra-index-url https://download.pytorch.org/whl/cpu
           pip install "numpy<1.24.0"
-          pip install "pytorch-lightning<2.0.0"
           pip install "fakeredis<2.11.1"
       - name: Output installed packages
4 changes: 0 additions & 4 deletions .github/workflows/windows-tests.yml
@@ -41,7 +41,6 @@ jobs:
           restore-keys: |
             ${{ runner.os }}-3.9-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -53,7 +52,6 @@ jobs:
           optuna --version
           pip install --progress-bar off .[test]
           pip install --progress-bar off .[optional]
-          pip install "pytorch-lightning<2.0.0"
           pip install PyQt6 # Install PyQT for using QtAgg as matplotlib backend.
           pip install "fakeredis<2.11.1"
@@ -110,7 +108,6 @@ jobs:
             mpi: "msmpi"
 
       # TODO(Shinichi): Remove the version constraint on Numpy
-      # TODO(Shinichi): Remove the version constraint on PyTorch Lightning
       # TODO(c-bata): Remove the version constraint on fakeredis
       - name: Install
         run: |
@@ -124,7 +121,6 @@ jobs:
           pip install --progress-bar off .[optional]
           pip install --progress-bar off .[integration]
           pip install "numpy<1.24.0"
-          pip install "pytorch-lightning<2.0.0"
           pip install "distributed<2023.3.2"
           pip install "fakeredis<2.11.1"
52 changes: 17 additions & 35 deletions tests/integration_tests/test_pytorch_lightning.py
@@ -1,10 +1,6 @@
-from typing import Any
-from typing import cast
-from typing import Dict
-from typing import List
-from typing import Mapping
-from typing import Sequence
-from typing import Union
+from __future__ import annotations
+
+from collections.abc import Sequence
 
 import numpy as np
 import pytest
@@ -19,7 +15,6 @@
 with try_import() as _imports:
     import pytorch_lightning as pl
     from pytorch_lightning import LightningModule
-    from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
     import torch
     from torch import nn
     from torch.multiprocessing.spawn import ProcessRaisedException
@@ -35,40 +30,34 @@ class Model(LightningModule):
     def __init__(self) -> None:
         super().__init__()
         self._model = nn.Sequential(nn.Linear(4, 8))
+        self.validation_step_outputs: list["torch.Tensor"] = []
 
     def forward(self, data: "torch.Tensor") -> "torch.Tensor":
         return self._model(data)
 
     def training_step(
         self, batch: Sequence["torch.Tensor"], batch_nb: int
-    ) -> Dict[str, "torch.Tensor"]:
+    ) -> dict[str, "torch.Tensor"]:
         data, target = batch
         output = self.forward(data)
         loss = F.nll_loss(output, target)
         return {"loss": loss}
 
-    def validation_step(
-        self, batch: Sequence["torch.Tensor"], batch_nb: int
-    ) -> Dict[str, "torch.Tensor"]:
+    def validation_step(self, batch: Sequence["torch.Tensor"], batch_nb: int) -> "torch.Tensor":
         data, target = batch
         output = self.forward(data)
         pred = output.argmax(dim=1, keepdim=True)
         accuracy = pred.eq(target.view_as(pred)).double().mean()
-        return {"validation_accuracy": accuracy}
+        self.validation_step_outputs.append(accuracy)
+        return accuracy
 
-    def validation_epoch_end(
+    def on_validation_epoch_end(
         self,
-        outputs: Union[
-            Sequence[Union["torch.Tensor", Mapping[str, Any]]],
-            Sequence[Sequence[Union["torch.Tensor", Mapping[str, Any]]]],
-        ],
     ) -> None:
-        if not len(outputs):
+        if not len(self.validation_step_outputs):
             return
 
-        accuracy = sum(
-            x["validation_accuracy"] for x in cast(List[Dict[str, "torch.Tensor"]], outputs)
-        ) / len(outputs)
+        accuracy = sum(self.validation_step_outputs) / len(self.validation_step_outputs)
         self.log("accuracy", accuracy)
 
     def configure_optimizers(self) -> "torch.optim.Optimizer":
@@ -91,9 +80,7 @@ class ModelDDP(Model):
     def __init__(self) -> None:
         super().__init__()
 
-    def validation_step(
-        self, batch: Sequence["torch.Tensor"], batch_nb: int
-    ) -> Dict[str, "torch.Tensor"]:
+    def validation_step(self, batch: Sequence["torch.Tensor"], batch_nb: int) -> "torch.Tensor":
         data, target = batch
         output = self.forward(data)
         pred = output.argmax(dim=1, keepdim=True)
@@ -104,15 +91,9 @@ def validation_step(
             accuracy = torch.tensor(0.6)
 
         self.log("accuracy", accuracy, sync_dist=True)
-        return {"validation_accuracy": accuracy}
+        return accuracy
 
-    def validation_epoch_end(
-        self,
-        output: Union[
-            Sequence[Union["torch.Tensor", Mapping[str, Any]]],
-            Sequence[Sequence[Union["torch.Tensor", Mapping[str, Any]]]],
-        ],
-    ) -> None:
+    def on_validation_epoch_end(self) -> None:
         return
 
 
@@ -121,6 +102,7 @@ def objective(trial: optuna.trial.Trial) -> float:
     callback = PyTorchLightningPruningCallback(trial, monitor="accuracy")
     trainer = pl.Trainer(
         max_epochs=2,
+        accelerator="cpu",
         enable_checkpointing=False,
         callbacks=[callback],
     )
@@ -168,7 +150,7 @@ def objective(trial: optuna.trial.Trial) -> float:
         devices=2,
         enable_checkpointing=False,
         callbacks=[callback],
-        strategy=DDPSpawnStrategy(find_unused_parameters=False),
+        strategy="ddp_spawn",
     )
 
     model = ModelDDP()
@@ -206,7 +188,7 @@ def objective(trial: optuna.trial.Trial) -> float:
         devices=2,
         enable_checkpointing=False,
         callbacks=[callback],
-        strategy=DDPSpawnStrategy(find_unused_parameters=False),
+        strategy="ddp_spawn",
    )
 
     model = ModelDDP()
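For context on the test-file changes: PyTorch Lightning 2.0 removed the `validation_epoch_end(outputs)` hook in favor of `on_validation_epoch_end()`, which receives no outputs, so a module must accumulate its per-step results itself. Below is a minimal standalone sketch of that pattern; `SimpleModel` and its layer sizes are illustrative and not taken from this PR.

# A minimal sketch of the Lightning >= 2.0 validation pattern adopted above.
# SimpleModel is illustrative only; it is not part of this PR.
import pytorch_lightning as pl
import torch
from torch import nn


class SimpleModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self._model = nn.Linear(4, 2)
        # Lightning 2.x no longer passes `outputs` to the epoch-end hook,
        # so each step's metric is collected on the module itself.
        self.validation_step_outputs: list[torch.Tensor] = []

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        data, target = batch
        pred = self._model(data).argmax(dim=1, keepdim=True)
        accuracy = pred.eq(target.view_as(pred)).double().mean()
        self.validation_step_outputs.append(accuracy)
        return accuracy

    def on_validation_epoch_end(self) -> None:
        # Replaces validation_epoch_end(self, outputs), removed in 2.0.
        if not self.validation_step_outputs:
            return
        accuracy = torch.stack(self.validation_step_outputs).mean()
        self.log("accuracy", accuracy)
        self.validation_step_outputs.clear()  # avoid growing across epochs

Likewise, `DDPSpawnStrategy` was folded into `DDPStrategy` in Lightning 2.0, which is why the tests now select the spawn launcher via the registry string `strategy="ddp_spawn"` instead of constructing a strategy class.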
