diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py
index 450c4bfc4..d7f38af64 100644
--- a/timm/models/hgnet.py
+++ b/timm/models/hgnet.py
@@ -1,6 +1,7 @@
 """ PP-HGNet (V1 & V2)
 Reference:
+https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md
 The Paddle Implement of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md)
 PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
 PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
 
@@ -11,17 +12,19 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import SelectAdaptivePool2d
+from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
 from ._builder import build_model_with_cfg
 from ._registry import register_model, generate_default_cfgs
 
-__all__ = ['PPHGNet']
+__all__ = ['HighPerfGpuNet']
 
 
 class LearnableAffineBlock(nn.Module):
-    def __init__(self,
-                 scale_value=1.0,
-                 bias_value=0.0):
+    def __init__(
+            self,
+            scale_value=1.0,
+            bias_value=0.0
+    ):
         super().__init__()
         self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
         self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
@@ -32,27 +35,28 @@ def forward(self, x):
 
 class ConvBNAct(nn.Module):
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        groups=1,
-        use_act=True,
-        use_lab=False
+            self,
+            in_chs,
+            out_chs,
+            kernel_size,
+            stride=1,
+            groups=1,
+            padding='',
+            use_act=True,
+            use_lab=False
     ):
         super().__init__()
         self.use_act = use_act
         self.use_lab = use_lab
-        self.conv = nn.Conv2d(
-            in_channels,
-            out_channels,
+        self.conv = create_conv2d(
+            in_chs,
+            out_chs,
             kernel_size,
-            stride,
-            padding=(kernel_size - 1) // 2,
+            stride=stride,
+            padding=padding,
             groups=groups,
-            bias=False)
-        self.bn = nn.BatchNorm2d(out_channels)
+        )
+        self.bn = nn.BatchNorm2d(out_chs)
         if self.use_act:
             self.act = nn.ReLU()
         else:
@@ -72,27 +76,29 @@ def forward(self, x):
 
 class LightConvBNAct(nn.Module):
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        groups=1,
-        use_lab=False
+            self,
+            in_chs,
+            out_chs,
+            kernel_size,
+            groups=1,
+            use_lab=False
     ):
         super().__init__()
         self.conv1 = ConvBNAct(
-            in_channels=in_channels,
-            out_channels=out_channels,
+            in_chs,
+            out_chs,
             kernel_size=1,
             use_act=False,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
         self.conv2 = ConvBNAct(
-            in_channels=out_channels,
-            out_channels=out_channels,
+            out_chs,
+            out_chs,
             kernel_size=kernel_size,
-            groups=out_channels,
+            groups=out_chs,
             use_act=True,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
 
     def forward(self, x):
         x = self.conv1(x)
@@ -100,15 +106,16 @@ def forward(self, x):
         return x
 
 
-class ESEModule(nn.Module):
-    def __init__(self, channels):
+class EseModule(nn.Module):
+    def __init__(self, chs):
         super().__init__()
         self.conv = nn.Conv2d(
-            in_channels=channels,
-            out_channels=channels,
+            chs,
+            chs,
             kernel_size=1,
             stride=1,
-            padding=0)
+            padding=0,
+        )
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
@@ -121,15 +128,15 @@ def forward(self, x):
 
 class StemV1(nn.Module):
     # for PP-HGNet
-    def __init__(self, stem_channels):
+    def __init__(self, stem_chs):
         super().__init__()
         self.stem = nn.Sequential(*[
             ConvBNAct(
-                in_channels=stem_channels[i],
-                out_channels=stem_channels[i + 1],
+                stem_chs[i],
+                stem_chs[i + 1],
                 kernel_size=3,
                 stride=2 if i == 0 else 1) for i in range(
-                len(stem_channels) - 1)
+                len(stem_chs) - 1)
         ])
         self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
@@ -141,39 +148,44 @@ def forward(self, x):
 
 class StemV2(nn.Module):
     # for PP-HGNetv2
-    def __init__(self, in_channels, mid_channels, out_channels, use_lab=False):
+    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
         super().__init__()
         self.stem1 = ConvBNAct(
-            in_channels=in_channels,
-            out_channels=mid_channels,
+            in_chs,
+            mid_chs,
             kernel_size=3,
             stride=2,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
         self.stem2a = ConvBNAct(
-            in_channels=mid_channels,
-            out_channels=mid_channels // 2,
+            mid_chs,
+            mid_chs // 2,
             kernel_size=2,
             stride=1,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
         self.stem2b = ConvBNAct(
-            in_channels=mid_channels // 2,
-            out_channels=mid_channels,
+            mid_chs // 2,
+            mid_chs,
             kernel_size=2,
             stride=1,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
         self.stem3 = ConvBNAct(
-            in_channels=mid_channels * 2,
-            out_channels=mid_channels,
+            mid_chs * 2,
+            mid_chs,
             kernel_size=3,
             stride=2,
-            use_lab=use_lab)
+            use_lab=use_lab,
+        )
        self.stem4 = ConvBNAct(
-            in_channels=mid_channels,
-            out_channels=out_channels,
+            mid_chs,
+            out_chs,
             kernel_size=1,
             stride=1,
-            use_lab=use_lab)
-        self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
+            use_lab=use_lab,
+        )
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
 
     def forward(self, x):
         x = self.stem1(x)
@@ -188,18 +200,19 @@ def forward(self, x):
         return x
 
 
-class HGBlock(nn.Module):
+class HighPerfGpuBlock(nn.Module):
     def __init__(
-        self,
-        in_channels,
-        mid_channels,
-        out_channels,
-        layer_num,
-        kernel_size=3,
-        residual=False,
-        light_block=False,
-        use_lab=False,
-        agg='ese',
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            layer_num,
+            kernel_size=3,
+            residual=False,
+            light_block=False,
+            use_lab=False,
+            agg='ese',
+            drop_path=0.,
     ):
         super().__init__()
         self.residual = residual
@@ -209,105 +222,120 @@ def __init__(
             if light_block:
                 self.layers.append(
                     LightConvBNAct(
-                        in_channels=in_channels if i == 0 else mid_channels,
-                        out_channels=mid_channels,
+                        in_chs if i == 0 else mid_chs,
+                        mid_chs,
                         kernel_size=kernel_size,
-                        use_lab=use_lab,))
+                        use_lab=use_lab,
+                    )
+                )
             else:
                 self.layers.append(
                     ConvBNAct(
-                        in_channels=in_channels if i == 0 else mid_channels,
-                        out_channels=mid_channels,
+                        in_chs if i == 0 else mid_chs,
+                        mid_chs,
                         kernel_size=kernel_size,
                         stride=1,
-                        use_lab=use_lab,))
+                        use_lab=use_lab,
+                    )
+                )
 
         # feature aggregation
-        total_channels = in_channels + layer_num * mid_channels
+        total_chs = in_chs + layer_num * mid_chs
         if agg == 'se':
             aggregation_squeeze_conv = ConvBNAct(
-                in_channels=total_channels,
-                out_channels=out_channels // 2,
+                total_chs,
+                out_chs // 2,
                 kernel_size=1,
                 stride=1,
-                use_lab=use_lab)
+                use_lab=use_lab,
+            )
             aggregation_excitation_conv = ConvBNAct(
-                in_channels=out_channels // 2,
-                out_channels=out_channels,
+                out_chs // 2,
+                out_chs,
                 kernel_size=1,
                 stride=1,
-                use_lab=use_lab)
+                use_lab=use_lab,
+            )
             self.aggregation = nn.Sequential(
                 aggregation_squeeze_conv,
-                aggregation_excitation_conv)
+                aggregation_excitation_conv,
+            )
         else:
             aggregation_conv = ConvBNAct(
-                in_channels=total_channels,
-                out_channels=out_channels,
+                total_chs,
+                out_chs,
                 kernel_size=1,
                 stride=1,
-                use_lab=use_lab)
-            att = ESEModule(out_channels)
+                use_lab=use_lab,
+            )
+            att = EseModule(out_chs)
             self.aggregation = nn.Sequential(
                 aggregation_conv,
-                att)
+                att,
+            )
+
+        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
 
     def forward(self, x):
         identity = x
-        output = []
-        output.append(x)
+        output = [x]
         for layer in self.layers:
             x = layer(x)
             output.append(x)
         x = torch.cat(output, dim=1)
         x = self.aggregation(x)
         if self.residual:
-            x = x + identity
+            x = self.drop_path(x) + identity
         return x
 
 
-class HGStage(nn.Module):
+class HighPerfGpuStage(nn.Module):
     def __init__(
-        self,
-        in_channels,
-        mid_channels,
-        out_channels,
-        block_num,
-        layer_num,
-        downsample=True,
-        stride=2,
-        light_block=False,
-        kernel_size=3,
-        use_lab=False,
-        agg='ese',
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            block_num,
+            layer_num,
+            downsample=True,
+            stride=2,
+            light_block=False,
+            kernel_size=3,
+            use_lab=False,
+            agg='ese',
+            drop_path=0.,
    ):
         super().__init__()
         self.downsample = downsample
         if downsample:
             self.downsample = ConvBNAct(
-                in_channels=in_channels,
-                out_channels=in_channels,
+                in_chs,
+                in_chs,
                 kernel_size=3,
                 stride=stride,
-                groups=in_channels,
+                groups=in_chs,
                 use_act=False,
-                use_lab=use_lab,)
+                use_lab=use_lab,
+            )
         else:
             self.downsample = nn.Identity()
 
         blocks_list = []
         for i in range(block_num):
             blocks_list.append(
-                HGBlock(
-                    in_channels if i == 0 else out_channels,
-                    mid_channels,
-                    out_channels,
+                HighPerfGpuBlock(
+                    in_chs if i == 0 else out_chs,
+                    mid_chs,
+                    out_chs,
                     layer_num,
                     residual=False if i == 0 else True,
                     kernel_size=kernel_size,
                     light_block=light_block,
                     use_lab=use_lab,
-                    agg=agg))
+                    agg=agg,
+                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
+                )
+            )
         self.blocks = nn.Sequential(*blocks_list)
 
     def forward(self, x):
@@ -318,25 +346,26 @@ def forward(self, x):
 
 class ClassifierHead(nn.Module):
     def __init__(
-        self,
-        num_features,
-        num_classes,
-        pool_type='avg',
-        drop_rate=0.,
-        use_last_conv=True,
-        class_expand=2048,
-        use_lab=False
+            self,
+            num_features,
+            num_classes,
+            pool_type='avg',
+            drop_rate=0.,
+            use_last_conv=True,
+            class_expand=2048,
+            use_lab=False
     ):
         super(ClassifierHead, self).__init__()
         self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW')
         if use_last_conv:
             last_conv = nn.Conv2d(
-                in_channels=num_features,
-                out_channels=class_expand,
+                num_features,
+                class_expand,
                 kernel_size=1,
                 stride=1,
                 padding=0,
-                bias=False)
+                bias=False,
+            )
             act = nn.ReLU()
             if use_lab:
                 lab = LearnableAffineBlock()
@@ -365,23 +394,24 @@ def forward(self, x, pre_logits: bool = False):
         return x
 
 
-class PPHGNet(nn.Module):
+class HighPerfGpuNet(nn.Module):
     def __init__(
-        self,
-        cfg,
-        in_chans=3,
-        num_classes=1000,
-        global_pool='avg',
-        use_last_conv=True,
-        class_expand=2048,
-        drop_rate=0.,
-        use_lab=False,
-        **kwargs,
+            self,
+            cfg,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            use_last_conv=True,
+            class_expand=2048,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            use_lab=False,
+            **kwargs,
     ):
-        super(PPHGNet, self).__init__()
+        super(HighPerfGpuNet, self).__init__()
         stem_type = cfg["stem_type"]
-        stem_channels = cfg["stem_channels"]
+        stem_chs = cfg["stem_chs"]
         stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -392,40 +422,43 @@ def __init__(
         assert stem_type in ['v1', 'v2']
         if stem_type == 'v2':
             self.stem = StemV2(
-                in_channels=in_chans,
-                mid_channels=stem_channels[0],
-                out_channels=stem_channels[1],
+                in_chs=in_chans,
+                mid_chs=stem_chs[0],
+                out_chs=stem_chs[1],
                 use_lab=use_lab)
         else:
-            self.stem = StemV1([in_chans] + stem_channels)
+            self.stem = StemV1([in_chans] + stem_chs)
 
         current_stride = 4
 
         stages = []
         self.feature_info = []
+        block_depths = [c[3] for c in stages_cfg]
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
         for i, stage_config in enumerate(stages_cfg):
-            in_channels, mid_channels, out_channels, block_num, is_downsample, light_block, kernel_size, layer_num = stage_config
-            stages += [HGStage(
-                in_channels=in_channels,
-                mid_channels=mid_channels,
-                out_channels=out_channels,
+            in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
+            stages += [HighPerfGpuStage(
+                in_chs=in_chs,
+                mid_chs=mid_chs,
+                out_chs=out_chs,
                 block_num=block_num,
                 layer_num=layer_num,
-                downsample=is_downsample,
+                downsample=downsample,
                 light_block=light_block,
                 kernel_size=kernel_size,
                 use_lab=use_lab,
-                agg='ese' if stem_type == 'v1' else 'se'
+                agg='ese' if stem_type == 'v1' else 'se',
+                drop_path=dpr[i],
             )]
-            self.num_features = out_channels
-            if is_downsample:
+            self.num_features = out_chs
+            if downsample:
                 current_stride *= 2
             self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         if num_classes > 0:
             self.head = ClassifierHead(
-                num_features=self.num_features,
+                self.num_features,
                 num_classes=num_classes,
                 pool_type=global_pool,
                 drop_rate=drop_rate,
@@ -468,7 +501,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
         if num_classes > 0:
             self.head = ClassifierHead(
-                num_features=self.num_features,
+                self.num_features,
                 num_classes=num_classes,
                 pool_type=global_pool,
                 drop_rate=self.drop_rate,
@@ -476,7 +509,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
                 class_expand=self.class_expand,
                 use_lab=self.use_lab)
         else:
-            if global_pool == 'avg':
+            if global_pool:
                 self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
             else:
                 self.head = nn.Identity()
@@ -498,8 +531,8 @@ def forward(self, x):
     # PP-HGNet
     hgnet_tiny={
         "stem_type": 'v1',
-        "stem_channels": [48, 48, 96],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [48, 48, 96],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [96, 96, 224, 1, False, False, 3, 5],
         "stage2": [224, 128, 448, 1, True, False, 3, 5],
         "stage3": [448, 160, 512, 2, True, False, 3, 5],
@@ -507,8 +540,8 @@ def forward(self, x):
     },
     hgnet_small={
         "stem_type": 'v1',
-        "stem_channels": [64, 64, 128],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [64, 64, 128],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [128, 128, 256, 1, False, False, 3, 6],
         "stage2": [256, 160, 512, 1, True, False, 3, 6],
         "stage3": [512, 192, 768, 2, True, False, 3, 6],
@@ -516,8 +549,8 @@ def forward(self, x):
     },
     hgnet_base={
         "stem_type": 'v1',
-        "stem_channels": [96, 96, 160],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [96, 96, 160],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [160, 192, 320, 1, False, False, 3, 7],
         "stage2": [320, 224, 640, 2, True, False, 3, 7],
         "stage3": [640, 256, 960, 3, True, False, 3, 7],
@@ -526,8 +559,8 @@ def forward(self, x):
     # PP-HGNetv2
     hgnetv2_b0={
         "stem_type": 'v2',
-        "stem_channels": [16, 16],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [16, 16],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [16, 16, 64, 1, False, False, 3, 3],
         "stage2": [64, 32, 256, 1, True, False, 3, 3],
         "stage3": [256, 64, 512, 2, True, True, 5, 3],
@@ -535,8 +568,8 @@ def forward(self, x):
     },
     hgnetv2_b1={
         "stem_type": 'v2',
-        "stem_channels": [24, 32],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [24, 32],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [32, 32, 64, 1, False, False, 3, 3],
         "stage2": [64, 48, 256, 1, True, False, 3, 3],
         "stage3": [256, 96, 512, 2, True, True, 5, 3],
@@ -544,8 +577,8 @@ def forward(self, x):
     },
     hgnetv2_b2={
         "stem_type": 'v2',
-        "stem_channels": [24, 32],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [24, 32],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [32, 32, 96, 1, False, False, 3, 4],
         "stage2": [96, 64, 384, 1, True, False, 3, 4],
         "stage3": [384, 128, 768, 3, True, True, 5, 4],
@@ -553,8 +586,8 @@ def forward(self, x):
     },
     hgnetv2_b3={
         "stem_type": 'v2',
-        "stem_channels": [24, 32],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [24, 32],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [32, 32, 128, 1, False, False, 3, 5],
         "stage2": [128, 64, 512, 1, True, False, 3, 5],
         "stage3": [512, 128, 1024, 3, True, True, 5, 5],
@@ -562,8 +595,8 @@ def forward(self, x):
     },
     hgnetv2_b4={
         "stem_type": 'v2',
-        "stem_channels": [32, 48],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [32, 48],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [48, 48, 128, 1, False, False, 3, 6],
         "stage2": [128, 96, 512, 1, True, False, 3, 6],
         "stage3": [512, 192, 1024, 3, True, True, 5, 6],
@@ -571,8 +604,8 @@ def forward(self, x):
     },
     hgnetv2_b5={
         "stem_type": 'v2',
-        "stem_channels": [32, 64],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [32, 64],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [64, 64, 128, 1, False, False, 3, 6],
         "stage2": [128, 128, 512, 2, True, False, 3, 6],
         "stage3": [512, 256, 1024, 5, True, True, 5, 6],
@@ -580,8 +613,8 @@ def forward(self, x):
     },
     hgnetv2_b6={
         "stem_type": 'v2',
-        "stem_channels": [48, 96],
-        # in_channels, mid_channels, out_channels, blocks, downsample, light_block, kernel_size, layer_num
+        "stem_chs": [48, 96],
+        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
         "stage1": [96, 96, 192, 2, False, False, 3, 6],
         "stage2": [192, 192, 512, 3, True, False, 3, 6],
         "stage3": [512, 384, 1024, 6, True, True, 5, 6],
@@ -593,7 +626,7 @@ def forward(self, x):
 def _create_hgnet(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     return build_model_with_cfg(
-        PPHGNet,
+        HighPerfGpuNet,
         variant,
         pretrained,
         model_cfg=model_cfgs[variant],
@@ -606,120 +639,106 @@ def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.95, 'interpolation': 'bilinear',
+        'crop_pct': 0.965, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'classifier': 'head.fc', 'first_conv': 'stem.0.conv',
+        'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv',
+        'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288),
         **kwargs,
     }
 
 
-# default_cfgs = generate_default_cfgs({
-#     'hgnet_tiny.paddle_in1k': _cfg(
-#         first_conv='stem.0.conv',
-#         hf_hub_id='timm/'),
-#     'hgnet_tiny.ssld_in1k': _cfg(
-#         first_conv='stem.0.conv',
-#         hf_hub_id='timm/'),
-#     'hgnet_small.paddle_in1k': _cfg(
-#         first_conv='stem.0.conv',
-#         hf_hub_id='timm/'),
-#     'hgnet_small.ssld_in1k': _cfg(
-#         first_conv='stem.0.conv',
-#         hf_hub_id='timm/'),
-#     'hgnet_base.ssld_in1k': _cfg(
-#         first_conv='stem.0.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b0.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b0.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b1.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b1.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b2.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b2.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b3.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b3.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b4.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b4.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b5.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b5.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b6.ssld_in1k': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-#     'hgnetv2_b6.ssld_stage1': _cfg(
-#         first_conv='stem.stem1.conv',
-#         hf_hub_id='timm/'),
-# })
+default_cfgs = generate_default_cfgs({
+    'hgnet_tiny.paddle_in1k': _cfg(
+        first_conv='stem.stem.0.conv',
+        hf_hub_id='timm/'),
+    'hgnet_tiny.ssld_in1k': _cfg(
+        first_conv='stem.stem.0.conv',
+        hf_hub_id='timm/'),
+    'hgnet_small.paddle_in1k': _cfg(
+        first_conv='stem.stem.0.conv',
+        hf_hub_id='timm/'),
+    'hgnet_small.ssld_in1k': _cfg(
+        first_conv='stem.stem.0.conv',
+        hf_hub_id='timm/'),
+    'hgnet_base.ssld_in1k': _cfg(
+        first_conv='stem.stem.0.conv',
+        hf_hub_id='timm/'),
+    'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg(
+        hf_hub_id='timm/'),
+})
 
 
 @register_model
-def hgnet_tiny(pretrained=False, **kwargs) -> PPHGNet:
+def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def hgnet_small(pretrained=False, **kwargs) -> PPHGNet:
+def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def hgnet_base(pretrained=False, **kwargs) -> PPHGNet:
+def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def hgnetv2_b0(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs)
 
 
 @register_model
-def hgnetv2_b1(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs)
 
 
 @register_model
-def hgnetv2_b2(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs)
 
 
 @register_model
-def hgnetv2_b3(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs)
 
 
 @register_model
-def hgnetv2_b4(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def hgnetv2_b5(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def hgnetv2_b6(pretrained=False, **kwargs) -> PPHGNet:
+def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet:
     return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs)
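Reviewer note (not part of the patch): a minimal smoke-test sketch, assuming this diff is applied on top of a timm checkout. The registered model names (`hgnet_*`, `hgnetv2_*`) are unchanged, so existing `timm.create_model` calls keep working; only the Python class name moves from `PPHGNet` to `HighPerfGpuNet`. The batch size, input resolution, and `drop_path_rate` value below are arbitrary.

```python
import torch
import timm

# classification path: the new drop_path_rate kwarg is threaded through to the blocks
model = timm.create_model('hgnetv2_b0', pretrained=False, drop_path_rate=0.1)
model.eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    print(model(x).shape)  # torch.Size([2, 1000])

# feature extraction path: feature_info / out_indices still drive features_only
features = timm.create_model('hgnetv2_b0', pretrained=False, features_only=True)
with torch.no_grad():
    for f in features(x):
        print(f.shape)  # four stage outputs at strides 4/8/16/32
```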
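Also for review, a standalone illustration of the new stochastic-depth schedule: `HighPerfGpuNet.__init__` splits one linear ramp of drop-path rates across all blocks, and each `HighPerfGpuStage` then passes `drop_path[i]` to its i-th block. The `dpr` expression is copied from the patch; the `block_depths` values are the hgnet_tiny per-stage block counts (`c[3]` of each stage config), and `drop_path_rate=0.1` is just an example value.

```python
import torch

drop_path_rate = 0.1         # e.g. timm.create_model(..., drop_path_rate=0.1)
block_depths = [1, 1, 2, 1]  # hgnet_tiny blocks per stage, i.e. [c[3] for c in stages_cfg]

# same expression as in the patch: one 0 -> drop_path_rate ramp over all blocks,
# chopped into per-stage lists so deeper blocks get higher drop probabilities
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
print(dpr)  # ~[[0.0], [0.025], [0.05, 0.075], [0.1]] up to float32 rounding
```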