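## ResNet demo

Build a ResNet from scratch with gamma graph-rewrite rules, load the pretrained torchvision weights into it, and check that its outputs match the torchvision model.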

In [1]:
import torch.utils.model_zoo as model_zoo
import torchvision

from gamma import *
from gamma.pytorch import *

#### 1. Load pre-trained weights from the torch model zoo

In [2]:
# Which ResNet variant to rebuild: must be a key of `model_config` (step 2) and the
# name of a torchvision.models constructor (step 5).
model_name = "resnet50"

In [3]:
# Pretrained checkpoint URL for `model_name`. torchvision >= 0.13 exposes checkpoints through
# `<Arch>_Weights` enums, and newer releases removed `torchvision.models.resnet.model_urls`,
# so use the enum when it exists and fall back to the legacy dict on older torchvision.
# IMAGENET1K_V1 is the weight set the legacy `pretrained=True` flag corresponds to.
_weights_enum = getattr(torchvision.models, model_name.replace('resnet', 'ResNet') + '_Weights', None)
if _weights_enum is not None:
    weights_url = _weights_enum.IMAGENET1K_V1.url
else:
    weights_url = torchvision.models.resnet.model_urls[model_name]
state_dict = model_zoo.load_url(weights_url)

# (torchvision parameter name, our parameter name) patterns passed to `rename`.
# NOTE(review): `{}` presumably matches an arbitrary dotted segment (e.g. "1.0") and `{k}`
# a named segment reused on both sides — confirm against gamma's `rename`.
name_rules = [
    ('conv1.{}', 'prep.conv.{}'),                                # stem conv
    ('bn1.{}',   'prep.bn.{}'),                                  # stem batch-norm
    ('layer{}.downsample.0.{}', 'layer{}.downsample.conv.{}'),   # projection shortcut conv
    ('layer{}.downsample.1.{}', 'layer{}.downsample.bn.{}'),     # projection shortcut bn
    ('layer{}.conv{k}.{}', 'layer{}.block{k}.conv.{}'),          # k-th conv inside a block
    ('layer{}.bn{k}.{}', 'layer{}.block{k}.bn.{}'),              # k-th bn inside a block
    ('fc.{}', 'fc.{}')                                           # classifier, unchanged
]

state_dict = rename(state_dict, name_rules)
# Visualise the renamed parameter hierarchy: each entry becomes an input-less constant
# node holding [tensor data, shape].
state = {n: (constant([v.data, tuple(v.size())]), []) for n, v in state_dict.items()}
draw(state, sep='.', direction='TB')

#### 2. Network description

In [4]:
# Placeholder graph node for one residual unit; the expansion rules in step 3 replace it
# with concrete conv/bn layers. h_channels (the bottleneck width) defaults to None so the
# 3-tuple configs can omit it.
block = node(
    namedtuple('Block', ('in_channels', 'out_channels', 'stride', 'h_channels')),
    h_channels=None,
)

def _stage_specs(depths, bottleneck):
    """Build per-stage block parameter lists for a torchvision-style ResNet.

    depths     -- number of residual units in each of the four stages
    bottleneck -- False: two-conv blocks, 3-tuples (in, out, stride);
                  True: bottleneck blocks with 4x channel expansion,
                        4-tuples (in, out, stride, h_channels)

    Returns a list of four lists. The first unit of each stage takes the previous
    stage's output width and carries the stride (1 for the first stage, 2 otherwise);
    the remaining units are shape-preserving.
    """
    expansion = 4 if bottleneck else 1
    stages = []
    prev_out = 64  # channel count leaving the stem (prep conv + maxpool)
    for stage_idx, (width, depth) in enumerate(zip((64, 128, 256, 512), depths)):
        out = width * expansion
        extra = (width,) if bottleneck else ()
        first = (prev_out, out, 1 if stage_idx == 0 else 2) + extra
        rest = (out, out, 1) + extra
        stages.append([first] + [rest] * (depth - 1))
        prev_out = out
    return stages


# Only the per-stage depths and the block type differ between variants; everything else
# is derived so the table cannot drift out of sync by copy-paste.
model_config = {
    'resnet18':  _stage_specs((2, 2, 2, 2), bottleneck=False),
    'resnet34':  _stage_specs((3, 4, 6, 3), bottleneck=False),
    'resnet50':  _stage_specs((3, 4, 6, 3), bottleneck=True),
    'resnet101': _stage_specs((3, 4, 23, 3), bottleneck=True),
    'resnet152': _stage_specs((3, 8, 36, 3), bottleneck=True),
}

# One (name, node) entry per residual unit, named ('layer<stage>', '<unit>') with stages
# numbered from 1 and units from 0 — the same layout as torchvision's `layerX.Y` keys.
blocks = [
    ((f'layer{stage}', str(unit)), block(*spec))
    for stage, stage_specs in enumerate(model_config[model_name], start=1)
    for unit, spec in enumerate(stage_specs)
]

In [5]:
num_classes = 1000  # must match the pretrained `fc` weights loaded in step 4

# Initial, un-expanded graph: stem conv_bn + ReLU, max-pool, the placeholder residual
# blocks, global average pool, and the classifier. Entries without an explicit input list
# presumably take the previous entry's output (pipeline chaining) — confirm in gamma.
net_initial = pipeline([
    ('prep', conv_bn(3, 64, (7, 7), stride=2, padding=3, activation=F.relu), ['input']),
    ('maxpool', max_pool((3, 3), stride=2, padding=1)),
    *blocks,
    ('avgpool', global_avg_pool()),
    # classifier input width = out_channels of the last residual block
    ('fc', linear(blocks[-1][1]['params']['out_channels'], num_classes))
])

draw(net_initial)

#### 3. Rewrite rules

In [6]:
# 1. Shortcut rules
# NOTE(review): @bind_vars presumably turns each parameter into a pattern variable that is
# matched against the graph, so parameter names are significant — confirm in gamma.
@bind_vars
def add_shortcuts(block_name, in_channels, out_channels, stride, h_channels, _in):
    """Rewrite a bare `block` node into a residual unit.

    LHS: a single `block(in_channels, out_channels, stride, h_channels)` node named
    `block_name` with input `_in`.
    RHS: the same block (now named `(block_name, 'block')`) and a `shortcut` node, both
    fed by `_in`, summed by an `add` node, followed by a ReLU `act` node.

    Returns (LHS, RHS, (old_name, new_name)); the last pair presumably tells apply_rules
    that consumers of `block_name` should now read from `(block_name, 'act')`.
    """
    LHS = {block_name: (block(in_channels, out_channels, stride, h_channels), [_in])}
    RHS = pipeline([
        ((block_name, 'block'), block(in_channels, out_channels, stride, h_channels), [_in]),
        ((block_name, 'shortcut'), shortcut(in_channels, out_channels, stride), [_in]),
        # add(True): the flag's meaning (presumably inplace) is defined in gamma — confirm
        ((block_name, 'add'), add(True), [(block_name, 'block'), (block_name, 'shortcut')]),
        # no input list: presumably chained from the preceding 'add' entry
        ((block_name, 'act'), activation_func(F.relu))
    ])
    return LHS, RHS, (block_name, (block_name, 'act'))

@bind_vars
def replace_identity_shortcuts(block_name, in_channels, _in):
    """Mark shape-preserving shortcuts as identities.

    Matches a `shortcut` whose out_channels equal its in_channels (the same variable is
    used for both) and whose stride is 1, and re-emits it with a fourth argument `True`
    — presumably an identity flag so no projection layer is built; confirm in gamma.
    Returns (LHS, RHS) with no node rename.
    """
    LHS = {(block_name, 'shortcut'): (shortcut(in_channels, in_channels, 1), [_in])}
    RHS = {(block_name, 'shortcut'): (shortcut(in_channels, in_channels, 1, True), [_in])}
    return LHS, RHS

@bind_vars
def replace_shortcuts(block_name, in_channels, out_channels, stride, _in):
    """Turn remaining shortcuts into 1x1 conv_bn projections.

    The projection uses the block's stride and is named `(block_name, 'downsample')`,
    the name used by the `layer{}.downsample.*` entries of `name_rules`, so torchvision's
    `downsample.0/1` weights are renamed onto it. The returned pair renames the old
    `(block_name, 'shortcut')` node to the new one.

    NOTE(review): this assumes identity shortcuts do not match this LHS (via their extra
    `True` argument, or because replace_identity_shortcuts runs first) — confirm in gamma.
    """
    LHS = {(block_name, 'shortcut'): (shortcut(in_channels, out_channels, stride), [_in])}
    RHS = {(block_name, 'downsample'): (conv_bn(in_channels, out_channels, (1, 1), stride), [_in])}
    return LHS, RHS, ((block_name, 'shortcut'), (block_name, 'downsample'))


# 2. Block-expansion rules
@bind_vars
def expand_2_blocks(block_name, in_channels, out_channels, stride, _in):
    """Expand a two-conv residual branch (the 3-tuple configs: resnet18/34).

    LHS: `block(in_channels, out_channels, stride)`, with h_channels left at its default
    None, so this presumably matches only blocks built without h_channels.
    RHS: a 3x3 conv_bn + ReLU carrying the stride, then a 3x3 conv_bn with no activation
    (the ReLU is applied after the residual add by add_shortcuts).
    Returns (LHS, RHS, rename) mapping `(block_name, 'block')` to its last node
    `(block_name, 'block2')`.
    """
    LHS = {(block_name, 'block'): (block(in_channels, out_channels, stride), [_in])}
    RHS = pipeline([
        ((block_name, 'block1'), conv_bn(in_channels, out_channels, (3, 3), stride=stride, padding=1, activation=F.relu), [_in]),
        ((block_name, 'block2'), conv_bn(out_channels, out_channels, (3, 3), stride=1, padding=1))
    ])
    return LHS, RHS, ((block_name, 'block'), (block_name, 'block2'))

                                          
@bind_vars
def expand_3_blocks(block_name, in_channels, out_channels, stride, h_channels, _in):
    """Expand a bottleneck residual branch (the 4-tuple configs: resnet50/101/152).

    RHS: a 1x1 conv_bn down to h_channels + ReLU, a 3x3 conv_bn carrying the stride + ReLU,
    then a 1x1 conv_bn up to out_channels with no activation (the ReLU follows the residual
    add). Putting the stride on the 3x3 conv matches torchvision's Bottleneck block.
    Returns (LHS, RHS, rename) mapping `(block_name, 'block')` to `(block_name, 'block3')`.
    """
    LHS = {(block_name, 'block'): (block(in_channels,out_channels, stride, h_channels), [_in])}
    RHS = pipeline([
        ((block_name, 'block1'), conv_bn(in_channels, h_channels, (1, 1), stride=1, activation=F.relu), [_in]),
        ((block_name, 'block2'), conv_bn(h_channels, h_channels, (3,3), stride=stride, padding=1, activation=F.relu)),
        ((block_name, 'block3'), conv_bn(h_channels, out_channels, (1,1), stride=1))
    ])
    return LHS, RHS, ((block_name, 'block'), (block_name, 'block3'))

# Rule order is the order apply_rules receives them in. Identity shortcuts are marked
# before the generic shortcut -> projection rewrite. expand_conv_bns comes from gamma
# (presumably splitting conv_bn nodes into conv/bn/activation nodes).
rules = [
    add_shortcuts(),
    replace_identity_shortcuts(),
    replace_shortcuts(),
    expand_2_blocks(),
    expand_3_blocks(),
    expand_conv_bns(),
]
for rule in rules:
    draw(rule)

In [7]:
# Rewrite the placeholder-block graph into concrete layers using the rules above.
net = apply_rules(net_initial, rules)
draw(net)

#### 4. Build the torch module and load the weights

In [8]:
# index_by_labels(net, '.') presumably flattens the nested node names into '.'-joined
# labels so they line up with the renamed state_dict keys from step 1 — confirm in gamma.
model = TorchGraph(index_by_labels(net, '.')).train()
load_state(model, state_dict)

#### 5. Test

In [9]:
from numpy.testing import assert_almost_equal

# Reference model with the same pretrained weights as step 1. Prefer the `weights=` API
# (torchvision >= 0.13, where `pretrained=` is deprecated) and fall back on older releases.
_weights_enum = getattr(torchvision.models, model_name.replace('resnet', 'ResNet') + '_Weights', None)
_torch_ctor = getattr(torchvision.models, model_name)
if _weights_enum is not None:
    torch_model = _torch_ctor(weights=_weights_enum.IMAGENET1K_V1)
else:
    torch_model = _torch_ctor(pretrained=True)

torch.manual_seed(0)  # fixed inputs so Restart & Run All is reproducible
inputs = torch.rand((2, 3, 224, 224))

# Eval mode first: BatchNorm then normalises with the loaded running_mean/running_var,
# so this also checks those buffers were transferred (train mode uses batch statistics
# and ignores them). Train mode runs last, leaving both models in train mode as before.
for mode in ('eval', 'train'):
    model.train(mode == 'train')
    torch_model.train(mode == 'train')
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model({'input': inputs})
        torch_output = torch_model(inputs)
    assert_almost_equal(to_numpy(output['fc']), to_numpy(torch_output), decimal=6,
                        err_msg=f'{model_name}: outputs differ in {mode} mode')
print('Success!')

Success!
