In [1]:
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet, fit
from common.util import download_file
import mxnet as mx
import numpy as np
import gzip, struct
from mxnet import profiler
from time import time

def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data',label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data',image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255

def get_mnist_iter(batch_size):
    """
    create data iterator with NDArrayIter
    """
    (train_lbl, train_img) = read_data(
            'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
            't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, batch_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, batch_size)
    return (train, val)

class random_mnist_iterator(mx.io.DataIter):
    '''random task ilab iterator'''
    #requires bucketing module, only constraint should be that symgen in the bucketing module must give a single output with name softmax[bucketing_key+1]
    def __init__(self, data_iter):
        super(random_mnist_iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label
        label_names=[]
        batch_size=[]
        for i in range(2):
            label_names.append('softmax%d_label'%(i+1))
            batch_size.append((self.batch_size,))
            label_names=['softmax_label']
        return zip(label_names,batch_size)  
           
         #provide_label must have an output like this       
        #return [('softmax1_label', (self.batch_size,)), \
         #       ('softmax2_label', (self.batch_size,)), \
                #('softmax4_label', (self.batch_size,)), \
                #('softmax3_label', (self.batch_size,))]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        bucket_key = np.random.randint(0,2)
        bucket_key = 0
        #print bucket_key
        labelnp=[]
        # prepare all the labels
        #print len(batch.label)
        ##print batch.label[0].asnumpy().shape
        #print batch.label[0].asnumpy()
        #for lab in batch.label[0].asnumpy():
        #    labelnp.append(mx.nd.array(lab))
        # take the subset, in this case the single label_id 
        coarse_lab = batch.label[0].asnumpy()>4
        c_lab = [mx.nd.array(1*coarse_lab)]
        #print c_lab
        #print 1*coarse_lab
        f_lab = [batch.label[0]]
        if bucket_key == 0:
            all_label = f_lab
        elif bucket_key == 1:
            all_label = c_lab
            
        
        #print all_label[0].asnumpy()
        # generates the provide label adequate to the current label
        label_names=[]
        batch_size=[]
        for bucket in [bucket_key]:
            label_names.append('softmax%d_label'%(bucket+1))
            batch_size.append((self.batch_size,))
        label_names=['softmax_label']
        #print zip(label_names,batch_size)
        return mx.io.DataBatch(data=batch.data, label=all_label, \
                   pad=batch.pad, index=batch.index, bucket_key=bucket_key, provide_data=self.data_iter.provide_data,  provide_label=zip(label_names,batch_size)) 
    

In [2]:
def get_symbol(num_classes=10,):
    data = mx.symbol.Variable('data')
   
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="relu")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="relu")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="relu")
    # second fullc
    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=num_classes)
    # loss
    lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    return lenet

def get_simple_symbol(num_classes=10,prefix=1):
    data = mx.symbol.Variable('data')
    data_name = ['data']
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20,name='conv1')
    tanh1 = mx.symbol.Activation(data=conv1, act_type="relu")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50,name='conv2%d' %prefix)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="relu",name='relu%d' %prefix)
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500,name='fc1%d' %prefix)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="relu")
    if prefix == 1:
    # second fullc
        fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=num_classes,name='fc2%d' %prefix)
    elif prefix == 2:
        fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=2,name='fc2%d' %prefix)
    # loss
    #lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax%d'%prefix)
    lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')

    label_name = ['softmax_label']

    #label_name = ['softmax%d_label'%prefix]
            
    return lenet,data_name, label_name

def get_gated_simple_symbol(num_classes=10,prefix=1):
    data = mx.symbol.Variable('data')
    data_name = ['data']
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20,name='conv1')
    tanh1 = mx.symbol.Activation(data=conv1, act_type="relu")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50,name='conv2%d' %prefix)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="relu",name='relu%d' %prefix)
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500,name='fc1%d' %prefix)
    gate = mx.sym.Variable('gate',init=mx.initializer.One(),shape=(1,),dtype='float32')
    fc1_gated = mx.sym.broadcast_mul(gate,fc1)
    tanh3 = mx.symbol.Activation(data=fc1_gated, act_type="relu")
    if prefix == 1:
    # second fullc
        fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=num_classes,name='fc2%d' %prefix)
    elif prefix == 2:
        fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=2,name='fc2%d' %prefix)
    # loss
    lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    label_name = ['softmax_label']
            
    return lenet,data_name, label_name

def get_big_symbol(num_classes=10):
    data = mx.symbol.Variable('data')
    data_name = ['data']  
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20,name='conv1')
    tanh1 = mx.symbol.Activation(data=conv1, act_type="relu")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2_1 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50,name='conv21')
    tanh2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu",name='relu1' )
    pool2_1 = mx.symbol.Pooling(data=tanh2_1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    
    conv2_2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50,name='conv22' )
    tanh2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu",name='relu2')
    pool2_2 = mx.symbol.Pooling(data=tanh2_2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # first fullc
    flatten_1 = mx.symbol.Flatten(data=pool2_1)
    fc1_1 = mx.symbol.FullyConnected(data=flatten_1, num_hidden=500,name='fc11')
    
    tanh3_1 = mx.symbol.Activation(data=fc1_1, act_type="relu")
    
    flatten_2 = mx.symbol.Flatten(data=pool2_2)
    fc1_2 = mx.symbol.FullyConnected(data=flatten_2, num_hidden=500,name='fc12' )
    tanh3_2 = mx.symbol.Activation(data=fc1_2, act_type="relu")
    # second fullc
    fc2_1 = mx.symbol.FullyConnected(data=tanh3_1, num_hidden=num_classes,name='fc21' )
    fc2_2 = mx.symbol.FullyConnected(data=tanh3_2, num_hidden=2,name='fc22' )
    # loss
    sym_1 = mx.symbol.SoftmaxOutput(data=fc2_1, name='softmax1')
    sym_2 = mx.symbol.SoftmaxOutput(data=fc2_2, name='softmax2')
    label_name=['softmax1_label','softmax2_label']
    
    return mx.sym.Group([sym_1,sym_2]),data_name, label_name

def sym_gen(key):
    if key == 0:
        sym, data_names,label_names = get_simple_symbol(num_classes=10,prefix=1)
    elif key == 1:    
        sym ,data_names,label_names= get_simple_symbol(num_classes=10,prefix=2)
    elif key == 2:
        sym,data_names,label_names = get_big_symbol(num_classes=10)
        
    return sym,data_names,label_names  

def gated_sym_gen(key):
    if key == 0:
        sym, data_names,label_names = get_simple_symbol(num_classes=10,prefix=1)
    elif key == 1:    
        sym ,data_names,label_names= get_gated_simple_symbol(num_classes=10,prefix=1)
    elif key == 2:
        sym,data_names,label_names = get_big_symbol(num_classes=10)
        
    return sym,data_names,label_names    
batch_size = 10000
train, val = get_mnist_iter(batch_size)
train = random_mnist_iterator(train)
val = random_mnist_iterator(val)

schedule = [20000,30000,40000]



In [3]:
sym,data,labels=gated_sym_gen(1)
#mx.viz.plot_network(sym)

In [4]:
# sym, data_names,label_names = get_simple_symbol(num_classes=10,prefix=1)
# sym,data_names,label_names = get_big_symbol(num_classes=10)
# sym.list_arguments()

In [5]:
#lel=train.provide_label[0][0]
mod = mx.mod.Module(sym, label_names=labels, context=[mx.gpu(0)])
#mod = mx.mod.BucketingModule(sym_gen,default_bucket_key=2, context=[mx.gpu(0)])
print labels

['softmax_label']


In [6]:
iterations = 5
#profiler.profiler_set_config(mode='symbolic', filename='profile_output.json')
# profiler.profiler_set_state('run')
optimizer='sgd'
kvstore='local'
optimizer_params={'learning_rate':0.1, 'momentum': 0.9,'wd':0.0000, 'lr_scheduler': mx.lr_scheduler.MultiFactorScheduler(step=schedule,factor=0.1) }
# mod.bind()
mod.bind(data_shapes=train.provide_data,
              label_shapes=train.provide_label)
mod.init_params()
mod.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

#profiler.profiler_set_state('run')

# real run    
for i in range(iterations):
    batch=train.next()
    start_for = time()
    mod.forward(batch, is_train=True)
    print 'forward time is'+' '+str(time()-start_for)
    start_back= time()
    mod.backward()
    print 'backward time is'+' '+str(time()-start_back)
    start_upd = time()
    mod.update()
    print 'update time is'+' '+str(time()-start_upd)
    for output in mod.get_outputs(merge_multi_context=False)[0]:
        output.wait_to_read()
#profiler.profiler_set_state('stop')

forward time is 0.000754833221436
backward time is 0.000375986099243
update time is 0.00330781936646
forward time is 0.000627040863037
backward time is 4.2200088501e-05
update time is 0.00121402740479
forward time is 0.000645160675049
backward time is 0.000134944915771
update time is 0.00147390365601
forward time is 0.000670194625854
backward time is 0.000105142593384
update time is 0.0014660358429
forward time is 0.00062894821167
backward time is 0.000285863876343
update time is 0.00125694274902


In [7]:
#no scalar
forward time is 0.000964879989624
backward time is 0.000248193740845
update time is 0.00177907943726
forward time is 0.000608921051025
backward time is 0.0024790763855
update time is 0.000823020935059
forward time is 0.000785827636719
backward time is 0.000428915023804
update time is 0.00115990638733
forward time is 0.000653982162476
backward time is 0.000313997268677
update time is 0.00117301940918
forward time is 0.000581026077271
backward time is 0.000270128250122
update time is 0.00115299224854

SyntaxError: invalid syntax (<ipython-input-7-e1129b9874fd>, line 2)

In [None]:
mod.fit(train,
        eval_data=val,
        eval_metric=[mx.metric.Accuracy()],
        batch_end_callback = [mx.callback.Speedometer(1,batch_size)],
        allow_missing=False,
        optimizer_params={'learning_rate':0.1, 'momentum': 0.9,'wd':0.0000, 'lr_scheduler': mx.lr_scheduler.MultiFactorScheduler(step=schedule,factor=0.1) },
        num_epoch=20)

  allow_missing=allow_missing, force_init=force_init)
INFO:root:Epoch[0] Train-accuracy=0.114600
INFO:root:Epoch[0] Time cost=0.139
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [0]	Speed: 11450.18 samples/sec	Train-accuracy=0.110600


saving logfiles0  at None-metric-log-0.txt


In [None]:
mod._buckets[2].get_params()


In [None]:
mod._buckets[1].eval_metric

In [None]:
a = train.next()
a.provide_data[0][1][0]

In [None]:
train.provide_data()