## MXNet demo

Convert a PyTorch ResNet to MXNet and check that both frameworks produce the same outputs for the same input batch.

In [1]:
import warnings; warnings.simplefilter(action='ignore', category=FutureWarning)

from gamma import *
from gamma.pytorch import *
from gamma import mxn
from gamma.models import *

In [2]:
# ResNet variant to convert; the same value is used to build the graph below.
depth = 34
state_dict = load_resnet_weights(depth)

# Rename PyTorch batch-norm parameters (weight/bias) to MXNet's gamma/beta names.
# The final '{}' -> '{}' pattern is a catch-all that leaves other names as-is
# (presumably `rename` applies the first matching rule — confirm in gamma).
name_rules = [
    ('{}/bn/weight', '{}/bn/gamma'),
    ('{}/bn/bias', '{}/bn/beta'),
    ('{}', '{}'),
]

mxn_state_dict = rename(state_dict, name_rules)

In [3]:
# Build the ResNet graph; only `net` is used by later cells
# (`net_initial` and `rules` are unpacked but not referenced again).
net_initial, rules, net = resnet(depth, num_classes=1000)

# Rewrite the graph with the MXNet rule set for use by mxn.MxnetGraph.
mnet = apply_rules(net, mxn.rules)

# Render both graphs for visual comparison.
# NOTE(review): if `draw` returns a displayable object rather than rendering
# itself, Jupyter only shows the last expression of a cell — wrap the first
# call in display(...) in that case.
draw(net)
draw(mnet)

In [4]:
# MXNet side: executable graph from the rewritten network, loaded with the renamed weights.
mxn_model = mxn.MxnetGraph(mnet)
mxn.load_state(mxn_model, mxn_state_dict)

# PyTorch side: executable graph switched to eval mode (presumably so batch-norm
# uses running statistics), loaded with the original weights whose names use '/'.
torch_model = TorchGraph(net).eval()
load_state(torch_model, state_dict, sep='/')

In [5]:
from numpy.testing import assert_almost_equal

# Seed the torch RNG so the random batch below, and therefore this check,
# is identical on every Restart & Run All.
torch.manual_seed(0)

# Random batch: 2 inputs of shape (3, 224, 224) plus integer targets in [0, 1000).
inputs = {'input': torch.rand((2, 3, 224, 224)), 'target': torch.randint(0, 1000, (2,), dtype=torch.int64)}

# Inference only: no autograd graph is needed for the PyTorch forward pass.
with torch.no_grad():
    torch_output = torch_model(inputs)
mxn_output = mxn_model(mxn.to_nd(inputs))

# The 'classifier' outputs of both frameworks must agree to 5 decimal places.
assert_almost_equal(mxn_output['classifier'].asnumpy(), to_numpy(torch_output['classifier']), decimal=5)
print('Success!')

Success!
