diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..054556e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +runs/ +weights/ diff --git a/configs/yolov6_tiny.py b/configs/yolov6_tiny.py index fea2595a..be455de2 100644 --- a/configs/yolov6_tiny.py +++ b/configs/yolov6_tiny.py @@ -15,7 +15,7 @@ out_channels=[256, 128, 128, 256, 256, 512], ), head=dict( - type='YOLOv6tHead', + type='EffiDeHead', in_channels=[128, 256, 512], num_layers=3, begin_indices=24, diff --git a/configs/yolov6s_finetune.py b/configs/yolov6s_finetune.py index 50d1d5a6..66e6600d 100644 --- a/configs/yolov6s_finetune.py +++ b/configs/yolov6s_finetune.py @@ -15,7 +15,7 @@ out_channels=[256, 128, 128, 256, 256, 512], ), head=dict( - type='YOLOv6sHead', + type='EffiDeHead', in_channels=[128, 256, 512], num_layers=3, begin_indices=24, diff --git a/deploy/ONNX/export_onnx.py b/deploy/ONNX/export_onnx.py index ef29802a..addd8ff6 100644 --- a/deploy/ONNX/export_onnx.py +++ b/deploy/ONNX/export_onnx.py @@ -13,7 +13,7 @@ sys.path.append(str(ROOT)) from yolov6.models.yolo import * -from yolov6.models.effidehead import EffiDeHead +from yolov6.models.effidehead import Detect from yolov6.layers.common import * from yolov6.utils.events import LOGGER from yolov6.utils.checkpoint import load_checkpoint @@ -53,7 +53,7 @@ if isinstance(m, Conv): # assign export-friendly activations if isinstance(m.act, nn.SiLU): m.act = SiLU() - elif isinstance(m, EffiDeHead): + elif isinstance(m, Detect): m.inplace = args.inplace y = model(img) # dry run diff --git a/deploy/OpenVINO/export_openvino.py b/deploy/OpenVINO/export_openvino.py index c5d0a5b4..7b59ae0f 100644 --- a/deploy/OpenVINO/export_openvino.py +++ b/deploy/OpenVINO/export_openvino.py @@ -14,7 +14,7 @@ sys.path.append(str(ROOT)) from yolov6.models.yolo import * -from yolov6.models.effidehead import EffiDeHead +from yolov6.models.effidehead import Detect from yolov6.layers.common import * from yolov6.utils.events import LOGGER from yolov6.utils.checkpoint import load_checkpoint @@ -54,7 +54,7 @@ if isinstance(m, Conv): # assign export-friendly activations if isinstance(m.act, nn.SiLU): m.act = SiLU() - elif isinstance(m, EffiDeHead): + elif isinstance(m, Detect): m.inplace = args.inplace y = model(img) # dry run diff --git a/yolov6/models/effidehead.py b/yolov6/models/effidehead.py index 1e69491d..4d9c1c75 100644 --- a/yolov6/models/effidehead.py +++ b/yolov6/models/effidehead.py @@ -4,20 +4,20 @@ from yolov6.layers.common import * -class EffiDeHead(nn.Module): +class Detect(nn.Module): '''Efficient Decoupled Head''' def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None): # detection layer super().__init__() assert head_layers is not None - self.num_classes = num_classes # number of classes - self.num_outputs = num_classes + 5 # number of outputs per anchor - self.num_layers = num_layers # number of detection layers + self.nc = num_classes # number of classes + self.no = num_classes + 5 # number of outputs per anchor + self.nl = num_layers # number of detection layers if isinstance(anchors, (list, tuple)): - self.num_anchors = len(anchors[0]) // 2 + self.na = len(anchors[0]) // 2 else: - self.num_anchors = anchors + self.na = anchors self.anchors = anchors - self.grid = [torch.zeros(1)] * self.num_layers + self.grid = [torch.zeros(1)] * num_layers self.prior_prob = 1e-2 self.inplace = inplace stride = [8, 16, 32] # strides computed during build @@ -43,17 +43,17 @@ def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_l def initialize_biases(self): for conv in self.cls_preds: - b = conv.bias.view(self.num_anchors, -1) + b = conv.bias.view(self.na, -1) b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob)) conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) for conv in self.obj_preds: - b = conv.bias.view(self.num_anchors, -1) + b = conv.bias.view(self.na, -1) b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob)) conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) def forward(self, x): z = [] - for i in range(self.num_layers): + for i in range(self.nl): x[i] = self.stems[i](x[i]) cls_x = x[i] reg_x = x[i] @@ -65,15 +65,15 @@ def forward(self, x): if self.training: x[i] = torch.cat([reg_output, obj_output, cls_output], 1) bs, _, ny, nx = x[i].shape - x[i] = x[i].view(bs, self.num_anchors, self.num_outputs, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() else: y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1) bs, _, ny, nx = y.shape - y = y.view(bs, self.num_anchors, self.num_outputs, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if self.grid[i].shape[2:4] != y.shape[2:4]: d = self.stride.device yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)]) - self.grid[i] = torch.stack((xv, yv), 2).view(1, self.num_anchors, ny, nx, 2).float() + self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float() if self.inplace: y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i] # wh @@ -81,7 +81,7 @@ def forward(self, x): xy = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy wh = torch.exp(y[..., 2:4]) * self.stride[i] # wh y = torch.cat((xy, wh, y[..., 4:]), -1) - z.append(y.view(bs, -1, self.num_outputs)) + z.append(y.view(bs, -1, self.no)) return x if self.training else torch.cat(z, 1) diff --git a/yolov6/models/yolo.py b/yolov6/models/yolo.py index 7d3cee8b..b30b6b89 100644 --- a/yolov6/models/yolo.py +++ b/yolov6/models/yolo.py @@ -6,7 +6,7 @@ from yolov6.utils.torch_utils import initialize_weights from yolov6.models.efficientrep import EfficientRep from yolov6.models.reppan import RepPANNeck -from yolov6.models.effidehead import EffiDeHead, build_effidehead_layer +from yolov6.models.effidehead import Detect, build_effidehead_layer class Model(nn.Module): @@ -14,15 +14,15 @@ def __init__(self, config, channels=3, num_classes=None, anchors=None): # model super().__init__() # Build network num_layers = config.model.head.num_layers - self.backbone, self.neck, self.head = build_network(config, channels, num_classes, anchors, num_layers) + self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers) # Init Detect head begin_indices = config.model.head.begin_indices out_indices_head = config.model.head.out_indices - self.stride = self.head.stride - self.head.i = begin_indices - self.head.f = out_indices_head - self.head.initialize_biases() + self.stride = self.detect.stride + self.detect.i = begin_indices + self.detect.f = out_indices_head + self.detect.initialize_biases() # Init weights initialize_weights(self) @@ -30,13 +30,13 @@ def __init__(self, config, channels=3, num_classes=None, anchors=None): # model def forward(self, x): x = self.backbone(x) x = self.neck(x) - x = self.head(x) + x = self.detect(x) return x def _apply(self, fn): self = super()._apply(fn) - self.head.stride = fn(self.head.stride) - self.head.grid = list(map(fn, self.head.grid)) + self.detect.stride = fn(self.detect.stride) + self.detect.grid = list(map(fn, self.detect.grid)) return self @@ -69,7 +69,7 @@ def build_network(config, channels, num_classes, anchors, num_layers): head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes) - head = EffiDeHead(num_classes, anchors, num_layers, head_layers=head_layers) + head = Detect(num_classes, anchors, num_layers, head_layers=head_layers) return backbone, neck, head