Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added implementation of YOLOv1 #23

Merged
merged 8 commits into from
May 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,6 @@ Each dictionary has 3 keys: box coordinates, classification probability, classif
YOLO
-------

.. autofunction:: yolov1

.. autofunction:: yolov2
104 changes: 91 additions & 13 deletions holocron/models/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import OrderedDict
import torch.nn as nn

__all__ = ['Darknet', 'darknet19']
__all__ = ['DarknetV1', 'DarknetV2', 'darknet24', 'darknet19']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
Expand All @@ -21,7 +21,71 @@ def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class DarkBlock(nn.Sequential):
class DarkBlockV1(nn.Sequential):
    """Sequence of alternating 1x1 and 3x3 convolutions, each followed by a LeakyReLU.

    Args:
        planes (list<int>): channel counts; ``planes[0]`` is the input width and each
            following entry is the output width of one convolution. A single-entry
            sequence yields an empty block.
    """

    def __init__(self, planes):

        layers = []
        for idx, (in_planes, out_planes) in enumerate(zip(planes[:-1], planes[1:])):
            # The block starts with a 1x1 convolution and alternates with 3x3 ones
            conv = conv1x1 if idx % 2 == 0 else conv3x3
            layers.append(conv(in_planes, out_planes))
            # Explicit 0.1 negative slope: the default (0.01) would differ from the
            # other LeakyReLU activations of this module, which all use 0.1
            layers.append(nn.LeakyReLU(0.1, inplace=True))

        super().__init__(*layers)


class DarknetBodyV1(nn.Module):
    """Convolutional feature extractor used by the Darknet V1 classifier.

    The input goes through a stride-2 7x7 convolution, a 3x3 convolution and three
    ``DarkBlockV1`` blocks, with a 2x2 max-pooling after each of the first four stages.

    Args:
        layout (list<list<int>>): three sequences (lists or tuples) holding the output
            channel counts of the convolutions of each block
    """

    def __init__(self, layout):

        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, 7, padding=3, stride=2)
        # Stateless modules, shared by several stages of the forward pass
        self.activation = nn.LeakyReLU(0.1, inplace=True)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = conv3x3(64, 192)

        # Each block takes the output width of the previous stage as its input width.
        # Unpacking (instead of `+`) accepts layout entries that are lists or tuples.
        self.block1 = DarkBlockV1([192, *layout[0]])
        self.block2 = DarkBlockV1([*layout[0][-1:], *layout[1]])
        self.block3 = DarkBlockV1([*layout[1][-1:], *layout[2]])

    def forward(self, x):
        # Stem: two convolutions, each followed by activation and pooling
        x = self.activation(self.conv1(x))
        x = self.pool(x)
        x = self.activation(self.conv2(x))
        x = self.pool(x)

        # Blocks: pooling after the first two only, the last one keeps its resolution
        x = self.block1(x)
        x = self.pool(x)
        x = self.block2(x)
        x = self.pool(x)
        x = self.block3(x)

        return x


class DarknetV1(nn.Module):
    """Darknet V1 classifier: a feature extractor followed by a pooled linear head.

    Args:
        layout (list<list<int>>): channel layout passed to the feature extractor; the
            last value of its third entry is the input width of the linear layer
        num_classes (int, optional): number of output classes
    """

    def __init__(self, layout, num_classes=1000):

        super().__init__()

        self.features = DarknetBodyV1(layout)

        # Pooling (7, 7) or global ?
        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(layout[2][-1], num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        # Extract features, then pool, flatten and project them to class scores
        return self.classifier(self.features(x))


class DarkBlockV2(nn.Sequential):

def __init__(self, in_planes, out_planes, nb_compressions=0):

Expand All @@ -36,10 +100,10 @@ def __init__(self, in_planes, out_planes, nb_compressions=0):
nn.BatchNorm2d(out_planes),
nn.LeakyReLU(0.1, inplace=True)])

super(DarkBlock, self).__init__(*layers)
super(DarkBlockV2, self).__init__(*layers)


class DarknetBody(nn.Module):
class DarknetBodyV2(nn.Module):

def __init__(self, layout, passthrough=False):

Expand All @@ -51,10 +115,10 @@ def __init__(self, layout, passthrough=False):
self.pool = nn.MaxPool2d(2)
self.conv2 = conv3x3(32, 64)
self.bn2 = nn.BatchNorm2d(64)
self.block1 = DarkBlock(64, *layout[0])
self.block2 = DarkBlock(layout[0][0], *layout[1])
self.block3 = DarkBlock(layout[1][0], *layout[2])
self.block4 = DarkBlock(layout[2][0], *layout[3])
self.block1 = DarkBlockV2(64, *layout[0])
self.block2 = DarkBlockV2(layout[0][0], *layout[1])
self.block3 = DarkBlockV2(layout[1][0], *layout[2])
self.block4 = DarkBlockV2(layout[2][0], *layout[3])
self.passthrough = passthrough

def forward(self, x):
Expand All @@ -76,13 +140,13 @@ def forward(self, x):
return x


class Darknet(nn.Module):
class DarknetV2(nn.Module):

def __init__(self, layout, num_classes=20):
def __init__(self, layout, num_classes=1000):

super().__init__()

self.features = DarknetBody(layout)
self.features = DarknetBodyV2(layout)

self.classifier = nn.Sequential(
conv1x1(layout[-1][0], num_classes),
Expand All @@ -97,7 +161,21 @@ def forward(self, x):
return x


def darknet19(num_classes=20):
def darknet24(num_classes=1000):
    """Darknet-24 from
    `"You Only Look Once: Unified, Real-Time Object Detection" <https://pjreddie.com/media/files/papers/yolo_1.pdf>`_

    Args:
        num_classes (int, optional): number of output classes

    Returns:
        torch.nn.Module: classification model
    """

    # Output channels of the convolutions of each of the three blocks
    layout = [
        [128, 256, 256, 512],
        [256, 512] * 4 + [512, 1024],
        [512, 1024, 512, 1024],
    ]

    return DarknetV1(layout, num_classes)


def darknet19(num_classes=1000):
"""Darknet-19 from
`"YOLO9000: Better, Faster, Stronger" <https://pjreddie.com/media/files/papers/YOLO9000.pdf>`_

Expand All @@ -108,4 +186,4 @@ def darknet19(num_classes=20):
torch.nn.Module: classification model
"""

return Darknet([(128, 1), (256, 1), (512, 2), (1024, 2)], num_classes)
return DarknetV2([(128, 1), (256, 1), (512, 2), (1024, 2)], num_classes)
Loading