From 03a82fdf1c2f6809ae4b29e803b4c032a82c550e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 24 Jul 2023 11:16:24 +0800 Subject: [PATCH] update constructor structure --- vision_toolbox/backbones/darknet.py | 56 ++++++++++++++++------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/vision_toolbox/backbones/darknet.py b/vision_toolbox/backbones/darknet.py index ff97b1b..ccc93ef 100644 --- a/vision_toolbox/backbones/darknet.py +++ b/vision_toolbox/backbones/darknet.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, NamedTuple import torch from torch import Tensor, nn @@ -14,6 +14,9 @@ from .base import BaseBackbone +_BASE_URL = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/" + + class DarknetBlock(nn.Module): def __init__(self, in_channels: int, expansion: float = 0.5) -> None: super().__init__() @@ -52,24 +55,30 @@ def forward(self, x: Tensor) -> Tensor: return out +class DarknetStageConfig(NamedTuple): + n_blocks: int + out_channels: int + + class Darknet(BaseBackbone): def __init__( self, stem_channels: int, - n_blocks_list: list[int], - out_channels_list: list[int], + stage_configs: list[DarknetStageConfig], stage_cls: Callable[..., nn.Module] = DarknetStage, ): + assert len(stage_configs) > 0 super().__init__() - self.out_channels_list = tuple(out_channels_list) + self.out_channels_list = tuple(cfg.out_channels for cfg in stage_configs) self.stride = 32 self.stem = ConvNormAct(3, stem_channels) self.stages = nn.ModuleList() - in_c = stem_channels - for n, c in zip(n_blocks_list, out_channels_list): - self.stages.append(stage_cls(n, in_c, c) if n > 0 else ConvNormAct(in_c, c, stride=2)) - in_c = c + in_ch = stem_channels + for n_blocks, out_ch in stage_configs: + stage = stage_cls(n_blocks, in_ch, out_ch) if n_blocks else ConvNormAct(in_ch, out_ch, 3, 2) + self.stages.append(stage) + in_ch = out_ch def get_feature_maps(self, x: Tensor) -> list[Tensor]: outputs = [self.stem(x)] @@ -84,25 +93,25 @@ def from_config(variant: str, pretrained: bool = False) -> Darknet: darknet53=((1, 2, 8, 8, 4), DarknetStage, "darknet53-94427f5b.pth"), # YOLOv3 cspdarknet53=((1, 2, 8, 8, 4), CSPDarknetStage, "cspdarknet53-3bfa0423.pth"), # CSPNet/YOLOv4 )[variant] - m = Darknet(32, n_blocks_list, (64, 128, 256, 512, 1024), stage_cls) + stage_configs = list(map(DarknetStageConfig, n_blocks_list, (64, 128, 256, 512, 1024))) + m = Darknet(32, stage_configs, stage_cls) if pretrained: - base_url = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/" - m._load_state_dict_from_url(base_url + ckpt) + m._load_state_dict_from_url(_BASE_URL + ckpt) return m class DarknetYOLOv5(BaseBackbone): - def __init__(self, stem_channels: int, n_blocks_list: list[int], out_channels_list: list[int]) -> None: + def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig]) -> None: super().__init__() - self.out_channels_list = (stem_channels,) + tuple(out_channels_list) + self.out_channels_list = (stem_channels,) + tuple(cfg.out_channels for cfg in stage_configs) self.stride = 32 self.stem = ConvNormAct(3, stem_channels, 6, 2) self.stages = nn.ModuleList() - in_c = stem_channels - for n, c in zip(n_blocks_list, out_channels_list): - self.stages.append(CSPDarknetStage(n, in_c, c)) - in_c = c + in_ch = stem_channels + for n_blocks, out_ch in stage_configs: + self.stages.append(CSPDarknetStage(n_blocks, in_ch, out_ch)) + in_ch = out_ch def get_feature_maps(self, x: Tensor) -> list[Tensor]: outputs = [self.stem(x)] @@ -119,12 +128,11 @@ def from_config(variant: str, pretrained: bool = False) -> DarknetYOLOv5: l=(1, 1, "darknet_yolov5l-8e25d388.pth"), x=(4 / 3, 5 / 4, "darknet_yolov5x-0ed0c035.pth"), )[variant] - m = DarknetYOLOv5( - int(64 * width_scale), - [int(d * depth_scale) for d in (3, 6, 9, 3)], - [int(w * width_scale) for w in (128, 256, 512, 1024)], - ) + stage_configs = [ + DarknetStageConfig(int(d * depth_scale), int(w * width_scale)) + for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024)) + ] + m = DarknetYOLOv5(int(64 * width_scale), stage_configs) if pretrained: - base_url = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/" - m._load_state_dict_from_url(base_url + ckpt) + m._load_state_dict_from_url(_BASE_URL + ckpt) return m