## Resnet demo

Build a ResNet from scratch and compare it to the torchvision model.

In [1]:
import torch
import torch.utils.model_zoo as model_zoo
import torchvision

from gamma import *
from gamma.pytorch import *

# Explicit imports are kept after the wildcard imports so that the names
# below cannot be shadowed by anything the wildcard imports export.
from functools import reduce
from numpy.testing import assert_almost_equal

#### 1. Load pre-trained weights from torch model zoo

In [2]:
# Architecture to build. Used as the key into `model_config`, into torchvision's
# `model_urls`, and as the attribute name of the reference model in `torchvision.models`.
model_name = 'resnet50'

In [3]:
# Download (or load from the local cache) the pretrained weights for `model_name`.
# NOTE(review): `torchvision.models.resnet.model_urls` is not present in newer
# torchvision releases — confirm the installed version still provides it.
state_dict = model_zoo.load_url(
    torchvision.models.resnet.model_urls[model_name]
)

# One `constant` graph node per weight tensor, carrying the tensor data and its
# shape, so the whole state dict can be drawn as a graph keyed by parameter name.
state = {
    name: constant(name, [tensor.data, tuple(tensor.size())], [])
    for name, tensor in state_dict.items()
}
draw(state, sep='.', direction='TB', scale=0.7)

#### 2. Network description

In [4]:
basic_block = node('basic_block', ['in_channels', 'out_channels', 'stride'])
bottleneck_block = node(
    'bottleneck_block',
    ['in_channels', 'h1_channels', 'h2_channels', 'out_channels', 'stride'],
)


def _basic_stages(depths):
    """Return the four stages of a basic-block ResNet.

    `depths` gives the number of blocks in each stage. Every stage is a list of
    `(basic_block, in_channels, out_channels, stride)` tuples: the first block
    changes the channel count (and uses stride 2 in every stage but the first),
    the remaining blocks keep the channel count with stride 1.
    """
    stages = []
    in_channels = 64
    for stage_idx, (width, depth) in enumerate(zip((64, 128, 256, 512), depths)):
        stride = 1 if stage_idx == 0 else 2
        first = (basic_block, in_channels, width, stride)
        rest = (basic_block, width, width, 1)
        stages.append([first] + [rest] * (depth - 1))
        in_channels = width
    return stages


def _bottleneck_stages(depths):
    """Return the four stages of a bottleneck-block ResNet.

    `depths` gives the number of blocks in each stage. Every stage is a list of
    `(bottleneck_block, in_channels, h1_channels, h2_channels, out_channels, stride)`
    tuples, where both hidden widths equal the stage width and the output width
    is four times the stage width.
    """
    stages = []
    in_channels = 64
    for stage_idx, (hidden, depth) in enumerate(zip((64, 128, 256, 512), depths)):
        out_channels = 4 * hidden
        stride = 1 if stage_idx == 0 else 2
        first = (bottleneck_block, in_channels, hidden, hidden, out_channels, stride)
        rest = (bottleneck_block, out_channels, hidden, hidden, out_channels, 1)
        stages.append([first] + [rest] * (depth - 1))
        in_channels = out_channels
    return stages


# Per-architecture layout: a list of stages, each stage a list of block specs.
model_config = {
    'resnet18': _basic_stages((2, 2, 2, 2)),
    'resnet34': _basic_stages((3, 4, 6, 3)),
    'resnet50': _bottleneck_stages((3, 4, 6, 3)),
    'resnet101': _bottleneck_stages((3, 4, 23, 3)),
    'resnet152': _bottleneck_stages((3, 8, 36, 3)),
}

# Instantiate one graph node per block, labelled ('layer<i>', '<j>') with the
# stage index starting at 1 and the block index starting at 0.
blocks = [
    block_type((f'layer{layer_no}', f'{block_no}'), block_params)
    for layer_no, layer_blocks in enumerate(model_config[model_name], 1)
    for block_no, (block_type, *block_params) in enumerate(layer_blocks)
]

In [5]:
num_classes = 1000

# Stem applied before the residual stages.
# Conv params follow the same positional layout used by the rewrite rules:
# [in_channels, out_channels, kernel_size, stride, <presumably padding>, <presumably bias>].
prep = [
    conv('conv1', [3, 64, (7, 7), 2, 3, False]),
    bn('bn1', [64]),
    relu('relu1', [True]),
    max_pool('maxpool1', [(3, 3), 2, 1]),
]

# Head: global pooling followed by a linear layer whose input width is the
# `out_channels` of the last residual block.
classifier = [
    global_avg_pool('avgpool', []),
    linear(
        'fc',
        [blocks[-1]['params']['out_channels'], num_classes],
    ),
]

# High-level network: stem, then the (still unexpanded) blocks, then the head.
net_initial = pipeline([*prep, *blocks, *classifier])
draw(net_initial)

#### 3. Rewrite rules

In [6]:
# Pattern variables used as node keys in the rewrite rules below.
# NOTE(review): IPython also binds `_<N>` names to cached cell outputs — confirm
# none of these is clobbered before the rules in this cell are built.
_in = var('in')
_out = var('out')
_0, _1, _2, _3, _4, _5, _6, _7, _8, _9 = (var(i) for i in range(10))

# Placeholder node for a residual shortcut; resolved by the shortcut rules below.
shortcut = node('shortcut', ['in_channels', 'out_channels', 'stride'])

#1. expand_blocks
@bind_vars
def expand_blocks(block_name, in_channels, out_channels, stride):
    """Rewrite rule that expands a `basic_block` node into its primitive nodes.

    Returns a `(LHS, RHS)` pair. The LHS matches a single `basic_block`; the RHS
    is conv1 -> bn1 -> relu1 -> conv2 -> bn2, summed with a `shortcut` node fed
    from the block input, followed by relu2. All RHS labels are prefixed with
    `block_name`.
    """
    pattern = {
        _out: basic_block(block_name, [in_channels, out_channels, stride], [_in]),
    }

    body = {
        _0: conv('conv1', [in_channels, out_channels, (3, 3), stride, 1, False], [_in]),
        _1: bn('bn1', [out_channels], [_0]),
        _2: relu('relu1', [True], [_1]),
        _3: conv('conv2', [out_channels, out_channels, (3, 3), 1, 1, False], [_2]),
        _4: bn('bn2', [out_channels], [_3]),
        # The shortcut branch reads the block input, not the main branch.
        _5: shortcut('shortcut', [in_channels, out_channels, stride], [_in]),
        _6: add('add', [True], [_4, _5]),
        _out: relu('relu2', [True], [_6]),
    }
    replacement = add_prefix(block_name, body)

    return pattern, replacement

@bind_vars
def expand_bn_blocks(block_name, in_channels, h1_channels, h2_channels, out_channels, stride):
    """Rewrite rule that expands a `bottleneck_block` node into its primitive nodes.

    Returns a `(LHS, RHS)` pair. The LHS matches a single `bottleneck_block`;
    the RHS is a 1x1 conv, a 3x3 conv (which carries the block's stride) and a
    final 1x1 conv, each followed by batch norm, summed with a `shortcut` node
    fed from the block input and followed by relu3. All RHS labels are prefixed
    with `block_name`.
    """
    pattern = {
        _out: bottleneck_block(
            block_name,
            [in_channels, h1_channels, h2_channels, out_channels, stride],
            [_in],
        ),
    }

    body = {
        _0: conv('conv1', [in_channels, h1_channels, (1, 1), 1, 0, False], [_in]),
        _1: bn('bn1', [h1_channels], [_0]),
        _2: relu('relu1', [True], [_1]),
        _3: conv('conv2', [h1_channels, h2_channels, (3, 3), stride, 1, False], [_2]),
        _4: bn('bn2', [h2_channels], [_3]),
        _5: relu('relu2', [True], [_4]),
        _6: conv('conv3', [h2_channels, out_channels, (1, 1), 1, 0, False], [_5]),
        _7: bn('bn3', [out_channels], [_6]),
        # The shortcut branch reads the block input, not the main branch.
        _8: shortcut('shortcut', [in_channels, out_channels, stride], [_in]),
        _9: add('add', [True], [_7, _8]),
        _out: relu('relu3', [True], [_9]),
    }
    replacement = add_prefix(block_name, body)

    return pattern, replacement


#2. replace_shortcuts
@bind_vars
def replace_id_shortcuts(block_name, in_channels):
    """Rewrite rule for shortcuts that change neither channel count nor resolution.

    Returns a `(LHS, RHS)` pair. The LHS matches a `shortcut` whose input and
    output channels are equal and whose stride is 1; the RHS replaces it with an
    `identity` node labelled `(block_name, 'identity')`.
    """
    pattern = {
        _out: shortcut((block_name, 'shortcut'), [in_channels, in_channels, 1], [_in]),
    }
    replacement = {
        _out: identity((block_name, 'identity'), [], [_in]),
    }
    return pattern, replacement

@bind_vars
def replace_shortcuts(block_name, in_channels, out_channels, stride):
    """Rewrite rule for the remaining (projection) shortcuts.

    Returns a `(LHS, RHS)` pair. The LHS matches any `shortcut` node; the RHS
    replaces it with a strided 1x1 conv followed by batch norm, labelled
    `((block_name, 'downsample'), '0')` and `((block_name, 'downsample'), '1')`.
    """
    downsample = (block_name, 'downsample')

    pattern = {
        _out: shortcut((block_name, 'shortcut'), [in_channels, out_channels, stride], [_in]),
    }
    replacement = {
        _0: conv((downsample, '0'), [in_channels, out_channels, (1, 1), stride, 0, False], [_in]),
        _out: bn((downsample, '1'), [out_channels], [_0]),
    }
    return pattern, replacement

# Rules in application order: blocks are expanded first, then shortcuts are
# resolved. The identity rule is listed before the general shortcut rule,
# presumably because the general pattern would also match identity shortcuts.
rules = [
    expand_blocks(),
    expand_bn_blocks(),
    replace_id_shortcuts(),
    replace_shortcuts(),
]

# Visualise each rule (pattern and replacement).
for rule in rules:
    draw(rule, scale=0.7)

In [7]:
# Fold the rewrite rules over the initial graph, one after another, in list order.
net = net_initial
for rule in rules:
    net = apply_rule(net, rule)

# Re-key every node by its label joined with '.', giving dotted names such as
# those used by the pretrained state dict.
net = reindex(
    net,
    {node_id: path_str(attrs['label'], '.') for node_id, attrs in net.items()},
)
draw(net, scale=0.7)

#### 4. Build torch module and load weights

In [8]:
# Build the model from the rewritten graph and put it in training mode — the
# same mode the reference torchvision model is put in for the comparison below.
model = TorchGraph(net).train()
# Copy the pretrained torchvision weights into the model; this relies on the
# graph's node names lining up with the state dict keys (TODO confirm how
# `load_state` handles missing or unexpected keys).
load_state(model, state_dict)

#### 5. Test

In [9]:
# Compare the rebuilt network against the reference torchvision model on the
# same random batch. Both models are in training mode, so batch norm uses the
# statistics of this batch in both cases.
torch_model = getattr(torchvision.models, model_name)(pretrained=True).train()

# Fixed seed so the random input batch, and hence the check, is reproducible.
torch.manual_seed(0)
inputs = torch.rand((16, 3, 224, 224))

# Forward-only comparison: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    output = model({'input': inputs})
    torch_output = torch_model(inputs)

# The graph model returns a dict of node outputs; 'fc' is the final classifier node.
assert_almost_equal(to_numpy(output['fc']), to_numpy(torch_output), decimal=6)
print('Success!')

Success!
