In [26]:
import mxnet as mx
import math
import sys
import os

In [70]:
def ConvModule(sym, num_filter, kernel, pad=(0, 0), stride=(1, 1), fix_gamma=True):
    """Stack Convolution -> BatchNorm -> LeakyReLU on top of ``sym``.

    Parameters
    ----------
    sym
        Symbol fed to the convolution as ``data``.
    num_filter
        Passed to ``mx.sym.Convolution`` as ``num_filter``.
    kernel, pad, stride
        Forwarded unchanged to ``mx.sym.Convolution``.
    fix_gamma
        Forwarded unchanged to ``mx.sym.BatchNorm``.

    Returns
    -------
    The symbol produced by the final ``mx.sym.LeakyReLU`` call.
    """
    convolved = mx.sym.Convolution(
        data=sym,
        kernel=kernel,
        stride=stride,
        pad=pad,
        num_filter=num_filter,
    )
    normalized = mx.sym.BatchNorm(data=convolved, fix_gamma=fix_gamma)
    # Original author's note on choosing LeakyReLU: it uses the same memory as
    # MXNet's own activation and less than the cuDNN one.
    return mx.sym.LeakyReLU(data=normalized, act_type="leaky")

def ResModule(sym, base_filter, stage, layer, fix_gamma=True):
    """Build one bottleneck residual unit on top of ``sym``.

    The unit is a 1x1 -> 3x3 -> 1x1 chain of ``ConvModule`` calls whose output
    is added to a shortcut branch. For ``layer == 0`` the shortcut is a 1x1
    ``ConvModule`` projection of the input; otherwise it is the input itself.

    Parameters
    ----------
    sym
        Input symbol of the unit.
    base_filter
        Filter count at stage 0; the unit uses ``base_filter * 2 ** stage``
        filters in its first two convolutions and four times that in the last.
    stage
        Index of the stage this unit belongs to.
    layer
        Index of the unit inside its stage.
    fix_gamma
        Forwarded to every ``ConvModule`` created here.

    Returns
    -------
    The symbol ``shortcut + conv3``.
    """
    num_f = base_filter * int(2 ** stage)
    # Only the first unit of each stage after stage 0 uses stride 2.
    stride = 2 if (stage != 0 and layer == 0) else 1

    conv1 = ConvModule(sym, num_f, kernel=(1, 1), pad=(0, 0), stride=(1, 1),
                       fix_gamma=fix_gamma)
    conv2 = ConvModule(conv1, num_f, kernel=(3, 3), pad=(1, 1), stride=(stride, stride),
                       fix_gamma=fix_gamma)
    conv3 = ConvModule(conv2, num_f * 4, kernel=(1, 1), pad=(0, 0), stride=(1, 1),
                       fix_gamma=fix_gamma)

    if layer == 0:
        # Project the shortcut so its filter count and stride match conv3.
        sym = ConvModule(sym, num_f * 4, kernel=(1, 1), pad=(0, 0), stride=(stride, stride),
                         fix_gamma=fix_gamma)

    sum_sym = sym + conv3
    # Annotate the critical points that can be saved as inter-stage parameter.
    # NOTE(review): the attribute is set on the shortcut symbol `sym`, not on
    # `sum_sym` -- kept as in the original; confirm this is the intended node.
    sym._set_attr(mirror_stage='True')
    return sum_sym


In [148]:
def get_cost(sym, type_dict=None, **kwargs):
    """Bind ``sym`` on CPU and return the allocation figure from its debug string.

    Parameters
    ----------
    sym
        Symbol to bind via ``simple_bind`` with ``grad_req='write'``.
    type_dict
        Forwarded to ``simple_bind``.
    **kwargs
        Forwarded to ``simple_bind`` (e.g. ``data=<shape>``).

    Returns
    -------
    int
        The single integer token on the first ``debug_str()`` line that
        contains the word ``allocated`` (presumably megabytes -- confirm
        against the executor's debug output).

    Raises
    ------
    IndexError
        If no line of the debug string contains ``allocated``.
    AssertionError
        If that line does not hold exactly one all-digit token.
    """
    executor = sym.simple_bind(
        ctx=mx.cpu(),
        grad_req='write',
        type_dict=type_dict,
        **kwargs
    )
    report_lines = executor.debug_str().split('\n')
    matching = [line for line in report_lines if 'allocated' in line]
    first_match = matching[0]

    sizes = []
    for token in first_match.split():
        if token.isdigit():
            sizes.append(int(token))
    assert len(sizes) == 1
    return sizes[0]

In [149]:
# Assemble the network symbol: a convolution + max-pool stem, four stages of
# ResModule units, then average pooling, a 1000-way fully connected layer and
# a softmax output.
base_filter = 64
data = mx.sym.Variable(name='data')
# The stem uses base_filter instead of repeating the literal 64.
conv1 = ConvModule(data, base_filter, (7, 7), pad=(3, 3), stride=(2, 2))
mp1 = mx.sym.Pooling(data=conv1, pool_type="max", kernel=(3, 3), stride=(2, 2))

sym = mp1
# Number of ResModule units in each stage.
layers = [3, 24, 36, 3]
# range() and print() with a single argument behave the same on Python 2
# and also run on Python 3.
for j in range(len(layers)):
    print(j)
    for i in range(layers[j]):
        sym = ResModule(sym, base_filter, j, i)

avg = mx.symbol.Pooling(data=sym, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
flatten = mx.symbol.Flatten(data=avg, name='flatten')
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=1000, name='fc1')
net = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
# Later cells refer to the finished network as `sym`.
sym = net

0
1
2
3


In [150]:
# Only the last expression of a cell is displayed, so this input count is
# computed but its value is not shown.
len(sym.list_inputs())
# Run type inference with `data` declared as float32; the result is the
# cell's displayed output.
sym.infer_type(data='float32')
# Disabled shape-inference check over the internal symbols.
#sym.get_internals().infer_shape((32, 3, 224,224))

([numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float32,
  numpy.float3

In [160]:
# Shape passed as `data` when binding the network; presumably
# (batch, channels, height, width) -- TODO confirm.
input_shape = (64, 3, 224, 224)

# Allocation figure reported for the network as built, before any plan search.
get_cost(sym, data=input_shape)

20800

In [161]:
import memonger
# Re-execute the memonger module so the running kernel uses its current
# source. `reload` is a builtin only on Python 2.
reload(memonger)

<module 'memonger' from 'memonger.pyc'>

In [162]:
# Run memonger's plan search on the network for the chosen input shape. The
# returned symbol is only displayed as the cell output, not bound to a name.
memonger.search_plan(sym, data=input_shape)

Search threshold=0 MB, cost=5882 MB
Search threshold=1780 MB, cost=3984 MB
Search threshold=1152 MB, cost=3914 MB
Search threshold=1358 MB, cost=4210 MB
Search threshold=1383 MB, cost=4210 MB
Search threshold=1408 MB, cost=4210 MB
Search threshold=1433 MB, cost=4210 MB
Search threshold=1458 MB, cost=4382 MB
Search threshold=1483 MB, cost=4382 MB
Search threshold=1508 MB, cost=4382 MB
Search threshold=1533 MB, cost=4382 MB
Search threshold=1558 MB, cost=4382 MB
Search threshold=1583 MB, cost=4382 MB
Search threshold=1608 MB, cost=4382 MB
Search threshold=1633 MB, cost=3984 MB
Search threshold=1658 MB, cost=3984 MB
Search threshold=1683 MB, cost=3984 MB
Search threshold=1708 MB, cost=3984 MB
Search threshold=1733 MB, cost=3984 MB
Search threshold=1758 MB, cost=3984 MB
Search threshold=1783 MB, cost=3984 MB
Search threshold=1808 MB, cost=3984 MB
Search threshold=1833 MB, cost=4958 MB
Find best plan with threshold=1152, cost=3914 MB


<Symbol softmax>