In [1]:
import os
import argparse
import logging
logging.basicConfig(level=logging.INFO)
from common import data, fit
from common.util import download_file
import mxnet as mx
import numpy as np
from mxnet.symbol import *

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
def download_cifar100(data_dir="/usr/stud/plapp/data/cifar100_mxnet"):
    """Return the paths of the CIFAR-100 RecordIO files.

    Despite its name this function does not download anything; it only builds
    the ``train.rec`` / ``test.rec`` paths inside ``data_dir``.

    Parameters
    ----------
    data_dir : str
        Directory containing the ``.rec`` files. The default keeps the
        previously hard-coded location; another location used before was
        ``/data/cifar-100-mxnet``.

    Returns
    -------
    tuple of str
        ``(train_rec_path, test_rec_path)``.
    """
    return (os.path.join(data_dir, "train.rec"),
            os.path.join(data_dir, "test.rec"))

In [4]:
(train_fname, val_fname) = download_cifar100()

# parse args
parser = argparse.ArgumentParser(description="train cifar100",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
fit.add_fit_args(parser)
data.add_data_args(parser)
data.add_data_aug_args(parser)
data.set_data_aug_level(parser, 2)

# Size of the full (unlabeled) training set, and the factor by which the
# labeled subset is reduced (5 -> one fifth of the training data is labeled).
num_unsup_examples = 50000
subset_factor = 5

parser.set_defaults(
    # data
    data_train     = train_fname,
    data_val       = val_fname,
    num_classes    = 100,
    image_shape    = '3,28,28',
    log_prefix     = './logs/cifar_b250_tree_semisup',
    prefix         = './checkpoints/cifar_b250_tree_semisup',
    pad_size       = 2,
    # train
    batch_size     = 50,  # TODO: currently has to divide 'num_validation_samples'
    num_epochs     = 200,
    lr_step_epochs = '80,120',  # this setting should converge to a good result
    gpus           = '0',
    loss_gpu       = 0,
    disp_batches   = 40,
    # Floor division keeps this an int; true division produced 10000.0.
    num_examples   = num_unsup_examples // subset_factor,
    wd             = 1e-4,
    lr             = .03,
    lr_factor      = .33,
    nembeddings    = 256,
    optimizer      = 'sgd',
    max_tree_depth = 2
)

# Parse an empty command line so only the defaults above are used (notebook).
args = parser.parse_args("")

# Number of unlabeled batches fed alongside each labeled batch.
unsup_multiplier = 1
# 500 training images per class in CIFAR-100; keep 1/subset_factor of them
# labeled. Floor division keeps the count an int.
labeled_per_class = 500 // subset_factor

# NOTE(review): sample_seed is not referenced anywhere else in this notebook;
# no RNG is seeded with it.
sample_seed = 47
# Validate every `val_interval` epochs; checkpoint every `save_interval` epochs.
val_interval = 1
save_interval = 10

In [5]:
from data.tree import TreeNode, TreeStructure

# Two-level label tree: root -> `num_superclasses` superclasses, each owning
# `classes_per_superclass` consecutive fine labels.
# NOTE(review): this assumes superclass i owns fine labels
# i*classes_per_superclass .. (i+1)*classes_per_superclass - 1. In the standard
# CIFAR-100 label order the fine classes are alphabetical and the members of a
# coarse class are NOT contiguous. Verify that the .rec files were built with
# remapped labels, otherwise this tree does not reflect the coarse labels.
num_superclasses = 20
classes_per_superclass = 5
assert num_superclasses * classes_per_superclass == args.num_classes, \
    "label tree does not cover all classes"

# tree from coarse labels (built in one pass instead of repeated list concatenation)
nodes = [
    TreeNode("superclass " + str(i),
             leafs=range(i * classes_per_superclass, (i + 1) * classes_per_superclass))
    for i in range(num_superclasses)
]

root = TreeNode("root", children=nodes)

# for testing: tree with only leafs
# coarseTree = TreeStructure(TreeNode("root", leafs=range(100)))

tree = TreeStructure(root)

In [6]:
from common.multi_iterator import Multi_iterator
from common.tree_iterator import Tree_iterator
from common.data import get_partial_rec_iter

# Labeled-subset size; depends only on values from the configuration cell.
num_sup_examples = labeled_per_class * args.num_classes

# Labeled iterator (devide_by=subset_factor, presumably 1/subset_factor of the
# records) together with the validation iterator, then an iterator over the
# whole training set for the unlabeled stream.
train_sup, val = get_partial_rec_iter(args, get_val=True, devide_by=subset_factor, shuffle=True)
train_unsup, _ = get_partial_rec_iter(args, get_val=False, devide_by=1, shuffle=True)

# Wrap the labeled and validation iterators with the label tree, limited to
# args.max_tree_depth levels.
train_sup_tree = Tree_iterator(train_sup, tree, args.max_tree_depth)
val_tree = Tree_iterator(val, tree, args.max_tree_depth)

# Combined training iterator: labeled tree batches plus `unsup_multiplier`
# unlabeled batches.
train = Multi_iterator(train_sup_tree, train_unsup, unsup_multiplier,
                       num_unsup_examples, num_sup_examples)

5  times more unsup data than sup data


In [7]:
from symbols import inception_cifar as base_net
from common.lba import compute_semisup_loss_tree, logit_loss_tree
from common.wrapper_module import get_embedding_shapes

# kvstore
# Key-value store of the type given by args.kv_store.
kv = mx.kvstore.create(args.kv_store)

# Parameter initialiser. A Xavier initialiser was tried previously:
# initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="avg", magnitude=2.34)
initializer = mx.init.Uniform(0.01)

# Number of unlabeled embeddings per batch (passed to the semi-supervised loss).
t_nb = unsup_multiplier * args.batch_size

# Base learning rate and scheduler built from the parsed arguments.
lr, lr_scheduler = fit._get_lr_scheduler(args, kv)

def buildEmbeddingModule(arg_p=None, aux_p=None):
    """Create, bind and initialise the embedding network module.

    The labeled input ``dataSup`` and the ``unsup_multiplier`` unlabeled inputs
    ``dataUnsup0``, ``dataUnsup1``, ... are concatenated along axis 0, passed
    through ``base_net.build_embeddings`` in a single pass, and split back into
    ``unsup_multiplier + 1`` outputs along axis 0 (labeled part first).

    Uses the notebook globals ``args``, ``unsup_multiplier``, ``train``,
    ``initializer``, ``lr`` and ``lr_scheduler``.

    Parameters
    ----------
    arg_p, aux_p : dict, optional
        Argument / auxiliary parameters (e.g. loaded from a checkpoint), passed
        to ``init_params`` together with ``initializer``.

    Returns
    -------
    mx.mod.Module
        The bound module with parameters and optimizer initialised.
    """
    data_names = ['dataSup'] + ['dataUnsup' + str(i) for i in range(unsup_multiplier)]

    # concat data, feed both through the network, then split it up again
    inputs = [Variable(name=name) for name in data_names]
    data = concat(*inputs, dim=0)

    embeddings = base_net.build_embeddings(data, nembeddings=args.nembeddings)
    splitted = split(embeddings, num_outputs=(unsup_multiplier+1), axis=0, name='split')

    # Devices for training. Use == rather than `is`: identity against a str
    # literal is unreliable and is a SyntaxWarning on Python 3.8+.
    if args.gpus is None or args.gpus == '':
        devs = mx.cpu()
    else:
        devs = [mx.gpu(int(i)) for i in args.gpus.split(',')]

    # create model
    model = mx.mod.Module(
        context       = devs,
        symbol        = splitted,
        data_names    = data_names,
        label_names   = None)

    model.bind(data_shapes=train.provide_data)
    model.init_params(initializer, arg_p, aux_p)
    # NOTE(review): rescale_grad is hard-coded to 0.005 and momentum (args.mom)
    # is not passed to the optimizer. Confirm both are intentional.
    model.init_optimizer(optimizer=args.optimizer, optimizer_params=(
       ('learning_rate', lr),
       ('wd', args.wd),
       ('rescale_grad', 0.005),
       ('lr_scheduler', lr_scheduler)))

    return model


def buildLossModule(arg_p=None, aux_p=None):
    """Create, bind and initialise the loss module that consumes the embeddings.

    Data inputs are ``embeddings_sup`` and ``embeddings_unsup0..`` (one per
    unlabeled batch). Label inputs are ``labelSup0..labelSup{L-1}``, where
    ``L = min(args.max_tree_depth, tree.depth)``. The grouped outputs are the
    ``logit_loss_tree`` losses followed, when ``unsup_multiplier >= 1``, by the
    walker losses and the visit loss from ``compute_semisup_loss_tree``.

    The module is bound with ``inputs_need_grad=True`` so gradients with respect
    to the embedding inputs are computed (presumably to feed them back into the
    embedding module).

    Uses the notebook globals ``args``, ``tree``, ``unsup_multiplier``, ``t_nb``,
    ``train``, ``initializer``, ``lr`` and ``lr_scheduler``.

    Parameters
    ----------
    arg_p, aux_p : dict, optional
        Argument / auxiliary parameters passed to ``init_params``.

    Returns
    -------
    mx.mod.Module
        The bound module with parameters and optimizer initialised.
    """
    # Number of tree levels that get a label input. Computed once; builtin
    # min() yields a plain int instead of a numpy scalar.
    num_levels = min(args.max_tree_depth, tree.depth)
    label_names = ['labelSup' + str(d) for d in range(num_levels)]
    unsup_names = ['embeddings_unsup' + str(i) for i in range(unsup_multiplier)]

    supEmbeddings = Variable(name="embeddings_sup")
    labels = [Variable(name=name) for name in label_names]
    overall_loss = []

    if unsup_multiplier >= 1:
        unsupEmbeddings = concat(*[Variable(name=name) for name in unsup_names], dim=0)

        # NOTE(review): walker_weights has two entries, presumably one per tree
        # level (max_tree_depth=2). Revisit if the depth changes.
        (walker_losses, visit_loss) = compute_semisup_loss_tree(supEmbeddings, unsupEmbeddings, labels, t_nb,
                                                                tree, walker_weights=[1.,1.], visit_weight=0.5,
                                                                maxDepth=args.max_tree_depth)
        overall_loss = walker_losses + [visit_loss]

    overall_loss = logit_loss_tree(supEmbeddings, labels, tree) + overall_loss

    # loss_gpu may legitimately be the int 0, so test explicitly for None / ''
    # instead of truthiness. Use == because the former `is ''` identity check
    # is unreliable.
    if args.loss_gpu is None or args.loss_gpu == '':
        devs = mx.cpu()
    else:
        devs = mx.gpu(args.loss_gpu)

    # create module
    model = mx.mod.Module(
        context = devs,
        symbol  = Group(overall_loss),
        data_names = ['embeddings_sup'] + unsup_names,
        label_names = label_names)

    # allocate memory by given the input data and label shapes
    model.bind(data_shapes=get_embedding_shapes(args.batch_size, args.nembeddings, unsup_multiplier),
               label_shapes=train.provide_label,
               inputs_need_grad=True)

    model.init_params(initializer, arg_p, aux_p)
    # NOTE(review): rescale_grad is hard-coded to 0.005 and momentum is not
    # passed. Confirm both are intentional.
    model.init_optimizer(optimizer=args.optimizer, optimizer_params=(
        ('learning_rate', lr),
        ('rescale_grad', 0.005),
        ('wd', args.wd),
        ('lr_scheduler', lr_scheduler)))

    return model

[16000.0, 24000.0]


In [8]:
from common.wrapper_module import WrapperModule
from common.lba import MultiAccuracy
#eval_metrics = Multi_Accuracy(num= 3 if unsup_multiplier >= 1 else 1)
                    
def fit_model(args, embeddingModule, lossModule, data, **kwargs):
    """Train the embedding and loss modules jointly through a ``WrapperModule``.

    An epoch-end callback scores the model on the validation iterator every
    ``val_interval`` epochs. It prints the result and writes it to
    ``args.log_prefix + 'logs'``, and every ``save_interval`` epochs it saves a
    checkpoint under ``args.prefix``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed training arguments.
    embeddingModule : mx.mod.Module
        Embedding network (from ``buildEmbeddingModule``).
    lossModule : mx.mod.Module
        Loss network (from ``buildLossModule``).
    data : tuple
        ``(train, val)`` data iterators. Inside this function the name shadows
        the ``data`` module imported in the first cell.
    **kwargs
        If both ``arg_params`` and ``aux_params`` are given they are passed to
        ``fit`` as-is; otherwise they are obtained from ``fit._load_model``.

    Returns
    -------
    WrapperModule
        The trained wrapper module.
    """
    # logging
    # NOTE: logging.basicConfig() does nothing once the root logger has
    # handlers. The first cell already configured it, so this format and level
    # are not applied (the captured output uses the default "INFO:root:" form).
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators
    (train, val) = data

    # load model parameters
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        _, arg_params, aux_params = fit._load_model(args, kv.rank)

    # The callback returned here is not used (checkpoints are written by
    # validate_model below). The call is kept for any side effects it may have.
    fit._save_model(args, kv.rank)

    model = WrapperModule(embeddingModule, lossModule, unsup_multiplier)

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size*(unsup_multiplier+1), args.disp_batches)]
    # NOTE(review): num=2 presumably matches the number of tree levels
    # (max_tree_depth=2). Confirm if the depth changes.
    ma = MultiAccuracy(num=2)

    # `with` closes the log file even if training raises or is interrupted.
    # Previously the handle leaked in that case.
    with open(args.log_prefix + 'logs', 'w') as logf:

        def validate_model(epoch, *args_):
            """Epoch-end callback: validate every val_interval epochs and checkpoint every save_interval epochs."""
            if epoch % val_interval != 0:
                return
            res = model.score(val)
            # TODO: pull this into default
            print('Epoch[%d] Validation-accuracy=%f' % (epoch,  res))
            logf.write('Epoch[%d] Validation-accuracy=%f \n' % (epoch,  res))  # python will convert \n to os.linesep
            logf.flush()

            if epoch % save_interval == 0:
                model.save_checkpoint(args.prefix, epoch)

        # run (validation is done in validate_model, so no eval_data is passed)
        model.fit(train,
            begin_epoch        = args.load_epoch if args.load_epoch else 0,
            num_epoch          = args.num_epochs,
            eval_metric        = ma,
            kvstore            = kv,
            arg_params         = arg_params,
            aux_params         = aux_params,
            batch_end_callback = batch_end_callbacks,
            epoch_end_callback = validate_model,
            allow_missing      = True)
    return model

In [None]:
# Rewind the combined training iterator before training.
train.reset()

# How to initialise the modules:
#   True  -> resume both modules from checkpoint 'cifar_b250_tree_semisup', epoch 10
#   False -> train from scratch
# (an earlier variant started only the embedding net from a supervised
#  checkpoint: mx.model.load_checkpoint('embedding_val30', 25))
resume_from_checkpoint = True

if resume_from_checkpoint:
    # NOTE(review): this loads from the prefix 'cifar_b250_tree_semisup' (relative
    # to the working directory), while fit_model saves under args.prefix
    # ('./checkpoints/...'). Training also restarts at epoch 0 because
    # args.load_epoch is None, so the LR schedule starts over. Confirm both.
    arg_p_emb, aux_p_emb, arg_p_loss, aux_p_loss = WrapperModule.load_checkpoint('cifar_b250_tree_semisup', 10)
    embeddingModule = buildEmbeddingModule(arg_p_emb, aux_p_emb)
    lossModule = buildLossModule(arg_p_loss, aux_p_loss)
else:
    embeddingModule = buildEmbeddingModule()
    # (semi)supervised loss module
    lossModule = buildLossModule()

m = fit_model(args, embeddingModule, lossModule, (train, val_tree))

INFO:root:start with arguments Namespace(batch_size=50, benchmark=0, data_nthreads=4, data_train='/usr/stud/plapp/data/cifar100_mxnet/train.rec', data_val='/usr/stud/plapp/data/cifar100_mxnet/test.rec', disp_batches=40, dtype='float32', gpus='0', image_shape='3,28,28', kv_store='device', load_epoch=None, log_prefix='./logs/cifar_b250_tree_semisup', loss_gpu=0, lr=0.03, lr_factor=0.33, lr_step_epochs='80,120', max_random_aspect_ratio=0, max_random_h=36, max_random_l=50, max_random_rotate_angle=0, max_random_s=50, max_random_scale=1, max_random_shear_ratio=0, max_tree_depth=2, min_random_scale=1, model_prefix=None, mom=0.9, monitor=0, nembeddings=256, network=None, num_classes=100, num_epochs=200, num_examples=10000.0, num_layers=None, optimizer='sgd', pad_size=2, prefix='./checkpoints/cifar_b250_tree_semisup', random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=0, wd=0.0001)
INFO:root:Epoch[0] Batch [40]	Speed: 318.92 samples/sec	multi-accuracy_0=0.066829

acc level  0 0.0647
acc level  1 0.0157
Epoch[0] Validation-accuracy=0.015700


INFO:root:Saved checkpoint to "./checkpoints/cifar_b250_tree_semisup_emb-0000.params"
INFO:root:Saved checkpoint to "./checkpoints/cifar_b250_tree_semisup_loss-0000.params"
INFO:root:Epoch[1] Batch [40]	Speed: 279.59 samples/sec	multi-accuracy_0=0.081951	multi-accuracy_1=0.025854
INFO:root:Epoch[1] Batch [80]	Speed: 263.32 samples/sec	multi-accuracy_0=0.075500	multi-accuracy_1=0.025500
INFO:root:Epoch[1] Batch [120]	Speed: 257.94 samples/sec	multi-accuracy_0=0.079500	multi-accuracy_1=0.024500
INFO:root:Epoch[1] Batch [160]	Speed: 256.07 samples/sec	multi-accuracy_0=0.070000	multi-accuracy_1=0.024000
INFO:root:Epoch[1] Train-multi-accuracy_0=0.073333
INFO:root:Epoch[1] Train-multi-accuracy_1=0.025128
INFO:root:Epoch[1] Time cost=75.870


acc level  0 0.0559
acc level  1 0.0239
Epoch[1] Validation-accuracy=0.023900


INFO:root:Epoch[2] Batch [40]	Speed: 257.14 samples/sec	multi-accuracy_0=0.060000	multi-accuracy_1=0.026341
INFO:root:Epoch[2] Batch [80]	Speed: 261.41 samples/sec	multi-accuracy_0=0.057500	multi-accuracy_1=0.018500
INFO:root:Epoch[2] Batch [120]	Speed: 257.87 samples/sec	multi-accuracy_0=0.076500	multi-accuracy_1=0.025500
INFO:root:Epoch[2] Batch [160]	Speed: 256.44 samples/sec	multi-accuracy_0=0.071000	multi-accuracy_1=0.033000
INFO:root:Epoch[2] Train-multi-accuracy_0=0.075897
INFO:root:Epoch[2] Train-multi-accuracy_1=0.017949
INFO:root:Epoch[2] Time cost=77.617


acc level  0 0.0876
acc level  1 0.0369
Epoch[2] Validation-accuracy=0.036900


INFO:root:Epoch[3] Batch [40]	Speed: 260.05 samples/sec	multi-accuracy_0=0.082927	multi-accuracy_1=0.036585
INFO:root:Epoch[3] Batch [80]	Speed: 257.29 samples/sec	multi-accuracy_0=0.080500	multi-accuracy_1=0.027500
INFO:root:Epoch[3] Batch [120]	Speed: 262.91 samples/sec	multi-accuracy_0=0.075000	multi-accuracy_1=0.034000
INFO:root:Epoch[3] Batch [160]	Speed: 264.22 samples/sec	multi-accuracy_0=0.078000	multi-accuracy_1=0.029500
INFO:root:Epoch[3] Train-multi-accuracy_0=0.079487
INFO:root:Epoch[3] Train-multi-accuracy_1=0.029744
INFO:root:Epoch[3] Time cost=76.761


acc level  0 0.0872
acc level  1 0.0297
Epoch[3] Validation-accuracy=0.029700


INFO:root:Epoch[4] Batch [40]	Speed: 257.25 samples/sec	multi-accuracy_0=0.089268	multi-accuracy_1=0.040000


In [None]:
# Scratch check: smallest element of a two-element list (leftover debugging cell).
np.array([1, 2]).min()

In [None]:
# Peek at the labels of one training batch.
# NOTE: this consumes a batch from `train`; call train.reset() before training again.
batch = train.next()
a = batch.label
len(a)

In [None]:
# `aux_p` is never bound by executed code in this notebook, so displaying it
# raised NameError on a fresh kernel. Show the embedding module's auxiliary
# parameters loaded from the checkpoint instead.
aux_p_emb

In [None]:
# Save the embedding module's parameters under prefix 'cifarsemisup200', epoch 200.
embeddingModule.save_checkpoint(prefix='cifarsemisup200', epoch=200)

In [None]:
# Save the loss module's parameters under prefix 'cifarsemisup200_loss', epoch 200.
lossModule.save_checkpoint(prefix='cifarsemisup200_loss', epoch=200)

In [None]:
# Rewind the validation iterator and show the first-axis length of the first
# label array of its first batch.
val.reset()
len(val.next().label[0])