Skip to content

Commit

Permalink
Allow any input in to_onnx and to_torchscript (Lightning-AI#4378)
Browse files Browse the repository at this point in the history
* branch merge

* sample

* update with valid input tensors

* pep

* pathlib

* Updated with BoringModel and added more input types

* try fix

* pep

* skip test with torch < 1.4

* fix test

* Apply suggestions from code review

* update tests

* Allow any input in to_onnx and to_torchscript

* Update tests/models/test_torchscript.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* no_grad

* try fix random failing test

* rm example_input_array

* rm example_input_array

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
  • Loading branch information
7 people committed Dec 12, 2020
1 parent b5a2afd commit 3100b78
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 77 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Various hooks to be used in the Lightning code."""

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
Expand Down Expand Up @@ -501,7 +501,7 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
Expand Down Expand Up @@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device)
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
device = device or self.device
return move_data_to_device(batch, device)


Expand Down
87 changes: 51 additions & 36 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
Expand Down Expand Up @@ -1530,12 +1531,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
else:
self._hparams = hp

def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
"""Saves the model in ONNX format
@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
input_sample: Optional[Any] = None,
**kwargs,
):
"""
Saves the model in ONNX format
Args:
file_path: The path of the file the model should be saved to.
input_sample: A sample of an input tensor for tracing.
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None (Use self.example_input_array)
**kwargs: Will be passed to torch.onnx.export function.
Example:
Expand All @@ -1554,31 +1562,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
... os.path.isfile(tmpfile.name)
True
"""
mode = self.training

if isinstance(input_sample, Tensor):
input_data = input_sample
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
if input_sample is not None:
if input_sample is None:
if self.example_input_array is None:
raise ValueError(
f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
raise ValueError(
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
input_data = input_data.to(self.device)
input_sample = self.example_input_array

input_sample = self.transfer_batch_to_device(input_sample)

if "example_outputs" not in kwargs:
self.eval()
with torch.no_grad():
kwargs["example_outputs"] = self(input_data)
kwargs["example_outputs"] = self(input_sample)

torch.onnx.export(self, input_data, file_path, **kwargs)
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)

@torch.no_grad()
def to_torchscript(
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
self,
file_path: Optional[Union[str, Path]] = None,
method: Optional[str] = 'script',
example_inputs: Optional[Any] = None,
**kwargs,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
Expand All @@ -1590,7 +1599,7 @@ def to_torchscript(
Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript's script or trace method. Default: 'script'
example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
example_inputs: An input to be used to do tracing when method is set to 'trace'.
Default: None (Use self.example_input_array)
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
:func:`torch.jit.trace` function.
Expand Down Expand Up @@ -1624,21 +1633,27 @@ def to_torchscript(
This LightningModule as a torchscript, regardless of whether file_path is
defined or not.
"""

mode = self.training
with torch.no_grad():
if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
f"{method}")

if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
if self.example_input_array is None:
raise ValueError(
'Choosing method=`trace` requires either `example_inputs`'
' or `model.example_input_array` to be defined'
)
example_inputs = self.example_input_array

# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
f" but value given was: {method}")

self.train(mode)

if file_path is not None:
Expand Down
49 changes: 22 additions & 27 deletions tests/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,44 @@
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base import BoringModel, EvalModelTemplate


def test_model_saves_with_input_sample(tmpdir):
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_model_saves_on_gpu(tmpdir):
"""Test that model saves on gpu"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


def test_model_saves_with_example_output(tmpdir):
"""Test that ONNX model saves when provided with example output"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.eval()
example_outputs = model.forward(input_sample)
model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
Expand All @@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):

def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
Expand Down Expand Up @@ -100,38 +102,31 @@ def test_model_saves_on_multi_gpu(tmpdir):

def test_verbose_param(tmpdir, capsys):
"""Test that output is present when verbose parameter is set"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, verbose=True)
captured = capsys.readouterr()
assert "graph(%" in captured.out


def test_error_if_no_input(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
"""Test that an error is thrown when there is no input tensor"""
model = BoringModel()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
r' `model.example_input_array` attribute is set.'):
model.to_onnx(file_path)


def test_error_if_input_sample_is_not_tensor(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = np.random.randn(1, 28 * 28)
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
f'`Tensor`'):
model.to_onnx(file_path, input_sample)


def test_if_inference_output_is_valid(tmpdir):
"""Test that the output inferred from ONNX model is same as from PyTorch"""
model = EvalModelTemplate()
trainer = Trainer(max_epochs=5)
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

trainer = Trainer(max_epochs=2)
trainer.fit(model)

model.eval()
Expand Down
Loading

0 comments on commit 3100b78

Please sign in to comment.