Merge pull request #1792 from AdeelH/onnx
Add experimental ONNX support
AdeelH committed Jun 14, 2023
2 parents 8a72681 + 7559566 commit 6b44624
Showing 7 changed files with 315 additions and 64 deletions.
@@ -169,18 +169,21 @@ def __str__(self) -> str:
def to_boxlist(self) -> BoxList:
return self.boxlist

def to_dict(self) -> dict:
def to_dict(self, round_boxes: bool = True) -> dict:
"""Returns a dict version of these labels.
The dict has a Box as a key and a tuple of (class_id, score)
as the value.
"""
d = {}
boxes = list(map(Box.from_npbox, self.get_npboxes()))
classes = list(self.get_class_ids())
scores = list(self.get_scores())
for box, class_id, score in zip(boxes, classes, scores):
d[box.tuple_format()] = (class_id, score)
npboxes = self.get_npboxes()
if round_boxes and np.issubdtype(npboxes.dtype, np.floating):
npboxes = npboxes.round(2)
classes = self.get_class_ids()
scores = self.get_scores().round(6)
d = {
Box.from_npbox(box): (class_id, score)
for box, class_id, score in zip(npboxes, classes, scores)
}
return d

@staticmethod
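The reworked to_dict() in the hunk above rounds floating-point box coordinates to 2 decimal places by default (pass round_boxes=False to keep full precision), rounds scores to 6 decimal places, and keys the dict by Box objects rather than coordinate tuples. A minimal usage sketch, assuming labels is an instance of the labels class modified above (e.g. ObjectDetectionLabels):

    # Hypothetical usage; `labels` stands in for an instance of the labels
    # class edited above.
    d = labels.to_dict()                        # box coords rounded to 2 decimals
    d_raw = labels.to_dict(round_boxes=False)   # keep full-precision coords
    for box, (class_id, score) in d.items():
        # keys are Box objects; values are (class_id, score) tuples
        print(box, class_id, score)
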
190 changes: 144 additions & 46 deletions rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
@@ -2,6 +2,7 @@
Optional, Tuple, Union, Type)
from typing_extensions import Literal
from abc import ABC, abstractmethod
import os
from os.path import join, isfile, basename, isdir
import warnings
import time
@@ -31,7 +32,7 @@
save_pipeline_config)
from rastervision.pytorch_learner.utils import (
get_hubconf_dir_from_cfg, aggregate_metrics, log_metrics_to_csv,
log_system_details)
log_system_details, ONNXRuntimeAdapter)
from rastervision.pytorch_learner.dataset.visualizer import Visualizer

if TYPE_CHECKING:
@@ -46,6 +47,9 @@

MODULES_DIRNAME = 'modules'
TRANSFORMS_DIRNAME = 'custom_albumentations_transforms'
BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth'
BUNDLE_MODEL_ONNX_FILENAME = 'model.onnx'
USE_ONNX = os.getenv('RASTERVISION_USE_ONNX', 'false').lower() in ('true', '1')

log = logging.getLogger(__name__)
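
USE_ONNX above is read once at import time from the RASTERVISION_USE_ONNX environment variable; any of '1', 'true', or 'True' enables it. A small sketch of how a caller might opt in (the import shown is only illustrative):

    import os

    # Must be set before the pytorch_learner module is imported, since the
    # flag is evaluated at module import time.
    os.environ['RASTERVISION_USE_ONNX'] = '1'   # 'true' also works

    from rastervision.pytorch_learner.learner import Learner  # picks up the flag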

@@ -184,7 +188,7 @@ def __init__(self,

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
# ---------------------------

self._onnx_mode = False
self.setup_model(
model_weights_path=model_weights_path,
model_def_path=model_def_path)
@@ -193,7 +197,8 @@ def __init__(self,
self.setup_training(loss_def_path=loss_def_path)
self.model.train()
else:
self.model.eval()
if not self.onnx_mode:
self.model.eval()

self.visualizer = self.get_visualizer_class()(
cfg.data.class_names, cfg.data.class_colors,
@@ -206,6 +211,7 @@ def from_model_bundle(cls: Type,
tmp_dir: Optional[str] = None,
cfg: Optional['LearnerConfig'] = None,
training: bool = False,
use_onnx_model: bool = USE_ONNX,
**kwargs) -> 'Learner':
"""Create a Learner from a model bundle.
@@ -227,6 +233,11 @@ def from_model_bundle(cls: Type,
model will be put into eval mode. If True, the training
apparatus will be set up and the model will be put into
training mode. Defaults to False.
use_onnx_model (bool, optional): If True and training=False and a
model.onnx file is available in the bundle, use that for
inference rather than the PyTorch weights. Defaults to the
boolean environment variable RASTERVISION_USE_ONNX if set,
False otherwise.
**kwargs: See :meth:`.Learner.__init__`.
Raises:
@@ -245,8 +256,6 @@
log.info(f'Unzipping model-bundle to {model_bundle_dir}')
unzip(model_bundle_path, model_bundle_dir)

model_weights_path = join(model_bundle_dir, 'model.pth')

if cfg is None:
config_path = join(model_bundle_dir, 'pipeline-config.json')

@@ -288,10 +297,20 @@ def from_model_bundle(cls: Type,
# config has been altered, so re-validate
cfg = build_config(cfg.dict())

if cfg.model is None and kwargs.get('model') is None:
raise ValueError(
'Model definition is not saved in the model-bundle. '
'Please specify the model explicitly.')
onnx_mode = False
if not training and use_onnx_model:
onnx_path = join(model_bundle_dir, 'model.onnx')
if file_exists(onnx_path):
model_weights_path = onnx_path
onnx_mode = True

if not onnx_mode:
if cfg.model is None and kwargs.get('model') is None:
raise ValueError(
'Model definition is not saved in the model-bundle. '
'Please specify the model explicitly.')
model_weights_path = join(model_bundle_dir,
BUNDLE_MODEL_WEIGHTS_FILENAME)

if cls == Learner:
if len(kwargs) > 0:
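
With the changes above, from_model_bundle() prefers a bundled model.onnx when use_onnx_model is True (and training is False), and falls back to the PyTorch weights in model.pth when no ONNX file is present. A hedged usage sketch; the bundle path and the SemanticSegmentationLearner subclass are placeholders:

    from rastervision.pytorch_learner import SemanticSegmentationLearner

    learner = SemanticSegmentationLearner.from_model_bundle(
        'model-bundle.zip',      # placeholder path/URI of a saved bundle
        training=False,
        use_onnx_model=True,     # or set RASTERVISION_USE_ONNX=1 instead
    )
    # learner.onnx_mode is True only if model.onnx existed in the bundle.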
@@ -557,6 +576,13 @@ def predict(self, x: Tensor, raw_out: bool = False) -> Any:
out = self.to_device(out, 'cpu')
return out

def predict_onnx(self, x: Tensor, raw_out: bool = False) -> Tensor:
"""Alternative to predict() for ONNX inference."""
out = self.model(x)
if not raw_out:
out = self.prob_to_pred(self.post_forward(out))
return out

def predict_dataset(self,
dataset: 'Dataset',
return_format: Literal['xyz', 'yz', 'z'] = 'z',
@@ -714,14 +740,18 @@ def _predict_dataloader(
might or might not be batched depending on the batched_output
argument.
"""
self.model.eval()

if self.onnx_mode:
log.info('Running inference with ONNX runtime.')
else:
self.model.eval()

for x, y in dl:
x = self.to_device(x, self.device)
z = self.predict(x, raw_out=raw_out, **predict_kw)
x = self.to_device(x, 'cpu')
y = self.to_device(y, 'cpu') if y is not None else y
z = self.to_device(z, 'cpu')
if self.onnx_mode:
z = self.predict_onnx(x, raw_out=raw_out, **predict_kw)
else:
z = self.predict(x, raw_out=raw_out, **predict_kw)
x = self.to_device(x, 'cpu')
if batched_output:
yield x, y, z
else:
@@ -738,31 +768,6 @@ def output_to_numpy(self, out: Tensor) -> np.ndarray:
"""
return out.numpy()

def numpy_predict(self, x: np.ndarray,
raw_out: bool = False) -> np.ndarray:
"""Make a prediction using an image or batch of images in numpy format.
If x.dtype is a subtype of np.unsignedinteger, it will be normalized
to [0, 1] using the max possible value of that dtype. Otherwise, x will
be assumed to be in [0, 1] already and will be cast to torch.float32
directly.
Args:
x: (ndarray) of shape [height, width, channels] or
[batch_sz, height, width, channels]
raw_out: if True, return prediction probabilities
Returns:
predictions using numpy arrays
"""
transform, _ = self.cfg.data.get_data_transforms()
x = self.normalize_input(x)
x = self.to_batch(x)
x = np.stack([transform(image=img)['image'] for img in x])
x = torch.from_numpy(x)
x = x.permute((0, 3, 1, 2))
out = self.predict(x, raw_out=raw_out)
return self.output_to_numpy(out)

def prob_to_pred(self, x: Tensor) -> Tensor:
"""Convert a Tensor with prediction probabilities to class ids.
@@ -830,6 +835,11 @@ def setup_model(self,
model_def_path (Optional[str], optional): Path to model definition.
Will be available when loading from a bundle. Defaults to None.
"""
self._onnx_mode = (model_weights_path is not None
and model_weights_path.lower().endswith('.onnx'))
if self._onnx_mode:
self.model = self.load_onnx_model(model_weights_path)
return
if self.model is None:
self.model = self.build_model(model_def_path=model_def_path)
self.model.to(device=self.device)
@@ -1038,7 +1048,7 @@ def plot_dataloaders(self,
#########
# Bundle
#########
def save_model_bundle(self):
def save_model_bundle(self, export_onnx: bool = True):
"""Save a model bundle.
This is a zip file with the model weights in .pth format and a serialized
@@ -1058,7 +1068,7 @@ def save_model_bundle(self):
model_bundle_dir = join(self.tmp_dir, 'model-bundle')
make_dir(model_bundle_dir, force_empty=True)

self._bundle_model(model_bundle_dir)
self._bundle_model(model_bundle_dir, export_onnx=export_onnx)
self._bundle_modules(model_bundle_dir)
self._bundle_transforms(model_bundle_dir)

@@ -1070,11 +1080,90 @@ def save_model_bundle(self):
log.info(f'Saving bundle to {zip_path}.')
zipdir(model_bundle_dir, zip_path)

def _bundle_model(self, model_bundle_dir: str) -> None:
def _bundle_model(self, model_bundle_dir: str,
export_onnx: bool = True) -> None:
"""Save model weights and copy them to bundle dir.."""
torch.save(self.model.state_dict(), self.last_model_weights_path)
shutil.copyfile(self.last_model_weights_path,
join(model_bundle_dir, 'model.pth'))
# pytorch
path = join(model_bundle_dir, BUNDLE_MODEL_WEIGHTS_FILENAME)
if file_exists(self.last_model_weights_path):
shutil.copyfile(self.last_model_weights_path, path)
else:
torch.save(self.model.state_dict(), path)

# ONNX
if export_onnx:
path = join(model_bundle_dir, BUNDLE_MODEL_ONNX_FILENAME)
self.export_to_onnx(path)

def export_to_onnx(self,
path: str,
model: Optional['nn.Module'] = None,
sample_input: Optional[Tensor] = None,
validate_export: bool = True,
**kwargs) -> None:
"""Export model to ONNX format via torch.onnx.export.
Args:
path (str): File path to save to.
model (Optional[nn.Module]): The model to export. If None,
self.model will be used. Defaults to None.
sample_input (Optional[Tensor]): Sample input to the model. If
None, a single batch from any available DataLoader in this
Learner will be used. Defaults to None.
validate_export (bool): If True, use onnx.checker.check_model to
validate exported model. An exception is raised if the check
fails. Defaults to True.
**kwargs (dict): Keyword args to pass to torch.onnx.export. These
override the default values used in the function definition.
Raises:
ValueError: If sample_input is None and the Learner has no valid
DataLoaders.
"""
if model is None:
model = self.model

training_state = model.training

model.eval()
if sample_input is None:
for split in ['train', 'valid', 'test']:
dl = self.get_dataloader(split)
if dl is not None:
break
else:
raise ValueError('sample_input not provided and Learner does '
'not have a DataLoader to get sample input '
'from.')
sample_input, _ = next(iter(dl))
sample_input = self.to_device(sample_input, self.device)

args = dict(
input_names=['x'],
output_names=['out'],
dynamic_axes={
'x': {
0: 'batch_size',
2: 'height',
3: 'width',
},
'out': {
0: 'batch_size',
},
},
training=torch.onnx.TrainingMode.EVAL,
opset_version=15,
)
args.update(**kwargs)
log.info('Exporting model to ONNX.')
torch.onnx.export(model, sample_input, path, **args)

model.train(training_state)

if validate_export:
import onnx
model_onnx = onnx.load(path)
onnx.checker.check_model(model_onnx)
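
export_to_onnx() above defaults to dynamic batch/height/width axes and opset 15, and any keyword args are forwarded to torch.onnx.export, overriding those defaults. A sketch of exporting a trained Learner with an explicit sample input and a newer opset (the shape and opset choice are only illustrative):

    import torch

    # (batch, channels, height, width); adjust to match the model's input.
    sample_input = torch.rand(1, 3, 256, 256)
    # If the model lives on a GPU, move the input there first, e.g.
    # sample_input = learner.to_device(sample_input, learner.device)
    learner.export_to_onnx(
        'model.onnx',
        sample_input=sample_input,
        validate_export=True,    # runs onnx.checker.check_model on the result
        opset_version=17,        # forwarded to torch.onnx.export
    )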

def _bundle_modules(self, model_bundle_dir: str) -> None:
"""Copy modules into bundle."""
@@ -1216,6 +1305,11 @@ def load_checkpoint(self):
args['strict'] = self.cfg.model.load_strict
self.load_weights(uri=weights_path, **args)

def load_onnx_model(self, model_path: str) -> ONNXRuntimeAdapter:
log.info(f'Loading ONNX model from {model_path}')
onnx_model = ONNXRuntimeAdapter.from_file(model_path)
return onnx_model
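
load_onnx_model() delegates to ONNXRuntimeAdapter from pytorch_learner.utils, whose implementation is not shown in this diff. As a rough mental model only (an assumption, not the actual code), such an adapter can be a thin wrapper around an onnxruntime InferenceSession that accepts and returns torch Tensors so it can stand in for self.model:

    import onnxruntime as ort
    import torch


    class OnnxAdapterSketch:
        """Illustrative stand-in for ONNXRuntimeAdapter (not the real class)."""

        def __init__(self, session: ort.InferenceSession):
            self.session = session
            self.input_name = session.get_inputs()[0].name

        @classmethod
        def from_file(cls, path: str) -> 'OnnxAdapterSketch':
            return cls(
                ort.InferenceSession(path, providers=['CPUExecutionProvider']))

        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            # onnxruntime works with numpy arrays; convert in and out.
            out = self.session.run(None, {self.input_name: x.cpu().numpy()})[0]
            return torch.from_numpy(out)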

def log_data_stats(self):
"""Log stats about each DataSet."""
if self.train_ds is not None:
@@ -1257,3 +1351,7 @@ def stop_tensorboard(self):
self.tb_writer.close()
if self.cfg.run_tensorboard:
self.tb_process.terminate()

@property
def onnx_mode(self) -> bool:
return self._onnx_mode
