In [1]:
#!pip install parse
import warnings; warnings.simplefilter(action='ignore', category=FutureWarning)

from gamma import *
from gamma.pytorch import *
import gamma

In [2]:
import numpy as np

# Axis-name conventions for TensorFlow and PyTorch tensors.
# 'act' is the layout of a 4-d activation tensor (one character per axis);
# integer keys give the axis labels of a weight tensor with that many dimensions.
TF_DIMS = {
    'act': 'NHWC',
    2: ('in_channels', 'out_channels'),
    4: ('kh', 'kw', 'in_channels', 'out_channels')
}

TORCH_DIMS = {
    'act': 'NCHW',
    2: ('out_channels', 'in_channels'),
    4: ('out_channels', 'in_channels', 'kh', 'kw')
}


def transpose(x, input_dims, output_dims):
    """Permute the axes of ``x`` from the ``input_dims`` order to the ``output_dims`` order.

    Args:
        x: array exposing a numpy-style ``transpose(axes)`` method (e.g. ``np.ndarray``).
        input_dims: sequence of axis labels describing the current layout of ``x``.
        output_dims: the same labels, in the desired order.

    Returns:
        ``x.transpose(axes)`` where ``axes[i]`` is the position of ``output_dims[i]``
        in ``input_dims`` (for numpy arrays this is a view, not a copy).

    Raises:
        ValueError: if a label in ``output_dims`` does not occur in ``input_dims``.
    """
    axes = [input_dims.index(dim) for dim in output_dims]
    return x.transpose(axes)

def state_dict_from_ckpt(ckpt_name):
    """Read every variable stored in a TensorFlow checkpoint.

    Args:
        ckpt_name: checkpoint path prefix handed to ``NewCheckpointReader``
            (e.g. ``'<dir>/mobilenet_v2_1.0_224.ckpt'``).

    Returns:
        dict mapping each variable name listed by the checkpoint reader to the
        value returned by ``reader.get_tensor`` for that name.
    """
    # Local import: TensorFlow is only needed when a checkpoint is actually read.
    from tensorflow.python import pywrap_tensorflow
    # Use the argument, not the notebook-global ``ckpt``: otherwise the function
    # silently ignores its input and fails if called before ``ckpt`` is defined.
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
    return {name: reader.get_tensor(name) for name in reader.get_variable_to_shape_map()}

def depthwise_to_grouped_conv(weights):
    """Reshape a depthwise kernel into a grouped-convolution kernel.

    ``weights`` must be 4-d, ``(kh, kw, in_channels, depth_mult)``; any other rank
    raises ValueError. The result is ``(kh, kw, 1, in_channels * depth_mult)``: one
    input channel per group, with the last two axes flattened in C order. The
    TF-style axis order is kept; ``tf_to_torch`` permutes it afterwards.
    """
    kernel_h, kernel_w, n_in, multiplier = weights.shape
    grouped_shape = (kernel_h, kernel_w, 1, n_in * multiplier)
    return weights.reshape(grouped_shape)

def tf_to_torch(state_dict):
    """Convert a TF checkpoint dict (name -> array) into a dict of torch tensors.

    For each entry, in order:
      * names containing 'depthwise_weights' are reshaped with depthwise_to_grouped_conv;
      * non-ndarray values are wrapped with np.array;
      * values with more than one dimension are permuted from the TF_DIMS layout
        to the TORCH_DIMS layout for their rank;
      * the result is wrapped with torch.from_numpy (shares memory with the array).

    Returns a new dict with the same keys.
    """
    converted = {}
    for name, value in state_dict.items():
        if 'depthwise_weights' in name:
            value = depthwise_to_grouped_conv(value)
        if not isinstance(value, np.ndarray):
            value = np.array(value)
        if value.ndim > 1:
            # Only ranks present in TF_DIMS/TORCH_DIMS (2 and 4) are supported.
            value = transpose(value, TF_DIMS[value.ndim], TORCH_DIMS[value.ndim])
        converted[name] = torch.from_numpy(value)
    return converted


def tf_to_torch_name_rules(use_ema=False):
    """Return (tf_pattern, torch_pattern) rename rules for per-layer variable names.

    Each pattern has a leading '{}' placeholder for the layer scope. When ``use_ema``
    is True, the TF patterns of the BatchNorm beta/gamma and of the conv
    weights/depthwise_weights/biases get the '/ExponentialMovingAverage' suffix.
    The BatchNorm moving_mean/moving_variance patterns never get it.
    """
    ema_suffix = '/ExponentialMovingAverage' if use_ema else ''
    # (TF leaf name, torch leaf name, whether the TF name takes the EMA suffix)
    table = (
        ('BatchNorm/beta',            'bn/bias',         True),
        ('BatchNorm/gamma',           'bn/weight',       True),
        ('BatchNorm/moving_mean',     'bn/running_mean', False),
        ('BatchNorm/moving_variance', 'bn/running_var',  False),
        ('weights',                   'conv/weight',     True),
        ('depthwise_weights',         'conv/weight',     True),
        ('biases',                    'conv/bias',       True),
    )
    return [
        ('{}/' + tf_leaf + (ema_suffix if averaged else ''), '{}/' + torch_leaf)
        for tf_leaf, torch_leaf, averaged in table
    ]

#### 1. Load pre-trained TensorFlow weights

In [3]:
# MobileNetV2 release to convert, fetched from the public mobilenet_v2 checkpoint bucket.
model_name = 'mobilenet_v2_1.0_224'
url = f'https://storage.googleapis.com/mobilenet_v2/checkpoints/{model_name}.tgz'
# NOTE(review): presumably get_file downloads and unpacks the .tgz and returns the
# extraction directory (the paths below treat `fname` as a directory) — confirm in gamma.
fname = gamma.utils.get_file(url)
# Checkpoint path prefix handed to the TF checkpoint reader.
ckpt = f'{fname}/{model_name}.ckpt'
# Load the '/ExponentialMovingAverage' copies of the trainable variables (see
# tf_to_torch_name_rules); the comparison against the frozen graph in step 5 asserts this is True.
use_ema = True

In [4]:
# Read the TF checkpoint and convert every array to a torch tensor in torch axis order.
state_dict = tf_to_torch(state_dict_from_ckpt(ckpt))

# Scope prefix of the MobileNetV2 variables in the TF checkpoint.
M2 = 'MobilenetV2/'
# Second-pass renames from TF scopes to the node names used in the network description
# below; '{}' / '{i}' are presumably placeholders captured and re-inserted by `rename`.
rules = [
    (M2+'Conv/{}',                      'prep/expand/{}'),    
    (M2+'expanded_conv/{}',             'prep/{}'),    
    (M2+'expanded_conv_{i}',            'block_{i}'),
    (M2+'Conv_1/{}',                    'classifier/expand/{}'),
    (M2+'Logits/Conv2d_1c_1x1/conv/{}', 'classifier/fc/{}'),
]

# Pass 1: per-layer leaf names (BatchNorm/*, weights, biases, ...) -> bn/* and conv/*.
state_dict = rename(state_dict, tf_to_torch_name_rules(use_ema))
# Pass 2: TF scopes -> graph node names. Must run after pass 1: the Logits rule matches
# the '.../conv/{}' names that pass 1 produces.
state_dict = rename(state_dict, rules)
state_dict['classifier/fc/weight'].squeeze_(); #drop the 1x1 kernel axes: this conv is applied to a 1x1 activation, so it is equivalent to an nn.Linear layer

# Wrap each tensor as a node with an empty input list (the same (node, inputs) format
# used by the graph dicts below) so the renamed weights can be visualised.
state = {n: (constant(v, size=v.shape), []) for n, v in state_dict.items()}
draw(state, direction='TB')

#### 2. Network description

In [5]:
# Settings shared by the conv_bn layers below.
act = F.relu6        # activation used by the expand / depthwise convs (the project convs have none)
eps = 1e-3           # epsilon forwarded to every conv_bn (presumably the BatchNorm eps)
num_classes = 1001   # output width of the final `linear` classifier
# Placeholder node for one bottleneck block; the expand_blocks rule later rewrites it into convs.
inverted_residual = node(namedtuple('inverted_residual', 'in_channels h_channels out_channels stride'))

# Bottleneck stages after the stem, one row per stage:
#   (input channels of the stage, output channels, number of blocks, stride of the first block).
# Every block expands its input 6x. Only the first block of a stage changes the channel
# count and applies the stride; the remaining blocks map stage_out -> stage_out at stride 1.
# Produces one (in_channels, h_channels, out_channels, stride) tuple per block.
layers = [
    (block_in, block_in * 6, stage_out, block_stride)
    for (stage_in, stage_out, n_blocks, stage_stride) in (
        ( 16,  24, 2, 2),
        ( 24,  32, 3, 2),
        ( 32,  64, 4, 2),
        ( 64,  96, 3, 1),
        ( 96, 160, 3, 2),
        (160, 320, 1, 1),
    )
    for (block_in, block_stride) in [(stage_in, stage_stride)] + [(stage_out, 1)] * (n_blocks - 1)
]

# Initial graph: the stem ('prep/*'), one placeholder inverted_residual node per entry
# of `layers` (named block_1 ... block_16) and the classifier head. Node names match
# the checkpoint keys produced by the rename rules in step 1.
# NOTE(review): only the first node lists its input explicitly; presumably `pipeline`
# feeds each later node from the previous one — confirm in gamma.
net_initial = pipeline([
    ('prep/expand', conv_bn(3, 32,  kernel_size=3, padding=1, stride=2, activation=act, eps=eps), ['input']),
    ('prep/depthwise', conv_bn(32, 32, kernel_size=3, padding=1, groups=32, activation=act, eps=eps)),
    ('prep/project', conv_bn(32, 16, kernel_size=1, eps=eps)),
    *((f'block_{i}', inverted_residual(*params)) for (i, params) in enumerate(layers, 1)),
    ('classifier/expand', conv_bn(320, 1280, kernel_size=1, activation=act, eps=eps)),
    ('classifier/avg_pool', global_avg_pool()),
    # Presumably inactive once the module is switched to eval() in step 4.
    ('classifier/dropout', dropout(p=0.5, inplace=True)),
    ('classifier/fc', linear(1280, num_classes))
])   

draw(net_initial)

#### 3. Rewrite rules

In [6]:
# Rewrite rule: add a shortcut around every inverted_residual block that keeps its
# channel count (out_channels == in_channels) and uses stride 1.
# NOTE(review): @bind_vars presumably turns the parameters into pattern variables that
# are matched against the graph, so the parameter names are part of the rule — confirm
# in gamma before renaming anything.
@bind_vars
def add_shortcuts(block_name, in_channels, h_channels, _in):
    # Pattern: block `block_name` with out_channels == in_channels and stride 1, fed by `_in`.
    LHS = {block_name: (inverted_residual(in_channels, h_channels, in_channels, 1), [_in])}
    
    # Replacement: the same block plus a path(block_name, 'add') node combining the
    # block's output with its input `_in`.
    RHS = { 
        block_name: (inverted_residual(in_channels, h_channels, in_channels, 1), [_in]),      
        path(block_name, 'add'): (add(inplace=True), [block_name, _in]),
    }    
    # Third element pairs the matched block with the new add node; presumably it tells
    # apply_rule that consumers of `block_name` should now read from the add — confirm.
    return LHS, RHS, (block_name, path(block_name, 'add'))


# Apply the shortcut rule to the initial network and show the result.
rule = add_shortcuts()
net = apply_rule(net_initial, rule)
draw(net)

In [7]:
# Rewrite rule: replace each inverted_residual placeholder with three conv_bn stages
# named '<block>/expand', '<block>/depthwise' and '<block>/project' (via prefix=block_name):
# 1x1 expand with activation, 3x3 depthwise (groups=h_channels) carrying the block's
# stride, and 1x1 projection without activation.
# NOTE(review): as with add_shortcuts, @bind_vars presumably treats the parameter names
# as pattern variables — do not rename them.
@bind_vars
def expand_blocks(block_name, in_channels, h_channels, out_channels, stride, _in):
    LHS = {block_name: (inverted_residual(in_channels, h_channels, out_channels, stride),[_in])}
    RHS = pipeline([
        ('expand', conv_bn(in_channels, h_channels, kernel_size=1, activation=act, eps=eps), [_in]),
        ('depthwise', conv_bn(h_channels, h_channels, kernel_size=3, stride=stride, padding=1, groups=h_channels, activation=act, eps=eps)),
        ('project', conv_bn(h_channels, out_channels, kernel_size=1, eps=eps)),
    ], prefix=block_name)
    # Presumably redirects consumers of the old block node to its 'project' stage.
    return LHS, RHS, (block_name, path(block_name, 'project'))


# expand_conv_bns and match_tf_padding come from the gamma star imports. Presumably they
# split each conv_bn into the separate conv/bn nodes that the checkpoint keys
# ('.../conv/weight', '.../bn/bias') expect, and reproduce TF padding — confirm in gamma.
rules = [expand_blocks(), expand_conv_bns(), match_tf_padding()]
for rule in rules: draw(rule)
# NOTE(review): this rebinds `net`, so re-running only this cell applies the rules to an
# already-rewritten graph; re-run from the add_shortcuts cell instead.
net = apply_rules(net, rules)
draw(net)

#### 4. Build torch module and load weights

In [8]:
# Build an executable module from the rewritten graph. Presumably TorchGraph is an
# nn.Module, so eval() switches BatchNorm/dropout to inference behaviour.
model = TorchGraph(net).eval()
# Copy the converted checkpoint tensors into the module; the keys are the graph node paths.
load_state(model, state_dict)

#### 5. Test against the frozen TensorFlow graph

In [9]:
# The frozen graph is compared with the EMA weights, so the test only makes sense
# when the '/ExponentialMovingAverage' variables were loaded in step 1.
assert use_ema, 'compare against the frozen graph only with use_ema=True'
import tensorflow as tf

# Fixed seed so the comparison is reproducible under Restart & Run All. A private
# RandomState leaves numpy's global RNG untouched.
inputs = np.random.RandomState(0).rand(2, 224, 224, 3).astype(np.float32)

torch_inputs = {'input': torch.from_numpy(transpose(inputs, TF_DIMS['act'], TORCH_DIMS['act']))}
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    torch_outputs = model(torch_inputs)

# Reference logits from the frozen TF graph shipped in the same archive. Use `with` so
# the file handle is closed instead of leaked.
with open(f'{fname}/{model_name}_frozen.pb', 'rb') as pb_file:
    gd = tf.GraphDef.FromString(pb_file.read())
inp, logits = tf.import_graph_def(gd,  return_elements = ['input:0', 'MobilenetV2/Logits/Squeeze:0'])
with tf.Session(graph=inp.graph):
    x = logits.eval(feed_dict={inp: inputs})

np.testing.assert_almost_equal(x, to_numpy(torch_outputs['classifier/fc']), decimal=4)
print('Success!')

Success!
