Skip to content

Commit

Permalink
Merge pull request #2056 from AdeelH/learner-bundle-config
Browse files Browse the repository at this point in the history
Remove `LearnerPipeline` and `LearnerPipelineConfig`
  • Loading branch information
AdeelH committed Feb 8, 2024
2 parents 9fd4dd5 + e67e488 commit 44891f0
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 75 deletions.
42 changes: 41 additions & 1 deletion rastervision_pipeline/rastervision/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
Type, Union)
import inspect
import logging
import json

from pydantic import ( # noqa
BaseModel, create_model, Field, root_validator, validate_model,
ValidationError, validator)

from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import str_to_file
from rastervision.pipeline.file_system import (str_to_file, json_to_file,
file_to_json)

if TYPE_CHECKING:
from rastervision.pipeline.pipeline_config import PipelineConfig
Expand Down Expand Up @@ -112,6 +114,35 @@ def validate_list(self, field: str, valid_options: List[str]):
if val not in valid_options:
raise ConfigError(f'{val} is not a valid option for {field}')

def to_file(self, uri: str, with_rv_metadata: bool = True) -> None:
"""Save a Config to a JSON file, optionally with RV metadata.
Args:
uri: URI to save to.
with_rv_metadata: If True, inject Raster Vision metadata such as
``plugin_versions``, so that the config can be upgraded when
loaded.
"""
cfg_json = self.json()
if with_rv_metadata:
# self.dict() --> json_to_file() would be simpler but runs into
# JSON serialization problems
cfg_dict = json.loads(cfg_json)
cfg_dict['plugin_versions'] = registry.plugin_versions
cfg_json = json.dumps(cfg_dict)
json_to_file(cfg_dict, uri)

@classmethod
def from_file(self, uri: str) -> 'Config':
"""Deserialize a Config from a JSON file, upgrading if possible.
Args:
uri: URI to load from.
"""
cfg_dict = load_config_dict(uri)
cfg = build_config(cfg_dict)
return cfg

def __repr_args__(self):
"""Override to delete 'type_hint' field."""
args = dict(super().__repr_args__())
Expand All @@ -133,6 +164,15 @@ def save_pipeline_config(cfg: 'PipelineConfig', output_uri: str) -> None:
str_to_file(cfg_json, output_uri)


def load_config_dict(uri: str) -> dict:
"""Load a serialized Config from a JSON file as a dict and upgrade it."""
cfg_dict = file_to_json(uri)
if 'plugin_versions' in cfg_dict:
cfg_dict = upgrade_config(cfg_dict)
cfg_dict.pop('plugin_versions', None)
return cfg_dict


def build_config(x: Union[dict, List[Union[dict, Config]], Config]
) -> Union[Config, List[Config]]:
"""Build a Config from various types of input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ def register_plugin(registry: 'Registry'):
import rastervision.pipeline
from rastervision.pytorch_learner.learner_config import *
from rastervision.pytorch_learner.learner import *
from rastervision.pytorch_learner.learner_pipeline_config import *
from rastervision.pytorch_learner.learner_pipeline import *
from rastervision.pytorch_learner.classification_learner_config import *
from rastervision.pytorch_learner.classification_learner import *
from rastervision.pytorch_learner.regression_learner_config import *
Expand All @@ -26,9 +24,6 @@ def register_plugin(registry: 'Registry'):
from rastervision.pytorch_learner.dataset import *

__all__ = [
# LearnerPipeline
LearnerPipeline.__name__,
LearnerPipelineConfig.__name__,
# Learner
Learner.__name__,
SemanticSegmentationLearner.__name__,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,24 @@
from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.utils import get_env_var
from rastervision.pipeline.file_system import (
sync_to_dir, json_to_file, file_to_json, make_dir, zipdir,
download_if_needed, download_or_copy, sync_from_dir, get_local_path, unzip,
str_to_file, is_local, get_tmp_dir)
sync_to_dir, json_to_file, make_dir, zipdir, download_if_needed,
download_or_copy, sync_from_dir, get_local_path, unzip, is_local,
get_tmp_dir)
from rastervision.pipeline.file_system.utils import file_exists
from rastervision.pipeline.utils import terminate_at_exit
from rastervision.pipeline.config import (build_config, upgrade_config,
save_pipeline_config)
from rastervision.pipeline.config import build_config
from rastervision.pytorch_learner.utils import (
get_hubconf_dir_from_cfg, aggregate_metrics, log_metrics_to_csv,
log_system_details, ONNXRuntimeAdapter, DDPContextManager)
aggregate_metrics, DDPContextManager, get_hubconf_dir_from_cfg,
get_learner_config_from_bundle_dir, log_metrics_to_csv, log_system_details,
ONNXRuntimeAdapter)
from rastervision.pytorch_learner.dataset.visualizer import Visualizer

if TYPE_CHECKING:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset, Sampler

from rastervision.pytorch_learner import (LearnerConfig,
LearnerPipelineConfig)
from rastervision.pytorch_learner import LearnerConfig

warnings.filterwarnings('ignore')

Expand Down Expand Up @@ -305,14 +304,7 @@ def from_model_bundle(cls: Type,
unzip(model_bundle_path, model_bundle_dir)

if cfg is None:
config_path = join(model_bundle_dir, 'pipeline-config.json')

config_dict = file_to_json(config_path)
config_dict = upgrade_config(config_dict)

learner_pipeline_cfg: 'LearnerPipelineConfig' = build_config(
config_dict)
cfg = learner_pipeline_cfg.learner
cfg = get_learner_config_from_bundle_dir(model_bundle_dir)

hub_dir = join(model_bundle_dir, MODULES_DIRNAME)
model_def_path = None
Expand Down Expand Up @@ -1024,8 +1016,8 @@ def setup_training(self, loss_def_path: Optional[str] = None) -> None:
"""
cfg = self.cfg

self.config_path = join(self.output_dir, 'learner-config.json')
str_to_file(cfg.json(), self.config_path)
self.config_path = join(self.output_dir_local, 'learner-config.json')
cfg.to_file(self.config_path)
self.log_path = join(self.output_dir_local, 'log.csv')
self.last_model_weights_path = join(self.output_dir_local,
'last-model.pth')
Expand Down Expand Up @@ -1399,9 +1391,6 @@ def save_model_bundle(self, export_onnx: bool = True):
This is a zip file with the model weights in .pth format and a serialized
copy of the LearningConfig, which allows for making predictions in the future.
"""
from rastervision.pytorch_learner.learner_pipeline_config import (
LearnerPipelineConfig)

if self.cfg.model is None:
log.warning(
'Model was not configured via ModelConfig, and therefore, '
Expand All @@ -1417,9 +1406,8 @@ def save_model_bundle(self, export_onnx: bool = True):
self._bundle_modules(model_bundle_dir)
self._bundle_transforms(model_bundle_dir)

pipeline_cfg = LearnerPipelineConfig(learner=self.cfg)
save_pipeline_config(pipeline_cfg,
join(model_bundle_dir, 'pipeline-config.json'))
cfg_uri = join(model_bundle_dir, 'learner-config.json')
shutil.copy(self.config_path, cfg_uri)

zip_path = join(self.output_dir_local, basename(self.model_bundle_uri))
log.info(f'Saving bundle to {zip_path}.')
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import (Any, Dict, Sequence, Tuple, Optional, Union, List,
Iterable, Container)
from typing import (TYPE_CHECKING, Any, Dict, Sequence, Tuple, Optional, Union,
List, Iterable, Container)
from os.path import basename, join, isfile
import logging

Expand All @@ -14,8 +14,13 @@
import pandas as pd
import onnxruntime as ort

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.pipeline.config import ConfigError
from rastervision.pipeline.file_system.utils import (file_exists, file_to_json,
get_tmp_dir)
from rastervision.pipeline.config import (build_config, Config, ConfigError,
upgrade_config)

if TYPE_CHECKING:
from rastervision.pytorch_learner import LearnerConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -487,3 +492,23 @@ def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
if isinstance(out, np.ndarray):
out = torch.from_numpy(out)
return out


def get_learner_config_from_bundle_dir(
model_bundle_dir: str) -> 'LearnerConfig':
config_path = join(model_bundle_dir, 'learner-config.json')
if file_exists(config_path):
cfg = Config.from_file(config_path)
else:
# backward compatibility
config_path = join(model_bundle_dir, 'pipeline-config.json')
if not file_exists(config_path):
raise FileNotFoundError(
'Could not find a valid config file in the bundle.')
pipeline_cfg_dict = file_to_json(config_path)
cfg_dict = pipeline_cfg_dict['learner']
cfg_dict['plugin_versions'] = pipeline_cfg_dict['plugin_versions']
cfg_dict = upgrade_config(cfg_dict)
cfg_dict.pop('plugin_versions', None)
cfg = build_config(cfg_dict)
return cfg
27 changes: 26 additions & 1 deletion tests/pytorch_learner/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
import pandas as pd

from rastervision.pipeline.file_system.utils import get_tmp_dir
from rastervision.pipeline.pipeline_config import PipelineConfig
from rastervision.pytorch_learner import (DataConfig, LearnerConfig,
SolverConfig)
from rastervision.pytorch_learner.utils import (
compute_conf_mat, compute_conf_mat_metrics, MinMaxNormalize,
adjust_conv_channels, Parallel, SplitTensor, AddTensors,
validate_albumentation_transform, A, color_to_triple,
channel_groups_to_imgs, plot_channel_groups,
serialize_albumentation_transform, deserialize_albumentation_transform,
aggregate_metrics, log_metrics_to_csv, log_system_details)
aggregate_metrics, log_metrics_to_csv, log_system_details,
get_learner_config_from_bundle_dir)
from tests.data_files.lambda_transforms import lambda_transforms
from tests import data_file_path

Expand Down Expand Up @@ -384,6 +388,27 @@ def test_log_metrics_to_csv(self):
def test_log_system_details(self):
self.assertNoError(log_system_details)

def test_get_learner_config_from_bundle_dir(self):
learner_cfg = LearnerConfig(solver=SolverConfig(), data=DataConfig())
with get_tmp_dir() as tmp_dir:
learner_cfg.to_file(join(tmp_dir, 'learner-config.json'))
cfg = get_learner_config_from_bundle_dir(tmp_dir)
self.assertIsInstance(cfg, LearnerConfig)

def test_get_learner_config_from_bundle_dir_backward_compatibility(self):
class MockLearnerPipelineConfig(PipelineConfig):
learner: LearnerConfig

learner_cfg = LearnerConfig(solver=SolverConfig(), data=DataConfig())
learner_pipeline_cfg = MockLearnerPipelineConfig(learner=learner_cfg)
with get_tmp_dir() as tmp_dir:
self.assertRaises(
FileNotFoundError,
lambda: get_learner_config_from_bundle_dir(tmp_dir))
learner_pipeline_cfg.to_file(join(tmp_dir, 'pipeline-config.json'))
cfg = get_learner_config_from_bundle_dir(tmp_dir)
self.assertIsInstance(cfg, LearnerConfig)


if __name__ == '__main__':
unittest.main()

0 comments on commit 44891f0

Please sign in to comment.