Skip to content

Commit

Permalink
fix issue when using latest omegaconf
Browse files Browse the repository at this point in the history
Summary:
`ShapeSpec` was a namedtuple, which supports both indexing and attribute access. This confused the latest omegaconf: lastest omegaconf interprets it as a tuple instead.

We never need to index ShapeSpec like `input_shape[0]`. Only attribute access like `input_shape.channels` is needed. So this PR changes it to dataclass.

And it also hardens the dataclass support in LazyConfig so that `instantiate()`, etc works correctly with it.

Follow up of 32c32e3. sstsai-adl I think this will fix the error you're seeing.

Pull Request resolved: #4253

Reviewed By: sstsai-adl

Differential Revision: D36635140

fbshipit-source-id: 8526ebc546941c448cdf9afe77a946ef612c74a5
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed May 25, 2022
1 parent 45b3fce commit 3cc9908
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 51 deletions.
7 changes: 6 additions & 1 deletion detectron2/config/instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def instantiate(cfg):
Returns:
object instantiated by cfg
"""
from omegaconf import ListConfig
from omegaconf import ListConfig, DictConfig, OmegaConf

if isinstance(cfg, ListConfig):
lst = [instantiate(x) for x in cfg]
Expand All @@ -56,6 +56,11 @@ def instantiate(cfg):
# list[objects] as arguments, such as ResNet, DatasetMapper
return [instantiate(x) for x in cfg]

# If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
# instantiate it to the actual dataclass.
if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
return OmegaConf.to_object(cfg)

if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
# conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
# but faster: https://github.com/facebookresearch/hydra/issues/1200
Expand Down
12 changes: 10 additions & 2 deletions detectron2/config/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import List, Tuple, Union
import cloudpickle
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import DictConfig, ListConfig, OmegaConf, SCMode

from detectron2.utils.file_io import PathManager
from detectron2.utils.registry import _convert_target_to_string
Expand Down Expand Up @@ -266,7 +266,15 @@ def _replace_type_by_name(x):

save_pkl = False
try:
dict = OmegaConf.to_container(cfg, resolve=False)
dict = OmegaConf.to_container(
cfg,
# Do not resolve interpolation when saving, i.e. do not turn ${a} into
# actual values when saving.
resolve=False,
# Save structures (dataclasses) in a format that can be instantiated later.
# Without this option, the type information of the dataclass will be erased.
structured_config_mode=SCMode.INSTANTIATE,
)
dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999)
with PathManager.open(filename, "w") as f:
f.write(dumped)
Expand Down
18 changes: 8 additions & 10 deletions detectron2/layers/shape_spec.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import namedtuple
from dataclasses import dataclass
from typing import Optional


class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
@dataclass
class ShapeSpec:
"""
A simple structure that contains basic shape specification about a tensor.
It is often used as the auxiliary inputs/outputs of models,
to complement the lack of shape inference ability among pytorch modules.
Attributes:
channels:
height:
width:
stride:
"""

def __new__(cls, channels=None, height=None, width=None, stride=None):
return super().__new__(cls, channels, height, width, stride)
channels: Optional[int] = None
height: Optional[int] = None
width: Optional[int] = None
stride: Optional[int] = None
16 changes: 15 additions & 1 deletion detectron2/utils/testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import numpy as np
import os
import tempfile
import torch

from detectron2 import model_zoo
from detectron2.config import CfgNode, instantiate
from detectron2.config import CfgNode, LazyConfig, instantiate
from detectron2.data import DatasetCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
Expand Down Expand Up @@ -139,3 +141,15 @@ def reload_script_model(module):
torch.jit.save(module, buffer)
buffer.seek(0)
return torch.jit.load(buffer)


def reload_lazy_config(cfg):
"""
Save an object by LazyConfig.save and load it back.
This is used to test that a config still works the same after
serialization/deserialization.
"""
with tempfile.TemporaryDirectory(prefix="detectron2") as d:
fname = os.path.join(d, "d2_cfg_test.yaml")
LazyConfig.save(cfg, fname)
return LazyConfig.load(fname)
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ def get_model_zoo_configs() -> List[str]:
"future", # used by caffe2
"pydot", # used to save caffe2 SVGs
"dataclasses; python_version<'3.7'",
"omegaconf>=2.1,<=2.2.0",
"omegaconf>=2.1",
"hydra-core>=1.1",
"black==22.3.0",
"scipy>1.5.1",
# If a new dependency is required at import time (in addition to runtime), it
# probably needs to exist in docs/requirements.txt, or as a mock in docs/conf.py
],
extras_require={
# optional dependencies, required by some features
"all": [
"scipy>1.5.1",
"shapely",
"pygments>=2.2",
"psutil",
Expand Down
79 changes: 44 additions & 35 deletions tests/config/test_instantiate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from omegaconf import __version__ as oc_version
from dataclasses import dataclass

from detectron2.config import instantiate, LazyCall as L
from detectron2.config import LazyConfig, instantiate, LazyCall as L
from detectron2.layers import ShapeSpec
from detectron2.utils.testing import reload_lazy_config

OC_VERSION = tuple(int(x) for x in oc_version.split(".")[:2])

Expand All @@ -25,32 +26,28 @@ def __call__(self, call_arg):
return call_arg + self.int_arg


@dataclass
class TestDataClass:
x: int
y: str


@unittest.skipIf(OC_VERSION < (2, 1), "omegaconf version too old")
class TestConstruction(unittest.TestCase):
def test_basic_construct(self):
objconf = L(TestClass)(
cfg = L(TestClass)(
int_arg=3,
list_arg=[10],
dict_arg={},
extra_arg=L(TestClass)(int_arg=4, list_arg="${..list_arg}"),
)

obj = instantiate(objconf)
self.assertIsInstance(obj, TestClass)
self.assertEqual(obj.int_arg, 3)
self.assertEqual(obj.extra_arg.int_arg, 4)
self.assertEqual(obj.extra_arg.list_arg, obj.list_arg)
for x in [cfg, reload_lazy_config(cfg)]:
obj = instantiate(x)
self.assertIsInstance(obj, TestClass)
self.assertEqual(obj.int_arg, 3)
self.assertEqual(obj.extra_arg.int_arg, 4)
self.assertEqual(obj.extra_arg.list_arg, obj.list_arg)

objconf.extra_arg.list_arg = [5]
obj = instantiate(objconf)
self.assertIsInstance(obj, TestClass)
self.assertEqual(obj.extra_arg.list_arg, [5])
# Test interpolation
x.extra_arg.list_arg = [5]
obj = instantiate(x)
self.assertIsInstance(obj, TestClass)
self.assertEqual(obj.extra_arg.list_arg, [5])

def test_instantiate_other_obj(self):
# do nothing for other obj
Expand All @@ -68,33 +65,45 @@ def test_instantiate_lazy_target(self):
objconf._target_._target_ = TestClass
self.assertEqual(instantiate(objconf), 7)

def test_instantiate_lst(self):
def test_instantiate_list(self):
lst = [1, 2, L(TestClass)(int_arg=1)]
x = L(TestClass)(int_arg=lst) # list as an argument should be recursively instantiated
x = instantiate(x).int_arg
self.assertEqual(x[:2], [1, 2])
self.assertIsInstance(x[2], TestClass)
self.assertEqual(x[2].int_arg, 1)

def test_instantiate_namedtuple(self):
x = L(TestClass)(int_arg=ShapeSpec(channels=1, width=3))
# test serialization
with tempfile.TemporaryDirectory() as d:
fname = os.path.join(d, "d2_test.yaml")
OmegaConf.save(x, fname)
with open(fname) as f:
x = yaml.unsafe_load(f)

x = instantiate(x)
self.assertIsInstance(x.int_arg, ShapeSpec)
self.assertEqual(x.int_arg.channels, 1)
def test_instantiate_dataclass(self):
cfg = L(ShapeSpec)(channels=1, width=3)
# Test original cfg as well as serialization
for x in [cfg, reload_lazy_config(cfg)]:
obj = instantiate(x)
self.assertIsInstance(obj, ShapeSpec)
self.assertEqual(obj.channels, 1)
self.assertEqual(obj.height, None)

def test_instantiate_dataclass_as_subconfig(self):
cfg = L(TestClass)(int_arg=1, extra_arg=ShapeSpec(channels=1, width=3))
# Test original cfg as well as serialization
for x in [cfg, reload_lazy_config(cfg)]:
obj = instantiate(x)
self.assertIsInstance(obj.extra_arg, ShapeSpec)
self.assertEqual(obj.extra_arg.channels, 1)
self.assertEqual(obj.extra_arg.height, None)

def test_bad_lazycall(self):
with self.assertRaises(Exception):
L(3)

def test_instantiate_dataclass(self):
a = L(TestDataClass)(x=1, y="s")
a = instantiate(a)
self.assertEqual(a.x, 1)
self.assertEqual(a.y, "s")
def test_interpolation(self):
cfg = L(TestClass)(int_arg=3, extra_arg="${int_arg}")

cfg.int_arg = 4
obj = instantiate(cfg)
self.assertEqual(obj.extra_arg, 4)

# Test that interpolation still works after serialization
cfg = reload_lazy_config(cfg)
cfg.int_arg = 5
obj = instantiate(cfg)
self.assertEqual(obj.extra_arg, 5)

0 comments on commit 3cc9908

Please sign in to comment.