In [None]:
# default_exp resnetx

In [None]:
#export
from wong.imports import *
from wong.core import *
from wong.config import cfg

In [None]:
#export
def get_pred(l:int, d:int=1):
    "get predecessor layer id."
    # l : current layer id (>= 1; layer 0 is the `Start` layer).
    # d : fold depth (>= 1); d == 1 gives the plain chain l -> l-1.
    # Returns the id (always in [0, l)) of the layer whose output feeds l's shortcut.
    assert l >= 1
    # d <= 0 would make `l % (d-1)` meaningless and yield a "predecessor" after `l`.
    assert d >= 1, 'fold depth d={} must be >= 1.'.format(d)
    # The first d-1 layers (and every layer when d == 1) form a plain chain.
    if l < d or d == 1:
        return l - 1
    # Afterwards layers zig-zag with period d-1: a layer at offset r (1 <= r <= d-2)
    # in its period links back 2*r layers; a layer on a period boundary links back 2*(d-1).
    remainder = l % (d - 1)
    return l - 2 * (remainder or (d - 1))

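Parameters:
- l : current layer id (`l >= 1`; layer 0 is the `Start` layer).
- d : fold depth (`d >= 1`; `d = 1` gives a plain chain where each layer links to the one before it).

Return:
- The id of the earlier layer that links directly to the current layer through its shortcut connection.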

In [None]:
# (fold depth, current layer) -> expected predecessor layer, checked in this order.
expected_preds = {
    (1, 12): 11,
    (4, 12): 6, (4, 11): 7, (4, 10): 8, (4, 9): 3, (4, 8): 4, (4, 7): 5, (4, 6): 0,
    (3, 4): 0, (3, 5): 3, (3, 6): 2,
}
for (fold_depth, layer), want in expected_preds.items():
    test_eq(get_pred(l=layer, d=fold_depth), want)

In [None]:
#export
def layer_diff(cur:int, pred:int, num_nodes:tuple):
    "layer difference between the current layer and the predecessor layer."
    # Returns how many stage boundaries lie between layer `pred` and layer `cur`.
    # Layer 0 is the `Start` layer and counts as a stage of its own; backbone stage i
    # then contributes `num_nodes[i]` consecutive layer ids.
    # This cell is exported because the exported `ResNetX` calls it.
    assert cur > pred
    if pred < 0:
        raise ValueError('pred={} must be >= 0.'.format(pred))
    sizes = (1,) + tuple(num_nodes)  # tuple() so lists are accepted as well

    def _stage_of(layer):
        "Index (into `sizes`) of the stage that contains `layer`."
        upper = 0
        for stage, size in enumerate(sizes):
            upper += size
            if layer < upper:
                return stage
        raise ValueError('layer {} is beyond the last stage of num_nodes={}.'.format(layer, num_nodes))

    return _stage_of(cur) - _stage_of(pred)

In [None]:
cur, pred, num_nodes = 4, 0, (3, 4, 6, 3)
# Layer 0 is the Start layer and layer 4 opens the second backbone stage,
# so the shortcut spans two stage boundaries.
layer_diff(cur, pred, num_nodes)

2

Parameters:
- Start : the stem layer (layer 0), which accepts the original images, transforms them, and feeds them into the backbone network.
- Unit : the operation at each node.
- fold : the fold depth (`fold >= 1`).
- ni : number of input channels of the backbone network, i.e. the number of output channels of `Start`.
- num_stages : number of stages in the backbone network (not an argument; derived as `len(num_nodes)`).
- num_nodes : number of nodes in each stage of the backbone network.
- base : base channel width of the backbone network; stage `i` has bottleneck width `base * exp**i`.
- exp : width expansion factor applied at each successive stage.
- bottle_scale : bottleneck scale; each stage outputs `bottle_scale` times its bottleneck width.
- first_downsample : whether to down-sample (stride 2) at the start of the first stage.
- c_in : number of input channels of the `Start` layer.
- c_out : number of classes in the output of the final classifier.
- kwargs : extra arguments passed through to every `Unit`.

In [None]:
#export
class ResNetX(nn.Module):
    "A folded resnet: the shortcut of layer `cur` comes from layer `get_pred(cur, fold)`."
    def __init__(self, Start, Unit, fold:int, ni:int, num_nodes:tuple, base:int=64, exp:int=2,
                 bottle_scale:int=1, first_downsample:bool=False, c_in:int=3, c_out:int=10, **kwargs):
        # Start(c_in, ni) builds the stem (layer 0); Unit(ni, no, nh, stride=..., **kwargs)
        # builds one node; stage i has bottleneck width base*exp**i and output width
        # bottle_scale times that; len(num_nodes) is the number of stages.
        super().__init__()
        assert fold >= 1, 'fold={} must be >= 1.'.format(fold)
        # Constraint (checked per layer by the assertion below): a shortcut may stay in its
        # stage, cross one stage boundary, or come from the stem (layer 0) into stage 1.
        # In practice the fold depth should be less than the combined length of any two
        # neighbouring stages.
        self.fold = fold
        origin_ni = ni
        num_stages = len(num_nodes)
        nhs = [base * exp ** i for i in range(num_stages)]
        nos = [nh * bottle_scale for nh in nhs]
        strides = [1 if i == 0 and not first_downsample else 2 for i in range(num_stages)]

        self.start = Start(c_in, ni)

        units, idmappings = [], []
        cur = 1
        for nh, no, nu, stride in zip(nhs, nos, num_nodes, strides):
            for j in range(nu):
                if j == 0:  # the first node (layer) of each stage changes width / resolution
                    units.append(Unit(ni, no, nh, stride=stride, **kwargs))
                else:
                    units.append(Unit(no, no, nh, stride=1, **kwargs))

                pred = get_pred(cur, fold)
                diff = layer_diff(cur, pred, num_nodes)
                assert diff == 0 or diff == 1 or (diff == 2 and pred == 0), \
                       'cur={}, pred={}, diff={} is not allowed.'.format(cur, pred, diff)
                if diff == 0:
                    # shortcut inside the current stage
                    idmappings.append(IdentityMapping(no, no, stride=1))
                elif diff == 1:
                    # shortcut from the previous stage (the stem, for stage 0);
                    # `ni` still holds that stage's output width here
                    idmappings.append(IdentityMapping(ni, no, stride=stride))
                else:
                    # diff == 2: shortcut from the stem straight into stage 1
                    mapping = IdentityMapping(origin_ni, no, stride=stride)
                    if strides[0] != 1:
                        # Stage 0 already downsampled the stem output, so the shortcut must
                        # as well, otherwise its spatial size would not match the unit output.
                        mapping = nn.Sequential(IdentityMapping(origin_ni, origin_ni, stride=strides[0]), mapping)
                    idmappings.append(mapping)
                cur += 1
            ni = no
        self.units = nn.ModuleList(units)
        self.idmappings = nn.ModuleList(idmappings)

        self.classifier = Classifier(nos[-1], c_out)
        init_cnn(self)

    def forward(self, x):
        "Run the stem, every unit plus its folded shortcut, then the classifier."
        # A shortcut reaches back at most max(1, 2*(fold-1)) layers (see get_pred) and each
        # new output is stored only after both of its inputs are read, so a ring buffer of
        # 2*fold-1 slots keeps every output that can still be referenced.
        n_slots = 2 * self.fold - 1
        results = {0: self.start(x)}
        cur = 0
        for unit, idmapping in zip(self.units, self.idmappings):
            cur += 1
            pred = get_pred(cur, self.fold)
            results[cur % n_slots] = unit(results[(cur - 1) % n_slots]) + idmapping(results[pred % n_slots])
        return self.classifier(results[cur % n_slots])
        

In [None]:
model = ResNetX(
    Start=conv_bn,
    Unit=resnet_bottleneck,
    fold=5,
    ni=64,
    num_nodes=(3, 8, 36, 3),  # same per-stage unit counts as ResNet-152
    base=64,
    exp=2,
    bottle_scale=4,
    first_downsample=False,
    zero_bn=True,             # not a ResNetX argument: forwarded to every Unit via **kwargs
)

In [None]:
x = torch.randn((2, 3, 64, 64))  # dummy batch: 2 samples, 3 channels (matches c_in=3), 64x64

In [None]:
# Smoke test: one forward and one backward pass through the folded network.
# Anomaly detection is a debugging aid that slows autograd down; it is enabled here so
# that a failing backward pass reports the forward operation that caused it.
with torch.autograd.set_detect_anomaly(True):
    out = model(x)
    out.mean().backward()  # scalar surrogate loss, used only to exercise the gradients

In [None]:
format(num_params(model), ',')  # parameter count with thousands separators

'64,894,538'