Skip to content

Commit

Permalink
Merge pull request #24 from Chilicyy/main
Browse files Browse the repository at this point in the history
update
  • Loading branch information
Chilicyy committed Jun 23, 2022
2 parents be354e8 + 8060703 commit 3185477
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__pycache__/
runs/
weights/
2 changes: 1 addition & 1 deletion configs/yolov6_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='YOLOv6tHead',
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
Expand Down
2 changes: 1 addition & 1 deletion configs/yolov6s_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='YOLOv6sHead',
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
Expand Down
4 changes: 2 additions & 2 deletions deploy/ONNX/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
sys.path.append(str(ROOT))

from yolov6.models.yolo import *
from yolov6.models.effidehead import EffiDeHead
from yolov6.models.effidehead import Detect
from yolov6.layers.common import *
from yolov6.utils.events import LOGGER
from yolov6.utils.checkpoint import load_checkpoint
Expand Down Expand Up @@ -53,7 +53,7 @@
if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, EffiDeHead):
elif isinstance(m, Detect):
m.inplace = args.inplace

y = model(img) # dry run
Expand Down
4 changes: 2 additions & 2 deletions deploy/OpenVINO/export_openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
sys.path.append(str(ROOT))

from yolov6.models.yolo import *
from yolov6.models.effidehead import EffiDeHead
from yolov6.models.effidehead import Detect
from yolov6.layers.common import *
from yolov6.utils.events import LOGGER
from yolov6.utils.checkpoint import load_checkpoint
Expand Down Expand Up @@ -54,7 +54,7 @@
if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, EffiDeHead):
elif isinstance(m, Detect):
m.inplace = args.inplace

y = model(img) # dry run
Expand Down
28 changes: 14 additions & 14 deletions yolov6/models/effidehead.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from yolov6.layers.common import *


class EffiDeHead(nn.Module):
class Detect(nn.Module):
'''Efficient Decoupled Head'''
def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None): # detection layer
super().__init__()
assert head_layers is not None
self.num_classes = num_classes # number of classes
self.num_outputs = num_classes + 5 # number of outputs per anchor
self.num_layers = num_layers # number of detection layers
self.nc = num_classes # number of classes
self.no = num_classes + 5 # number of outputs per anchor
self.nl = num_layers # number of detection layers
if isinstance(anchors, (list, tuple)):
self.num_anchors = len(anchors[0]) // 2
self.na = len(anchors[0]) // 2
else:
self.num_anchors = anchors
self.na = anchors
self.anchors = anchors
self.grid = [torch.zeros(1)] * self.num_layers
self.grid = [torch.zeros(1)] * num_layers
self.prior_prob = 1e-2
self.inplace = inplace
stride = [8, 16, 32] # strides computed during build
Expand All @@ -43,17 +43,17 @@ def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_l

def initialize_biases(self):
for conv in self.cls_preds:
b = conv.bias.view(self.num_anchors, -1)
b = conv.bias.view(self.na, -1)
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
for conv in self.obj_preds:
b = conv.bias.view(self.num_anchors, -1)
b = conv.bias.view(self.na, -1)
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

def forward(self, x):
z = []
for i in range(self.num_layers):
for i in range(self.nl):
x[i] = self.stems[i](x[i])
cls_x = x[i]
reg_x = x[i]
Expand All @@ -65,23 +65,23 @@ def forward(self, x):
if self.training:
x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
bs, _, ny, nx = x[i].shape
x[i] = x[i].view(bs, self.num_anchors, self.num_outputs, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
else:
y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
bs, _, ny, nx = y.shape
y = y.view(bs, self.num_anchors, self.num_outputs, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if self.grid[i].shape[2:4] != y.shape[2:4]:
d = self.stride.device
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
self.grid[i] = torch.stack((xv, yv), 2).view(1, self.num_anchors, ny, nx, 2).float()
self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i] # wh
else:
xy = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
wh = torch.exp(y[..., 2:4]) * self.stride[i] # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.num_outputs))
z.append(y.view(bs, -1, self.no))
return x if self.training else torch.cat(z, 1)


Expand Down
20 changes: 10 additions & 10 deletions yolov6/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,37 @@
from yolov6.utils.torch_utils import initialize_weights
from yolov6.models.efficientrep import EfficientRep
from yolov6.models.reppan import RepPANNeck
from yolov6.models.effidehead import EffiDeHead, build_effidehead_layer
from yolov6.models.effidehead import Detect, build_effidehead_layer


class Model(nn.Module):
def __init__(self, config, channels=3, num_classes=None, anchors=None): # model, input channels, number of classes
super().__init__()
# Build network
num_layers = config.model.head.num_layers
self.backbone, self.neck, self.head = build_network(config, channels, num_classes, anchors, num_layers)
self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers)

# Init Detect head
begin_indices = config.model.head.begin_indices
out_indices_head = config.model.head.out_indices
self.stride = self.head.stride
self.head.i = begin_indices
self.head.f = out_indices_head
self.head.initialize_biases()
self.stride = self.detect.stride
self.detect.i = begin_indices
self.detect.f = out_indices_head
self.detect.initialize_biases()

# Init weights
initialize_weights(self)

def forward(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
x = self.detect(x)
return x

def _apply(self, fn):
self = super()._apply(fn)
self.head.stride = fn(self.head.stride)
self.head.grid = list(map(fn, self.head.grid))
self.detect.stride = fn(self.detect.stride)
self.detect.grid = list(map(fn, self.detect.grid))
return self


Expand Down Expand Up @@ -69,7 +69,7 @@ def build_network(config, channels, num_classes, anchors, num_layers):

head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes)

head = EffiDeHead(num_classes, anchors, num_layers, head_layers=head_layers)
head = Detect(num_classes, anchors, num_layers, head_layers=head_layers)

return backbone, neck, head

Expand Down

0 comments on commit 3185477

Please sign in to comment.