diff --git a/python/nano/src/bigdl/nano/pytorch/lightning.py b/python/nano/src/bigdl/nano/pytorch/lightning.py
new file mode 100644
index 00000000000..ab8242ecb92
--- /dev/null
+++ b/python/nano/src/bigdl/nano/pytorch/lightning.py
@@ -0,0 +1,69 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+
+from pytorch_lightning import LightningModule
+from torch import nn, Tensor
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+
+
+class LightningModuleFromTorch(LightningModule):
+    def __init__(self, model: nn.Module, loss: _Loss, optimizer: Optimizer):
+        """
+        Wrap a raw PyTorch model, loss function and optimizer into a LightningModule.
+
+        :param model: The PyTorch model to be wrapped.
+        :param loss: A torch loss function.
+        :param optimizer: A torch optimizer.
+        """
+        super().__init__()
+        self.model = model
+        self.loss = loss
+        self.optimizer = optimizer
+
+    def forward(self, batch):
+        # Pass only as many inputs as the wrapped model's forward accepts
+        nargs = self.model.forward.__code__.co_argcount
+        return self.model(*(batch[:nargs - 1]))
+
+    def training_step(self, batch, batch_idx):
+        _, y = batch
+        y_hat = self(batch)
+        loss = self.loss(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        _, y = batch
+        y_hat = self(batch)
+        loss = self.loss(y_hat, y)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        _, y = batch
+        y_hat = self(batch)
+        loss = self.loss(y_hat, y)
+        return loss
+
+    def configure_optimizers(self):
+        return self.optimizer
+
+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
+                        strict: bool = True):
+        try:
+            super().load_state_dict(state_dict, strict=strict)
+        except RuntimeError:
+            self.model.load_state_dict(state_dict, strict=strict)
diff --git a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
index 632c82064f0..ba62be8bd4f 100644
--- a/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
+++ b/python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -16,12 +16,17 @@
 
 from logging import warning
 
-import torch
+from typing import Any, List, Optional
+
 import pytorch_lightning as pl
-from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin
-from bigdl.nano.common import check_avx512
+import torch
 from pytorch_lightning.plugins.environments import LightningEnvironment
-from typing import Any, List, Optional
+from torch import nn
+from torch.nn.modules.loss import _Loss
+
+from bigdl.nano.common import check_avx512
+from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
+from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin
 
 distributed_backends = ["spawn", "ray"]
 
@@ -104,3 +109,28 @@ def __init__(self, num_processes: int = 1,
 
             super().__init__(accelerator=accelerator, plugins=[plugin],
                              *args, **kwargs)
+
+    @staticmethod
+    def compile(model: nn.Module, loss: _Loss = None, optimizer: torch.optim.Optimizer = None):
+        """
+        Construct a pytorch-lightning model. If model is already a pytorch-lightning
+        model, it is returned unchanged. If model is a plain PyTorch model, a new
+        LightningModule is constructed from the given model, loss and optimizer.
+
+        :param model: A model instance.
+        :param loss: Loss function used to construct the pytorch-lightning model.
+                     Should be None if model is an instance of pl.LightningModule.
+        :param optimizer: Optimizer used to construct the pytorch-lightning model.
+                          Should be None if model is an instance of pl.LightningModule.
+        :return: A LightningModule object.
+        """
+        assert isinstance(model, nn.Module), \
+            "Model must be an instance of nn.Module, but got {}".format(model.__class__)
+        if isinstance(model, pl.LightningModule):
+            assert not (loss or optimizer), \
+                "Loss and optimizer should be None if model is a pytorch-lightning model."
+            return model
+        else:
+            assert loss and optimizer, \
+                "Loss and optimizer are required to construct a LightningModule instance."
+            return LightningModuleFromTorch(model, loss, optimizer)
diff --git a/python/nano/test/test_lightning.py b/python/nano/test/test_lightning.py
new file mode 100644
index 00000000000..4ac7aa8f688
--- /dev/null
+++ b/python/nano/test/test_lightning.py
@@ -0,0 +1,69 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from unittest import TestCase
+
+import torch
+from torch import nn
+
+from _train_torch_lightning import create_data_loader, data_transform
+from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
+from bigdl.nano.pytorch.trainer import Trainer
+from bigdl.nano.pytorch.vision.models import vision
+
+num_classes = 10
+batch_size = 256
+num_workers = 0
+data_dir = os.path.join(os.path.dirname(__file__), "data")
+
+
+class ResNet18(nn.Module):
+    def __init__(self, pretrained=True, include_top=False, freeze=True):
+        super().__init__()
+        backbone = vision.resnet18(pretrained=pretrained, include_top=include_top, freeze=freeze)
+        output_size = backbone.get_output_size()
+        head = nn.Linear(output_size, num_classes)
+        self.model = torch.nn.Sequential(backbone, head)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+model = ResNet18(pretrained=True, include_top=False, freeze=True)
+loss = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+
+class TestLightningModuleFromTorch(TestCase):
+
+    def test_resnet18(self):
+        pl_model = LightningModuleFromTorch(model, loss, optimizer)
+        train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)
+        trainer = Trainer(max_epochs=1)
+        trainer.fit(pl_model, train_loader)
+
+    def test_load_state_dict_from_torch(self):
+        torch.save(model.state_dict(), "resnet18_test.pth")
+        pl_model = LightningModuleFromTorch(model, loss, optimizer)
+        state_dict = torch.load("resnet18_test.pth")
+        pl_model.load_state_dict(state_dict)
+
+    def test_load_state_dict_from_lightning(self):
+        pl_model = LightningModuleFromTorch(model, loss, optimizer)
+        torch.save(pl_model.state_dict(), "lightning_resnet18_test.pth")
+        state_dict = torch.load("lightning_resnet18_test.pth")
+        pl_model.load_state_dict(state_dict)
diff --git a/python/nano/test/test_trainer_ipex.py b/python/nano/test/test_trainer_ipex.py
index 1c7d2d8f023..d056ce1b2da 100644
--- a/python/nano/test/test_trainer_ipex.py
+++ b/python/nano/test/test_trainer_ipex.py
@@ -18,10 +18,15 @@
 import pytest
 import os
 from unittest import TestCase
+
+import torch
+from torch import nn
+
+from test._train_torch_lightning import create_data_loader, data_transform
+from bigdl.nano.pytorch.trainer import Trainer
 from bigdl.nano.pytorch.vision.models import vision
 from test._train_torch_lightning import train_with_linear_top_layer
 
-
 batch_size = 256
 num_workers = 0
 data_dir = os.path.join(os.path.dirname(__file__), "data")
@@ -36,6 +41,26 @@ def test_resnet18_ipex(self):
             resnet18, batch_size, num_workers,
             data_dir, use_orca_lite_trainer=True)
 
+    def test_trainer_compile(self):
+        class ResNet18(nn.Module):
+            def __init__(self, num_classes, pretrained=True, include_top=False, freeze=True):
+                super().__init__()
+                backbone = vision.resnet18(pretrained=pretrained, include_top=include_top, freeze=freeze)
+                output_size = backbone.get_output_size()
+                head = nn.Linear(output_size, num_classes)
+                self.model = nn.Sequential(backbone, head)
+
+            def forward(self, x):
+                return self.model(x)
+
+        model = ResNet18(10, pretrained=True, include_top=False, freeze=True)
+        loss = nn.CrossEntropyLoss()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+        trainer = Trainer(max_epochs=1)
+        pl_model = trainer.compile(model, loss, optimizer)
+        train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)
+        trainer.fit(pl_model, train_loader)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
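
A minimal usage sketch of the Trainer.compile API added above, assuming bigdl-nano is installed; the toy model, random tensors and hyperparameters are illustrative only and not part of the patch:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from bigdl.nano.pytorch.trainer import Trainer

    # A plain PyTorch model, loss and optimizer (illustrative toy setup).
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Any DataLoader yielding (input, target) batches works here.
    samples = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
    train_loader = DataLoader(samples, batch_size=16)

    # Trainer.compile wraps the torch objects into a LightningModuleFromTorch;
    # passing an existing pl.LightningModule would return it unchanged.
    trainer = Trainer(max_epochs=1)
    pl_model = trainer.compile(model, loss, optimizer)
    trainer.fit(pl_model, train_loader)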