# Matching Strategy

In [1]:
import sys
sys.path.append('../')

from data import datasets
from data import transforms, utils
from models.ssd300 import SSD300

In [2]:
transform = transforms.Compose(
    [transforms.Ignore(difficult=True),
     transforms.Normalize(),
     transforms.Centered(),
     transforms.Resize((300, 300)), # if resizing first, can't be normalized
     transforms.OneHot(class_nums=datasets.VOC_class_nums),
     transforms.ToTensor()]
)

dataset = datasets.Compose(datasets.VOC_class_nums, datasets=(datasets.VOC2007Dataset, datasets.VOC2012_TrainValDataset), transform=transform)

model = SSD300(datasets.VOC_class_nums, batch_norm=False)

In [3]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

from models.core.boxes import center2minmax
from models.core.inference import toVisualizeImg, tensor2cvimg, toVisualizeRectangleimg

In [4]:
img, targets = dataset[0]
loc, conf = targets[:, :4], targets[:, 4:]

print(loc)
print(conf)

print(loc.shape)
print(conf.argmax(dim=1).shape)

plt.figure()
plt.imshow(toVisualizeImg(img, loc, conf_indices=conf.argmax(dim=1), classes=datasets.VOC_classes))

tensor([[0.5870, 0.7333, 0.1220, 0.3413],
        [0.4180, 0.8480, 0.1760, 0.2880],
        [0.5360, 0.6573, 0.1080, 0.2800]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])
torch.Size([3, 4])
torch.Size([3])


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f5e599ea860>

In [5]:
from models.core.boxes import matching_strategy

model.load_vgg_weights()

print(img.shape)
print(targets.shape)

imgs, gts = utils.batch_ind_fn((dataset[0],))
print(gts)
print(imgs.shape)

box_num = imgs.shape[0]

# predict
predicts, dboxes = model(imgs)
pos_indicator, gt_locs, gt_confs = matching_strategy(gts, dboxes, batch_num=1)

torch.Size([3, 300, 300])
torch.Size([3, 25])
tensor([[3.0000, 0.5870, 0.7333, 0.1220, 0.3413, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [3.0000, 0.4180, 0.8480, 0.1760, 0.2880, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [3.0000, 0.5360, 0.6573, 0.1080, 0.2800, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
torch.Size([1, 3, 300, 300])


In [6]:
dbox_img = toVisualizeRectangleimg(img, dboxes[::22], thickness=1)
    
plt.figure()
#plt.imshow(resized_img)
plt.imshow(dbox_img)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f5e506b1ac8>

In [7]:
voc_classes = datasets.VOC_classes + ['back ground']

print(voc_classes, len(voc_classes))
print(gt_locs[pos_indicator].shape)
#print(gt_locs[0].tolist())
print(gt_confs[pos_indicator].argmax(dim=1).shape)
print(gt_confs[0].argmax(dim=1).tolist())

print(gt_locs[pos_indicator].tolist())

['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'back ground'] 21
torch.Size([71, 4])
torch.Size([71])
[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 

In [8]:
plt.figure()
dboxes_b = torch.cat([dboxes.unsqueeze(0)] * box_num, dim=0)
plt.imshow(toVisualizeRectangleimg(img, dboxes_b[pos_indicator]))
#plt.imshow(toVisualizeImg(img, gt_locs[pos_indicator], conf_indices=gt_confs[pos_indicator].argmax(dim=1), classes=voc_classes))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f5e50690c88>

In [62]:
angles = np.linspace(0, 255, 20).astype(np.uint8)
#print(angles.shape)
hsvs = np.array((0, 255, 255))[np.newaxis, np.newaxis, :].astype(np.uint8)
hsvs = np.repeat(hsvs, 20, axis=0)
#print(hsvs.shape)
hsvs[:, 0, 0] += angles

rgbs = cv2.cvtColor(hsvs, cv2.COLOR_HSV2RGB).astype(np.int)
locs_mm = center2minmax(loc).numpy()

h, w, c = img.shape
locs_mm[:, ::2] *= w
locs_mm[:, 1::2] *= h
locs_mm = np.clip(locs_mm, 0, w).astype(np.int)

topleft = locs_mm[0, :2]
bottomright = locs_mm[0, 2:]
print(tuple(topleft), tuple(bottomright), tuple(rgbs[0, 0]))
print(tensor2cvimg(img).dtype, tensor2cvimg(img).shape)
plt.figure()
plt.imshow(tensor2cvimg(img))
a = tuple(rgbs[0, 0].tolist())
print(type(a[0]))
cv2.rectangle(tensor2cvimg(img), tuple(topleft), tuple(bottomright), a, 2)

(157, 1) (194, 2) (255, 0, 0)
uint8 (300, 300, 3)


<IPython.core.display.Javascript object>

<class 'int'>


array([[[ 11,  10,  10],
        [ 10,  10,  10],
        [  9,  13,  10],
        ...,
        [166, 186, 195],
        [164, 187, 193],
        [162, 186, 192]],

       [[ 18,  19,  21],
        [ 12,  13,  14],
        [ 10,   9,  10],
        ...,
        [162, 189, 195],
        [162, 189, 196],
        [162, 189, 196]],

       [[ 84,  91,  94],
        [ 69,  75,  77],
        [ 56,  58,  62],
        ...,
        [164, 190, 197],
        [165, 190, 197],
        [164, 188, 195]],

       ...,

       [[ 21,  10,   6],
        [ 19,   9,   5],
        [ 27,  10,   1],
        ...,
        [ 61,  73, 104],
        [ 58,  70, 102],
        [ 55,  67,  99]],

       [[ 35,   9,   5],
        [ 32,  10,   2],
        [ 51,  18,   4],
        ...,
        [ 57,  72, 102],
        [ 62,  77, 107],
        [ 65,  79, 109]],

       [[ 51,  17,   6],
        [ 68,  23,  13],
        [ 86,  28,   9],
        ...,
        [ 57,  74, 104],
        [ 64,  79, 108],
        [ 69,  84, 113]]