# Playground for building & testing the JPLDD model

In [1]:
# Imports
import cv2
import matplotlib.pyplot as plt

import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.functional import pixel_shuffle, softmax
import torchvision
from torchvision.models import resnet
from torch.nn.modules.utils import _pair
from typing import Optional, Callable

Now load a sample image:

## Add Network Definitions
ALIKED Backbone Encoder Parts

In [2]:
# ALIKED model variants, keyed by name. "c1".."c4" and "dim" are the channel
# widths consumed by AlikedEncoder; "K" and "M" are carried along for the
# descriptor head (presumably SDDH kernel size / number of sample positions —
# TODO confirm against the code that reads them).
_ALIKED_CFG_KEYS = ("c1", "c2", "c3", "c4", "dim", "K", "M")
aliked_cfgs = {
    variant: dict(zip(_ALIKED_CFG_KEYS, values))
    for variant, values in (
        ("aliked-t16", (8, 16, 32, 64, 64, 3, 16)),
        ("aliked-n16", (16, 32, 64, 128, 128, 3, 16)),
        ("aliked-n16rot", (16, 32, 64, 128, 128, 3, 16)),
        ("aliked-n32", (16, 32, 64, 128, 128, 3, 32)),
    )
}


class AlikedEncoder(nn.Module):
    """ALIKED feature encoder: four conv/deformable-conv stages plus aggregation.

    The image is processed at full, 1/2, 1/8 and 1/32 resolution. Each stage
    output is projected to ``dim // 4`` channels, upsampled back to the input
    resolution and concatenated along the channel axis.

    Args:
        conf: mapping providing the channel widths ``"c1"``, ``"c2"``,
            ``"c3"``, ``"c4"`` and the aggregated width ``"dim"``.
    """

    def __init__(self, conf):
        super().__init__()
        # get configurations
        c1, c2, c3, c4, dim = conf["c1"], conf["c2"], conf["c3"], conf["c4"], conf["dim"]
        # the two high-resolution stages use plain convs, the two coarse ones
        # use deformable convs (without modulation mask)
        conv_types = ["conv", "conv", "dcn", "dcn"]
        mask = False

        # build model
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.norm = nn.BatchNorm2d
        # a single activation instance is shared by every block
        self.gate = nn.SELU(inplace=True)
        self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
        self.block2 = ResBlock(
            c1,
            c2,
            1,
            nn.Conv2d(c1, c2, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[1],
        )
        self.block3 = ResBlock(
            c2,
            c3,
            1,
            nn.Conv2d(c2, c3, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[2],
            mask=mask,
        )
        self.block4 = ResBlock(
            c3,
            c4,
            1,
            nn.Conv2d(c3, c4, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[3],
            mask=mask,
        )
        # 1x1 projections of every stage to dim // 4 channels
        self.conv1 = resnet.conv1x1(c1, dim // 4)
        self.conv2 = resnet.conv1x1(c2, dim // 4)
        self.conv3 = resnet.conv1x1(c3, dim // 4)
        # block4 emits c4 channels, so the projection must take c4 inputs
        # (identical to the previous `dim` whenever c4 == dim, which holds for
        # every entry of aliked_cfgs)
        self.conv4 = resnet.conv1x1(c4, dim // 4)
        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )
        # upsample4 is not used in forward(); kept so the module's attributes
        # stay unchanged
        self.upsample4 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=True
        )
        self.upsample8 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=True
        )
        self.upsample32 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=True
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Encode an image into a dense feature map.

        Args:
            image: B x 3 x H x W tensor. H and W should be divisible by 32 so
                that the pooled maps upsample back to exactly H x W for the
                final concatenation.

        Returns:
            B x (4 * (dim // 4)) x H x W aggregated feature map.
        """
        # ================================== feature encoder
        x1 = self.block1(image)  # B x c1 x H x W
        x2 = self.pool2(x1)
        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
        x3 = self.pool4(x2)
        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
        x4 = self.pool4(x3)
        x4 = self.block4(x4)  # B x c4 x H/32 x W/32
        # ================================== feature aggregation
        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32
        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)

        return x1234
    
class DeformableConv2d(nn.Module):
    """Deformable 2D convolution (optionally modulated) built on torchvision.

    A regular convolution (``offset_conv``) predicts per-location sampling
    offsets (and, if ``mask`` is set, modulation weights); the weights of
    ``regular_conv`` are then applied through
    ``torchvision.ops.deform_conv2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: (square) kernel size.
        stride: convolution stride; applied to the offset prediction and to
            the deformable convolution alike.
        padding: zero padding applied by both convolutions.
        bias: whether the deformable convolution has a bias term.
        mask: if True, additionally predict a sigmoid modulation mask.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            mask=False,
    ):
        super(DeformableConv2d, self).__init__()

        self.stride = stride
        self.padding = padding
        self.mask = mask

        # 2 offset channels (and 1 mask channel when enabled) per kernel tap
        self.channel_num = (
            3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
        )
        self.offset_conv = nn.Conv2d(
            in_channels,
            self.channel_num,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )

        # only used as a container for the weight/bias of the deformable conv
        self.regular_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, x):
        """Apply the deformable convolution to ``x`` (B x C_in x H x W)."""
        h, w = x.shape[2:]
        # offsets are limited to a quarter of the larger spatial extent
        max_offset = max(h, w) / 4.0

        out = self.offset_conv(x)
        if self.mask:
            o1, o2, mask = torch.chunk(out, 3, dim=1)
            offset = torch.cat((o1, o2), dim=1)
            mask = torch.sigmoid(mask)
        else:
            offset = out
            mask = None
        offset = offset.clamp(-max_offset, max_offset)
        # The stride must be forwarded: the offset map was computed with it,
        # so deform_conv2d has to produce an output of the same spatial size.
        x = torchvision.ops.deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            stride=self.stride,
            padding=self.padding,
            mask=mask,
        )
        return x


def get_conv(
        inplanes,
        planes,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        conv_type="conv",
        mask=False,
):
    """Build a convolution layer of the requested flavour.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size, stride, padding, bias: forwarded to the layer.
        conv_type: ``"conv"`` for ``nn.Conv2d`` or ``"dcn"`` for
            ``DeformableConv2d``.
        mask: only used by the deformable variant.

    Raises:
        TypeError: if ``conv_type`` is neither ``"conv"`` nor ``"dcn"``.
    """
    if conv_type not in ("conv", "dcn"):
        raise TypeError

    if conv_type == "dcn":
        # the deformable layer receives the padding as an explicit pair
        return DeformableConv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=_pair(padding),
            bias=bias,
            mask=mask,
        )

    return nn.Conv2d(
        inplanes,
        planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )


class ConvBlock(nn.Module):
    """Two stacked (conv -> norm -> activation) layers at constant resolution.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        gate: activation module applied after each norm; defaults to an
            in-place ReLU.
        norm_layer: factory called with the channel count to build each
            normalisation layer; defaults to ``nn.BatchNorm2d``.
        conv_type: convolution flavour understood by ``get_conv``.
        mask: forwarded to ``get_conv``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            gate: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            conv_type: str = "conv",
            mask: bool = False,
    ):
        super().__init__()
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        conv_kwargs = dict(kernel_size=3, conv_type=conv_type, mask=mask)
        self.conv1 = get_conv(in_channels, out_channels, **conv_kwargs)
        self.bn1 = make_norm(out_channels)
        self.conv2 = get_conv(out_channels, out_channels, **conv_kwargs)
        self.bn2 = make_norm(out_channels)

    def forward(self, x):
        """Map B x in_channels x H x W to B x out_channels x H x W."""
        first = self.conv1(x)
        first = self.gate(self.bn1(first))  # B x out_channels x H x W
        second = self.conv2(first)
        return self.gate(self.bn2(second))  # B x out_channels x H x W


# modified based on torchvision\models\resnet.py#27->BasicBlock
class ResBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an optional skip projection.

    Args:
        inplanes: channels of the input tensor.
        planes: channels of the output tensor.
        stride: stored on the instance but not applied by any layer of this
            block.
        downsample: optional module applied to the input to form the skip
            branch (e.g. a 1x1 conv to match the channel count).
        groups, base_width, dilation: accepted for signature parity with
            torchvision's BasicBlock; only the defaults are supported.
        gate: activation module; defaults to an in-place ReLU.
        norm_layer: normalisation factory; defaults to ``nn.BatchNorm2d``.
        conv_type: convolution flavour understood by ``get_conv``.
        mask: forwarded to ``get_conv``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            gate: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            conv_type: str = "conv",
            mask: bool = False,
    ) -> None:
        super(ResBlock, self).__init__()
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # reject the BasicBlock options this variant does not implement
        if groups != 1 or base_width != 64:
            raise ValueError("ResBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")

        conv_kwargs = dict(kernel_size=3, conv_type=conv_type, mask=mask)
        self.conv1 = get_conv(inplanes, planes, **conv_kwargs)
        self.bn1 = make_norm(planes)
        self.conv2 = get_conv(planes, planes, **conv_kwargs)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return gate(bn2(conv2(gate(bn1(conv1(x))))) + skip(x))."""
        # skip branch: projected input if a downsample module was supplied
        identity = x if self.downsample is None else self.downsample(x)

        out = self.gate(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.gate(out)

SDDH (descriptors), also taken from ALIKED

In [3]:
# Helper used in SDDH
def get_patches(
        tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
    """Cut a ``ps`` x ``ps`` patch out of ``tensor`` around each given point.

    Args:
        tensor: C x H x W feature map.
        required_corners: N x 2 tensor of (x, y) pixel positions.
        ps: patch side length.

    Returns:
        N x C x ps x ps tensor where ``out[n, :, j, i]`` equals
        ``tensor[:, cy + j, cx + i]`` and ``(cx, cy)`` is the clamped top-left
        corner derived from ``required_corners[n]``.
    """
    c, h, w = tensor.shape
    # shift from the given point to the patch's top-left corner
    corner = (required_corners - ps / 2 + 1).long()
    # keep the patch inside the map
    # NOTE(review): the upper bound w - 1 - ps leaves one extra pixel of
    # margin compared to w - ps — confirm this is intended.
    corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
    corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
    offset = torch.arange(0, ps)

    # `indexing` only exists from torch 1.10 on. Compare numerically: a plain
    # string comparison would rank "1.9" above "1.10".
    try:
        torch_version = tuple(int(p) for p in torch.__version__.split(".")[:2])
    except ValueError:
        torch_version = (1, 10)  # unparsable build string: assume a recent torch
    kw = {"indexing": "ij"} if torch_version >= (1, 10) else {}
    x, y = torch.meshgrid(offset, offset, **kw)
    patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)  # ps x ps x 1 x 2
    patches = patches.to(corner) + corner[None, None]  # ps x ps x N x 2
    pts = patches.reshape(-1, 2)
    # index the H x W x C view with (rows, cols), i.e. (y, x)
    sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
    sampled = sampled.reshape(ps, ps, -1, c)
    assert sampled.shape[:3] == patches.shape[:3]
    return sampled.permute(2, 3, 0, 1)


class SDDH(nn.Module):
    """Sparse deformable descriptor head.

    For every keypoint a small patch of the feature map is used to predict
    ``n_pos`` sampling offsets (and optionally mask weights). Features are
    bilinearly sampled at the offset positions and aggregated into one
    L2-normalised descriptor per keypoint.

    Args:
        dims: channel count of the feature map and of the descriptors.
        kernel_size: side length of the patch used to estimate offsets.
        n_pos: number of sample positions per keypoint.
        gate: activation placed between the two offset convolutions. The
            default instance is created once at definition time and shared by
            all SDDH instances (harmless for a stateless ReLU).
        conv2D: aggregate with a 1x1 convolution (``convM``) instead of the
            learned ``agg_weights`` tensor.
        mask: additionally predict a sigmoid weight per sample position.
    """

    def __init__(
            self,
            dims: int,
            kernel_size: int = 3,
            n_pos: int = 8,
            gate=nn.ReLU(),
            conv2D=False,
            mask=False,
    ):
        super(SDDH, self).__init__()
        self.kernel_size = kernel_size
        self.n_pos = n_pos
        self.conv2D = conv2D
        self.mask = mask

        self.get_patches_func = get_patches

        # estimate offsets: 2 channels (x, y) per position, +1 for the mask
        self.channel_num = 3 * n_pos if mask else 2 * n_pos
        self.offset_conv = nn.Sequential(
            nn.Conv2d(
                dims,
                self.channel_num,
                kernel_size=kernel_size,
                stride=1,
                padding=0,
                bias=True,
            ),
            gate,
            nn.Conv2d(
                self.channel_num,
                self.channel_num,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
        )

        # sampled feature conv
        self.sf_conv = nn.Conv2d(
            dims, dims, kernel_size=1, stride=1, padding=0, bias=False
        )

        # convM
        if not conv2D:
            # deformable desc weights
            agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
            self.register_parameter("agg_weights", agg_weights)
        else:
            self.convM = nn.Conv2d(
                dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
            )

    def forward(self, x, keypoints):
        """Compute descriptors for the given keypoints.

        Args:
            x: B x C x H x W feature map.
            keypoints: sequence of B tensors, each N_kpts x 2 in (x, y) order
                with coordinates normalised to [-1, 1].

        Returns:
            ``(descriptors, offsets)``: two lists of length B holding the
            N_kpts x C L2-normalised descriptors and the N_kpts x n_pos x 2
            predicted offsets (in pixels) of every image.
        """
        # x: [B,C,H,W]
        # keypoints: list, [[N_kpts,2], ...] (w,h)
        b, c, h, w = x.shape
        wh = torch.tensor([[w - 1, h - 1]], device=x.device)
        max_offset = max(h, w) / 4.0

        offsets = []
        descriptors = []
        # get offsets for each keypoint
        for ib in range(b):
            xi, kptsi = x[ib], keypoints[ib]
            kptsi_wh = (kptsi / 2 + 0.5) * wh  # [-1, 1] -> pixel coordinates
            N_kpts = len(kptsi)

            if self.kernel_size > 1:
                patch = self.get_patches_func(
                    xi, kptsi_wh.long(), self.kernel_size
                )  # [N_kpts, C, K, K]
            else:
                kptsi_wh_long = kptsi_wh.long()
                patch = (
                    xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
                    .permute(1, 0)
                    .reshape(N_kpts, c, 1, 1)
                )

            offset = self.offset_conv(patch).clamp(
                -max_offset, max_offset
            )  # [N_kpts, channel_num, 1, 1]
            if self.mask:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 3]
                # Read the mask channel BEFORE dropping it; taking index -1
                # after the slice would use the y-offset as mask weight.
                mask_weight = torch.sigmoid(offset[:, :, -1])  # [N_kpts, n_pos]
                offset = offset[:, :, :-1]  # [N_kpts, n_pos, 2]
            else:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 2]
            offsets.append(offset)  # for visualization

            # get sample positions
            pos = kptsi_wh.unsqueeze(1) + offset  # [N_kpts, n_pos, 2]
            pos = 2.0 * pos / wh[None] - 1  # back to [-1, 1] for grid_sample
            pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)

            # sample features
            features = F.grid_sample(
                xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
            )  # [1,C,(N_kpts*n_pos),1]
            features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
                1, 0, 2, 3
            )  # [N_kpts, C, n_pos, 1]
            if self.mask:
                features = torch.einsum("ncpo,np->ncpo", features, mask_weight)

            features = torch.selu_(self.sf_conv(features)).squeeze(
                -1
            )  # [N_kpts, C, n_pos]
            # convM
            if not self.conv2D:
                descs = torch.einsum(
                    "ncp,pcd->nd", features, self.agg_weights
                )  # [N_kpts, C]
            else:
                features = features.reshape(N_kpts, -1)[
                           :, :, None, None
                           ]  # [N_kpts, C*n_pos, 1, 1]
                # flatten only the trailing 1x1 dims; a bare squeeze() would
                # also drop the keypoint dimension when N_kpts == 1
                descs = self.convM(features).flatten(1)  # [N_kpts, C]

            # normalize
            descs = F.normalize(descs, p=2.0, dim=1)
            descriptors.append(descs)

        return descriptors, offsets


SMH (keypoint + junction map), taken from ALIKED

In [4]:
class SMH(nn.Module):
    """Score-map head: a small conv stack followed by a sigmoid.

    Args:
        input_dim: channel count of the incoming feature map.
    """

    def __init__(self, input_dim):
        super(SMH, self).__init__()
        self.gate = nn.SELU(inplace=True)

        # conv -> gate three times, then a final conv down to one channel;
        # the same activation instance is reused between all convolutions
        hidden_convs = (
            resnet.conv1x1(input_dim, 8),
            resnet.conv3x3(8, 4),
            resnet.conv3x3(4, 4),
        )
        layers = []
        for conv in hidden_convs:
            layers.extend((conv, self.gate))
        layers.append(resnet.conv3x3(4, 1))
        self.score_head = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score map for ``x`` (feature map, not normalized)."""
        logits = self.score_head(x)
        return torch.sigmoid(logits)

DKD (extracts keypoints from the heatmap)

In [5]:
# Util Needed for DKD
def simple_nms(scores: torch.Tensor, nms_radius: int):
    """Zero out every score that is not a local maximum within ``nms_radius``.

    Maxima are found with max-pooling over a (2 * nms_radius + 1) window; two
    refinement passes additionally recover maxima that only emerge once the
    neighbourhoods of already accepted maxima are suppressed.
    """
    window = nms_radius * 2 + 1

    def pooled(values):
        # sliding-window maximum with the same spatial size as the input
        return torch.nn.functional.max_pool2d(
            values, kernel_size=window, stride=1, padding=nms_radius
        )

    blank = torch.zeros_like(scores)
    keep = scores == pooled(scores)

    for _ in range(2):
        # everything in the window of an accepted maximum is suppressed
        suppressed = pooled(keep.float()) > 0
        remaining = torch.where(suppressed, blank, scores)
        fresh_maxima = remaining == pooled(remaining)
        keep = keep | (fresh_maxima & (~suppressed))
    return torch.where(keep, scores, blank)

class DKD(nn.Module):
    """Keypoint detection on a score map: NMS, selection, optional soft-argmax."""

    def __init__(
            self,
            radius: int = 2,
            top_k: int = 0,
            scores_th: float = 0.2,
            n_limit: int = 20000,
    ):
        """
        Args:
            radius: soft detection radius, kernel size is (2 * radius + 1)
            top_k: top_k > 0: return top k keypoints
            scores_th: top_k <= 0 threshold mode:
                scores_th > 0: return keypoints with scores>scores_th
                else: return keypoints with scores > scores.mean()
            n_limit: max number of keypoint in threshold mode
        """
        super().__init__()
        self.radius = radius
        self.top_k = top_k
        self.scores_th = scores_th
        self.n_limit = n_limit
        self.kernel_size = 2 * self.radius + 1
        self.temperature = 0.1  # tuned temperature
        self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
        # local xy grid
        x = torch.linspace(-self.radius, self.radius, self.kernel_size)
        # `indexing` only exists from torch 1.10 on. Compare numerically: a
        # plain string comparison would rank "1.9" above "1.10".
        try:
            torch_version = tuple(int(p) for p in torch.__version__.split(".")[:2])
        except ValueError:
            torch_version = (1, 10)  # unparsable build string: assume recent torch
        kw = {"indexing": "ij"} if torch_version >= (1, 10) else {}
        # (kernel_size*kernel_size) x 2 : (w,h)
        self.hw_grid = (
            torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
        )

    def forward(
            self,
            scores_map: torch.Tensor,
            sub_pixel: bool = True,
            image_size: Optional[torch.Tensor] = None,
    ):
        """
        :param scores_map: Bx1xHxW
        :param sub_pixel: whether to use sub-pixel keypoint detection
        :param image_size: optional per-image (w, h) of the valid region inside
            the score map; detections closer than ``radius`` to its right or
            bottom edge (and everything beyond) are discarded
        :return: (keypoints, scoredispersitys, kptscores), three lists of
            length B: keypoints are Nx2 (x, y) positions normalised to -1~1,
            kptscores are the N scores sampled at those positions and
            scoredispersitys the N dispersity values (when sub_pixel is False
            this list simply repeats kptscores)
        """
        b, c, h, w = scores_map.shape
        scores_nograd = scores_map.detach()
        nms_scores = simple_nms(scores_nograd, self.radius)

        # remove border
        nms_scores[:, :, : self.radius, :] = 0
        nms_scores[:, :, :, : self.radius] = 0
        if image_size is not None:
            for i in range(scores_map.shape[0]):
                # Use dedicated names: h and w must keep describing the score
                # map, they are needed below to decode flat indices and to
                # normalise the coordinates.
                img_w, img_h = image_size[i].long()
                nms_scores[i, :, img_h.item() - self.radius:, :] = 0
                nms_scores[i, :, :, img_w.item() - self.radius:] = 0
        else:
            nms_scores[:, :, -self.radius:, :] = 0
            nms_scores[:, :, :, -self.radius:] = 0

        # detect keypoints without grad
        if self.top_k > 0:
            flat_scores = nms_scores.view(b, -1)
            # topk raises when asked for more elements than available
            topk = torch.topk(flat_scores, min(self.top_k, flat_scores.shape[1]))
            indices_keypoints = [topk.indices[i] for i in range(b)]  # B x top_k
        else:
            if self.scores_th > 0:
                masks = nms_scores > self.scores_th
                if masks.sum() == 0:
                    # nothing passes the fixed threshold anywhere in the
                    # batch: fall back to the per-image mean score
                    th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                    masks = nms_scores > th.reshape(b, 1, 1, 1)
            else:
                th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                masks = nms_scores > th.reshape(b, 1, 1, 1)
            masks = masks.reshape(b, -1)

            indices_keypoints = []  # list, B x (any size)
            scores_view = scores_nograd.reshape(b, -1)
            for mask, scores in zip(masks, scores_view):
                indices = mask.nonzero()[:, 0]
                if len(indices) > self.n_limit:
                    # keep only the n_limit highest-scoring candidates
                    kpts_sc = scores[indices]
                    sort_idx = kpts_sc.sort(descending=True)[1]
                    sel_idx = sort_idx[: self.n_limit]
                    indices = indices[sel_idx]
                indices_keypoints.append(indices)

        wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)

        keypoints = []
        scoredispersitys = []
        kptscores = []
        if sub_pixel:
            # detect soft keypoints with grad backpropagation
            patches = self.unfold(scores_map)  # B x (kernel**2) x (H*W)
            self.hw_grid = self.hw_grid.to(scores_map)  # to device
            for b_idx in range(b):
                patch = patches[b_idx].t()  # (H*W) x (kernel**2)
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                patch_scores = patch[indices_kpt]  # M x (kernel**2)
                # flat index -> (x, y) on the score map
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2

                # max is detached to prevent undesired backprop loops in the graph
                max_v = patch_scores.max(dim=1).values.detach()[:, None]
                x_exp = (
                        (patch_scores - max_v) / self.temperature
                ).exp()  # M * (kernel**2), in [0, 1]

                # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
                xy_residual = (
                        x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
                )  # Soft-argmax, Mx2

                hw_grid_dist2 = (
                        torch.norm(
                            (self.hw_grid[None, :, :] - xy_residual[:, None, :])
                            / self.radius,
                            dim=-1,
                        )
                        ** 2
                )
                scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)

                # compute result keypoints
                keypoints_xy = keypoints_xy_nms + xy_residual
                keypoints_xy = keypoints_xy / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)

                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[
                           0, 0, 0, :
                           ]  # CxN

                keypoints.append(keypoints_xy)
                scoredispersitys.append(scoredispersity)
                kptscores.append(kptscore)
        else:
            for b_idx in range(b):
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                # To avoid warning: UserWarning: __floordiv__ is deprecated
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2
                keypoints_xy = keypoints_xy_nms / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)
                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[
                           0, 0, 0, :
                           ]  # CxN
                keypoints.append(keypoints_xy)
                scoredispersitys.append(kptscore)  # for jit.script compatability
                kptscores.append(kptscore)

        return keypoints, scoredispersitys, kptscores

Line Heatmap decoder (Taken from SOLD2)

In [6]:
class PixelShuffleDecoder(nn.Module):
    """Pixel shuffle decoder.

    Alternates 3x3 conv blocks with 2x pixel-shuffle upsampling and finishes
    with a 1x1 convolution, so the output is ``2 ** num_upsample`` times
    larger than the input in each spatial dimension.

    Args:
        input_feat_dim: channels of the input feature map.
        num_upsample: number of 2x pixel-shuffle steps (1 to 4).
        output_channel: channels of the decoded map.

    Raises:
        ValueError: if ``num_upsample`` is not in 1..4.
    """

    def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
        super(PixelShuffleDecoder, self).__init__()
        # Get channel parameters
        self.channel_conf = self.get_channel_conf(num_upsample)

        # Define the pixel shuffle
        self.pixshuffle = nn.PixelShuffle(2)

        # Process the feature
        self.conv_block_lst = []
        # The input block
        self.conv_block_lst.append(
            nn.Sequential(
                nn.Conv2d(input_feat_dim, self.channel_conf[0],
                          kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(self.channel_conf[0]),
                nn.ReLU(inplace=True)
            ))

        # Intermediate block: each one runs after a pixel shuffle, which has
        # already divided the channel count by 4, hence channel -> channel
        for channel in self.channel_conf[1:-1]:
            self.conv_block_lst.append(
                nn.Sequential(
                    nn.Conv2d(channel, channel, kernel_size=3,
                              stride=1, padding=1),
                    nn.BatchNorm2d(channel),
                    nn.ReLU(inplace=True)
                ))

        # Output block
        self.conv_block_lst.append(
            nn.Conv2d(self.channel_conf[-1], output_channel,
                      kernel_size=1, stride=1, padding=0)
        )
        self.conv_block_lst = nn.ModuleList(self.conv_block_lst)

    # Get num of channels based on number of upsampling.
    def get_channel_conf(self, num_upsample):
        """Return the channel count before/after every pixel shuffle.

        Starts at 256 and divides by 4 per upsampling step, e.g.
        ``[256, 64, 16]`` for 2 and ``[256, 64, 16, 4]`` for 3.
        """
        # 256 channels can be shuffled at most 4 times (256 -> 64 -> 16 -> 4 -> 1)
        if num_upsample not in (1, 2, 3, 4):
            raise ValueError(
                "num_upsample must be 1, 2, 3 or 4, got {!r}".format(num_upsample)
            )
        return [256 // 4 ** step for step in range(int(num_upsample) + 1)]

    def forward(self, input_features):
        """Decode B x input_feat_dim x H x W features into the upsampled map."""
        # Iterate til output block
        out = input_features
        for block in self.conv_block_lst[:-1]:
            out = block(out)
            out = self.pixshuffle(out)

        # Output layer
        out = self.conv_block_lst[-1](out)

        return out

Add Line Extractor (taken from SOLD2)

In [7]:
def line_map_to_segments(junctions, line_map):
    """ Convert a line map to a Nx2x2 list of segments.

    Args:
        junctions: J x 2 array of junction coordinates (HW format).
        line_map: J x J array; an entry equal to 1 at (i, j) connects
            junction i with junction j. The array is not modified.

    Returns:
        N x 2 x 2 array where ``out[n, 0]`` and ``out[n, 1]`` are the two end
        points of the n-th segment. A connection present in both (i, j) and
        (j, i) is reported once.
    """
    line_map_tmp = line_map.copy()

    # Start from the empty float array so that the result dtype (and the
    # (0, 2, 2) shape when nothing is connected) matches incremental stacking.
    pieces = [np.zeros([0, 2, 2])]
    for idx in range(junctions.shape[0]):
        # if no connectivity, just skip it
        if line_map_tmp[idx, :].sum() == 0:
            continue
        # Record the line segments starting at this junction
        for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
            p1 = junctions[idx, :]  # HW format
            p2 = junctions[idx2, :]
            pieces.append(np.stack((p1, p2), axis=0)[None, ...])

            # Clear both directions so the segment is not emitted again
            line_map_tmp[idx, idx2] = 0
            line_map_tmp[idx2, idx] = 0

    # One concatenation at the end instead of re-copying the output array
    # for every segment.
    return np.concatenate(pieces, axis=0)


# Taken from SOLD2
def convert_junc_predictions(predictions, grid_size,
                             detect_thresh=1 / 65, topk=300):
    """ Convert torch junction predictions to numpy arrays for evaluation.

    A softmax is taken over the channel axis and the last channel is dropped.
    The remaining channels are (a) summed into a coarse per-cell probability
    and (b) pixel-shuffled by ``grid_size`` into a full-resolution map, which
    is additionally filtered with ``super_nms``.

    Returns:
        dict with "junc_pred" (full-resolution map), "junc_pred_nms" (the same
        map after NMS) and "junc_prob" (per-cell probability).
    """
    # Channel-wise probabilities, moved to the CPU
    probabilities = softmax(predictions.detach(), dim=1).cpu()
    # Every channel except the last one
    cell_channels = probabilities[:, :-1, :, :]

    # Coarse map: total probability of the kept channels per cell
    channels_last = probabilities.numpy().transpose(0, 2, 3, 1)
    junc_prob_np = np.sum(channels_last[:, :, :, :-1], axis=-1)

    # Full-resolution map in channels-last layout
    shuffled = pixel_shuffle(cell_channels, grid_size)
    junc_pred_np = shuffled.cpu().numpy().transpose(0, 2, 3, 1)
    junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)

    result = {"junc_pred": junc_pred_np.squeeze(-1)}
    result["junc_pred_nms"] = junc_pred_np_nms
    result["junc_prob"] = junc_prob_np
    return result


# Taken from SOLD2
def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
    """ Non-maximum suppression adapted from SuperPoint.

    Args:
        prob_predictions: batch of probability maps. Each map is indexed as
            H x W with a trailing singleton axis (the score column built below
            has to be 2-D to be concatenated with the coordinates).
        dist_thresh: suppression distance in pixels, passed to ``nms_fast``.
        prob_thresh: points with a probability below this value are ignored.
        top_k: if > 0, keep only the ``top_k`` highest-scoring points per
            map; 0 or None keeps every surviving point.

    Returns:
        B x H x W array that is zero everywhere except at the surviving
        points, which carry their probability.
    """
    # Iterate through batch dimension
    im_h = prob_predictions.shape[1]
    im_w = prob_predictions.shape[2]
    output_lst = []
    for i in range(prob_predictions.shape[0]):
        prob_pred = prob_predictions[i, ...]
        # Filter the points using prob_thresh
        coord = np.where(prob_pred >= prob_thresh)  # HW format
        points = np.concatenate((coord[0][..., None], coord[1][..., None]),
                                axis=1)  # HW format

        # Get the probability score
        prob_score = prob_pred[points[:, 0], points[:, 1]]

        # Perform super nms
        # Modify the in_points to xy format (instead of HW format)
        in_points = np.concatenate((coord[1][..., None], coord[0][..., None],
                                    prob_score), axis=1).T
        keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh)
        # Remember to flip outputs back to HW format
        keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T)
        keep_score = keep_points_[-1, :].T

        # Whether we only keep the topk value (nms_fast returns the points
        # sorted by decreasing score). Test for None first: comparing None
        # with an int raises a TypeError.
        if top_k is not None and top_k > 0:
            k = min(keep_points.shape[0], top_k)
            keep_points = keep_points[:k, :]
            keep_score = keep_score[:k]

        # Re-compose the probability map
        # (builtin int: the np.int alias was removed in NumPy 1.24)
        output_map = np.zeros([im_h, im_w])
        output_map[keep_points[:, 0].astype(int),
                   keep_points[:, 1].astype(int)] = keep_score.squeeze()

        output_lst.append(output_map[None, ...])

    return np.concatenate(output_lst, axis=0)


# Taken from SOLD2
def nms_fast(in_corners, H, W, dist_thresh):
    """
    Fast approximate Non-Max-Suppression on numpy corners shaped:
      3xN [x_i,y_i,conf_i]^T

    How it works: corners are rounded to integer pixels and marked on an
    HxW grid. They are then visited from highest to lowest confidence; a
    corner that is still marked is kept and wipes every mark inside the
    square window of half-size dist_thresh around it.

    Grid values:
      1 : corner still waiting to be processed.
      0 : empty or suppressed.
     -1 : corner that was kept.

    NOTE: because of the rounding the effective suppression distance may
    differ slightly from dist_thresh. Corners are assumed to lie inside the
    image.

    Inputs
      in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
      H - Image height.
      W - Image width.
      dist_thresh - Suppression distance (integer, infinity-norm).
    Returns
      nmsed_corners - 3xM numpy matrix with the surviving corners, sorted by
        decreasing confidence.
      nmsed_inds - M length numpy vector with the indices of the surviving
        corners in in_corners.
    """
    # Order the corners by decreasing confidence and round their positions.
    order = np.argsort(-in_corners[2, :])
    corners = in_corners[:, order]
    rcorners = corners[:2, :].round().astype(int)
    num_corners = rcorners.shape[1]

    # Trivial inputs: nothing to suppress with zero or one corner.
    if num_corners == 0:
        return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
    if num_corners == 1:
        out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
        return out, np.zeros((1)).astype(int)

    grid = np.zeros((H, W)).astype(int)  # NMS state per pixel.
    inds = np.zeros((H, W)).astype(int)  # Rank of the corner at each pixel.
    for rank, (cx, cy) in enumerate(rcorners.T):
        grid[cy, cx] = 1
        inds[cy, cx] = rank

    # Pad the grid so windows around border corners stay inside the array.
    pad = dist_thresh
    grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')

    # Best corner first: keep it if still marked and clear its neighbourhood.
    for cx, cy in rcorners.T:
        px, py = cx + pad, cy + pad  # position in the padded grid
        if grid[py, px] != 1:
            continue
        grid[py - pad:py + pad + 1, px - pad:px + pad + 1] = 0
        grid[py, px] = -1

    # Collect the kept corners and restore the confidence ordering.
    keepy, keepx = np.where(grid == -1)
    inds_keep = inds[keepy - pad, keepx - pad]
    out = corners[:, inds_keep]
    by_conf = np.argsort(-out[-1, :])
    return out[:, by_conf], order[inds_keep[by_conf]]

class LineExtractor(object):
    """
    Not learned method for line detection from junctions and line-heatmaps
    Adapted from SOLD2 implementation

    The configuration mapping passed to the constructor must provide the keys
    "grid_size", "junc_detect_thresh", "max_num_junctions" and
    "detect_thresh".
    """
    # Line detector cfg in SOLD 2 can be found sold2/config/export_line_features.yaml under (line_detector_cfg)
    # ToDo: check config handling & different configs for training and evaluation? (sold2 gives only detect thresh in training, other params not defined)
    def __init__(self, device, line_extractor_cfg):
        """Read the junction settings and build the line segment detector.

        Parameters:
            device: Device forwarded to the line detector on every call.
            line_extractor_cfg: Mapping with the four keys listed on the class.
        """
        self.grid_size = line_extractor_cfg["grid_size"]
        self.junc_detect_thresh = line_extractor_cfg["junc_detect_thresh"]
        self.max_num_junctions = line_extractor_cfg["max_num_junctions"]
        self.device = device

        self.detect_thresh = line_extractor_cfg["detect_thresh"] # in cfg file: detection_thresh: 0.0153846 # 1/65
        self.line_detector = LineSegmentDetectionModule(self.detect_thresh)

    def __call__(self, junctions, line_heatmap, valid_mask=None):
        """Extract line segments from junction and line-heatmap predictions.

        Parameters:
            junctions: Raw junction prediction, handed as-is to
                convert_junc_predictions.
            line_heatmap: Line heatmap logits, indexed here as (B, C, H, W);
                only the first batch element is used. With C == 2 a softmax
                over the channels is taken and channel 1 is kept, otherwise a
                sigmoid is applied.
            valid_mask: Optional mask multiplied onto the NMS'ed junction map
                (presumably H x W - confirm against the caller).
        Returns:
            Tuple (line_segments, line_heatmap, junctions). `line_heatmap` is
            the unmodified input; `junctions` are the junctions returned by
            the line detector as a numpy array.
        """
        junc_np = convert_junc_predictions(
            junctions, self.grid_size,
            self.junc_detect_thresh, self.max_num_junctions)
        junc_mask = junc_np["junc_pred_nms"].squeeze()
        if valid_mask is not None:
            junc_mask = junc_mask * valid_mask
        junc_coords = np.where(junc_mask)
        # One junction per row, columns are the first two index axes.
        junctions = np.stack([junc_coords[0], junc_coords[1]], axis=-1)

        if line_heatmap.shape[1] == 2:
            # Convert to single channel directly from here
            heatmap = softmax(line_heatmap, dim=1)[:, 1:, :, :]
        else:
            heatmap = torch.sigmoid(line_heatmap)
        # detach() first: Tensor.numpy() refuses tensors that require grad,
        # which is the case whenever this runs outside torch.no_grad().
        heatmap = heatmap.detach().cpu().numpy().transpose(
            0, 2, 3, 1)[0, :, :, 0]

        # Run the line detector. The heatmap it returns is not needed here.
        line_map, junctions, _ = self.line_detector.detect(
            junctions, heatmap, device=self.device)
        if isinstance(line_map, torch.Tensor):
            line_map = line_map.cpu().numpy()
        if isinstance(junctions, torch.Tensor):
            junctions = junctions.cpu().numpy()
        line_segments = line_map_to_segments(junctions, line_map)

        return line_segments, line_heatmap, junctions


class LineSegmentDetectionModule(object):
    """ Module extracting line segments from junctions and line heatmaps. """

    def __init__(
            self, detect_thresh, num_samples=64, sampling_method="local_max",
            inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
            max_local_patch_radius=3, lambda_radius=2.,
            use_candidate_suppression=False, nms_dist_tolerance=3.,
            use_heatmap_refinement=False, heatmap_refine_cfg=None,
            use_junction_refinement=False, junction_refine_cfg=None):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation (0. ~ 1.)
            num_samples: Number of sampling locations along the line segments.
            sampling_method: Sampling method on locations ("bilinear" or "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction recovery.
            max_local_patch_radius: The max patch to be considered in local maximum search.
            lambda_radius: The lambda factor in linear local maximum search formulation
            use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods.
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Detecting junctions on the boundary parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-compute the linspace sampler
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """ Convert inputs to desired torch tensor. """
        if isinstance(inputs, np.ndarray):
            outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
        elif isinstance(inputs, torch.Tensor):
            outputs = inputs.to(torch.float32).to(device)
        else:
            raise ValueError(
                "[Error] Inputs must either be torch tensor or numpy ndarray.")

        return outputs

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """ Main function performing line segment detection. """
        # Convert inputs to torch tensor
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Perform the heatmap refinement
        if self.use_heatmap_refinement:
            if self.heatmap_refine_cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap,
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )
            elif self.heatmap_refine_cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap,
                    self.heatmap_refine_cfg["num_blocks"],
                    self.heatmap_refine_cfg["overlap_ratio"],
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )

        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros([num_junctions, num_junctions],
                                    device=device, dtype=torch.int32)

        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Generate the candidate map
        candidate_map = torch.triu(torch.ones(
            [num_junctions, num_junctions], device=device, dtype=torch.int32),
            diagonal=1)

        # Fetch the image boundary
        if len(heatmap.shape) > 2:
            H, W, _ = heatmap.shape
        else:
            H, W = heatmap.shape

        # Optionally perform candidate filtering
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions,
                                                       candidate_map)

        # Fetch the candidates
        candidate_index_map = torch.where(candidate_map)
        candidate_index_map = torch.cat([candidate_index_map[0][..., None],
                                         candidate_index_map[1][..., None]],
                                        dim=-1)

        # Get the corresponding start and end junctions
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Get the sampling locations (N x 64)
        sampler = self.torch_sampler.to(device)[None, ...]
        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \
                         candidate_junc_end[:, 0:1] * (1 - sampler)
        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \
                         candidate_junc_end[:, 1:2] * (1 - sampler)

        # Clip to image boundary
        cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
        cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

        # Local maximum search
        if self.sampling_method == "local_max":
            # Compute normalized segment lengths
            segments_length = torch.sqrt(torch.sum(
                (candidate_junc_start.to(torch.float32) -
                 candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
            normalized_seg_length = (segments_length
                                     / (((H ** 2) + (W ** 2)) ** 0.5))

            # Perform local max search
            num_cand = cand_h.shape[0]
            group_size = 10000
            if num_cand > group_size:
                num_iter = math.ceil(num_cand / group_size)
                sampled_feat_lst = []
                for iter_idx in range(num_iter):
                    if not iter_idx == num_iter - 1:
                        cand_h_ = cand_h[iter_idx * group_size:
                                         (iter_idx + 1) * group_size, :]
                        cand_w_ = cand_w[iter_idx * group_size:
                                         (iter_idx + 1) * group_size, :]
                        normalized_seg_length_ = normalized_seg_length[
                                                 iter_idx * group_size: (iter_idx + 1) * group_size]
                    else:
                        cand_h_ = cand_h[iter_idx * group_size:, :]
                        cand_w_ = cand_w[iter_idx * group_size:, :]
                        normalized_seg_length_ = normalized_seg_length[
                                                 iter_idx * group_size:]
                    sampled_feat_ = self.detect_local_max(
                        heatmap, cand_h_, cand_w_, H, W,
                        normalized_seg_length_, device)
                    sampled_feat_lst.append(sampled_feat_)
                sampled_feat = torch.cat(sampled_feat_lst, dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W,
                    normalized_seg_length, device)
        # Bilinear sampling
        elif self.sampling_method == "bilinear":
            # Perform bilinear sampling
            sampled_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # [Simple threshold detection]
        # detection_results is a mask over all candidates
        detection_results = (torch.mean(sampled_feat, dim=-1)
                             > self.detect_thresh)

        # [Inlier threshold detection]
        if self.inlier_thresh > 0.:
            inlier_ratio = torch.sum(
                sampled_feat > self.detect_thresh,
                dim=-1).to(torch.float32) / self.num_samples
            detection_results_inlier = inlier_ratio >= self.inlier_thresh
            detection_results = detection_results * detection_results_inlier

        # Convert detection results back to line_map_pred
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0],
        detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1],
        detected_junc_indexes[:, 0]] = 1

        # Perform junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device)

        return line_map_pred, junctions, heatmap

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """ Global heatmap refinement method. """
        # Grab the top 10% values
        heatmap_values = heatmap[heatmap > valid_thresh]
        sorted_values = torch.sort(heatmap_values, descending=True)[0]
        top10_len = math.ceil(sorted_values.shape[0] * ratio)
        max20 = torch.mean(sorted_values[:top10_len])
        heatmap = torch.clamp(heatmap / max20, min=0., max=1.)
        return heatmap

    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
                             ratio=0.2, valid_thresh=2e-3):
        """ Local heatmap refinement method. """
        # Get the shape of the heatmap
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        count_map = torch.zeros(heatmap.shape, dtype=torch.float,
                                device=heatmap.device)
        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
                                     device=heatmap.device)
        # Iterate through each block
        for h_idx in range(num_blocks):
            for w_idx in range(num_blocks):
                # Fetch the heatmap
                h_start = round(h_idx * h_block * increase_ratio)
                w_start = round(w_idx * w_block * increase_ratio)
                h_end = h_start + h_block if h_idx < num_blocks - 1 else H
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh)

                # Aggregate it to the final heatmap
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1
        heatmap_output = torch.clamp(heatmap_output / count_map,
                                     max=1., min=0.)

        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """ Suppress overlapping long lines in the candidate segments. """
        # Define the distance tolerance
        dist_tolerance = self.nms_dist_tolerance

        # Compute distance between junction pairs
        # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
        line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1)
                                   - junctions[None, ...]) ** 2, dim=-1) ** 0.5

        # Fetch all the "detected lines"
        seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
        start_point_idxs = seg_indexes[0]
        end_point_idxs = seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Fetch corresponding entries
        line_dists = line_dist_map[start_point_idxs, end_point_idxs]

        # Check whether they are on the line
        dir_vecs = ((end_points - start_points)
                    / torch.norm(end_points - start_points,
                                 dim=-1)[..., None])
        # Get the orthogonal distance
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
        # Check whether they are projected directly onto the segment
        proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
                / line_dists[..., None, None])
        # proj is num_segs x num_junction x 1
        proj_mask = (proj >= 0) * (proj <= 1)
        cand_angles = torch.acos(
            torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
            / cand_vecs_norm[..., None])
        cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
        junc_dist_mask = cand_dists <= dist_tolerance
        junc_mask = junc_dist_mask * proj_mask

        # Minus starting points
        num_segs = start_point_idxs.shape[0]
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
        start_point_idxs].to(torch.int)
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
        end_point_idxs].to(torch.int)

        # Get the invalid candidate mask
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask],
        end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred,
                                heatmap, H, W, device):
        """ Refine the line endpoints in a similar way as in LSD. """
        # Get the config
        junction_refine_cfg = self.junction_refine_cfg

        # Fetch refinement parameters
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2
        # Fetch the 2D perturb mat
        perturb_vec = torch.arange(
            start=-perturb_interval * side_perturbs,
            end=perturb_interval * (side_perturbs + 1),
            step=perturb_interval, device=device)
        w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec)
        perturb_tensor = torch.cat([
            w1_grid[..., None], h1_grid[..., None],
            w2_grid[..., None], h2_grid[..., None]], dim=-1)
        perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)

        # Fetch the junctions and line_map
        junctions = junctions.clone()
        line_map = line_map_pred

        # Fetch all the detected lines
        detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1))
        start_point_idxs = detected_seg_indexes[0]
        end_point_idxs = detected_seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        line_segments = torch.cat([start_points.unsqueeze(dim=1),
                                   end_points.unsqueeze(dim=1)], dim=1)

        line_segment_candidates = (line_segments.unsqueeze(dim=1)
                                   + perturb_tensor_flat[None, ...])
        # Clip the boundaries
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1)
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1)

        # Iterate through all the segments
        refined_segment_lst = []
        num_segments = line_segments.shape[0]
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            # Get the corresponding start and end junctions
            candidate_junc_start = segment[:, 0, :]
            candidate_junc_end = segment[:, 1, :]

            # Get the sampling locations (N x 64)
            sampler = self.torch_sampler.to(device)[None, ...]
            cand_samples_h = (candidate_junc_start[:, 0:1] * sampler +
                              candidate_junc_end[:, 0:1] * (1 - sampler))
            cand_samples_w = (candidate_junc_start[:, 1:2] * sampler +
                              candidate_junc_end[:, 1:2] * (1 - sampler))

            # Clip to image boundary
            cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
            cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

            # Perform bilinear sampling
            segment_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...][None, ...])

        # Concatenate back to segments
        refined_segments = torch.cat(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new,
                                                 refined_segments)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """ Convert the list of segments to line map. """
        # Create empty line map
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions], device=device)

        # Iterate through every segment
        for idx in range(segments.shape[0]):
            # Get the junctions from a single segement
            seg = segments[idx, ...]
            junction1 = seg[0, :]
            junction2 = seg[1, :]

            # Get index
            idx_junction1 = torch.where(
                (junctions == junction1).sum(axis=1) == 2)[0]
            idx_junction2 = torch.where(
                (junctions == junction2).sum(axis=1) == 2)[0]

            # label the corresponding entries
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """ Detection by bilinear sampling. """
        # Get the floor and ceiling locations
        cand_h_floor = torch.floor(cand_h).to(torch.long)
        cand_h_ceil = torch.ceil(cand_h).to(torch.long)
        cand_w_floor = torch.floor(cand_w).to(torch.long)
        cand_w_ceil = torch.ceil(cand_w).to(torch.long)

        # Perform the bilinear sampling
        cand_samples_feat = (
                heatmap[cand_h_floor, cand_w_floor] * (cand_h_ceil - cand_h)
                * (cand_w_ceil - cand_w) + heatmap[cand_h_floor, cand_w_ceil]
                * (cand_h_ceil - cand_h) * (cand_w - cand_w_floor) +
                heatmap[cand_h_ceil, cand_w_floor] * (cand_h - cand_h_floor)
                * (cand_w_ceil - cand_w) + heatmap[cand_h_ceil, cand_w_ceil]
                * (cand_h - cand_h_floor) * (cand_w - cand_w_floor))

        return cand_samples_feat

    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
                         normalized_seg_length, device):
        """ Detection by local maximum search. """
        # Compute the distance threshold
        dist_thresh = (0.5 * (2 ** 0.5)
                       + self.lambda_radius * normalized_seg_length)
        # Make it N x 64
        dist_thresh = torch.repeat_interleave(dist_thresh[..., None],
                                              self.num_samples, dim=-1)

        # Compute the candidate points
        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]],
                                dim=-1)
        cand_points_round = torch.round(cand_points)  # N x 64 x 2

        # Construct local patches 9x9 = 81
        patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1),
                                  int(2 * self.local_patch_radius + 1)],
                                 device=device)
        patch_center = torch.tensor(
            [[self.local_patch_radius, self.local_patch_radius]],
            device=device, dtype=torch.float32)
        H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
        patch_points = torch.cat([H_patch_points[..., None],
                                  W_patch_points[..., None]], dim=-1)
        # Fetch the circle region
        patch_center_dist = torch.sqrt(torch.sum(
            (patch_points - patch_center) ** 2, dim=-1))
        patch_points = (patch_points[patch_center_dist
                                     <= self.local_patch_radius, :])
        # Shift [0, 0] to the center
        patch_points = patch_points - self.local_patch_radius

        # Construct local patch mask
        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
                                + patch_points[None, None, ...])
        patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2)
                                           - patch_points_shifted) ** 2,
                                          dim=-1))
        patch_dist_mask = patch_dist < dist_thresh[..., None]

        # Get all points => num_points_center x num_patch_points x 2
        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0,
                               max=H - 1).to(torch.long)
        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0,
                               max=W - 1).to(torch.long)
        points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)

        # Sample the feature (N x 64 x 81)
        sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
        # Filtering using the valid mask
        sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)
        if len(sampled_feat) == 0:
            sampled_feat_lmax = torch.empty(0, 64)
        else:
            sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)

        return sampled_feat_lmax


Other utils

In [8]:
class InputPadder(object):
    """Pads 4-D image batches so height and width are multiples of `divis_by` (8 by default)."""

    def __init__(self, h: int, w: int, divis_by: int = 8):
        """Compute the padding needed for images of height `h` and width `w`."""
        self.ht = h
        self.wd = w
        # Amount missing to the next multiple (0 if already a multiple)
        extra_ht = -self.ht % divis_by
        extra_wd = -self.wd % divis_by
        left, top = extra_wd // 2, extra_ht // 2
        # Order expected by F.pad: left, right, top, bottom
        self._pad = [left, extra_wd - left, top, extra_ht - top]

    def pad(self, x: torch.Tensor):
        """Return `x` padded on its last two dimensions by replicating the border."""
        assert x.ndim == 4
        return F.pad(x, self._pad, mode="replicate")

    def unpad(self, x: torch.Tensor):
        """Crop away the padding added by `pad` from the last two dimensions."""
        assert x.ndim == 4
        left, right, top, bottom = self._pad
        row_end = x.shape[-2] - bottom
        col_end = x.shape[-1] - right
        return x[..., top:row_end, left:col_end]

In [None]:
# Basemodel

In [9]:
from abc import ABCMeta, abstractmethod
from copy import copy

import omegaconf
from omegaconf import OmegaConf
from torch import nn

class MetaModel(ABCMeta):
    """Metaclass that pre-populates each class namespace with the merged
    default configuration of its bases."""

    # The data model specifies __prepare__ as a classmethod; the plain
    # function form only worked because it was looked up unbound on the
    # metaclass.
    @classmethod
    def __prepare__(mcs, name, bases, **kwds):
        """Build the initial namespace of the class being created.

        Merges `base_default_conf` and `default_conf` of every base, in that
        order, into one OmegaConf config; later entries override earlier
        ones. The result is exposed to the class body as
        `base_default_conf`.
        """
        total_conf = OmegaConf.create()
        for base in bases:
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)


class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize this model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
        "timeit": False,  # time forward pass
    }
    required_data_keys = []
    strict_conf = False

    # Flipped by set_initialized(), e.g. after load_state_dict()
    are_weights_initialized = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model.

        `conf` may be a plain dict or an OmegaConf config; it is merged on top
        of the defaults collected from the class hierarchy and then frozen.
        """
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # Convert plain dicts first: the backward-compatibility branch below
        # uses omegaconf context managers that need an OmegaConf object.
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)

        # fixme: backward compatibility
        if "pad" in conf and "pad" not in default_conf:  # backward compat.
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Per-instance copy so that _init can extend it without touching the
        # class attribute
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        """Set the training mode; batch-norm layers stay in eval mode when
        `freeze_batch_normalization` is configured."""
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def load_state_dict(self, *args, **kwargs):
        """Load the state dict of the model, and set the model to initialized."""
        ret = super().load_state_dict(*args, **kwargs)
        self.set_initialized()
        return ret

    def is_initialized(self):
        """Recursively check if the model is initialized, i.e. weights are loaded"""
        is_initialized = True  # initialize to true and perform recursive and
        for _, w in self.named_children():
            if isinstance(w, BaseModel):
                # if children is BaseModel, we perform recursive check
                is_initialized = is_initialized and w.is_initialized()
            else:
                # else, we check if self is initialized or the children has no params
                n_params = len(list(w.parameters()))
                is_initialized = is_initialized and (
                    n_params == 0 or self.are_weights_initialized
                )
        return is_initialized

    def set_initialized(self, to: bool = True):
        """Recursively set the initialization state."""
        self.are_weights_initialized = to
        # Walk the child modules (mirrors is_initialized). Iterating the
        # parameters can never yield a BaseModel, so it would not recurse.
        for child in self.children():
            if isinstance(child, BaseModel):
                child.set_initialized(to)

## Main Network file

In [20]:
class JointPointLineDetectorDescriptor(BaseModel):
    """Joint keypoint and line detector/descriptor (work in progress).

    The network is assembled from:

    * an ``AlikedEncoder`` backbone producing the hidden feature map,
    * an ``SMH`` head scoring keypoints/junctions,
    * ``DKD`` to turn the score map into keypoints,
    * an ``SDDH`` head computing keypoint descriptors,
    * a ``PixelShuffleDecoder`` producing a line heatmap and a
      ``LineExtractor`` turning heatmap + junctions into line segments.
    """

    # currently contains only ALIKED
    # ToDo: create default conf once everything is running -> default conf is
    # merged with input conf to the init method!
    default_conf = {
        "model_name": "aliked-n16",
        "max_num_keypoints": -1,
        "detection_threshold": 0.2,
        "force_num_keypoints": False,
        "pretrained": True,
        "nms_radius": 2,
    }

    # taken from ALIKED which gives max num keypoints to detect! ToDo
    n_limit_max = 20000

    line_extractor_cfg = {
        "detect_thresh": 1/65,
        "grid_size": 8,
        "junc_detect_thresh": 1/65,
        "max_num_junctions": 300
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Build all sub-networks from the merged configuration.

        Args:
            conf: merged configuration; ``model_name`` selects an entry of
                ``aliked_cfgs``, while ``nms_radius``, ``detection_threshold``
                and ``max_num_keypoints`` parameterize the keypoint detector.
        """
        print(f"final config dict(type={type(conf)}): {conf}")
        # get configurations
        # c1-c4 -> output dimensions of encoder blocks, dim -> dimension of hidden feature map
        # K=Kernel-Size, M=num sampling pos
        aliked_model_cfg = aliked_cfgs[conf.model_name]
        dim = aliked_model_cfg["dim"]
        K = aliked_model_cfg["K"]
        M = aliked_model_cfg["M"]
        # Load Network Components
        print(f"aliked cfg(type={type(aliked_model_cfg)}): {aliked_model_cfg}")
        self.encoder = AlikedEncoder(aliked_model_cfg)
        self.keypoint_and_junction_branch = SMH(dim)  # using SMH from ALIKE here
        # Differentiable Keypoint Detection from ALIKE
        self.dkd = DKD(
            radius=conf.nms_radius,
            top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
            scores_th=conf.detection_threshold,
            n_limit=(
                conf.max_num_keypoints
                if conf.max_num_keypoints > 0
                else self.n_limit_max
            ),
        )
        self.descriptor_branch = SDDH(dim, K, M, gate=nn.SELU(inplace=True), conv2D=False, mask=False)
        self.line_heatmap_branch = PixelShuffleDecoder(input_feat_dim=dim)  # Use SOLD2 branch
        self.line_extractor = LineExtractor(torch.device("cpu"), self.line_extractor_cfg)  # Use SOLD2 one
        # we take the endpoints of lines and interpolate to get the descriptor
        self.line_descriptor = torch.lerp

    def _forward(self, data):
        """Detect and describe keypoints and lines for a batch of images.

        Args:
            data: dictionary with the key ``"image"`` holding a 4-D tensor
                (unpacked below as ``_, _, h, w``) and optionally
                ``"image_size"``, which is forwarded to the keypoint detector.

        Returns:
            dict with the keys ``keypoints``, ``keypoint_descriptors``,
            ``keypoint_scores``, ``score_dispersity``, ``score_map``,
            ``line_heatmap``, ``line_endpoints`` and ``line_descriptors``.
        """
        # load image and padder
        image = data["image"]
        div_by = 2 ** 5
        print(f"image({type(image)}): {image.shape}")
        # presumably pads H and W to multiples of div_by -- TODO confirm in InputPadder
        padder = InputPadder(image.shape[-2], image.shape[-1], div_by)

        # Get Hidden Feature Map and Keypoint/junction scoring
        feature_map_padded = self.encoder(padder.pad(image))
        score_map_padded = self.keypoint_and_junction_branch(feature_map_padded)
        feature_map_padded_normalized = torch.nn.functional.normalize(feature_map_padded, p=2, dim=1)
        feature_map = padder.unpad(feature_map_padded_normalized)
        keypoint_and_junction_score_map = padder.unpad(score_map_padded)

        # Call the module (not .forward) so that nn.Module hooks are honoured.
        line_heatmap = self.line_heatmap_branch(feature_map)

        keypoints, kptscores, scoredispersitys = self.dkd(
            keypoint_and_junction_score_map, image_size=data.get("image_size")
        )

        # ToDo: Does it work well to use keypoints for junctions?? -> Design decision
        # ToDo: need preprocessing like in SOLD2 repo before passing it to line extractor?
        line_map, junctions, heatmap = self.line_extractor(keypoints, line_heatmap)
        line_segments = line_map_to_segments(junctions, line_map)

        # The descriptor head is registered as ``descriptor_branch`` in _init;
        # the offsets it also returns are not needed here.
        descriptors, _ = self.descriptor_branch(feature_map, keypoints)
        # TODO: can we make sure endpoints are always keypoints?! + Fix Input to this function
        # TODO: Interpolate line-endpoint descriptors
        line_descriptors = self.line_descriptor(line_segments[0], line_segments[1], 0.5)

        _, _, h, w = image.shape
        wh = torch.tensor([w, h], device=image.device)
        # no padding required,
        # we can set detection_threshold=-1 and conf.max_num_keypoints
        return {
            # presumably maps keypoints from [-1, 1] to pixel coordinates -- TODO confirm
            "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0,  # B N 2
            "keypoint_descriptors": torch.stack(descriptors),  # B N D
            "keypoint_scores": torch.stack(kptscores),  # B N
            "score_dispersity": torch.stack(scoredispersitys),
            "score_map": keypoint_and_junction_score_map,  # Bx1xHxW
            "line_heatmap": heatmap,
            "line_endpoints": line_segments,  # as tuples
            "line_descriptors": line_descriptors  # as vectors
        }

    def loss(self, pred, data):
        """Training loss; not implemented yet."""
        raise NotImplementedError

    def count_trainable_parameters(self):
        """Return the total number of elements in all parameters that require grad."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [11]:
# ToDo: Figure out default config
# Each entry names the implementation used for one component of the network.
# Only values passed here override the model's default conf; everything else
# is taken from the defaults.
# NOTE(review): JointPointLineDetectorDescriptor._init does not appear to read
# any of these component keys yet -- confirm before relying on them.
config = {
    "backbone-encoder": {
        "name": "ALIKED"  # backbone encoder
    },
    "keypoint-and-junction-decoder": {
        "name": "ALIKED-SMH"  # keypoint/junction score decoder
    },
    "keypoint-detector": {
        "name": "ALIKE-DKD"  # keypoint detector
    },
    "line-heatmap-decoder": {
        "name": "SOLD2-PixelShuffle"  # line heatmap decoder
    },
    "line-detector": {
        "name": "SOLD2-Lineextractor"  # line detector
    },
    "descriptors": {
        "name": "ALIKED-SDDH"  # descriptor head
    }
}

## Testing

In [17]:
# taken from glue_factory.utils.image
def read_image(path, grayscale=False):
    """Load the image stored at ``path`` with OpenCV.

    Args:
        path: location of the image file; converted with ``str`` before use.
        grayscale: if true, decode as a single-channel grayscale image;
            otherwise decode as colour and return the channels in RGB order.

    Returns:
        The decoded image array.

    Raises:
        IOError: if OpenCV could not read the file.
    """
    if grayscale:
        flag = cv2.IMREAD_GRAYSCALE
    else:
        flag = cv2.IMREAD_COLOR
    loaded = cv2.imread(str(path), flag)
    if loaded is None:
        raise IOError(f"Could not read image at {path}.")
    if grayscale:
        return loaded
    # Reverse the channel axis: OpenCV's colour order -> RGB.
    return loaded[..., ::-1]

from PIL import Image # alternative to load image

def load_img_pil(path, resize=True, as_npy=True):
    """Open an image with PIL, optionally resizing and converting it to a batched array.

    Args:
        path: location of the image file.
        resize: if true, apply ``torchvision.transforms.Resize(800)``.
        as_npy: if true, return a NumPy array with a leading batch axis of
            size one; otherwise return the (possibly resized) PIL image.

    Returns:
        A NumPy array produced by ``ToTensor`` with an extra leading axis, or
        the PIL image when ``as_npy`` is false.
    """
    pil_img = Image.open(path)
    if resize:
        resizer = torchvision.transforms.Resize(800)
        pil_img = resizer(pil_img)
    if not as_npy:
        return pil_img
    to_tensor = torchvision.transforms.ToTensor()
    array = to_tensor(pil_img).numpy()
    # add artificial batch dimension
    return np.expand_dims(array, axis=0).copy()
# print(img.shape)
# plt.imshow(img)

In [None]:
# Load the sample image: resized via Resize(800) and returned with a leading batch axis
img_path = "sample_eiffel.jpg"
img = load_img_pil(img_path, resize=True, as_npy=True)

# Build the network, report its size and switch it to inference mode
device = torch.device("cpu")
our_network = JointPointLineDetectorDescriptor(config)
num_params = our_network.count_trainable_parameters()
print(f"Num-Parameters trainable: {num_params}")
our_network.to(device)
our_network.eval()

# Prepare Img
# NOTE: the image is not square after resizing; only the shape printed below is authoritative
print(f"img shape: {img.shape}")
img = torch.from_numpy(img).float().to(device)  # convert to tensor and move onto device

# run eval
with torch.no_grad():
    test_data = {"image": img}
    # Call the module itself instead of .forward() so nn.Module hooks are run.
    predictions = our_network(test_data)
predictions

final config dict(type=<class 'omegaconf.dictconfig.DictConfig'>): {'name': None, 'trainable': True, 'freeze_batch_normalization': False, 'timeit': False, 'model_name': 'aliked-n16', 'max_num_keypoints': -1, 'detection_threshold': 0.2, 'force_num_keypoints': False, 'pretrained': True, 'nms_radius': 2, 'backbone-encoder': {'name': 'ALIKED'}, 'keypoint-and-junction-decoder': {'name': 'ALIKED-SMH'}, 'keypoint-detector': {'name': 'ALIKE-DKD'}, 'line-heatmap-decoder': {'name': 'SOLD2-PixelShuffle'}, 'line-detector': {'name': 'SOLD2-Lineextractor'}, 'descriptors': {'name': 'ALIKED-SDDH'}}
aliked cfg(type=<class 'dict'>): {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'K': 3, 'M': 16}
Num-Parameters trainable: 1010126
img shape: (1, 3, 1332, 800)
torch.float32
image(<class 'torch.Tensor'>): torch.Size([1, 3, 1332, 800])
