In [None]:
# default_exp resnetx

In [None]:
#export
from wong.imports import *
from wong.core import *
from wong.config import cfg, assert_cfg

from torchvision.models.utils import load_state_dict_from_url


In [None]:
from fastcore.all import *  # test_eq

# ResNetX
> a folded resnet

The key distinguishing feature of our proposed architecture is the use of skip connections — concatenation skips as in DenseNet, or additive skips as in ResNet — but with selective long-range and short-range skip connections rather than dense connectivity.

Despite the many parameter-efficient designs based on depthwise convolutions, for GPU-based deployment the ResNet architecture provides a comparable or better speed–accuracy trade-off.

Ref:

XNect: Real-time Multi-person 3D Human Pose Estimation with a Single RGB Camera


The proposed networks reduce computations by 20% with equivalent or even superior accuracy on the ImageNet dataset, and significantly outperform state-of-the-art approaches in terms of AP_50 on the MS COCO object detection dataset.

Ref:
CSPNet: A new backbone that can enhance learning capability of CNN

is more accurate and more computationally efficient than state-of-the-art ResNet networks.

which achieve much better accuracy and efficiency than previous ConvNets.

A residual network with multiple direct paths

To compare ResNetX with ResNet, we run an ablation study. Since ResNet is a special case of ResNetX with fold=1, we first express ResNet as a ResNetX and then vary fold over 1, 2, 3 and 4 to evaluate its performance. The first method uses transfer learning: we take a pre-trained ResNet-152 model, fill the weights of the ResNetX model from it, and then fine-tune, which gives a better result. The second method trains the model from scratch; **TODO**: this description is unfinished.

https://petewarden.com/2017/10/29/how-do-cnns-deal-with-position-differences/

As you go deeper into a network, the number of channels will typically increase, but the size of the image will shrink. This shrinking is done using pooling layers, traditionally with average pooling but more commonly using maximum pooling these days.

In [None]:
#export
def get_pred(l:int, d:int=1, start_id:int=None, end_id:int=None):
    """get predecessor layer id.

    Returns the id of the earlier layer whose output is linked directly (by a skip
    connection) to layer `l`.

    Outside the folded region `[start_id, end_id]`, or when the fold depth `d` is 1, the
    predecessor is simply the previous layer `l - 1`. Inside the folded region the skip
    jumps back `2 * (1 + r)` layers, where `r = (l - 1 - (start_id - d)) % (d - 1)`
    cycles through `0 .. d-2`.

    Parameters:
    - l : current layer id (>= 1).
    - d : fold depth (>= 1).
    - start_id : first layer of the folded region; defaults to `d` and must be >= `d`.
    - end_id : last layer of the folded region; `None` means the region has no upper bound.
    """
    if start_id is None: start_id = d
    assert l >= 1 and d >= 1 and start_id >= d, \
        'invalid arguments: l={}, d={}, start_id={}'.format(l, d, start_id)
    # The ordering check only applies to an explicit end. With no end, every layer
    # (including those before `start_id`) must still be accepted.
    assert end_id is None or end_id > start_id, \
        'end_id={} must be greater than start_id={}'.format(end_id, start_id)
    in_fold = d > 1 and l >= start_id and (end_id is None or l <= end_id)
    if not in_fold:
        return l - 1
    remainder = (l - 1 - (start_id - d)) % (d - 1)
    return l - 2 * (1 + remainder)

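Parameters:
- l : current layer id (layers are numbered from 1).
- start_id : index of the first node of the folded region (defaults to `d`).
- end_id : index of the last node of the folded region (when omitted, the region has no upper bound).
- d : fold depth.

Return:
- The id of the earlier layer that links directly (through a skip connection) to the current layer.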


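\begin{equation}\label{eq:resnetx}
   \mathrm{pred}(l) = l - i, \qquad
   i = 
   \left\{
      \begin{array}{ll}
      1 & l < s \lor l > e \lor d=1 ; \\
      2 \left(1 + \left(l-1-(s-d)\right) \bmod (d-1)\right) & \textrm{else} .
      \end{array}
      \right.
\end{equation}

where $s$ is `start_id` (default $d$) and $e$ is `end_id`.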


In [None]:
# fold depth 2 with the folded region starting at layer 13: layer 17 links back to layer 15
get_pred(l=17, d=2, start_id=13)

15

In [None]:
# fold depth 5 with the folded region starting at layer 8: layer 50 links back to layer 44
get_pred(l=50, d=5, start_id=8)

44

In [None]:
# fold depth 1 is a plain chain: the predecessor is always the previous layer
test_eq(get_pred(l=12, d=1, start_id=1), 11)

# inside the folded region the skip jumps back 2*(1+r) layers
test_eq(get_pred(l=8, d=5, start_id=7), 4)
test_eq(get_pred(l=12, d=4, start_id=6), 10)


In [None]:
#export
def layer_diff(cur:int, pred:int, num_nodes:tuple):
    "layer difference between the current layer and the predecessor layer."
    assert cur > pred
    num_nodes = (1,) + num_nodes
    cumsum = 0  # start with 0
    for i, num in enumerate(num_nodes):
        if cumsum <= cur < cumsum + num:
            cur_layer = i
        if cumsum <= pred < cumsum + num:
            pred_layer = i
        cumsum += num
    diff = cur_layer - pred_layer
    return diff

In [None]:
# node 4 is the first node of the second backbone stage and node 0 is the stem output;
# the stem counts as its own stage, so the two are 2 stages apart
num_nodes = (3,4,6,3)
cur, pred = 4,0
layer_diff(cur, pred, num_nodes)

2

Parameters:
- Stem : the stem stage, which accepts the original images, transforms them, and feeds them into the backbone network.
- Unit : the operation at each node.
- Conn : the connection (skip / identity mapping) between nodes.
- Tail : the classifier head applied to the output of the last stage.
- fold : the fold depth.
- ni : number of input channels of the backbone network (the number of output channels of the stem).
- *num_stages : number of stages in the backbone network (derived as `len(num_nodes)`, not passed explicitly).*
- num_nodes : number of nodes in each stage of the backbone network.
- start_id : index of the first node of the folded region of ResNetX.
- end_id : index of the last node of the folded region of ResNetX.
- base : base channel width of the backbone network.
- exp : channel expansion factor between consecutive stages.
- bottle_scale : bottleneck scale (a stage's output width is its hidden width times `bottle_scale`).
- first_downsample: whether to down-sample at the start of the first stage.
- deep_stem : whether to use one 7x7 conv or three 3x3 convs in the stem stage (currently not passed to `Stem`).
- c_in : number of input channels of the stem.
- c_out : number of classes in the output of the final classifier.
- kwargs : extra keyword arguments passed on to `Unit`.

In [None]:
#export
class ResNetX(nn.Module):
    """A folded resnet.

    Nodes (layers) are numbered from 1 through the stages given by `num_nodes`; node 0 is
    the output of `Stem`. Node `cur` computes
    `Unit(output of cur-1) + Conn(output of pred)` with
    `pred = get_pred(cur, fold, start_id, end_id)`. The last node's output goes through `Tail`.
    """
    def __init__(self, Stem, Unit, Conn, Tail, fold:int, ni:int, num_nodes:tuple, start_id:int=None, end_id:int=None,
                 base:int=64, exp:int=2, bottle_scale:int=1, first_downsample:bool=False,
                 c_in:int=3, c_out:int=10, **kwargs):
        super(ResNetX, self).__init__()
        # fold depth should be less than the sum length of any two neighboring stages

        # The folded region cannot start before node `fold`. `None` also means "start at fold";
        # comparing `None < fold` directly would raise a TypeError.
        if start_id is None or start_id < fold: start_id = fold
        # `None` means the folded region runs through the last node.
        if end_id is None: end_id = sum(num_nodes)
        origin_ni = ni
        num_stages = len(num_nodes)
        nhs = [base * exp ** i for i in range(num_stages)]  # hidden width of each stage
        nos = [int(nh * bottle_scale) for nh in nhs]         # output width of each stage
        strides = [1 if i==0 and not first_downsample else 2 for i in range(num_stages)]

        self.stem = Stem(c_in, no=ni) # , deep_stem

        units = []
        idmappings = []
        cur = 1
        for i, (nh, no, nu, stride) in enumerate(zip(nhs, nos, num_nodes, strides)):
            for j in range(nu):
                if j == 0: # the first node(layer) of each stage
                    units += [Unit(ni, no, nh, stride=stride, **kwargs)]
                else:
                    units += [Unit(no, no, nh, stride=1, **kwargs)]

                pred = get_pred(cur, fold, start_id, end_id)
                diff = layer_diff(cur, pred, num_nodes)
                assert diff == 0 or diff == 1 or (diff == 2 and pred == 0), \
                       'cur={}, pred={}, diff={} is not allowed.'.format(cur, pred, diff)
                # `Conn` adapts the predecessor's output to this node's width and resolution
                if diff == 0:    # same stage: same width, no striding
                    idmappings += [Conn(no, no, stride=1)]
                elif diff == 1:  # predecessor in the previous stage (or the stem, for the first stage)
                    idmappings += [Conn(ni, no, stride=stride)]
                elif diff == 2:  # predecessor is the stem output
                    # NOTE(review): a single `stride` only matches the resolution when the first
                    # stage does not downsample (first_downsample=False) — confirm for other configs
                    idmappings += [Conn(origin_ni, no, stride=stride)]
                cur += 1
            ni = no
        self.units = nn.ModuleList(units)
        self.idmappings = nn.ModuleList(idmappings)

        self.classifier = Tail(nos[-1], c_out)
        self.fold, self.start_id, self.end_id = fold, start_id, end_id
        self.num_nodes = num_nodes
        init_cnn(self)

    def forward(self, x):
        # A predecessor is at most 2*(fold-1) nodes back, so only the last 2*fold-1 node outputs
        # are kept, in a ring buffer keyed by `node_id % size`.
        size = 2 * self.fold - 1
        results = {0: self.stem(x)}
        cur = 0
        for unit, idmapping in zip(self.units, self.idmappings):
            cur += 1
            pred = get_pred(cur, self.fold, self.start_id, self.end_id)
            # Every node has a `Conn` (built in __init__ for its stage difference). Applying it
            # for all nodes keeps the shortcut at stage transitions, as in ResNet's downsample
            # branch, which is also where the pretrained `downsample` weights are loaded.
            results[cur % size] = unit(results[(cur - 1) % size]) + idmapping(results[pred % size])
        return self.classifier(results[cur % size])

    def my_load_state_dict(self, state_dict, local_to_pretrained):
        """Copy matching tensors from a pretrained `state_dict` into this model, in place.

        `local_to_pretrained` maps a local module prefix (e.g. 'stem.0.') to the prefix used
        in `state_dict` (e.g. 'conv1.'). Unmapped modules fall back to the prefix 'none', so
        they are normally skipped. Keys missing from `state_dict` are skipped as well.

        Returns the list of error messages (shape mismatches / failed copies); empty on success.
        """
        error_msgs = []
        def load(module, prefix=''):
            local_name_params = itertools.chain(module._parameters.items(), module._buffers.items())
            local_state = {k: v.data for k, v in local_name_params if v is not None}

            new_prefix = local_to_pretrained.get(prefix, 'none')
            for name, param in local_state.items():
                key = new_prefix + name
                if key in state_dict:
                    input_param = state_dict[key]

                    if input_param.shape != param.shape:
                        # local shape should match the one in checkpoint
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'
                                          .format(key, input_param.shape, param.shape))
                        continue

                    try:
                        param.copy_(input_param)
                    except Exception:
                        error_msgs.append('While copying the parameter named "{}", '
                                          'whose dimensions in the model are {} and '
                                          'whose dimensions in the checkpoint are {}.'
                                          .format(key, param.size(), input_param.size()))

            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(self)
        load = None # break load->load reference cycle
        return error_msgs
            
        

In [None]:
#export
def resnet_local_to_pretrained(num_nodes, fold, start_id, end_id):
    """Map local ResNetX module-name prefixes to torchvision ResNet prefixes.

    The pretrained model is restricted to torchvision.models.resnet.
    - The stem's sub-modules 0/1 map to `conv1.`/`bn1.`.
    - Sub-modules 1, 2, 4, 5, 7, 8 of each unit map to that block's
      conv1, bn1, conv2, bn2, conv3, bn3.
    - For nodes whose predecessor lies one stage earlier, the connection's
      `unit.0.`/`unit.1.` map to the stage's `layerN.0.downsample.0.`/`.1.`.
    """
    mapping = {'stem.0.': 'conv1.', 'stem.1.': 'bn1.'}
    unit_slots = ('1.', '2.', '4.', '5.', '7.', '8.')
    tv_names = ('conv1.', 'bn1.', 'conv2.', 'bn2.', 'conv3.', 'bn3.')

    offset = 0
    for stage, count in enumerate(num_nodes, start=1):
        stage_prefix = 'layer' + str(stage) + '.'
        for block in range(count):
            node = offset + block + 1          # 1-based node id, as used by get_pred
            local_idx = str(node - 1)          # 0-based index into units / idmappings
            link = get_pred(node, fold, start_id, end_id)
            if layer_diff(node, link, num_nodes) == 1:
                mapping['idmappings.' + local_idx + '.unit.0.'] = stage_prefix + '0.downsample.0.'
                mapping['idmappings.' + local_idx + '.unit.1.'] = stage_prefix + '0.downsample.1.'
            block_prefix = stage_prefix + str(block) + '.'
            mapping.update({'units.' + local_idx + '.' + slot: block_prefix + tv
                            for slot, tv in zip(unit_slots, tv_names)})
        offset += count

    return mapping


Three priority levels for setting the configuration (each later level overrides the earlier ones):
- `default_cfg` the default configuration, which sets all the option names and their default values
- `cfg_file` the configuration file, which overrides the default configuration
- `cfg_list` the configuration list, which overrides all the previous configurations.

In [None]:
#export
def resnetx(default_cfg:dict, cfg_file:str=None, cfg_list:list=None, pretrained:bool=False, **kwargs):
    """wrapped resnetx: build a `ResNetX` from a yacs configuration.

    `default_cfg` (a yacs `CfgNode`) is cloned first, so the caller's configuration is
    neither modified by the merges nor frozen. `cfg_file` and then `cfg_list` override
    the clone, in that order. `GRAPH.STEM/UNIT/CONN/TAIL` name classes or functions in
    this module. Extra `kwargs` are forwarded to `ResNetX`.
    If `pretrained`, weights are loaded from `cfg.URL` and every parameter is frozen.
    """
    assert default_cfg.__class__.__module__ == 'yacs.config' and default_cfg.__class__.__name__ == 'CfgNode'
    # Work on a private copy. Merging into (and freezing) the shared default config would
    # leak this call's settings into later calls. defrost() handles an already-frozen input.
    cfg = default_cfg.clone()
    cfg.defrost()
    if cfg_file is not None: cfg.merge_from_file(cfg_file)
    if cfg_list is not None: cfg.merge_from_list(cfg_list)
    assert_cfg(cfg)
    cfg.freeze()

    Stem = getattr(sys.modules[__name__], cfg.GRAPH.STEM)
    Unit = getattr(sys.modules[__name__], cfg.GRAPH.UNIT)
    Conn = getattr(sys.modules[__name__], cfg.GRAPH.CONN)
    Tail = getattr(sys.modules[__name__], cfg.GRAPH.TAIL)
    # start_id >= fold + 1, fold <= 6
    model = ResNetX(Stem=Stem, Unit=Unit, Conn=Conn, Tail=Tail, fold=cfg.GRAPH.FOLD, ni=cfg.GRAPH.NI, num_nodes=cfg.GRAPH.NUM_NODES, 
                    start_id=cfg.GRAPH.START_ID, end_id=cfg.GRAPH.END_ID, base=cfg.GRAPH.BASE, exp=cfg.GRAPH.EXP, bottle_scale=cfg.GRAPH.BOTTLE_SCALE,
                    first_downsample=cfg.GRAPH.FIRST_DOWNSAMPLE, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfg.URL)
        local_to_pretrained = resnet_local_to_pretrained(cfg.GRAPH.NUM_NODES, cfg.GRAPH.FOLD,cfg.GRAPH.START_ID,cfg.GRAPH.END_ID)
        model.my_load_state_dict(state_dict, local_to_pretrained)
        # freeze all parameters (including the classifier); unfreeze the ones to fine-tune afterwards
        for param in model.parameters():
            param.requires_grad = False
    return model

In [None]:
# default configuration imported from wong.config
cfg

CfgNode({'URL': '', 'GRAPH': CfgNode({'NUM_STAGES': 4, 'NUM_NODES': (3, 8, 36, 3), 'NUM_CHANNELS': (64, 128, 256, 512), 'STEM': '', 'UNIT': '', 'CONN': '', 'TAIL': '', 'FOLD': 1, 'START_ID': 0, 'END_ID': 0, 'NI': 64, 'BASE': 64, 'EXP': 2, 'BOTTLE_SCALE': 4.0, 'FIRST_DOWNSAMPLE': False, 'DEEP_STEM': False})})

In [None]:
# Example: a fold-2 ResNetX with ResNet-152 stage depths (3, 8, 36, 3). The folded region
# lies inside the third stage; the other stages stay plain residual chains.
num_nodes = (3, 8, 36, 3)
num_all_nodes = sum(num_nodes)
fold = 2
start_id = num_nodes[0] + num_nodes[1] + fold + 1 
end_id = num_all_nodes - num_nodes[-1]  # last node of the third stage (47 here)
cfg_list = ["GRAPH.STEM", "resnet_stem",
            "GRAPH.UNIT", "mbconv",  # resnet_bottleneck
            "GRAPH.CONN", "IdentityMapping",
            "GRAPH.TAIL", "Classifier",
            "GRAPH.NUM_NODES", num_nodes,
            "GRAPH.FOLD", fold,
            "GRAPH.START_ID", start_id,
            "GRAPH.END_ID", end_id,
            "GRAPH.NI", 64,
            "GRAPH.BASE", 64,
            "GRAPH.EXP", 2,
            "GRAPH.BOTTLE_SCALE", 0.5, # 4
            "URL", 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
           ]
# pass a clone so the shared default `cfg` is not modified and this cell can be re-run;
# `ks=5` is forwarded to `Unit`
model = resnetx(cfg.clone(), cfg_list=cfg_list, pretrained=False, c_out=100, ks=5)

In [None]:
# module structure of the model built above
model

ResNetX(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (units): ModuleList(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64, bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, ke

In [None]:
# Alternative (disabled): build the model from a YAML config file instead of `cfg_list`
# cfg_file = 'configs/imagenet/resnet/resnet152.yaml'
# model = resnetx(cfg, cfg_file=cfg_file, pretrained=False, c_out=100)

**Tip** : Three ways to get a `class` or `function` object from its string name:

- `getattr(sys.modules[__name__], cfg.GRAPH.STEM)`
- `globals()[cfg.GRAPH.STEM]`
- `eval(cfg.GRAPH.STEM)` (avoid this with untrusted configuration files: `eval` executes arbitrary code)

In [None]:
# dummy input batch: 2 samples, 3 channels, 64x64 (the stem's first conv expects 3 input channels)
x = torch.randn(2,3,64,64)

In [None]:
# smoke test: one forward and one backward pass with autograd anomaly detection enabled
with torch.autograd.set_detect_anomaly(True):
    out = model(x)
    out.mean().backward()

In [None]:
"{:,}".format(num_params(model))

'60,225,700'

## Load Pretrained Models

In [None]:
from torchvision.models import resnet152

In [None]:
# reference torchvision ResNet-152 (no pretrained weights requested)
m_resnet152 = resnet152()

In [None]:
# module in which the model's class was defined ('__main__' when defined in this notebook)
model.__module__

'__main__'