Skip to content

Commit

Permalink
Unify sold2 config dataclasses (#2899)
Browse files Browse the repository at this point in the history
* refactor: store sold2_detector config dataclasses in utils/structures.py

* fix: update path to DetectorCfg in Docstring

* Move structures.py and restate import path

* Update SOLD2_detector docstring

* Add LineMatcherCfg dataclass to structures.py

* Update LineMatcherCfg to include line_score

* refactor: cfg of SOLD2 and WunschLineMatcher to be dataclasses

* my bad...

* fix: update WunschLineMatcher initialization

* fix: update SOLD2 initialiation

* fix: rollback due to inconsistencies with sold2 config to dataclass

* chore: remove repetitive words (#2902)

* chore: remove repetitive words

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* CI: Drop macos-latest runner for torch 1.9.1 (#2905)

* fix (CI): remove old torch on macos

* chore: ensure last pytorch

* [pre-commit.ci] pre-commit suggestions (#2894)

* [pre-commit.ci] pre-commit suggestions

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.4.1 → v0.4.2](astral-sh/ruff-pre-commit@v0.4.1...v0.4.2)

* fix up031

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: João Gustavo A. Amorim <joaogustavoamorim@gmail.com>

* feat: in_range filtering (#2895)

* initial commit

* add docs

* add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update docs

* correct typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests and docs

* change randn to rand

* Modify docs indentation

* correct docs and remove unused vars

* add return_mask

* Remove shape in doc

* fix docs format

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: edgar <edgar.riba@gmail.com>

---------

Co-authored-by: peicuiping <168072318+peicuiping@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: João Gustavo A. Amorim <joaogustavoamorim@gmail.com>
Co-authored-by: Vicent Gilabert <44602177+vgilabert94@users.noreply.github.com>
Co-authored-by: edgar <edgar.riba@gmail.com>
  • Loading branch information
6 people committed May 16, 2024
1 parent bdd07f3 commit 3ce96a3
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 57 deletions.
4 changes: 2 additions & 2 deletions kornia/feature/sold2/sold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

from kornia.core import Module, Tensor, concatenate, pad, stack
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.feature.sold2.structures import LineDetectorCfg
from kornia.geometry.conversions import normalize_pixel_coordinates
from kornia.utils import map_location_to_cpu

from .backbones import SOLD2Net
from .sold2_detector import LineDetectorCfg, LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"


default_cfg: Dict[str, Any] = {
"backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5},
"use_descriptor": True,
Expand Down
58 changes: 3 additions & 55 deletions kornia/feature/sold2/sold2_detector.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import math
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import torch

from kornia.core import Module, Tensor, concatenate, sin, stack, tensor, where, zeros
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.feature.sold2.structures import DetectorCfg, HeatMapRefineCfg, JunctionRefineCfg, LineDetectorCfg
from kornia.geometry.bbox import nms
from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu, torch_meshgrid

Expand All @@ -16,58 +16,6 @@
urls["wireframe"] = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download"


@dataclass
class HeatMapRefineCfg:
mode: str = "local"
ratio: float = 0.2
valid_thresh: float = 0.001
num_blocks: int = 20
overlap_ratio: float = 0.5


@dataclass
class JunctionRefineCfg:
num_perturbs: int = 9
perturb_interval: float = 0.25


@dataclass
class LineDetectorCfg:
detect_thresh: float = 0.5
num_samples: int = 64
inlier_thresh: float = 0.99
use_candidate_suppression: bool = True
nms_dist_tolerance: float = 3.0
heatmap_low_thresh: float = 0.15
heatmap_high_thresh: float = 0.2
max_local_patch_radius: float = 3
lambda_radius: float = 2.0
use_heatmap_refinement: bool = True
heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg)
use_junction_refinement: bool = True
junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg)


@dataclass
class BackboneCfg:
input_channel: int = 1
depth: int = 4
num_stacks: int = 2
num_blocks: int = 1
num_classes: int = 5


@dataclass
class DetectorCfg:
backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg)
use_descriptor: bool = False
grid_size: int = 8
keep_border_valid: bool = True
detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection
max_num_junctions: int = 500 # maximum number of junctions per image
line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg)


class SOLD2_detector(Module):
r"""Module, which detects line segments in an image.
Expand All @@ -93,8 +41,8 @@ def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None
if isinstance(config, dict):
warnings.warn(
"Usage of config as a plain dictionary is deprecated in favor of"
" `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries"
"as config will be removed in kornia v0.8.0 (December 2024).",
"`kornia.features.sold2.structures.DetectorCfg`. The support of plain"
"dictionaries as config will be removed in kornia v0.8.0 (December 2024).",
category=DeprecationWarning,
stacklevel=2,
)
Expand Down
64 changes: 64 additions & 0 deletions kornia/feature/sold2/structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from dataclasses import dataclass, field


@dataclass
class HeatMapRefineCfg:
mode: str = "local"
ratio: float = 0.2
valid_thresh: float = 0.001
num_blocks: int = 20
overlap_ratio: float = 0.5


@dataclass
class JunctionRefineCfg:
num_perturbs: int = 9
perturb_interval: float = 0.25


@dataclass
class LineDetectorCfg:
detect_thresh: float = 0.5
num_samples: int = 64
inlier_thresh: float = 0.99
use_candidate_suppression: bool = True
nms_dist_tolerance: float = 3.0
heatmap_low_thresh: float = 0.15
heatmap_high_thresh: float = 0.2
max_local_patch_radius: float = 3
lambda_radius: float = 2.0
use_heatmap_refinement: bool = True
heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg)
use_junction_refinement: bool = True
junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg)


@dataclass
class LineMatcherCfg:
cross_check: bool = True
num_samples: int = 10
min_dist_pts: int = 8
top_k_candidates: int = 10
grid_size: int = 8
line_score: bool = False # True to compute saliency on a line


@dataclass
class BackboneCfg:
input_channel: int = 1
depth: int = 4
num_stacks: int = 2
num_blocks: int = 1
num_classes: int = 5


@dataclass
class DetectorCfg:
backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg)
use_descriptor: bool = False
grid_size: int = 8
keep_border_valid: bool = True
detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection
max_num_junctions: int = 500 # maximum number of junctions per image
line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg)
line_matcher_cfg: LineMatcherCfg = field(default_factory=LineMatcherCfg)

0 comments on commit 3ce96a3

Please sign in to comment.