## ResNet demo

Build a ResNet from scratch and compare it to the torchvision model.

In [1]:
import torch.utils.model_zoo as model_zoo
import torchvision

from gamma import *
from gamma.pytorch import *

#### 1. Load pre-trained weights from torch model zoo

In [2]:
# Architecture to build. The same string is used to look up the pretrained
# weight URL, the entry in `model_config`, and the torchvision reference model,
# so it has to be valid for all three.
model_name = 'resnet50'

In [3]:
# Fetch the pretrained torchvision checkpoint for the selected architecture.
state_dict = model_zoo.load_url(torchvision.models.resnet.model_urls[model_name])

# (source pattern, target pattern) pairs handed to `rename`. The targets follow
# the prep / layerN.downsample / layerN.blockK naming used for the graph nodes
# built later in this notebook.
# NOTE(review): the `{}` / `{k}` placeholder semantics live in gamma's `rename` — confirm there.
name_rules = [
    # stem
    ('conv1.{}', 'prep.conv.{}'),
    ('bn1.{}', 'prep.bn.{}'),
    # projection shortcut
    ('layer{}.downsample.0.{}', 'layer{}.downsample.conv.{}'),
    ('layer{}.downsample.1.{}', 'layer{}.downsample.bn.{}'),
    # convolutions / batch norms inside each residual block
    ('layer{}.conv{k}.{}', 'layer{}.block{k}.conv.{}'),
    ('layer{}.bn{k}.{}', 'layer{}.block{k}.bn.{}'),
    # classifier
    ('fc.{}', 'fc.{}'),
]
state_dict = rename(state_dict, name_rules)

# One `constant` node per renamed parameter, carrying its data and shape, so
# the checkpoint layout can be drawn as a graph.
state = {
    key: constant(key, [tensor.data, tuple(tensor.shape)], [])
    for key, tensor in state_dict.items()
}
draw(state, sep='.', direction='TB')

#### 2. Network description

In [4]:
# Node type standing for a whole residual block; the rewrite rules further down
# expand it into conv_bn nodes. `h_channels` is passed as None here, presumably
# as its default, since the three-element configs below omit it.
# NOTE(review): `node` and `namedtuple` are not imported explicitly in this
# notebook; presumably both arrive via the `gamma` star imports — confirm.
block = node(namedtuple('Block', ['in_channels', 'out_channels', 'stride', 'h_channels']), h_channels=None)

def _stage(in_channels, out_channels, stride, depth, h_channels=None):
    """Return the parameter tuples for one stage of `depth` residual blocks.

    The first block maps `in_channels` to `out_channels` with the given
    `stride`; the remaining `depth - 1` blocks keep `out_channels` at stride 1.
    When `h_channels` is given it is appended as a fourth tuple element.
    """
    extra = () if h_channels is None else (h_channels,)
    first = (in_channels, out_channels, stride) + extra
    rest = (out_channels, out_channels, 1) + extra
    return [first] + [rest] * (depth - 1)


def _basic_stages(depths):
    """Four stages of three-element (in, out, stride) block specs."""
    d1, d2, d3, d4 = depths
    return [
        _stage(64, 64, 1, d1),
        _stage(64, 128, 2, d2),
        _stage(128, 256, 2, d3),
        _stage(256, 512, 2, d4),
    ]


def _bottleneck_stages(depths):
    """Four stages of four-element (in, out, stride, hidden) block specs."""
    d1, d2, d3, d4 = depths
    return [
        _stage(64, 256, 1, d1, 64),
        _stage(256, 512, 2, d2, 128),
        _stage(512, 1024, 2, d3, 256),
        _stage(1024, 2048, 2, d4, 512),
    ]


# Per-architecture list of stages; each stage is a list of block parameter
# tuples. The numbers passed in are the block counts of the four stages.
model_config = {
    'resnet18': _basic_stages((2, 2, 2, 2)),
    'resnet34': _basic_stages((3, 4, 6, 3)),
    'resnet50': _bottleneck_stages((3, 4, 6, 3)),
    'resnet101': _bottleneck_stages((3, 4, 23, 3)),
    'resnet152': _bottleneck_stages((3, 8, 36, 3)),
}

# Flatten the per-stage specs of the chosen architecture into one list of
# `block` nodes labelled ('layer<i>', '<j>'): stages count from 1, blocks
# within a stage from 0.
blocks = [
    block((f'layer{i}', f'{j}'), *spec)
    for i, stage in enumerate(model_config[model_name], start=1)
    for j, spec in enumerate(stage)
]

In [5]:
num_classes = 1000  # output width of the final linear layer
act = F.relu        # activation handed to the conv_bn / activation_func nodes

# Stem: a 7x7 stride-2 conv_bn fed by the 'input' node, then a 3x3 stride-2 max-pool.
prep = [
    conv_bn('prep', 3, 64, (7, 7), stride=2, padding=3, activation=act, inputs_=['input']),
    max_pool('maxpool', (3, 3), stride=2, padding=1),
]

# Head: global average pool, then a linear layer whose input width is the
# `out_channels` of the last residual block.
fc_in_features = blocks[-1]['params']['out_channels']
classifier = [
    global_avg_pool('avgpool'),
    linear('fc', fc_in_features, num_classes),
]

# Coarse network: residual blocks are still single `block` nodes at this point.
net_initial = pipeline([*prep, *blocks, *classifier])
draw(net_initial)

#### 3. Rewrite rules

In [6]:
# Pattern variables used as node keys and input references in the rules below.
_in, _out = var('in'), var('out')
_0, _1, _2, _3, _4, _5, _6 = (var(k) for k in range(7))
# 1. add_shortcuts
@bind_vars
def add_shortcuts(block_name, in_channels, out_channels, stride, h_channels):
    """Rule: give a `block` node a parallel shortcut branch.

    The pattern is a single `block` reading from `_in`. The replacement keeps
    that block, adds a `shortcut` node reading from the same input, sums the
    two with an `add` node and applies `act` to the sum.

    NOTE(review): this is called with no arguments when the rule list is
    built, so `bind_vars` presumably supplies a pattern variable for each
    parameter — confirm in gamma.

    Returns:
        (pattern, replacement) dicts keyed by pattern variables.
    """
    block_args = (block_name, in_channels, out_channels, stride, h_channels)

    pattern = {_out: block(*block_args, inputs_=[_in])}
    replacement = {
        _0: block(*block_args, inputs_=[_in]),
        _1: shortcut((block_name, 'shortcut'), in_channels, out_channels, stride, inputs_=[_in]),
        _2: add((block_name, 'add'), True, inputs_=[_0, _1]),
        _out: activation_func((block_name, 'act'), act, inputs_=[_2]),
    }
    return pattern, replacement

@bind_vars
def replace_identity_shortcuts(block_name, in_channels):
    """Rule: rewrite a shortcut whose in/out channels are equal and whose stride is 1.

    The replacement is the same `shortcut` node with an extra positional
    `True` argument.
    NOTE(review): presumably that flag marks the shortcut as an identity
    mapping — confirm against gamma's `shortcut` signature.

    Returns:
        (pattern, replacement) dicts keyed by pattern variables.
    """
    label = (block_name, 'shortcut')
    pattern = {
        _out: shortcut(label, in_channels, in_channels, 1, inputs_=[_in]),
    }
    replacement = {
        _out: shortcut(label, in_channels, in_channels, 1, True, inputs_=[_in]),
    }
    return pattern, replacement

@bind_vars
def replace_shortcuts(block_name, in_channels, out_channels, stride):
    """Rule: turn a `shortcut` node into a 1x1 `conv_bn` labelled 'downsample'.

    The replacement conv_bn takes the shortcut's channels and stride and has
    no activation argument.

    Returns:
        (pattern, replacement) dicts keyed by pattern variables.
    """
    pattern = {
        _out: shortcut((block_name, 'shortcut'), in_channels, out_channels, stride, inputs_=[_in]),
    }
    replacement = {
        _out: conv_bn((block_name, 'downsample'), in_channels, out_channels, (1, 1), stride, inputs_=[_in]),
    }
    return pattern, replacement


# 2. expand_blocks
@bind_vars
def expand_basic_blocks(block_name, in_channels, out_channels, stride):
    """Rule: expand a `block` built without `h_channels` into two 3x3 conv_bn nodes.

    The first conv_bn carries the block's stride and the `act` activation; the
    second runs at stride 1 with no activation argument. Both are passed
    through `add_prefix` with the block's name.

    Returns:
        (pattern, replacement) dicts keyed by pattern variables.
    """
    pattern = {_out: block(block_name, in_channels, out_channels, stride, inputs_=[_in])}

    first = conv_bn('block1', in_channels, out_channels, (3, 3),
                    stride=stride, padding=1, activation=act, inputs_=[_in])
    second = conv_bn('block2', out_channels, out_channels, (3, 3),
                     stride=1, padding=1, inputs_=[_0])
    replacement = add_prefix(block_name, {_0: first, _out: second})
    return pattern, replacement

                                          
@bind_vars
def expand_blocks(block_name, in_channels, out_channels, stride, h_channels):
    """Rule: expand a `block` with `h_channels` into a 1x1 -> 3x3 -> 1x1 conv_bn chain.

    The 1x1 convs run at stride 1; the block's stride is applied by the middle
    3x3 conv. The first two conv_bn nodes use the `act` activation, the last
    one has no activation argument. All three are passed through `add_prefix`
    with the block's name.

    Returns:
        (pattern, replacement) dicts keyed by pattern variables.
    """
    pattern = {_out: block(block_name, in_channels, out_channels, stride, h_channels, inputs_=[_in])}

    reduce = conv_bn('block1', in_channels, h_channels, (1, 1),
                     stride=1, activation=act, inputs_=[_in])
    spatial = conv_bn('block2', h_channels, h_channels, (3, 3),
                      stride=stride, padding=1, activation=act, inputs_=[_0])
    expand = conv_bn('block3', h_channels, out_channels, (1, 1),
                     stride=1, inputs_=[_1])
    replacement = add_prefix(block_name, {_0: reduce, _1: spatial, _out: expand})
    return pattern, replacement


# Instantiate every rewrite rule, in the order they are handed to `apply_rules`.
# NOTE(review): `expand_conv_bns` is not defined in this notebook; presumably
# it comes from the `gamma.pytorch` star import — confirm.
rules = [
    add_shortcuts(),
    replace_identity_shortcuts(),
    replace_shortcuts(),
    expand_basic_blocks(),
    expand_blocks(),
    expand_conv_bns(),
]
for rule in rules:
    draw(rule)

In [7]:
# Rewrite the coarse network with the rule list defined above, then draw the
# resulting graph.
net = apply_rules(net_initial, rules)
draw(net)

#### 4. Build torch module and load weights

In [8]:
# Build a module from the rewritten graph, with nodes indexed by '.'-joined
# labels, and put it in train mode explicitly; the torchvision reference in
# the test cell is put in train mode as well.
# NOTE(review): presumably TorchGraph is an nn.Module subclass — confirm.
model = TorchGraph(index_by_labels(net, '.')).train()
# Load the renamed pretrained checkpoint into the graph model.
load_state(model, state_dict)

#### 5. Test

In [9]:
from numpy.testing import assert_almost_equal

# Seed the RNG so the random input batch, and therefore this comparison, is
# identical on every Restart & Run All.
torch.manual_seed(0)

# Reference implementation with pretrained weights, explicitly put in train mode.
torch_model = getattr(torchvision.models, model_name)(pretrained=True).train()
inputs = torch.rand((2, 3, 224, 224))

# Only forward passes are needed, so skip autograd bookkeeping for both networks.
with torch.no_grad():
    output = model({'input': inputs})
    torch_output = torch_model(inputs)

# The graph model's output is indexed by node name; 'fc' is the classifier node.
assert_almost_equal(to_numpy(output['fc']), to_numpy(torch_output), decimal=6)
print('Success!')

Success!
