In [2]:
import sys
from pathlib import Path

# Locate the repository root (parent of the notebook's folder) so `utils.*` imports work.
# `__vsc_ipynb_file__` is injected only by the VS Code Jupyter extension; in classic
# Jupyter/JupyterLab it is undefined, so fall back to the kernel's working directory,
# which Jupyter sets to the notebook's folder by default.
_vsc_notebook = globals().get("__vsc_ipynb_file__")
if _vsc_notebook is not None:
    FILE = Path(_vsc_notebook).resolve()
    NOTEBOOK_DIR = FILE.parent
else:
    FILE = None  # notebook path is unknown outside VS Code
    NOTEBOOK_DIR = Path.cwd().resolve()
ROOT = NOTEBOOK_DIR.parent  # equals FILE.parents[1] when FILE is known
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
    
import torch


In [3]:
from utils.tal.anchor_generator import make_anchors, dist2bbox, bbox2dist

In [7]:
# Peek at a small random 4-D tensor. A private, seeded generator keeps the displayed
# values reproducible on Restart & Run All without consuming the global RNG.
example_tensor = torch.randn(1, 2, 3, 4, generator=torch.Generator().manual_seed(0))
example_tensor

tensor([[[[ 1.07694, -0.65539, -1.86710,  0.23305],
          [-0.26468,  0.22513,  0.24349, -0.69900],
          [-1.55059,  0.26184,  1.37744, -1.62508]],

         [[ 1.14924,  0.44390,  0.02803, -2.12241],
          [ 0.40042, -0.29661,  1.64121,  1.56942],
          [ 0.01325,  0.67009, -0.32774, -0.24635]]]])

In [4]:
input_size = 640

strides = [8, 16, 32]

# One dummy feature map per pyramid level, spatial size input_size // stride.
# Per the make_anchors source quoted below, only each map's shape (and the dtype/device
# of the first map) is read, so random contents are fine.
feats = [torch.randn(1, 256, input_size // s, input_size // s) for s in strides]

anchor_points, stride_tensor = make_anchors(feats, strides)

# Sanity check: one anchor per grid cell summed over all levels
# (80*80 + 40*40 + 20*20 = 8400 for a 640 input with strides 8/16/32).
num_anchors = sum((input_size // s) ** 2 for s in strides)
assert anchor_points.shape == (num_anchors, 2)
assert stride_tensor.shape == (num_anchors, 1)

print('Anchor Points:', anchor_points)
print('Stride Tensor:', stride_tensor)
print('Number of Anchor Points:', anchor_points.shape)
print('Number of Stride Tensor Elements:', stride_tensor.shape)

Anchor Points: tensor([[ 0.50000,  0.50000],
        [ 1.50000,  0.50000],
        [ 2.50000,  0.50000],
        ...,
        [17.50000, 19.50000],
        [18.50000, 19.50000],
        [19.50000, 19.50000]])
Stride Tensor: tensor([[ 8.],
        [ 8.],
        [ 8.],
        ...,
        [32.],
        [32.],
        [32.]])
Number of Anchor Points: torch.Size([8400, 2])
Number of Stride Tensor Elements: torch.Size([8400, 1])


```python
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchor points and per-anchor strides from multi-level feature maps.

    For each level i, one anchor is emitted per (row, col) cell of ``feats[i]``,
    located at ``(col + grid_cell_offset, row + grid_cell_offset)``. Coordinates are
    in that level's grid units and are NOT multiplied by the stride; callers
    presumably scale them by the returned stride tensor to get input-pixel positions.

    Args:
        feats: Sequence of feature maps. Each is unpacked as ``_, _, h, w = shape``,
            so it must be 4-D with the spatial size (h, w) in the last two dims.
            ``feats[0]`` also supplies the dtype/device of both outputs.
        strides: One stride per level, paired with ``feats`` by index; its length
            decides how many levels are processed.
        grid_cell_offset: Offset added to the integer cell indices (0.5 places each
            anchor at its cell centre).

    Returns:
        anchor_points: Tensor of shape (sum_i h_i * w_i, 2) holding (x, y) pairs,
            row-major within a level (x varies fastest), levels concatenated in
            ``strides`` order.
        stride_tensor: Tensor of shape (sum_i h_i * w_i, 1) holding each anchor's stride.
    """
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        # TORCH_1_10 is defined in the source module (not visible here); presumably a
        # torch>=1.10 flag, since older torch.meshgrid has no `indexing` kwarg — verify.
        # Both branches produce 'ij' grids: sy varies down rows, sx across columns.
        sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
        # Stack to (h, w, 2) as (x, y) and flatten row-major to (h*w, 2).
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
```

In [10]:
# Toy 4x4 version of the meshgrid step inside make_anchors (offset 0.5 = cell centres).
# With indexing='ij' the first grid varies down the rows (y), the second across columns (x).
h = w = 4
sy = torch.arange(h) + 0.5
sx = torch.arange(w) + 0.5
torch.meshgrid(sy, sx, indexing='ij')

(tensor([[0.50000, 0.50000, 0.50000, 0.50000],
         [1.50000, 1.50000, 1.50000, 1.50000],
         [2.50000, 2.50000, 2.50000, 2.50000],
         [3.50000, 3.50000, 3.50000, 3.50000]]),
 tensor([[0.50000, 1.50000, 2.50000, 3.50000],
         [0.50000, 1.50000, 2.50000, 3.50000],
         [0.50000, 1.50000, 2.50000, 3.50000],
         [0.50000, 1.50000, 2.50000, 3.50000]]))