In [1]:
import os
import math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
from common.util import download_file
import subprocess
import time

from IPython.core.debugger import Tracer

In [2]:
def unpickle(file):
    """Load and return the single object stored in the pickle file at path ``file``.

    ``encoding='bytes'`` makes 8-bit strings that were pickled under Python 2
    load as ``bytes`` objects instead of being decoded as text.

    Warning: unpickling can execute arbitrary code -- only call this on files
    from a trusted source.
    """
    # ``with`` guarantees the file is closed even if loading raises.
    with open(file, 'rb') as fo:
        return cPickle.load(fo, encoding='bytes')

In [3]:
class Options:
    """Static configuration for the semi-supervised ImageNet VGG-DRM experiment.

    Every setting is a plain instance attribute assigned in ``__init__``.
    ``exp_name`` is derived from ``num_train_sup`` and ``seed_val`` and is used
    to name the log directory and the saved parameter files.
    """

    def __init__(self):
        # --- experiment identity ---------------------------------------------
        self.seed_val = 0
        # presumably the number of labeled training images ("nlabels" in exp_name)
        self.num_train_sup = 128000

        # --- batching: per-device size, multiplied by the GPU count later -----
        self.batch_size = 8

        # --- filesystem locations ---------------------------------------------
        self.data_dir = '/tanData/datasets/imagenet/data'
        self.log_dir = '/tanData/logs'
        self.model_dir = '/tanData/models'
        name_template = 'imagenet_nlabels_%i_allconv13_lddrm_mm_pathnorm_seed_%i'
        self.exp_name = name_template % (self.num_train_sup, self.seed_val)

        # --- devices: use GPUs first_gpu .. first_gpu + gpus - 1 ----------------
        self.gpus = 4
        self.first_gpu = 0

        # --- weights of the auxiliary loss terms used in train() ---------------
        self.alpha_drm = 0.005
        self.alpha_pn = 0.01
        self.alpha_kl = 0.005
        self.alpha_nn = 0.005

        # --- model switches forwarded to VGG_DRM --------------------------------
        self.use_bias = True
        self.use_bn = True
        self.do_topdown = True
        self.do_countpath = False
        self.do_pn = True
        self.relu_td = False
        self.do_nn = True

opt = Options()

In [4]:
# Global batch size: per-device batch times the number of devices (at least 1,
# so a CPU-only run keeps the per-device size).
batch_size = opt.batch_size * max(1, opt.gpus)

# One context per requested GPU, offset by ``first_gpu``; fall back to the CPU.
if opt.gpus > 0:
    ctx = [mx.gpu(opt.first_gpu + k) for k in range(opt.gpus)]
else:
    ctx = [mx.cpu()]

# Only its ``num_workers`` / ``rank`` are used below, to shard the record files.
kv = mx.kvstore.create('device')

In [5]:
def gpu_device(ctx=mx.gpu(0)):
    """Return ``ctx`` if a small NDArray can be created on it, otherwise ``None``.

    Parameters
    ----------
    ctx : mx.Context, default ``mx.gpu(0)``
        Device to probe.
    """
    try:
        _ = mx.nd.array([1, 2, 3], ctx=ctx)
    except mx.MXNetError:
        return None
    return ctx

# Fail fast if any configured device is unusable. This runs at module level, so
# the offset comes from ``opt`` (there is no ``self`` here); adding it makes the
# message report the physical device id that was used to build ``ctx``.
for i, gpu in enumerate(ctx):
    assert gpu_device(gpu) is not None, 'GPU device %i is not available!' % (i + opt.first_gpu)

In [6]:
# Per-experiment log directory. ``exist_ok=True`` makes this cell idempotent on
# re-runs and avoids the window between a separate exists() check and mkdir.
log_dir = os.path.join(opt.log_dir, opt.exp_name)
os.makedirs(log_dir, exist_ok=True)

In [7]:
# Settings shared by every record iterator: label layout, input size, the
# per-iterator batch (half of ``batch_size``, since each training step pairs one
# unlabeled and one labeled batch), mean subtraction, and sharding across
# kvstore workers.
_shared_iter_kwargs = dict(
    label_width=1,
    data_name='data',
    label_name='softmax_label',
    data_shape=(3, 224, 224),
    batch_size=batch_size // 2,
    mean_r=123.68,
    mean_g=116.779,
    mean_b=103.939,
    fill_value=0,
    num_parts=kv.num_workers,
    part_index=kv.rank,
    preprocess_threads=32,
)

# Training iterators add shuffling, random crop/mirror, scale and aspect-ratio
# jitter, and random H/S/L colour shifts on top of the shared settings.
_train_iter_kwargs = dict(
    _shared_iter_kwargs,
    pad=0,
    shuffle=True,
    rand_crop=True,
    rand_mirror=True,
    max_random_scale=1.0,
    min_random_scale=0.533,
    max_aspect_ratio=0.25,
    random_h=36,
    random_s=50,
    random_l=50,
)

train_data_unsup = mx.io.ImageRecordIter(
    path_imgrec=os.path.join(opt.data_dir, 'train_unsup_480.rec'),
    **_train_iter_kwargs)

train_data_sup = mx.io.ImageRecordIter(
    path_imgrec=os.path.join(opt.data_dir, 'train_sup_480.rec'),
    **_train_iter_kwargs)

# Validation is deterministic: no random crop or mirror (shuffle left at its default).
valid_data = mx.io.ImageRecordIter(
    path_imgrec=os.path.join(opt.data_dir, 'val_256.rec'),
    rand_crop=False,
    rand_mirror=False,
    **_shared_iter_kwargs)

In [8]:
# Loss objects used by the training loop: cross-entropy on the logits and L2 for
# the reconstruction term. ``L1_loss`` appears unused in this notebook.
criterion, L2_loss, L1_loss = (
    gluon.loss.SoftmaxCrossEntropyLoss(),
    gluon.loss.L2Loss(),
    gluon.loss.L1Loss(),
)

In [9]:
class Normal(mx.init.Initializer):
    """Initialize arrays with samples from a normal distribution N(mean, sigma**2).

    Unlike the built-in ``mx.init.Normal`` (mean fixed at zero), this variant
    takes a ``mean``; e.g. ``Normal(mean=1, sigma=0.02)`` is used for
    batch-norm ``gamma`` parameters in this notebook.

    Parameters
    ----------
    mean : float, default 0
        Mean of the normal distribution.
    sigma : float, default 0.01
        Standard deviation of the normal distribution.
    """
    def __init__(self, mean=0, sigma=0.01):
        # Record both hyper-parameters in the base-class kwargs so that
        # ``dumps()`` describes the full configuration, not only ``sigma``.
        super(Normal, self).__init__(mean=mean, sigma=sigma)
        self.sigma = sigma
        self.mean = mean

    def _init_weight(self, _, arr):
        # Fill ``arr`` in place with N(mean, sigma**2) samples.
        mx.random.normal(self.mean, self.sigma, out=arr)

In [10]:
# from resnet import ResNet164_v2
from mxnet.gluon.model_zoo import vision as models
from vgg_ld_opt_imagenet import VGG_DRM

In [11]:
# Running training-set metrics; ``train`` resets both at the start of each epoch.
acc_top1, acc_top5 = mx.metric.Accuracy(), mx.metric.TopKAccuracy(5)

In [12]:
import datetime
writer = SummaryWriter(os.path.join(opt.log_dir, opt.exp_name))

def test(ctx, val_data, net=None):
    """Evaluate a model on ``val_data`` and return ``(top1_error, top5_error)``.

    Parameters
    ----------
    ctx : list of mx.Context
        Devices each validation batch is split across.
    val_data : mx.io.DataIter
        Validation iterator. It is reset first, so every call sees the full set.
    net : gluon.Block, optional
        Model to evaluate. If omitted, a module-level ``net`` is used when one
        exists; pass it explicitly when the model only lives in a local variable.

    Raises
    ------
    NameError
        If ``net`` is omitted and no module-level ``net`` is defined.
    """
    if net is None:
        if 'net' not in globals():
            raise NameError("name 'net' is not defined; call test(ctx, val_data, net=model)")
        net = globals()['net']

    # A DataIter is not rewound automatically once exhausted; without this,
    # every call after the first sees zero batches and the errors come out NaN.
    val_data.reset()

    acc_top1_val = mx.metric.Accuracy()
    acc_top5_val = mx.metric.TopKAccuracy(5)
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for X in data:
            out = net(X)
            # The training loop unpacks net(x) as [logits, xhat, _, loss_pn, loss_nn];
            # the metrics need only the logits.
            outputs.append(out[0] if isinstance(out, (list, tuple)) else out)
        acc_top1_val.update(label, outputs)
        acc_top5_val.update(label, outputs)

    _, top1 = acc_top1_val.get()
    _, top5 = acc_top5_val.get()
    return (1 - top1, 1 - top5)

In [13]:
def _make_train_iter(rec_name):
    """Return a shuffled, randomly cropped/mirrored ImageRecordIter over ``opt.data_dir/rec_name``.

    Used for the record files ``train`` switches to at epoch 100. Unlike the
    initial training iterators it applies no scale/aspect/colour jitter.
    """
    return mx.io.ImageRecordIter(
        path_imgrec=os.path.join(opt.data_dir, rec_name),
        label_width=1,
        data_name='data',
        label_name='softmax_label',
        data_shape=(3, 224, 224),
        batch_size=batch_size // 2,
        mean_r=123.68,
        mean_g=116.779,
        mean_b=103.939,
        pad=0,
        fill_value=0,
        shuffle=True,
        rand_crop=True,
        rand_mirror=True,
        num_parts=kv.num_workers,
        part_index=kv.rank,
        preprocess_threads=32)


def _evaluate(net, ctx, val_data):
    """Return ``(top1_error, top5_error)`` of ``net`` over the whole of ``val_data``."""
    # DataIters are not rewound automatically; reset so every epoch sees all batches.
    val_data.reset()
    top1_metric = mx.metric.Accuracy()
    top5_metric = mx.metric.TopKAccuracy(5)
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            out = net(x)
            # net(x) is unpacked as [logits, ...] in training; keep only the logits.
            outputs.append(out[0] if isinstance(out, (list, tuple)) else out)
        top1_metric.update(label, outputs)
        top5_metric.update(label, outputs)
    _, top1 = top1_metric.get()
    _, top5 = top5_metric.get()
    return 1 - top1, 1 - top5


def train(net, train_data_unsup, train_data_sup, valid_data, num_epochs, lr, wd, ctx, lr_decay, relu_indx):
    """Semi-supervised training loop; returns the best validation errors seen.

    Each step draws one unlabeled and one labeled batch of equal size, splits
    both across ``ctx`` and, per device, minimises

    * unlabeled: alpha_drm * L2(xhat, x) + alpha_kl * kl + alpha_nn * loss_nn + alpha_pn * loss_pn
    * labeled:   cross-entropy + the same weighted auxiliary terms

    where ``kl`` is minus KL(softmax || uniform over 1000 classes). After every
    epoch the model is evaluated on ``valid_data``. Parameters with the best
    top-1 / top-5 validation error are saved under ``opt.model_dir``, plus a
    snapshot every 50 epochs.

    Parameters
    ----------
    net : gluon.Block
        Model; ``net(x)`` and ``net(x, y)`` are unpacked as
        ``[logits, xhat, _, loss_pn, loss_nn]``.
    train_data_unsup, train_data_sup : mx.io.DataIter
        Unlabeled / labeled training iterators. From epoch 100 on they are
        replaced by iterators over the 256-px record files.
    valid_data : mx.io.DataIter
        Validation iterator.
    num_epochs : int
        Epoch budget; the loop runs ``num_epochs - 100`` epochs.
    lr, lr_decay, relu_indx
        Accepted but currently unused: Adam runs at a fixed learning rate of 1e-4.
    wd : float
        Weight decay passed to the optimizer.
    ctx : list of mx.Context
        Devices to train on.

    Returns
    -------
    tuple of float
        ``(best_top1_err, best_top5_err)`` on the validation set.
    """
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': 0.0001, 'wd': wd})

    best_top1_err = np.inf
    best_top5_err = np.inf
    log_interval = 1
    # Checkpoints are written into this directory; make sure it exists first.
    os.makedirs(opt.model_dir, exist_ok=True)

    # NOTE(review): runs num_epochs - 100 epochs (400 for num_epochs=500) -- confirm intended.
    for epoch in range(num_epochs - 100):
        train_data_unsup.reset()
        train_data_sup.reset()

        tic = time.time()
        btic = time.time()
        acc_top1.reset()
        acc_top5.reset()
        train_loss = 0
        num_batch = 0

        if epoch == 100:
            # Switch to the 256-px record files; fresh iterators need no reset.
            # NOTE(review): these names take no format arguments (like the 480-px
            # files); add %i placeholders if per-split files were intended.
            train_data_unsup = _make_train_iter('train_unsup_256.rec')
            train_data_sup = _make_train_iter('train_sup_256.rec')

        for i, (batch_unsup, batch_sup) in enumerate(zip(train_data_unsup, train_data_sup)):
            assert batch_unsup.data[0].shape[0] == batch_sup.data[0].shape[0], "batch_unsup and batch_sup must have the same size"

            bs = batch_unsup.data[0].shape[0]

            data_unsup = gluon.utils.split_and_load(batch_unsup.data[0], ctx_list=ctx, batch_axis=0)
            label_unsup = gluon.utils.split_and_load(batch_unsup.label[0], ctx_list=ctx, batch_axis=0)
            data_sup = gluon.utils.split_and_load(batch_sup.data[0], ctx_list=ctx, batch_axis=0)
            label_sup = gluon.utils.split_and_load(batch_sup.label[0], ctx_list=ctx, batch_axis=0)

            loss = []
            outputs_unsup = []
            outputs_sup = []

            with autograd.record():
                for xuns, yuns, xsup, ysup in zip(data_unsup, label_unsup, data_sup, label_sup):
                    # Unlabeled half: labels are not used in the loss.
                    [output_unsup, xhat_unsup, _, loss_pn_unsup, loss_nn_unsup] = net(xuns)
                    loss_drm_unsup = L2_loss(xhat_unsup, xuns)
                    softmax_unsup = nd.softmax(output_unsup)
                    # sum_k p_k * log(1000 * p_k) is KL(p || uniform over 1000 classes),
                    # up to the 1e-8 stabiliser; it enters the loss negated.
                    loss_kl_unsup = -nd.sum(nd.log(1000.0*softmax_unsup + 1e-8) * softmax_unsup, axis=1)
                    loss_unsup = opt.alpha_drm * loss_drm_unsup + opt.alpha_kl * loss_kl_unsup + opt.alpha_nn * loss_nn_unsup + opt.alpha_pn * loss_pn_unsup

                    # Labeled half: cross-entropy plus the same auxiliary terms.
                    [output_sup, xhat_sup, _, loss_pn_sup, loss_nn_sup] = net(xsup, ysup)
                    loss_xentropy_sup = criterion(output_sup, ysup)
                    loss_drm_sup = L2_loss(xhat_sup, xsup)
                    softmax_sup = nd.softmax(output_sup)
                    loss_kl_sup = -nd.sum(nd.log(1000.0*softmax_sup + 1e-8) * softmax_sup, axis=1)
                    loss_sup = loss_xentropy_sup + opt.alpha_drm * loss_drm_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_nn * loss_nn_sup + opt.alpha_pn * loss_pn_sup

                    loss.append(loss_unsup + loss_sup)
                    outputs_sup.append(output_sup)
                    outputs_unsup.append(output_unsup)

            for l in loss:
                l.backward()

            trainer.step(bs)

            # Score both halves with both metrics so top-1 and top-5 describe the
            # same samples.
            labels = label_unsup + label_sup
            outputs = outputs_unsup + outputs_sup
            acc_top1.update(labels, outputs)
            acc_top5.update(labels, outputs)
            train_loss += sum(l.sum().asscalar() for l in loss)
            num_batch += 1
            if log_interval and not (i + 1) % log_interval:
                _, top1 = acc_top1.get()
                _, top5 = acc_top5.get()
                err_top1, err_top5 = (1-top1, 1-top5)
                print('Epoch[%d] Batch [%d]     Speed: %f samples/sec   top1-err=%f     top5-err=%f'%(
                          epoch, i, batch_size*log_interval/(time.time()-btic), err_top1, err_top5))
                btic = time.time()

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        err_top1, err_top5 = (1-top1, 1-top5)
        # max(..., 1) guards against an epoch that produced no batches.
        train_loss /= max(num_batch, 1) * batch_size
        writer.add_scalars('acc', {'train_top1': 1.0 - err_top1}, epoch)
        writer.add_scalars('acc', {'train_top5': 1.0 - err_top5}, epoch)

        err_top1_val, err_top5_val = _evaluate(net, ctx, valid_data)

        if err_top1_val < best_top1_err:
            best_top1_err = err_top1_val
            net.collect_params().save('%s/%s_best_top1.params'%(opt.model_dir, opt.exp_name))

        if err_top5_val < best_top5_err:
            best_top5_err = err_top5_val
            net.collect_params().save('%s/%s_best_top5.params'%(opt.model_dir, opt.exp_name))

        print('[Epoch %d] training: err-top1=%f err-top5=%f loss=%f'%(epoch, err_top1, err_top5, train_loss))
        print('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
        print('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

        writer.add_scalars('acc', {'valid_top1': 1.0 - err_top1_val}, epoch)
        writer.add_scalars('acc', {'valid_top5': 1.0 - err_top5_val}, epoch)

        if not epoch % 50:
            net.collect_params().save('%s/%s_epoch_%i.params'%(opt.model_dir, opt.exp_name, epoch))

    return best_top1_err, best_top5_err

In [14]:
# Arguments that run_train forwards to train(). NOTE: train() currently ignores
# learning_rate and lr_decay (it uses Adam at a fixed 1e-4).
num_epochs, learning_rate, weight_decay, lr_decay = 500, 0.15, 5e-4, 0.1

In [15]:
def _init_params(model, ctx):
    """Initialise every parameter of ``model`` on ``ctx``, choosing the initializer by parameter name."""
    for param in model.collect_params().values():
        name = param.name
        if 'conv' in name or 'dense' in name:
            init = mx.initializer.Xavier() if 'weight' in name else mx.init.Zero()
        elif 'batchnorm' in name or 'instancenorm' in name:
            if 'gamma' in name:
                init = Normal(mean=1, sigma=0.02)
            elif 'running_var' in name:
                # Gluon's BatchNorm defaults its running variance to ones; zeros
                # would make inference divide by sqrt(eps) until the stats warm up.
                init = mx.init.One()
            else:
                init = mx.init.Zero()
        elif 'biasadder' in name:
            init = mx.init.Zero()
        else:
            # NOTE(review): parameters matching none of these patterns stay
            # uninitialised, so Gluon will raise when they are first used.
            continue
        param.initialize(init=init, ctx=ctx)


def run_train(num_exp, ctx):
    """Build, initialise and train ``num_exp`` fresh VGG16-based DRM models on ImageNet.

    Parameters
    ----------
    num_exp : int
        Number of independent experiments (fresh models) to run.
    ctx : list of mx.Context
        Devices to initialise and train on.

    Returns
    -------
    list of tuple
        One ``(best_top1_err, best_top5_err)`` pair per experiment.
    """
    results = []
    for _ in range(num_exp):
        # VGG16-based DRM for 1000-class ImageNet; per-device half-batch size.
        model = VGG_DRM('VGG16', batch_size=opt.batch_size // 2, num_class=1000, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath, do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn)
        _init_params(model, ctx)

        # Positions of ReLU layers inside model.features (forwarded to train()).
        children = model.features._children
        relu_indx = [idx for idx in range(len(children))
                     if children[idx].name.find('relu') != -1]

        best_top1_err, best_top5_err = train(model, train_data_unsup, train_data_sup, valid_data, num_epochs, learning_rate, weight_decay, ctx, lr_decay, relu_indx)
        results.append((best_top1_err, best_top5_err))
    return results

In [16]:
run_train(num_exp=1, ctx=ctx)  # a single experiment on the devices in ``ctx``

Epoch[0] Batch [0]     Speed: 2.009015 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [1]     Speed: 25.790286 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [2]     Speed: 27.574156 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [3]     Speed: 25.973174 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [4]     Speed: 24.721301 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [5]     Speed: 26.846751 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [6]     Speed: 28.411358 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [7]     Speed: 27.875774 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [8]     Speed: 28.821168 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [9]     Speed: 28.266540 samples/sec   top1-err=1.000000     top5-err=1.000000
Epoch[0] Batch [10]     Speed: 26.632874 samples/sec   top1-e

KeyboardInterrupt: 