## ResNet-18 demo

Build a ResNet-18 from scratch using graph rewrite rules, and check that its outputs match those of torchvision's reference model.

In [1]:
import torch
import torch.utils.model_zoo as model_zoo
import torchvision

from gamma import *
from gamma.pytorch import *

# Imported after the star imports on purpose, so that a same-named export from
# gamma cannot shadow them.
from functools import reduce
from numpy.testing import assert_almost_equal

#### 1. Load pre-trained weights from the PyTorch model zoo

In [2]:
# Older torchvision releases expose a `model_urls` dict on the resnet module;
# newer releases (from 0.13) dropped it in favour of per-architecture weight
# enums, so fall back to the enum's URL when the dict is missing.
try:
    resnet18_url = torchvision.models.resnet.model_urls['resnet18']
except AttributeError:
    resnet18_url = torchvision.models.ResNet18_Weights.IMAGENET1K_V1.url
state_dict = model_zoo.load_url(resnet18_url)

# Wrap each weight tensor as a constant graph node (value + shape) so the
# parameter hierarchy can be drawn; `sep='.'` presumably groups nodes by the
# dotted prefixes of their names.
state = {n: constant(n, [v.data, tuple(v.size())], []) for n, v in state_dict.items()}
draw(state, sep='.', direction='TB', scale=0.7)

#### 2. Network description

In [3]:
# Opaque node type for a ResNet basic block; a rewrite rule expands it into
# explicit layers later on.
basic_block = node('basic_block', ['in_channels', 'out_channels', 'stride'])

num_classes = 1000

# Stem: 7x7 conv (stride 2) -> bn -> relu -> 3x3 max-pool.
prep = [
    conv('conv1', [3, 64, (7, 7), 2, 3, False]),
    bn('bn1', [64]),
    relu('relu1', []),
    max_pool('maxpool1', [(3, 3), 2, 1]),
]

# ResNet-18 body: four stages of two basic blocks each. Only the first block of
# a stage changes the width / stride; the second keeps out_channels and stride 1.
# Each entry: (stage name, in_channels, out_channels, stride of first block).
resnet18_stages = [
    ('layer1', 64, 64, 1),
    ('layer2', 64, 128, 2),
    ('layer3', 128, 256, 2),
    ('layer4', 256, 512, 2),
]
layers = [
    block
    for stage, c_in, c_out, first_stride in resnet18_stages
    for block in (
        basic_block((stage, '0'), [c_in, c_out, first_stride]),
        basic_block((stage, '1'), [c_out, c_out, 1]),
    )
]

# Head: global average pooling followed by the fully connected classifier.
classifier = [
    global_avg_pool('avgpool', []),
    linear('fc', [512, num_classes]),
]

net_initial = pipeline(prep + layers + classifier)
draw(net_initial)

#### 3. Rewrite rules

In [4]:
# Pattern variables shared by the rewrite rules: `_in` / `_out` mark the
# boundary of a matched subgraph, `_0`..`_6` name its internal nodes.
_in, _out = var('in'), var('out')
_0, _1, _2, _3, _4, _5, _6 = map(var, range(7))

# Intermediate skip-connection node emitted by expand_blocks; the shortcut rules
# replace it with either an identity or a 1x1 conv + bn downsample path.
shortcut = node('shortcut', ['in_channels', 'out_channels', 'stride'])

#1. expand_blocks: rewrite each opaque `basic_block` node into its explicit layers.
@bind_vars
def expand_blocks(block_name, in_channels, out_channels, stride):
    """Rewrite rule that expands a `basic_block` node into its layers.

    The rule is instantiated with no arguments (`expand_blocks()`), so the
    parameters are supplied by `bind_vars` -- presumably as pattern variables
    bound to the matched node's name and params; confirm against gamma.

    Returns:
        (LHS, RHS): LHS matches a `basic_block` node fed by `_in` and producing
        `_out`; RHS is the replacement subgraph, with every node name prefixed
        by `block_name` via `add_prefix`.
    """
    LHS = {_out: basic_block(block_name, [in_channels, out_channels, stride], [_in]),}
  
    RHS = add_prefix(block_name, {
        # Main path: 3x3 conv (carries the block's stride) -> bn -> relu -> 3x3 conv -> bn.
        _0: conv('conv1', [in_channels, out_channels, (3, 3), stride, 1, False], [_in]),   
        _1: bn('bn1', [out_channels], [_0]),
        _2: relu('relu1', [], [_1]),
        _3: conv('conv2', [out_channels, out_channels, (3, 3), 1, 1, False], [_2]),
        _4: bn('bn2', [out_channels], [_3]),
        # Placeholder skip connection from the block input; the shortcut rules
        # later turn it into an identity or a downsampling conv + bn.
        _5: shortcut('shortcut', [in_channels, out_channels, stride], [_in]),
        # Residual sum followed by the block's final activation. Binding the last
        # node to `_out` presumably keeps consumers of the original block wired
        # to the expanded subgraph.
        _6: add('add', [], [_4, _5]),     
        _out: relu('relu2', [], [_6]),
    })
  
    return LHS, RHS

#2. replace shortcuts: identity where shapes are preserved, conv + bn downsample otherwise.
@bind_vars
def replace_id_shortcuts(block_name, in_channels):
    """Rewrite rule turning a shape-preserving shortcut into an identity node.

    The pattern uses `in_channels` for both channel params and fixes the stride
    at 1, so it should only match shortcuts that change neither width nor
    resolution (assuming a repeated pattern variable must bind one value).

    Returns:
        (pattern, replacement) dicts keyed by the output variable `_out`.
    """
    matched = shortcut((block_name, 'shortcut'), [in_channels, in_channels, 1], [_in])
    replacement = identity((block_name, 'identity'), [], [_in])
    return {_out: matched}, {_out: replacement}

@bind_vars
def replace_shortcuts(block_name, in_channels, out_channels, stride):
    """Rewrite rule turning any remaining shortcut into a downsample path.

    The shortcut becomes a 1x1 conv (with the block's stride, no padding, no
    bias) followed by batch-norm. The nodes are named `downsample`/`0` and
    `downsample`/`1` under the block, mirroring torchvision's module layout.

    Returns:
        (pattern, replacement) dicts; `_0` names the intermediate conv node.
    """
    downsample = (block_name, 'downsample')
    pattern = {
        _out: shortcut((block_name, 'shortcut'), [in_channels, out_channels, stride], [_in]),
    }
    replacement = {
        _0: conv((downsample, '0'), [in_channels, out_channels, (1, 1), stride, 0, False], [_in]),
        _out: bn((downsample, '1'), [out_channels], [_0]),
    }
    return pattern, replacement

# Order matters: replace_id_shortcuts must run before the more general
# replace_shortcuts, whose pattern would otherwise also match the
# shape-preserving shortcuts (assuming apply_rule rewrites every match).
rules = [make_rule() for make_rule in (expand_blocks, replace_id_shortcuts, replace_shortcuts)]

# NOTE(review): if `draw` only returns a displayable object (as the bare
# `draw(...)` cell endings elsewhere suggest) instead of displaying it itself,
# nothing is shown from inside this loop -- consider `display(draw(...))`;
# confirm against gamma.
for rule in rules:
    draw(rule, scale=0.7)

In [5]:
# Apply the rewrite rules one after another, each to the output of the previous one.
net_rewritten = net_initial
for rule in rules:
    net_rewritten = apply_rule(net_rewritten, rule)

# Rename nodes from their nested tuple labels to '.'-joined strings, presumably
# so they line up with the dotted keys of the torchvision state_dict.
net = reindex(net_rewritten, {node_id: path_str(attrs['label'], '.') for node_id, attrs in net_rewritten.items()})
draw(net, scale=0.7)

#### 4. Build the PyTorch module and load the pre-trained weights

In [6]:
# Build a PyTorch module from the rewritten graph and put it in eval mode
# (assuming TorchGraph is a torch.nn.Module, batch-norm then uses its running
# statistics instead of batch statistics).
model = TorchGraph(net).eval()
# Copy the pre-trained weights into the module; presumably matched by the
# dotted node names produced by reindex -- confirm in gamma.pytorch.
load_state(model, state_dict)

#### 5. Compare against torchvision's ResNet-18

In [7]:
# Reference model: torchvision's own ResNet-18 loaded with exactly the same
# weights, so any mismatch below comes from the graph construction rather than
# from the weights. Building it weight-free and loading `state_dict` also avoids
# the `pretrained=` flag, which newer torchvision releases deprecate.
torch_model = torchvision.models.resnet18()
torch_model.load_state_dict(state_dict)
torch_model.eval()

torch.manual_seed(0)  # fixed seed so the random test batch is reproducible
inputs = torch.rand((2, 3, 224, 224))

# Inference only: skip autograd bookkeeping for both forward passes.
with torch.no_grad():
    output = model({'input': inputs})
    torch_output = torch_model(inputs)

assert_almost_equal(to_numpy(output['fc']), to_numpy(torch_output), decimal=6)
print('Success!')

Success!
