Merged
Changes from all commits
31 commits
b185867
added decorator to create a pytorch lightning model from torch
Oct 27, 2021
8151c8a
added unit test for pytorch lightning decorator
Oct 27, 2021
5e4b992
refactoring - renaming, adding hints and docstring
Oct 28, 2021
c6f5127
moved lightning extension to nano/pytorch
Oct 28, 2021
6580715
remove loss, optim creator and directly pass loss and optimizer to in…
Oct 28, 2021
06f4b88
added another implementation for pytorch to lightning
Oct 29, 2021
ff2f028
use LightningModuleFromTorch to create lightning module from pytorch
Nov 3, 2021
aae8fe9
remove temporary change
Nov 3, 2021
6af7535
remove redundant part
Nov 3, 2021
c02c13e
added trainer.compile to convert pytorch to pytorch-lightning
Nov 3, 2021
4b930f2
added unit test for trainer.compile
Nov 3, 2021
6aa2dc1
fixed return when input is pl model
Nov 3, 2021
d5bb5c1
added type hint for LightningModuleFromTorch.copy
Nov 3, 2021
a893aa3
Renamed copy as _copy
Nov 4, 2021
c6fb693
Modified comment of compile
Nov 4, 2021
596e4e6
added input checking
Nov 4, 2021
db4466b
refactored docstring
Nov 4, 2021
90eabc7
Reformat docstring
Nov 4, 2021
153c1a2
Tiny changes
Nov 4, 2021
c642dc9
reformat
Nov 4, 2021
1caa2f1
correct the import
Nov 4, 2021
f38dff2
type check and
Nov 4, 2021
9dbd8f0
assign model as a member variable
Nov 9, 2021
293e54a
override load_state_dict
Nov 9, 2021
d3c20d5
fix test_trainer_compile
Nov 9, 2021
6fabec3
fix test_lightning
Nov 9, 2021
d50e403
try lightning module and then self.model
Nov 9, 2021
6646454
rename _forward as forward
Nov 9, 2021
3dc152d
type check
Nov 9, 2021
36b8b6c
optimize imports
Nov 11, 2021
daa9b27
Merge branch 'branch-2.0' into pytorch_lightning_wrapper
yangw1234 Nov 11, 2021
69 changes: 69 additions & 0 deletions python/nano/src/bigdl/nano/pytorch/lightning.py
@@ -0,0 +1,69 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import OrderedDict

from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer


class LightningModuleFromTorch(LightningModule):
def __init__(self, model: nn.Module, loss: _Loss, optimizer: Optimizer):
"""
        Integrate a PyTorch model, loss function and optimizer into a pytorch-lightning module.

:param model: Pytorch model to be converted.
:param loss: A torch loss function.
:param optimizer: A torch optimizer.
"""
super().__init__()
self.model = model
self.loss = loss
self.optimizer = optimizer

def forward(self, batch):
        # Models take different numbers of inputs; use the argument count of the wrapped
        # forward (excluding self) to decide how many leading batch elements to pass in.
        nargs = self.model.forward.__code__.co_argcount
        return self.model(*(batch[:nargs - 1]))

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(batch)
loss = self.loss(y_hat, y)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(batch)
loss = self.loss(y_hat, y)
return loss

def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(batch)
loss = self.loss(y_hat, y)
return loss

def configure_optimizers(self):
return self.optimizer

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
        try:
            super().load_state_dict(state_dict, strict=strict)
        except RuntimeError:
            # Fall back to the wrapped torch model's keys when the state_dict
            # comes from the original torch model rather than the lightning wrapper.
            self.model.load_state_dict(state_dict, strict=strict)
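For context, a minimal usage sketch (not part of the diff) of the new LightningModuleFromTorch: it wraps a plain torch model together with a loss and an optimizer so a pytorch-lightning Trainer can fit it. The toy network, loss and optimizer below are hypothetical stand-ins.

import torch
from torch import nn

from bigdl.nano.pytorch.lightning import LightningModuleFromTorch

# Any plain nn.Module will do; a tiny classifier stands in here.
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Wrap model, loss and optimizer into a LightningModule that pytorch-lightning can train.
pl_model = LightningModuleFromTorch(net, loss, optimizer)
# pl_model can now be passed to a Trainer's fit() together with a DataLoader of (x, y) batches.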
38 changes: 34 additions & 4 deletions python/nano/src/bigdl/nano/pytorch/trainer/Trainer.py
@@ -16,12 +16,17 @@


from logging import warning
import torch
from typing import Any, List, Optional

import pytorch_lightning as pl
from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin
from bigdl.nano.common import check_avx512
import torch
from pytorch_lightning.plugins.environments import LightningEnvironment
from typing import Any, List, Optional
from torch import nn
from torch.nn.modules.loss import _Loss

from bigdl.nano.common import check_avx512
from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
from bigdl.nano.pytorch.plugins.ddp_spawn import DDPSpawnPlugin

distributed_backends = ["spawn", "ray"]

@@ -104,3 +109,28 @@ def __init__(self, num_processes: int = 1,

super().__init__(accelerator=accelerator,
plugins=[plugin], *args, **kwargs)

@staticmethod
def compile(model: nn.Module, loss: _Loss = None, optimizer: torch.optim.Optimizer = None):
"""
        Construct a pytorch-lightning model. If model is already a pytorch-lightning model,
        return it unchanged. If model is a pytorch model, construct a new pytorch-lightning
        module from model, loss and optimizer.

        :param model: A model instance.
        :param loss: Loss to construct the pytorch-lightning model.
            Should be None if model is an instance of pl.LightningModule.
        :param optimizer: Optimizer to construct the pytorch-lightning model.
            Should be None if model is an instance of pl.LightningModule.
:return: A LightningModule object.
"""
assert isinstance(model, nn.Module), \
"Model must be instance of nn.Module but got {}".format(model.__class__)
if isinstance(model, pl.LightningModule):
assert not (loss or optimizer), \
"Loss and optimizer should be None if model is a pytorch-lightning model."
return model
else:
assert loss and optimizer, \
"Loss and optimizer are required to construct a LightningModule instance."
return LightningModuleFromTorch(model, loss, optimizer)
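For context, a minimal usage sketch (not part of the diff) of the new Trainer.compile: a plain torch model plus loss and optimizer is wrapped into a LightningModuleFromTorch, while a pl.LightningModule would be returned unchanged. The toy model and the commented-out data loader below are hypothetical.

import torch
from torch import nn

from bigdl.nano.pytorch.trainer import Trainer

net = nn.Linear(10, 2)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

trainer = Trainer(max_epochs=1)
# compile wraps the torch model into a LightningModuleFromTorch; loss and
# optimizer would have to be None if net were already a pl.LightningModule.
pl_model = trainer.compile(net, loss, optimizer)
# trainer.fit(pl_model, train_loader)  # train_loader: any torch DataLoader of (x, y) batches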
69 changes: 69 additions & 0 deletions python/nano/test/test_lightning.py
@@ -0,0 +1,69 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from unittest import TestCase

import torch
from torch import nn

from _train_torch_lightning import create_data_loader, data_transform
from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
from bigdl.nano.pytorch.trainer import Trainer
from bigdl.nano.pytorch.vision.models import vision

num_classes = 10
batch_size = 256
num_workers = 0
data_dir = os.path.join(os.path.dirname(__file__), "data")


class ResNet18(nn.Module):
def __init__(self, pretrained=True, include_top=False, freeze=True):
super().__init__()
backbone = vision.resnet18(pretrained=pretrained, include_top=include_top, freeze=freeze)
output_size = backbone.get_output_size()
head = nn.Linear(output_size, num_classes)
self.model = torch.nn.Sequential(backbone, head)

def forward(self, x):
return self.model(x)


model = ResNet18(pretrained=True, include_top=False, freeze=True)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


class TestLightningModuleFromTorch(TestCase):

def test_resnet18(self):
pl_model = LightningModuleFromTorch(model, loss, optimizer)
train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)
trainer = Trainer(max_epochs=1)
trainer.fit(pl_model, train_loader)

def test_load_state_dict_from_torch(self):
torch.save(model.state_dict(), "resnet18_test.pth")
pl_model = LightningModuleFromTorch(model, loss, optimizer)
state_dict = torch.load("resnet18_test.pth")
pl_model.load_state_dict(state_dict)

def test_load_state_dict_from_lightning(self):
pl_model = LightningModuleFromTorch(model, loss, optimizer)
torch.save(pl_model.state_dict(), "lightning_resnet18_test.pth")
state_dict = torch.load("lightning_resnet18_test.pth")
pl_model.load_state_dict(state_dict)
27 changes: 26 additions & 1 deletion python/nano/test/test_trainer_ipex.py
@@ -18,10 +18,15 @@
import pytest
import os
from unittest import TestCase

import torch
from torch import nn

from test._train_torch_lightning import create_data_loader, data_transform
from bigdl.nano.pytorch.trainer import Trainer
from bigdl.nano.pytorch.vision.models import vision
from test._train_torch_lightning import train_with_linear_top_layer


batch_size = 256
num_workers = 0
data_dir = os.path.join(os.path.dirname(__file__), "data")
@@ -36,6 +41,26 @@ def test_resnet18_ipex(self):
resnet18, batch_size, num_workers, data_dir,
use_orca_lite_trainer=True)

def test_trainer_compile(self):
class ResNet18(nn.Module):
def __init__(self, num_classes, pretrained=True, include_top=False, freeze=True):
super().__init__()
backbone = vision.resnet18(pretrained=pretrained, include_top=include_top, freeze=freeze)
output_size = backbone.get_output_size()
head = nn.Linear(output_size, num_classes)
self.model = nn.Sequential(backbone, head)

def forward(self, x):
return self.model(x)

model = ResNet18(10, pretrained=True, include_top=False, freeze=True)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
trainer = Trainer(max_epochs=1)
pl_model = trainer.compile(model, loss, optimizer)
train_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform)
trainer.fit(pl_model, train_loader)


if __name__ == '__main__':
pytest.main([__file__])