In [1]:
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx

  chunks = self.iterencode(o, _one_shot=True)


In [2]:
def download_cifar10():
    data_dir="/efs/data/cifar-100-mxnet"
    #data_dir="/data/cifar-100-mxnet"
    fnames = (os.path.join(data_dir, "train.rec"),
              os.path.join(data_dir, "test.rec"))
    return fnames

  chunks = self.iterencode(o, _one_shot=True)


In [3]:
(train_fname, val_fname) = download_cifar10()

# parse args
parser = argparse.ArgumentParser(description="train cifar100",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
fit.add_fit_args(parser)
data.add_data_args(parser)
data.add_data_aug_args(parser)
data.set_data_aug_level(parser, 2)
parser.set_defaults(
    # data
    data_train     = train_fname,
    data_val       = val_fname,
    num_classes    = 100,
    num_examples  = 50000,
    image_shape    = '3,28,28',
 #   model_prefix   = '/efs/checkpoints/inceptionStandard',
    pad_size       = 4,
    # train
    batch_size     = 128,
    num_epochs     = 45,
    lr             = .05,
    lr_step_epochs = '38',
    gpus           = '4,5,6,7'
)
args = parser.parse_args("")

  chunks = self.iterencode(o, _one_shot=True)


In [4]:
from mxnet.io import DataDesc

#todo unsupervised should not contain the supervised samples
#(trainUnsup, valUnsup) = data.get_rec_iter(args)

#train = mx.io.PrefetchingIter([trainSup, trainUnsup],
#                        rename_data=[{'data': 'dataSup'}, 
#                                     {'data': 'dataUnsup'}],
#                        rename_label=[{'softmax_label': 'labelSup'}, 
#                                     {'softmax_label': 'labelUnsup'}])
sup_batch_size = 128
unsup_multiplier = 1

args.batch_size = sup_batch_size * (unsup_multiplier+1)
(train, val) = data.get_rec_iter(args, shuffle=False)


class Split_Iterator(mx.io.DataIter):
    '''splits dataset into supervised and unsupervised part'''

    def __init__(self, data_iter):
        super(Split_Iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        data = [DataDesc('dataSup', (sup_batch_size,3,28,28), 'float32')]
        return data + [DataDesc('dataUnsup'+str(i), (sup_batch_size,3,28,28), 'float32') for i in range(unsup_multiplier)]
        
    @property
    def provide_label(self):
        return [DataDesc('labelSup', (sup_batch_size,), 'float32')]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        
        if unsup_multiplier >= 1:
        
            data = batch.data[0]
            labels = batch.label[0]

            dataSplits = mx.ndarray.split(data, axis=0, num_outputs=(unsup_multiplier+1))
            labelSplits = mx.ndarray.split(labels, axis=0, num_outputs=(unsup_multiplier+1))

            # todo shuffling, does not make sense in a minibatch           
            
            return mx.io.DataBatch(data=dataSplits, label=[labelSplits[0]], \
                        pad=batch.pad, index=batch.index)
        else:
            return mx.io.DataBatch(data=batch.data, label=batch.label, \
                pad=batch.pad, index=batch.index)

            
splittedTrainData = Split_Iterator(train)

  chunks = self.iterencode(o, _one_shot=True)


In [5]:
def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
    bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))
    act = mx.symbol.LeakyReLU(data=bn, act_type='rrelu', name='rrelu_%s%s' %(name, suffix))
    return act

def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' %  name))
    # concat
    concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat

def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1),  name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(1,1), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat

  chunks = self.iterencode(o, _one_shot=True)


In [6]:
def build_embeddings(data, nembeddings, grad_scale=1.0):
    # data
    #data = mx.symbol.Variable(name="data")
    # stage 2
    in3a = InceptionFactoryA(data, 64, 64, 64, 64, 96, "avg", 32, '3a')
    in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
    in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c')
    # stage 3
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 4
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = mx.symbol.Flatten(data=avg, name='flatten')
    
    #flatten has 1024 outputs
    #arg_shape, output_shape, aux_shape = flatten.infer_shape(data=(1, 3,28,28))
    #print(output_shape)
    
    # nice and smooth embeddings
    fc0 = mx.symbol.FullyConnected(data=flatten, num_hidden=nembeddings, name='fc0') 
    return fc0

  chunks = self.iterencode(o, _one_shot=True)


In [7]:
from mxnet.symbol import *

def getshape(tensor):
    arg_shape, output_shape, aux_shape = tensor.infer_shape(labelSup=(128,))
    print(output_shape)
def getshapeData(tensor):
    arg_shape, output_shape, aux_shape = tensor.infer_shape(
        dataUnsup=(128,3,28,28),dataSup=(128,3,28,28))
    print(output_shape)
    
def compute_visit_loss(p, visit_weight=1):
    
    visit_probability = mean(p, axis=(0), keepdims=True, name='visit_prob')
    
    t_nb = sup_batch_size * unsup_multiplier
    
    init = mx.initializer.Constant(t_nb)
    t_nb = var('t_nb', init=init, dtype='float32', shape=(1))
    
    target = broadcast_div(ones_like(visit_probability), t_nb)
    
    arg_shape, output_shape, aux_shape = visit_probability.infer_shape(
        dataUnsup0=(128,3,28,28),dataSup=(128,3,28,28))
    print(output_shape)
    arg_shape, output_shape, aux_shape = target.infer_shape(
        dataUnsup0=(128,3,28,28),dataSup=(128,3,28,28))
    print(output_shape)
    arg_shape, output_shape, aux_shape = t_nb.infer_shape()
    print(output_shape)
    
    visit_loss = SoftmaxOutput(visit_probability, target, grad_scale=visit_weight, name='visit_loss')
    
    return visit_loss
    
def compute_semisup_loss(a,b,labels,walker_weight=1., visit_weight=1.):
    equality_matrix = broadcast_equal(reshape(labels, shape=(-1,1)), labels, name="eqmat")
    
    equality_matrix = cast(equality_matrix, dtype='float32')
    p_target = broadcast_div(equality_matrix,
                             sum(equality_matrix, axis=(1), keepdims=True))
    
    match_ab = dot(a,transpose(b),name='match_ab')
    p_ab = softmax(match_ab, name='p_ab')
    p_ba = softmax(transpose(match_ab), name='p_ba')
    p_aba = dot(p_ab, p_ba, name='p_aba')
    
    #todo: create walk statistics
    
    # softmaxOutput should be cross entropy loss: https://github.com/dmlc/mxnet/issues/1969
    # apparently this calculates the gradient of cross entropy loss for backprop, so should
    # be equivalent
    
    # probably no need to log: l = log(1e-8 + p_aba, name='log_aba')
    
    walker_loss = SoftmaxOutput(p_aba, p_target, name='loss_aba', grad_scale=walker_weight)
    
    # this would be some kind of cross entropy loss. does not work yet though
    #cross_entropy = p_target * log(out) + (1 - p_target) * log(1 - out)
    #loss = MakeLoss(cross_entropy)
    
    visit_loss = compute_visit_loss(p_ab, visit_weight)

    return (walker_loss, visit_loss)

  chunks = self.iterencode(o, _one_shot=True)


In [8]:
def logit_loss(embeddings, labels, nclasses, grad_scale=1):
    fc1 = mx.symbol.FullyConnected(data=embeddings, num_hidden=nclasses, name='fc1'+str(grad_scale))
    softmax = mx.symbol.SoftmaxOutput(fc1, labels, name='softmax'+str(grad_scale), grad_scale=grad_scale)
    return softmax

  chunks = self.iterencode(o, _one_shot=True)


In [9]:
def build():
    dataSup = mx.symbol.Variable(name="dataSup")
    labelSup = mx.symbol.Variable(name='labelSup')
    overall_loss = []
    
    if unsup_multiplier >= 1:
        dataUnsup = mx.symbol.Variable(name="dataUnsup0")

        # concat data, feed both through the network
        # then split it up again
        data = concat(dataSup, dataUnsup, dim=0)

        embeddings = build_embeddings(data, nembeddings=256)
        splitted = split(embeddings, num_outputs=(unsup_multiplier+1), axis=0)

        supEmbeddings = splitted[0]
        unsupEmbeddings = splitted[1]
        (walker_loss, visit_loss) = compute_semisup_loss(supEmbeddings, unsupEmbeddings, labelSup, 
                                                     walker_weight=0.5, visit_weight=0.5)
        overall_loss = [walker_loss, visit_loss]
        
    else:
        supEmbeddings = build_embeddings(dataSup, nembeddings=256)
        
    overall_loss = [logit_loss(supEmbeddings, labelSup, 100)] + overall_loss
    
    return Group(overall_loss)


  chunks = self.iterencode(o, _one_shot=True)


In [10]:
class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        super(Multi_Accuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):

        for i in range(len(preds)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[0].asnumpy().astype('int32')

            #mx.metric.check_label_shapes(label, pred_label)

            if i is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)

  chunks = self.iterencode(o, _one_shot=True)


In [None]:
def fit_model(args, network, data, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the nerual network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators
    (train, val) = data
    if args.test_io:
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i+1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
                    i, args.disp_batches*args.batch_size/(time.time()-tic)))
                tic = time.time()
        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = fit._load_model(args, kv.rank)
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = fit._save_model(args, kv.rank)

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # learning rate
    lr, lr_scheduler = fit._get_lr_scheduler(args, kv)

    data_names = ['dataSup'] + ['dataUnsup'+str(i) for i in range(unsup_multiplier)]
    # create model
    model = mx.mod.Module(
        context       = devs,
        symbol        = network,
        label_names   = ['labelSup'],
        data_names   = data_names
    )
    
    #print(model.label_names)
    lr_scheduler  = lr_scheduler
    optimizer_params = {
            'learning_rate': lr,
            'momentum' : args.mom,
            'wd' : args.wd,
            'lr_scheduler': lr_scheduler}

    monitor = mx.mon.Monitor(interval=1000, pattern='.*aba_backward.*') 
    
    initializer = mx.init.Xavier(
        rnd_type='gaussian', factor_type="in", magnitude=2)
    # initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34),

    # evaluation metrices
    eval_metrics = Multi_Accuracy(num=3)

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]

    # run
    model.fit(train,
        begin_epoch        = args.load_epoch if args.load_epoch else 0,
        num_epoch          = args.num_epochs,
        eval_data          = val,
        eval_metric        = eval_metrics,
        kvstore            = kv,
        optimizer          = args.optimizer,
        optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        batch_end_callback = batch_end_callbacks,
        epoch_end_callback = checkpoint,
        allow_missing      = True,
        monitor            = monitor)
    return model

  chunks = self.iterencode(o, _one_shot=True)


In [None]:
inception = build()

model = fit_model(args, inception, (splittedTrainData, val))

INFO:root:start with arguments Namespace(batch_size=256, benchmark=0, data_nthreads=4, data_train='/efs/data/cifar-100-mxnet/train.rec', data_val='/efs/data/cifar-100-mxnet/test.rec', disp_batches=20, dtype='float32', gpus='4,5,6,7', image_shape='3,28,28', kv_store='device', load_epoch=None, lr=0.05, lr_factor=0.1, lr_step_epochs='38', max_random_aspect_ratio=0, max_random_h=36, max_random_l=50, max_random_rotate_angle=0, max_random_s=50, max_random_scale=1, max_random_shear_ratio=0, min_random_scale=1, model_prefix=None, mom=0.9, monitor=0, network=None, num_classes=100, num_epochs=45, num_examples=50000, num_layers=None, optimizer='sgd', pad_size=4, random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=0, wd=0.0001)


[(1, 128)]
[(1, 128)]
[(1,)]


INFO:root:Batch:       1 loss_aba_backward_data         0.0794742	
INFO:root:Batch:       1 loss_aba_backward_label        0.0754529	
INFO:root:Batch:       1 p_aba_backward_0               0.0575235	
INFO:root:Batch:       1 p_aba_backward_1               0.0646825	
INFO:root:Batch:       1 loss_aba_backward_data         0.0838479	
INFO:root:Batch:       1 loss_aba_backward_label        0.0682083	
INFO:root:Batch:       1 p_aba_backward_0               0.0802431	
INFO:root:Batch:       1 p_aba_backward_1               0.0714157	
INFO:root:Batch:       1 loss_aba_backward_data         0.0794634	
INFO:root:Batch:       1 loss_aba_backward_label        0.0743439	
INFO:root:Batch:       1 p_aba_backward_0               0.0650386	
INFO:root:Batch:       1 p_aba_backward_1               0.0649255	
INFO:root:Batch:       1 loss_aba_backward_data         0.0826275	
INFO:root:Batch:       1 loss_aba_backward_label        0.0665453	
INFO:root:Batch:       1 p_aba_backward_0               0.0751

In [None]:
# batch 10 train 0.389, valid: 0.40