Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions timm/models/hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F

from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE

_BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -101,7 +97,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w18_small_v2 = dict(
hrnet_w18_small_v2=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -137,7 +133,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w18 = dict(
hrnet_w18=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -173,7 +169,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w30 = dict(
hrnet_w30=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -209,7 +205,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w32 = dict(
hrnet_w32=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -245,7 +241,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w40 = dict(
hrnet_w40=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -281,7 +277,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w44 = dict(
hrnet_w44=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -317,7 +313,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w48 = dict(
hrnet_w48=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -353,7 +349,7 @@ def _cfg(url='', **kwargs):
),
),

hrnet_w64 = dict(
hrnet_w64=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
Expand Down Expand Up @@ -456,7 +452,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):

def _make_fuse_layers(self):
if self.num_branches == 1:
return None
return nn.Identity()

num_branches = self.num_branches
num_inchannels = self.num_inchannels
Expand All @@ -470,7 +466,7 @@ def _make_fuse_layers(self):
nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
fuse_layer.append(nn.Identity())
else:
conv3x3s = []
for k in range(i - j):
Expand Down Expand Up @@ -619,7 +615,7 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer)
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
transition_layers.append(nn.Identity())
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
Expand Down Expand Up @@ -686,8 +682,11 @@ def get_classifier(self):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
num_features = self.num_features * self.global_pool.feat_mult()
if num_classes:
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()

def forward_features(self, x):
x = self.conv1(x)
Expand All @@ -699,24 +698,21 @@ def forward_features(self, x):
x = self.layer1(x)

x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
for i in range(len(self.transition1)):
x_list.append(self.transition1[i](x))
y_list = self.stage2(x_list)

x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
for i in range(len(self.transition2)):
if not isinstance(self.transition2[i], nn.Identity):
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)

x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
for i in range(len(self.transition3)):
if not isinstance(self.transition3[i], nn.Identity):
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
Expand Down