Skip to content

Commit

Permalink
feat: Improved YOLO concatdownsample layer's efficiency (#48)
Browse files Browse the repository at this point in the history
* feat: Improved concatdownsample speed

* test: Updated concat downsample unittest

* refactor: Updated default argument for YOLO training
  • Loading branch information
frgfm authored Jun 26, 2020
1 parent 511c1b1 commit 880489c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
10 changes: 3 additions & 7 deletions holocron/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,17 @@ def concat_downsample2d(x, scale_factor):
scale_factor (int): spatial scaling factor
Returns:
torch.Tensor[N, 4C, H / 2, W / 2]: downsampled tensor
torch.Tensor[N, scale_factor ** 2 * C, H / scale_factor, W / scale_factor]: downsampled tensor
"""

b, c, h, w = x.shape

if (h % scale_factor != 0) or (w % scale_factor != 0):
raise AssertionError("Spatial size of input tensor must be multiples of `scale_factor`")
new_h, new_w = h // scale_factor, w // scale_factor

# N * C * H * W --> N * C * (H/scale_factor) * scale_factor * (W/scale_factor) * scale_factor
out = x.view(b, c, new_h, scale_factor, new_w, scale_factor)
# Move extra axes to last position to flatten them with channel dimension
out = out.permute(0, 2, 4, 1, 3, 5).flatten(3)
# Reorder all axes
out = out.permute(0, 3, 1, 2)
out = torch.cat([x[..., i::scale_factor, j::scale_factor]
for i in range(scale_factor) for j in range(scale_factor)], dim=1)

return out

Expand Down
2 changes: 1 addition & 1 deletion references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('data_path', type=str, help='path to dataset folder')
parser.add_argument('--model', default='darknet19', help='model')
parser.add_argument('--model', default='yolov2', help='model')
parser.add_argument("--freeze-backbone", dest='freeze_backbone', action='store_true',
help="Should the backbone be frozen")
parser.add_argument('--device', default='cuda', help='device')
Expand Down
20 changes: 10 additions & 10 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,20 @@ def test_concatdownsample2d(self):

num_batches = 2
num_chan = 4
x = torch.rand(num_batches, num_chan, 4, 4)
scale_factor = 2
x = torch.arange(num_batches * num_chan * 4 ** 2).view(num_batches, num_chan, 4, 4)

# Test functional API
self.assertRaises(AssertionError, F.concat_downsample2d, x, 3)
out = F.concat_downsample2d(x, 2)
self.assertEqual(out.shape, (num_batches, num_chan * 2 ** 2, x.shape[2] // 2, x.shape[3] // 2))
self.assertTrue(torch.equal(out, torch.stack((x[..., ::2, ::2],
x[..., ::2, 1::2],
x[..., 1::2, ::2],
x[..., 1::2, 1::2]), dim=2).view(num_batches, -1,
x.shape[2] // 2,
x.shape[3] // 2)))
out = F.concat_downsample2d(x, scale_factor)
self.assertEqual(out.shape, (num_batches, num_chan * scale_factor ** 2,
x.shape[2] // scale_factor, x.shape[3] // scale_factor))

# Check first and last values
self.assertTrue(torch.equal(out[0][0], torch.tensor([[0, 2], [8, 10]])))
self.assertTrue(torch.equal(out[0][-num_chan], torch.tensor([[5, 7], [13, 15]])))
# Test module
mod = downsample.ConcatDownsample2d(2)
mod = downsample.ConcatDownsample2d(scale_factor)
self.assertTrue(torch.equal(mod(x), out))

def test_init(self):
Expand Down

0 comments on commit 880489c

Please sign in to comment.