In [1]:
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

mx.random.seed(42)

class Net(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            # layers created in name_scope will inherit name space
            # from parent layer.
            self.conv1 = nn.Conv2D(6, kernel_size=5)
            self.pool1 = nn.MaxPool2D(pool_size=2)
            self.conv2 = nn.Conv2D(16, kernel_size=5)
            self.pool2 = nn.MaxPool2D(pool_size=2)
            self.fc1 = nn.Dense(120)
            self.fc2 = nn.Dense(84)
            # You can use a Dense layer for fc3 but we do dot product manually
            # here for illustration purposes.
            self.fc3_weight = self.params.get('fc3_weight', shape=(10, 84))

    def hybrid_forward(self, F, x, fc3_weight):
        # Here `F` can be either mx.nd or mx.sym, x is the input data,
        # and fc3_weight is either self.fc3_weight.data() or
        # self.fc3_weight.var() depending on whether x is Symbol or NDArray
        print(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # 0 means copy over size from corresponding dimension.
        # -1 means infer size from the rest of dimensions.
        x = x.reshape((0, -1))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.dot(x, fc3_weight, transpose_b=True)
        return x

In [2]:
net = Net()

In [3]:
net._children

OrderedDict([('conv1', Conv2D(None -> 6, kernel_size=(5, 5), stride=(1, 1))),
             ('pool1',
              MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)),
             ('conv2', Conv2D(None -> 16, kernel_size=(5, 5), stride=(1, 1))),
             ('pool2',
              MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)),
             ('fc1', Dense(None -> 120, linear)),
             ('fc2', Dense(None -> 84, linear))])

In [4]:
net._reg_params

{'fc3_weight': Parameter net0_fc3_weight (shape=(10, 84), dtype=<class 'numpy.float32'>)}

In [5]:
net._flags

[]

In [6]:
net._cached_graph

()

In [7]:
net._cached_op

In [8]:
net.initialize()

In [9]:
x = mx.nd.random_normal(shape=(16, 1, 28, 28))

In [10]:
net(x)


[[[[-1.32751155e+00 -3.21425800e-03  2.21258700e-01 ... -9.93653476e-01
     7.21561074e-01  2.13800624e-01]
   [-1.15573585e-01 -6.79732859e-01 -3.19358349e-01 ...  9.27734792e-01
     7.16477394e-01 -1.62013519e+00]
   [-1.20223366e-01  1.80030048e+00 -8.33329916e-01 ...  7.20802903e-01
    -8.19966644e-02  1.31949830e+00]
   ...
   [-7.68818617e-01 -2.06548309e+00 -1.54131448e+00 ... -5.70446610e-01
    -2.37787798e-01 -1.11657321e+00]
   [ 1.02199569e-01 -7.87420928e-01  4.67959344e-01 ...  1.08413196e+00
     8.50915074e-01  7.07204640e-01]
   [ 1.55082130e+00  8.45636368e-01  3.37201893e-01 ...  1.57013047e+00
    -7.40349710e-01  1.78594506e+00]]]


 [[[-5.39527774e-01 -1.09334636e+00 -6.54727042e-01 ... -1.39392352e+00
     1.61527574e+00 -9.38860178e-01]
   [-1.80605817e+00 -2.46331286e+00  1.28754568e+00 ...  7.73943901e-01
    -2.16658860e-01 -1.13087308e+00]
   [-1.19116390e+00 -7.51096904e-01 -4.72258389e-01 ... -1.43151629e+00
     4.19944555e-01  4.37604368e-01]
   ...



[[-2.3739152e-03  1.2288181e-03 -2.6196400e-03  1.7307401e-02
   1.3160441e-03 -1.1935915e-03  9.0257125e-03 -3.4432653e-03
   6.9268309e-03 -2.5144888e-03]
 [-1.4261892e-03  2.8535747e-03  7.0181466e-04  1.1935305e-02
   2.3925959e-03 -1.1124604e-03  5.5938973e-03 -1.5007076e-03
   5.9103300e-03 -4.1204025e-03]
 [-2.0689422e-03  2.8114044e-03 -2.5111223e-03  1.9862700e-02
   1.7765041e-03 -7.0621585e-04  8.7961880e-03 -1.3219281e-03
   7.7348771e-03 -4.3676980e-03]
 [ 8.9091738e-04  3.9138813e-03 -8.1995809e-03  1.5244451e-02
  -1.2768649e-03 -6.1824790e-04  9.6626244e-03 -3.7912414e-03
   7.6066726e-03 -4.3306183e-03]
 [-4.9248966e-04  1.2496498e-03 -7.8753149e-04  1.5128225e-02
  -1.2085200e-03 -1.0806179e-03  8.5483445e-03 -1.0144414e-03
   2.2771757e-03 -1.7787039e-04]
 [ 2.1468885e-03  2.0143101e-03 -1.8229801e-03  1.7467944e-02
  -2.2699921e-03 -4.8990070e-05  1.0768503e-02  4.2958197e-04
   3.4390155e-03 -1.4848784e-03]
 [-4.3142017e-04  2.1216008e-03 -1.1781133e-03  1.2875386

In [11]:
net._flags

[]

In [12]:
net._cached_graph

()

In [13]:
net._cached_op

In [14]:
net.hybridize()

In [21]:
from mxnet.gluon.block import _flatten

In [40]:
def t(*args):
    return args

In [41]:
_, fmt = _flatten(t(x), "input")
fmt

[0]

In [28]:
net._in_format

[0]

forward -> _call_cached_op -> _build_cache -> _flatten -> ...

In [15]:
net._active

True

In [16]:
net._cached_op is None

True

In [17]:
from mxnet.symbol import Symbol
from mxnet.ndarray import NDArray


def _flatten(args, inout_str):
    if isinstance(args, NDArray):
        return [args], int(0)
    if isinstance(args, Symbol):
        length = len(args.list_outputs())
        length = length if length > 1 else 0
        return [args], int(length)

    assert isinstance(args, (list, tuple)), \
        "HybridBlock %s must be (nested) list of Symbol or NDArray, " \
        "but got %s of type %s"%(inout_str, str(args), str(type(args)))
    flat = []
    fmts = []
    for i in args:
        arg, fmt = _flatten(i, inout_str)
        flat.extend(arg)
        fmts.append(fmt)
    return flat, fmts

In [18]:
from mxnet import symbol

In [19]:
def _regroup(args, fmt):
    if isinstance(fmt, int):
        if fmt == 0:
            return args[0], args[1:]
        return args[:fmt], args[fmt:]

    assert isinstance(args, (list, tuple)), \
        "HybridBlock output must be (nested) list of Symbol or NDArray, " \
        "but got %s of type %s"%(str(args), str(type(args)))
    ret = []
    for i in fmt:
        res, args = _regroup(args, i)
        ret.append(res)
    return ret, args

In [20]:
# data, out = net._get_graph(x)
# data, out

print(not net._cached_graph)
print(net._in_format is None)

args, net._in_format= _flatten(x, "input")
print(len(args), args[0].shape)
print(net._in_format)

inputs = [symbol.var("data")]
print(inputs)

grouped_inputs = _regroup(inputs, net._in_format)[0]
print(grouped_inputs)

params = {i: j.var() for i, j in net._reg_params.items()}
print(params)

with net.name_scope():
    out = net.hybrid_forward(symbol, *grouped_inputs, **params)
    
out, net._out_format = _flatten(out, "output")
print(out, net._out_format)

net._cached_graph = inputs, symbol.Group(out)
print(net._cached_graph)

True
True
1 (16, 1, 28, 28)
0
[<Symbol data>]
<Symbol data>
{'fc3_weight': <Symbol net0_fc3_weight>}
<Symbol data>
[<Symbol net0_dot0>] 0
([<Symbol data>], <Symbol net0_dot0>)


In [21]:
inputs

[<Symbol data>]

In [22]:
net._in_format

0

In [23]:
grouped_inputs

<Symbol data>

In [24]:
data, out = net._cached_graph

In [25]:
data, out

([<Symbol data>], <Symbol net0_dot0>)

In [26]:
data_names = {data.name: i for i, data in enumerate(data)}
data_names

{'data': 0}

In [27]:
params = net.collect_params()
params

net0_ (
  Parameter net0_fc3_weight (shape=(10, 84), dtype=<class 'numpy.float32'>)
  Parameter net0_conv0_weight (shape=(6, 1, 5, 5), dtype=<class 'numpy.float32'>)
  Parameter net0_conv0_bias (shape=(6,), dtype=<class 'numpy.float32'>)
  Parameter net0_conv1_weight (shape=(16, 6, 5, 5), dtype=<class 'numpy.float32'>)
  Parameter net0_conv1_bias (shape=(16,), dtype=<class 'numpy.float32'>)
  Parameter net0_dense0_weight (shape=(120, 256), dtype=float32)
  Parameter net0_dense0_bias (shape=(120,), dtype=float32)
  Parameter net0_dense1_weight (shape=(84, 120), dtype=float32)
  Parameter net0_dense1_bias (shape=(84,), dtype=float32)
)

In [28]:
input_names = out.list_inputs()
input_names

['data',
 'net0_conv0_weight',
 'net0_conv0_bias',
 'net0_conv1_weight',
 'net0_conv1_bias',
 'net0_dense0_weight',
 'net0_dense0_bias',
 'net0_dense1_weight',
 'net0_dense1_bias',
 'net0_fc3_weight']

In [31]:
param_names = set(params.keys())
param_names

{'net0_conv0_bias',
 'net0_conv0_weight',
 'net0_conv1_bias',
 'net0_conv1_weight',
 'net0_dense0_bias',
 'net0_dense0_weight',
 'net0_dense1_bias',
 'net0_dense1_weight',
 'net0_fc3_weight'}

In [32]:
expected_names = set(input_names)
expected_names

{'data',
 'net0_conv0_bias',
 'net0_conv0_weight',
 'net0_conv1_bias',
 'net0_conv1_weight',
 'net0_dense0_bias',
 'net0_dense0_weight',
 'net0_dense1_bias',
 'net0_dense1_weight',
 'net0_fc3_weight'}

In [33]:
used_data_names = [i for i in data_names if i in expected_names]
used_data_names

['data']

In [34]:
len(used_data_names) != len(data_names)

False

In [35]:
used_param_names = [i for i in param_names if i in expected_names]
used_param_names

['net0_conv0_weight',
 'net0_conv0_bias',
 'net0_conv1_weight',
 'net0_conv1_bias',
 'net0_dense1_weight',
 'net0_dense0_weight',
 'net0_dense1_bias',
 'net0_fc3_weight',
 'net0_dense0_bias']

In [36]:
len(used_param_names) != len(param_names)

False

In [37]:
data_indices = []
param_indices = []
net._cached_op_args = []

In [38]:
for i, name in enumerate(input_names):
    if name in data_names:
        data_indices.append(i)
        net._cached_op_args.append((True, data_names[name]))
    else:
        param_indices.append(i)
        net._cached_op_args.append((False, params[name]))
net._cached_op_args

[(True, 0),
 (False,
  Parameter net0_conv0_weight (shape=(6, 1, 5, 5), dtype=<class 'numpy.float32'>)),
 (False,
  Parameter net0_conv0_bias (shape=(6,), dtype=<class 'numpy.float32'>)),
 (False,
  Parameter net0_conv1_weight (shape=(16, 6, 5, 5), dtype=<class 'numpy.float32'>)),
 (False,
  Parameter net0_conv1_bias (shape=(16,), dtype=<class 'numpy.float32'>)),
 (False, Parameter net0_dense0_weight (shape=(120, 256), dtype=float32)),
 (False, Parameter net0_dense0_bias (shape=(120,), dtype=float32)),
 (False, Parameter net0_dense1_weight (shape=(84, 120), dtype=float32)),
 (False, Parameter net0_dense1_bias (shape=(84,), dtype=float32)),
 (False,
  Parameter net0_fc3_weight (shape=(10, 84), dtype=<class 'numpy.float32'>))]

In [39]:
net._flags

[]

In [40]:
flags = [
    ('data_indices', data_indices), 
    ('param_indices', param_indices)
] + net._flags
flags

[('data_indices', [0]), ('param_indices', [1, 2, 3, 4, 5, 6, 7, 8, 9])]

In [41]:
from mxnet import ndarray

In [42]:
net._cached_op

In [43]:
net._cached_op = ndarray.CachedOp(out, flags)
net._cached_op

<mxnet._ctypes.ndarray.CachedOp at 0x2b9ac38d0d8>

---

---

In [44]:
net(x)

AssertionError: Invalid input format

In [14]:
net._flags

[]

In [15]:
net._cached_graph

([<Symbol data>], <Symbol net0_dot0>)

In [16]:
net._cached_op

<mxnet._ctypes.ndarray.CachedOp at 0x2eeaf832d68>

---