# Deferred Initialization

In [1]:
from mxnet import init, nd
from mxnet.gluon import nn

def getnet():
    """Build a small two-layer network: Dense(256, relu) followed by Dense(10).

    Neither layer is given ``in_units``, so the input dimension of each
    weight is left unspecified when the network is constructed.

    Returns:
        The assembled ``nn.Sequential`` container.
    """
    model = nn.Sequential()
    hidden = nn.Dense(256, activation='relu')
    output = nn.Dense(10)
    # Register the layers in forward order.
    for layer in (hidden, output):
        model.add(layer)
    return model
# Build a fresh network and list its parameters. No input has been seen yet,
# so the input dimension of each weight is still unknown (reported as 0 in
# the output of this cell).
net = getnet()
net.collect_params()

sequential0_ (
  Parameter dense0_weight (shape=(256, 0), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 0), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

### Still Not Initialized After `initialize()`

In [2]:
# Calling initialize() on its own does not resolve the shapes: the weights
# listed in this cell's output still report an input dimension of 0.
net.initialize()
net.collect_params()

sequential0_ (
  Parameter dense0_weight (shape=(256, 0), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 0), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

### When Initialization Actually Happens

In [3]:
# Random input of shape (2, 20); the second dimension becomes the input
# dimension of the first Dense layer (see dense0_weight in the output).
x = nd.random.uniform(shape=(2, 20))
net(x)            # first forward pass; afterwards the weight shapes are fully resolved
net.collect_params()

sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

## Deferred Initialization in Practice

In [4]:
class MyInit(init.Initializer):
    """Demo initializer that only reports which weight it was asked to set up."""

    def _init_weight(self, name, data):
        """Print the parameter name and the shape of ``data``.

        ``data`` itself is not modified by this method.
        """
        shape = data.shape
        print('Init', name, shape)
        # The actual initialization logic is omitted here.

# Fresh network using the logging initializer. This cell has no recorded
# output: MyInit prints nothing here because the input shapes are not yet known.
net = getnet()
net.initialize(init=MyInit())

### Initialization Runs on the First Forward Pass

In [5]:
x = nd.random.uniform(shape=(2, 20))
# First forward pass: this is where MyInit actually runs (the 'Init ...'
# lines in the output appear before '2nd forward').
y = net(x)
print('2nd forward')
# Second forward pass: nothing is printed after '2nd forward', so the
# initializer is not invoked again.
y = net(x)

Init dense2_weight (256, 20)
Init dense3_weight (10, 256)
2nd forward



## Forced Initialization



In [6]:
# Case 1: the existing network has already processed input, so its shapes are
# known; re-initializing with force_reinit=True invokes MyInit immediately.
print('already known the shape')
net.initialize(init=MyInit(), force_reinit=True)
# Case 2: every layer is given in_units, so the weight shapes are known at
# construction time and initialize() invokes MyInit without a forward pass.
print('specified the input shape')
net = nn.Sequential()
net.add(nn.Dense(256, in_units=20, activation='relu'))
net.add(nn.Dense(10, in_units=256))
net.initialize(init=MyInit())

already known the shape
Init dense2_weight (256, 20)
Init dense3_weight (10, 256)
specified the input shape
Init dense4_weight (256, 20)
Init dense5_weight (10, 256)
