Skip to content

Commit

Permalink
Refactor SOLD2 and WunschLineMatcher Dict Config to Dataclasses (#2901)
Browse files Browse the repository at this point in the history
  • Loading branch information
lappemic committed May 17, 2024
1 parent 3ce96a3 commit dfab71f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 44 deletions.
72 changes: 30 additions & 42 deletions kornia/feature/sold2/sold2.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,21 @@
import warnings
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor, concatenate, pad, stack
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.feature.sold2.structures import LineDetectorCfg
from kornia.feature.sold2.structures import DetectorCfg, LineMatcherCfg
from kornia.geometry.conversions import normalize_pixel_coordinates
from kornia.utils import map_location_to_cpu
from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu

from .backbones import SOLD2Net
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

# Download locations of the pretrained checkpoints, keyed by model name.
# SOLD2.__init__ reads urls["wireframe"] when `pretrained` is true.
# NOTE(review): the checkpoint is served over plain HTTP — consider HTTPS if the host offers it.
urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"

# Legacy dict-style default configuration for SOLD2.
# NOTE(review): the dataclass-based lines of SOLD2.__init__ fall back to `DetectorCfg()`
# instead of this dict — presumably this is the pre-refactor default; confirm before relying on it.
default_cfg: Dict[str, Any] = {
# Backbone settings; the whole config is handed to SOLD2Net by SOLD2.__init__.
"backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5},
"use_descriptor": True,
"grid_size": 8,
"keep_border_valid": True,
"detection_thresh": 0.0153846, # = 1/65: threshold of junction detection
"max_num_junctions": 500, # maximum number of junctions per image
"line_detector_cfg": LineDetectorCfg(),
# Keyword arguments for the legacy WunschLineMatcher(**kwargs) constructor call.
"line_matcher_cfg": {
"cross_check": True,
"num_samples": 5,
"min_dist_pts": 8,
"top_k_candidates": 10,
"grid_size": 4,
},
}


class SOLD2(Module):
r"""Module, which detects and describe line segments in an image.
Expand Down Expand Up @@ -59,27 +43,37 @@ class SOLD2(Module):
>>> matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
"""

def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
    """Build the SOLD2 network, the line detector and the line matcher.

    Args:
        pretrained: if True, download the "wireframe" checkpoint and load it into the model.
        config: detector configuration. ``None`` selects ``DetectorCfg()``. A plain ``dict``
            is still accepted for backward compatibility: it triggers a
            ``DeprecationWarning`` and is converted with ``dict_to_dataclass``.
    """
    import copy  # local import: only needed to protect the caller's config object

    if isinstance(config, dict):
        warnings.warn(
            "Usage of config as a plain dictionary is deprecated in favor of"
            " `kornia.feature.sold2.structures.DetectorCfg`. The support of plain dictionaries"
            " as config will be removed in kornia v0.8.0 (December 2024).",
            category=DeprecationWarning,
            stacklevel=2,
        )
        config = dict_to_dataclass(config, DetectorCfg)
    super().__init__()

    # Initialize some parameters.
    # Work on a shallow copy so that forcing `use_descriptor` below does not leak into a
    # config object the caller may reuse elsewhere (e.g. for a detector-only model).
    self.config = DetectorCfg() if config is None else copy.copy(config)
    self.config.use_descriptor = True  # Only difference to SOLD2_detector DetectorCfg
    self.grid_size = self.config.grid_size
    self.junc_detect_thresh = self.config.detection_thresh
    self.max_num_junctions = self.config.max_num_junctions

    # Load the pre-trained model (SOLD2Net consumes a plain dict).
    self.model = SOLD2Net(dataclass_to_dict(self.config))
    if pretrained:
        pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu)
        state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
        self.model.load_state_dict(state_dict)
    # NOTE(review): assumed to run regardless of `pretrained` — confirm against the repository.
    self.eval()

    # Initialize the line detector
    self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)

    # Initialize the line matcher
    self.line_matcher = WunschLineMatcher(self.config.line_matcher_cfg)

def forward(self, img: Tensor) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -143,22 +137,16 @@ class WunschLineMatcher(Module):
TODO: move it later in kornia.feature.matching
"""

def __init__(self, config: Optional[LineMatcherCfg] = None) -> None:
    """Store the matcher configuration and mirror its fields as attributes.

    Args:
        config: matcher settings. ``None`` creates a fresh ``LineMatcherCfg()`` for this
            instance.
    """
    super().__init__()
    # Build the default here rather than in the signature: a default instance in the
    # signature is created once at definition time and would be shared (and mutable)
    # across every matcher constructed without an explicit config.
    self.config = LineMatcherCfg() if config is None else config
    self.cross_check = self.config.cross_check
    self.num_samples = self.config.num_samples
    self.min_dist_pts = self.config.min_dist_pts
    self.top_k_candidates = self.config.top_k_candidates
    self.grid_size = self.config.grid_size
    self.line_score = self.config.line_score  # True to compute saliency on a line

def forward(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
"""Find the best matches between two sets of line segments and their corresponding descriptors."""
Expand Down
4 changes: 2 additions & 2 deletions kornia/feature/sold2/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class LineDetectorCfg:
@dataclass
class LineMatcherCfg:
    """Settings consumed by ``WunschLineMatcher``.

    Each field is declared exactly once; the defaults are the values that were in effect
    when ``num_samples`` and ``grid_size`` were declared twice (5 and 4 respectively).
    """

    cross_check: bool = True
    num_samples: int = 5
    min_dist_pts: int = 8
    top_k_candidates: int = 10
    grid_size: int = 4
    line_score: bool = False  # True to compute saliency on a line


Expand Down

0 comments on commit dfab71f

Please sign in to comment.