In [1]:
import os
import argparse
import logging
logging.basicConfig(level=logging.ERROR)
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx
import numpy as np
import gzip, struct
import time
from mxnet.symbol import *

In [8]:
parser = argparse.ArgumentParser(description="train mnist",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
fit.add_fit_args(parser)
parser.set_defaults(
    # train
    gpus           = '0,1,2,3,4,5,6,7',
    loss_gpu       = 0,
    batch_size     = 100,
    disp_batches   = 100,
    num_epochs     = 20000,
    num_examples   = 100, 
    num_classes    = 10,
    wd             = 1e-4,
    lr             = .001,
    lr_factor      = .33,
    optimizer      = 'adam',
    lr_step_epochs = '5000,10000,15000',
    nembeddings    = 128,
)
args = parser.parse_args("")
unsup_multiplier = 1
labeled_per_class = 10
sample_seed = 47
val_interval = 100
num_unsup_examples = 60000
num_sup_examples = labeled_per_class * args.num_classes

In [9]:
def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data',label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data',image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255

def sample_by_label(images, labels, n_per_label, num_labels, seed=None):
    """Extract equal number of sampels per class."""
    res_img = []
    res_lbl = []
    rng = np.random.RandomState(seed=seed)
    for i in range(num_labels):
        a = images[labels == i]
        
        if n_per_label == -1:  # use all available labeled data
            res_img.append(a)
        else:  # use randomly chosen subset
            choice = rng.choice(len(a), n_per_label, False)
            print(choice)
            r = a[choice]
            
            res_img.append(r)
            
            lbls = np.ones(n_per_label) * i
            res_lbl.append(lbls)
    return (res_img, res_lbl)


"""
create data iterator with NDArrayIter
"""
(train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
(val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

(imgs, lbls) = sample_by_label(to4d(train_img), train_lbl, labeled_per_class, 10, seed = sample_seed)
imgs = np.vstack(imgs)
lbls = np.hstack(lbls)

train_sup = mx.io.NDArrayIter(
    imgs, lbls, args.batch_size, shuffle=True, data_name='dataSup', label_name='labelSup')
train_unsup = mx.io.NDArrayIter(
    to4d(train_img), label=None, batch_size=args.batch_size, 
    shuffle=True, data_name='dataUnsup')
val = mx.io.NDArrayIter(
    to4d(val_img), val_lbl, args.batch_size*1) # use larger test batch size

[5840 5656  576 3662 3626 4448 4352 5659 4106 4176]
[1103 5635 3920 6070 6075 3763 6509 1527 1191 2650]
[5128 5073 2352 2550  164 5421 5858 2597 1559 3745]
[1750 1338 2470 1011 2191  774 4962 5773 6060 1810]
[2295 5622 5409  555 5314 2464  909 3597 1125 3279]
[5183 2728 2781 3977  345 4871 4113 1403 2388 1468]
[5142 1629  953 4917  950  672 2646 5433 3551 5325]
[5726 2479 1952 6120 4552 2228 2253 4124 4963 3627]
[4854  222  260 4818 1262 3203 5607  990 5167 1794]
[3534   41 4560 5218 4152  972 5041 2605  640 3111]


In [10]:
from common.multi_iterator import Multi_iterator
    
train = Multi_iterator(train_sup, train_unsup, unsup_multiplier, num_unsup_examples, num_sup_examples)

600  times more unsup data than sup data


In [11]:
from symbols import mnist
from common.lba import compute_semisup_loss, logit_loss
from common.wrapper_module import get_embedding_shapes

# kvstore
kv = mx.kvstore.create(args.kv_store)

t_nb = args.batch_size * unsup_multiplier
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="avg", magnitude=2.34)

lr, lr_scheduler = fit._get_lr_scheduler(args, kv)

def buildEmbeddingModule():
    dataSup = mx.symbol.Variable(name="dataSup")
    dataUnsup = []
    for i in range(unsup_multiplier):
        dataUnsup.append(Variable(name="dataUnsup"+str(i)))

    # concat data, feed both through the network
    # then split it up again
    data = concat(dataSup, *dataUnsup, dim=0)

    embeddings = mnist.build_embeddings(data)
    splitted = split(embeddings, num_outputs=(unsup_multiplier+1), axis=0, name='split')
    
    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    data_names = ['dataSup'] + ['dataUnsup'+str(i) for i in range(unsup_multiplier)]
        
    # create model
    model = mx.mod.Module(
        context       = devs,
        symbol        = splitted,
        data_names    = data_names,
        label_names   = None)
    
    model.bind(data_shapes=train.provide_data)         
    model.init_params(initializer)    
    model.init_optimizer(optimizer='adam', optimizer_params=(
        ('learning_rate', lr), 
        ('wd', 1e-4),
       # ('rescale_grad', 0.1),
        ('lr_scheduler', lr_scheduler)))

    return model


def buildLossModule():
    supEmbeddings = Variable(name="embeddings_sup")
    labelSup = mx.symbol.Variable(name='labelSup')
    overall_loss = []
    
    if unsup_multiplier >= 1:
        unsupEmbeddings = []
        for i in range(unsup_multiplier):
            unsupEmbeddings.append(Variable(name="embeddings_unsup"+str(i)))
    
        unsupEmbeddings = concat(*unsupEmbeddings, dim=0)
            
        (walker_loss, visit_loss) = compute_semisup_loss(supEmbeddings, unsupEmbeddings, labelSup, t_nb, 
                                                     walker_weight=1.0, visit_weight=1.0)
        overall_loss = [walker_loss, visit_loss]
        
    overall_loss = [logit_loss(supEmbeddings, labelSup, 10)] + overall_loss
    
    # todo maybe use gpu
    devs = mx.cpu() if args.loss_gpu is None or args.loss_gpu is '' else mx.gpu(args.loss_gpu)

    # create module
    model = mx.mod.Module(
        context = devs,
        symbol  = Group(overall_loss),
        data_names = ['embeddings_sup']+['embeddings_unsup'+str(i) for i in range(unsup_multiplier)],
        label_names = ['labelSup'])
    
    # allocate memory by given the input data and label shapes
    model.bind(data_shapes=get_embedding_shapes(args.batch_size, args.nembeddings, unsup_multiplier), 
               label_shapes=train.provide_label,
               inputs_need_grad=True)
    
    model.init_params(initializer)  
    model.init_optimizer(optimizer='adam', optimizer_params=(
        ('learning_rate', lr), 
       # ('rescale_grad', 0.1),
        ('wd', 1e-4),
        ('lr_scheduler', lr_scheduler)))
    
    return model

[5000.0, 10000.0, 15000.0]


In [12]:
from common.wrapper_module import WrapperModule
#eval_metrics = Multi_Accuracy(num= 3 if unsup_multiplier >= 1 else 1)
                    
def fit_model(args, data, **kwargs):
    """
    train a model
    args : argparse returns
    data_loader : function that returns the train and val data iterators
    """
    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)
    
    # data iterators
    (train, val) = data

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else: 
        sym, arg_params, aux_params = fit._load_model(args, kv.rank)
        
    # save model
    checkpoint = fit._save_model(args, kv.rank)

    model = WrapperModule(buildEmbeddingModule(), buildLossModule(), unsup_multiplier)
    
    def validate_model(epoch, *args):
        if epoch % val_interval != 0: 
            return
        res = model.score(val)
        #TODO: pull this into default
        print('Epoch[%d] Validation-accuracy=%f' % (epoch,  res))

    #monitor = mx.mon.Monitor(interval=1000, pattern='.*aba_backward.*') 
    monitor = mx.mon.Monitor(interval=1000, pattern='.*') 
    
    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
    
    epoch_end_callbacks = validate_model

    # run
    model.fit(train,
        begin_epoch        = args.load_epoch if args.load_epoch else 0,
        num_epoch          = args.num_epochs,
        #eval_data          = val,
        #eval_metric        = eval_metrics,
        kvstore            = kv,
        arg_params         = arg_params,
        aux_params         = aux_params,
        batch_end_callback = batch_end_callbacks,
        epoch_end_callback = epoch_end_callbacks,
        allow_missing      = True
        #monitor            = monitor
             )

In [13]:
# training is here
train.reset()

start = time.time()
fit_model(args, (train, val))
print(time.time() - start)

Epoch[0] Validation-accuracy=0.093400
Epoch[100] Validation-accuracy=0.928100
Epoch[200] Validation-accuracy=0.949300
Epoch[300] Validation-accuracy=0.960000
Epoch[400] Validation-accuracy=0.964100
Epoch[500] Validation-accuracy=0.971000
Epoch[600] Validation-accuracy=0.973300
Epoch[700] Validation-accuracy=0.976300
Epoch[800] Validation-accuracy=0.977200
Epoch[900] Validation-accuracy=0.977800
Epoch[1000] Validation-accuracy=0.979400
Epoch[1100] Validation-accuracy=0.978700
Epoch[1200] Validation-accuracy=0.980500
Epoch[1300] Validation-accuracy=0.980200
Epoch[1400] Validation-accuracy=0.981300
Epoch[1500] Validation-accuracy=0.980200
Epoch[1600] Validation-accuracy=0.982300
Epoch[1700] Validation-accuracy=0.983000
Epoch[1800] Validation-accuracy=0.983600
Epoch[1900] Validation-accuracy=0.982700
Epoch[2000] Validation-accuracy=0.984200
Epoch[2100] Validation-accuracy=0.984200
Epoch[2200] Validation-accuracy=0.984600
Epoch[2300] Validation-accuracy=0.983100
Epoch[2400] Validation-accur

KeyboardInterrupt: 

In [None]:
a = []
isinstance(3, list)

In [None]:
class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        super(Multi_Accuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):

        #for i in range(len(preds)):
        for i in range(1):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[0].asnumpy().astype('int32')

            #mx.metric.check_label_shapes(label, pred_label)
            
            #print((pred_label.flat == label.flat).sum())
            #print(len(pred_label.flat))

            
            self.sum_metric[i] += (pred_label.flat == label.flat).sum()
            self.num_inst[i] += len(pred_label.flat)

In [None]:
print(labels.asnumpy())
model.get_outputs()[1].asnumpy()[0:10,0:10]