In [1]:
import argparse,logging, os, math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
import subprocess
import time

from IPython.core.debugger import Tracer

from gluon_se_resnext_w_d_maxmin import se_resnext

In [2]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)

In [3]:
class Options:
    def __init__(self):
        self.gpus = '0,1,2,3,4,5,6,7' #the gpus will be used, e.g "0,1,2,3"
        self.data_dir = '/tanData/datasets/imagenet/data/imagenet_senet' #the input data directory
        self.log_dir = '/tanData/logs'
        self.model_dir ='/tanData/models'
        self.exp_name = 'exp1'
        self.data_type = 'imagenet' #the dataset type
        self.depth = 50 #the depth of resnet
        self.batch_size = 32 #the batch size
        self.num_group = 64 #the number of convolution groups
        self.drop_out = 0.0 #the probability of an element to be zeroed
        self.alpha_max = 0.5
        self.alpha_min = 0.5
        
        self.list_dir = './' #the directory which contain the training list file
        self.lr = 0.1 #initialization learning rate
        self.mom = 0.9 #momentum for sgd
        self.bn_mom = 0.9 #momentum for batch normlization
        self.wd = 0.0001 #weight decay for sgd
        self.workspace = 512 #memory space size(MB) used in convolution, 
                            #if xpu memory is oom, then you can try smaller vale, such as --workspace 256 
        self.num_classes = 1000 #the class number of your task
        self.aug_level = 2 # level 1: use only random crop and random mirror, 
                           #level 2: add scale/aspect/hsv augmentation based on level 1, 
                           #level 3: add rotation/shear augmentation based on level 2 
        self.num_examples = 1281167 # the number of training examples
        self.kv_store = 'device' # the kvstore type'
        self.model_load_epoch = 8 # load the model on an epoch using the model-load-prefix
        self.frequent = 50 # frequency of logging
        self.memonger = False # true means using memonger to save momory, https://github.com/dmlc/mxnet-memonger
        self.retrain = False # true means continue training
        
args = Options()

In [4]:
hdlr = logging.FileHandler('./log/log-se-resnext-{}-{}.log'.format(args.data_type, args.depth))
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logging.info(args)

2018-09-04 03:55:29,673 - <__main__.Options object at 0x7fc08f3bb828>


In [5]:
kv = mx.kvstore.create(args.kv_store)
ctx = mx.cpu() if args.gpus is None else [mx.gpu(int(i)) for i in args.gpus.split(',')]
batch_size = args.batch_size
batch_size *= max(1, len(ctx))
begin_epoch = args.model_load_epoch if args.model_load_epoch else 0
if not os.path.exists("./model"):
    os.mkdir("./model")
model_prefix = "seresnext_{}_{}_{}_{}".format(args.data_type, args.depth, kv.rank, args.exp_name)
# model_prefix = "model/se-resnext-{}-{}-{}".format(args.data_type, args.depth, kv.rank)
arg_params = None
aux_params = None
if args.retrain:
    _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.model_load_epoch)

In [6]:
train_data = mx.io.ImageRecordIter(
    path_imgrec         = os.path.join(args.data_dir, "train.rec") if args.data_type == 'cifar10' else
                          os.path.join(args.data_dir, "train_256_q90.rec") if args.aug_level == 1
                          else os.path.join(args.data_dir, "train_480_q90.rec") ,
    label_width         = 1,
    data_name           = 'data',
    label_name          = 'softmax_label',
    data_shape          = (3, 32, 32) if args.data_type=="cifar10" else (3, 224, 224),
    batch_size          = batch_size,
    pad                 = 4 if args.data_type == "cifar10" else 0,
    fill_value          = 127,  # only used when pad is valid
    rand_crop           = True,
    max_random_scale    = 1.0,  # 480 with imagnet, 32 with cifar10
    min_random_scale    = 1.0 if args.data_type == "cifar10" else 1.0 if args.aug_level == 1 else 0.533,  # 256.0/480.0=0.533, 256.0/384.0=0.667 256.0/256=1.0
    max_aspect_ratio    = 0 if args.data_type == "cifar10" else 0 if args.aug_level == 1 else 0.25, # 0.25
    random_h            = 0 if args.data_type == "cifar10" else 0 if args.aug_level == 1 else 36,  # 0.4*90
    random_s            = 0 if args.data_type == "cifar10" else 0 if args.aug_level == 1 else 50,  # 0.4*127
    random_l            = 0 if args.data_type == "cifar10" else 0 if args.aug_level == 1 else 50,  # 0.4*127
    max_rotate_angle    = 0 if args.aug_level <= 2 else 10,
    max_shear_ratio     = 0 if args.aug_level <= 2 else 0.0, #0.1 args.aug_level = 3
    rand_mirror         = True,
    shuffle             = True,
    num_parts           = kv.num_workers,
    part_index          = kv.rank)
val_data = mx.io.ImageRecordIter(
    path_imgrec         = os.path.join(args.data_dir, "val.rec") if args.data_type == 'cifar10' else
                          os.path.join(args.data_dir, "val_256_q90.rec"),
    label_width         = 1,
    data_name           = 'data',
    label_name          = 'softmax_label',
    batch_size          = batch_size,
    data_shape          = (3, 32, 32) if args.data_type=="cifar10" else (3, 224, 224),
    rand_crop           = False,
    rand_mirror         = False,
    num_parts           = kv.num_workers,
    part_index          = kv.rank)

In [7]:
class Normal(mx.init.Initializer):
    """Initializes weights with random values sampled from a normal distribution
    with a mean and standard deviation of `sigma`.
    """
    def __init__(self, mean=0, sigma=0.01):
        super(Normal, self).__init__(sigma=sigma)
        self.sigma = sigma
        self.mean = mean

    def _init_weight(self, _, arr):
        mx.random.normal(self.mean, self.sigma, out=arr)

In [8]:
def multi_factor_scheduler(begin_epoch, epoch_size, step=[30, 60, 90, 95, 110, 120], factor=0.1):
    step_ = [epoch_size * (x-begin_epoch) for x in step if x-begin_epoch > 0]
    return mx.lr_scheduler.MultiFactorScheduler(step=step_, factor=factor) if len(step_) else None

In [9]:
criterion = gluon.loss.SoftmaxCrossEntropyLoss()
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
import datetime
writer = SummaryWriter(os.path.join(args.log_dir, args.exp_name))

In [10]:
def test(net, val_data, ctx):
    val_data.reset()
    
    acc_top1_val = mx.metric.Accuracy()
    acc_top5_val = mx.metric.TopKAccuracy(5)
    acc_top1_val_max = mx.metric.Accuracy()
    acc_top5_val_max = mx.metric.TopKAccuracy(5)
    acc_top1_val_min = mx.metric.Accuracy()
    acc_top5_val_min = mx.metric.TopKAccuracy(5)
    acc_top1_val.reset()
    acc_top5_val.reset()
    acc_top1_val_max.reset()
    acc_top5_val_max.reset()
    acc_top1_val_min.reset()
    acc_top5_val_min.reset()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        
        outputs = []
        outputsmax = []
        outputsmin = []
        for x in data:
            zmax, zmin = net(x)
            z = args.alpha_max * zmax + args.alpha_min * zmin
            outputs.append(z)
            outputsmax.append(zmax)
            outputsmin.append(zmin)
            
            
        acc_top1_val.update(label, outputs)
        acc_top5_val.update(label, outputs)
        acc_top1_val_max.update(label, outputsmax)
        acc_top5_val_max.update(label, outputsmax)
        acc_top1_val_min.update(label, outputsmin)
        acc_top5_val_min.update(label, outputsmin)

    _, top1 = acc_top1_val.get()
    _, top5 = acc_top5_val.get()
    _, top1max = acc_top1_val_max.get()
    _, top5max = acc_top5_val_max.get()
    _, top1min = acc_top1_val_min.get()
    _, top5min = acc_top5_val_min.get()
    return (top1, top5, top1max, top5max, top1min, top5min)

In [11]:
def train(net, train_data, val_data, num_epochs, ctx):
    epoch_size = max(int(args.num_examples / batch_size / kv.num_workers), 1)
    lr_sch = multi_factor_scheduler(begin_epoch, epoch_size, step=[30, 60, 90, 95, 110, 120], factor=0.1)
    trainer = gluon.Trainer(net.collect_params(), 'nag', {'learning_rate':args.lr, 'momentum':args.mom, 'wd':args.wd, 'lr_scheduler': lr_sch})
    
    prev_time = datetime.datetime.now()
    best_top1_val = 0.; best_top1_valmax = 0.; best_top1_valmin = 0.
    best_top5_val = 0.; best_top5_valmax = 0.; best_top5_valmin = 0.
    log_interval = 500
    
    for epoch in range(begin_epoch, num_epochs):
        train_data.reset()
        
        tic = time.time()
        btic = time.time()
        acc_top1.reset()
        acc_top5.reset()
        train_loss = 0
        num_batch = 0
        
        for i, batch in enumerate(train_data):
            bs = batch.data[0].shape[0]
            
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            
            loss = []
            outputs = []
            
            with autograd.record():
                for x, y in zip(data, label):
                    zmax, zmin = net(x)
                    loss_xent = args.alpha_max * criterion(zmax, y) + args.alpha_min * criterion(zmin, y)
                    z = args.alpha_max * zmax + args.alpha_min * zmin

                    loss.append(loss_xent)
                    outputs.append(z)
                    
            for l in loss:
                l.backward()
                
            trainer.step(bs)
            
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)
            train_loss += sum([l.sum().asscalar() for l in loss])
            num_batch += 1
            if log_interval and not i % log_interval:
                _, top1 = acc_top1.get()
                _, top5 = acc_top5.get()
                logging.info('Epoch[%d] Batch [%d] Lr: %f     Speed: %f samples/sec   top1-acc=%f     top5-acc=%f'%(
                          epoch, i, trainer.learning_rate, batch_size*log_interval/(time.time()-btic), top1, top5))
                btic = time.time()
        
        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        train_loss /= num_batch * batch_size
        writer.add_scalars('acc', {'train_top1': top1}, epoch)
        writer.add_scalars('acc', {'train_top5': top5}, epoch)
        
        top1_val, top5_val, top1_valmax, top5_valmax, top1_valmin, top5_valmin = test(net=net, val_data=val_data, ctx=ctx)
        
        if top1_val > best_top1_val:
            best_top1_val = top1_val
            net.collect_params().save('%s/%s_best_top1.params'%(args.model_dir, model_prefix))
        
        if top1_valmax > best_top1_valmax:
            best_top1_valmax = top1_valmax
            net.collect_params().save('%s/%s_best_top1_max.params'%(args.model_dir, model_prefix))
            
        if top1_valmin > best_top1_valmin:
            best_top1_valmin = top1_valmin
            net.collect_params().save('%s/%s_best_top1_min.params'%(args.model_dir, model_prefix))
        
        if top5_val > best_top5_val:
            best_top5_val = top5_val
            net.collect_params().save('%s/%s_best_top5.params'%(args.model_dir, model_prefix))
        
        if top5_valmax > best_top5_valmax:
            best_top5_valmax = top5_valmax
            net.collect_params().save('%s/%s_best_top5_max.params'%(args.model_dir, model_prefix))
        
        if top5_valmin > best_top5_valmin:
            best_top5_valmin = top5_valmin
            net.collect_params().save('%s/%s_best_top5_min.params'%(args.model_dir, model_prefix))
        
        logging.info('[Epoch %d] training: acc-top1=%f acc-top5=%f loss=%f lr=%f'%(epoch, top1, top5, train_loss, trainer.learning_rate))
        logging.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
        logging.info('[Epoch %d] validation: acc-top1=%f acc-top5=%f best-acc-top1=%f best-acc-top5=%f'%(epoch, top1_val, top5_val, best_top1_val, best_top5_val))
        logging.info('[Epoch %d] validation: acc-top1-max=%f acc-top5-max=%f best-acc-top1-max=%f best-acc-top5-max=%f'%(epoch, top1_valmax, top5_valmax, best_top1_valmax, best_top5_valmax))
        logging.info('[Epoch %d] validation: acc-top1-min=%f acc-top5-min=%f best-acc-top1-min=%f best-acc-top5-min=%f'%(epoch, top1_valmin, top5_valmin, best_top1_valmin, best_top5_valmin))
        
        writer.add_scalars('acc', {'valid_top1': top1_val}, epoch)
        writer.add_scalars('acc', {'valid_top5': top5_val}, epoch)
        writer.add_scalars('acc', {'valid_top1_max': top1_valmax}, epoch)
        writer.add_scalars('acc', {'valid_top5_max': top5_valmax}, epoch)
        writer.add_scalars('acc', {'valid_top1_min': top1_valmin}, epoch)
        writer.add_scalars('acc', {'valid_top5_min': top5_valmin}, epoch)
        
        net.collect_params().save('%s/%s_current.params'%(args.model_dir, model_prefix))
        if not epoch % 10:
            net.collect_params().save('%s/%s_epoch_%i.params'%(args.model_dir, model_prefix, epoch))
    
    return best_top1_val, best_top5_val

In [12]:
ratio_list = [0.25, 0.125, 0.0625, 0.03125]   # 1/4, 1/8, 1/16, 1/32
if args.depth == 18:
    units = [2, 2, 2, 2]
elif args.depth == 34:
    units = [3, 4, 6, 3]
elif args.depth == 50:
    units = [3, 4, 6, 3]
elif args.depth == 101:
    units = [3, 4, 23, 3]
elif args.depth == 152:
    units = [3, 8, 36, 3]
elif args.depth == 200:
    units = [3, 24, 36, 3]
elif args.depth == 269:
    units = [3, 30, 48, 8]
else:
    raise ValueError("no experiments done on detph {}, you can do it youself".format(args.depth))

num_epochs = 200 if args.data_type == "cifar10" else 125

In [13]:
def run_train(ctx):        
    model = se_resnext(units=units, num_stage=4, filter_list=[64, 256, 512, 1024, 2048] if args.depth >=50 else [64, 64, 128, 256, 512], ratio_list=ratio_list, num_class=args.num_classes, num_group=args.num_group, data_type="imagenet", drop_out=args.drop_out, bn_mom=args.bn_mom)
    model.collect_params().load('/tanData/models/seresnext_imagenet_50_0_exp1_current.params', ctx=ctx)
#     for param in model.collect_params().values():
#         if param.name.find('conv') != -1 or param.name.find('dense') != -1:
#             if param.name.find('weight') != -1:
#                 param.initialize(init=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2), ctx=ctx)
#             else:
#                 param.initialize(init=mx.init.Zero(), ctx=ctx)
#         elif param.name.find('batchnorm') != -1:
#             if param.name.find('gamma') != -1:
#                 param.initialize(init=Normal(mean=1, sigma=0.02), ctx=ctx)
#             else:
#                 param.initialize(init=mx.init.Zero(), ctx=ctx)
#         elif param.name.find('biasadder') != -1:
#             param.initialize(init=mx.init.Zero(), ctx=ctx)
#         else:
#             param.initialize(init=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2), ctx=ctx)
                  
    model.hybridize()
        
    best_top1_val, best_top5_val = train(net=model, train_data=train_data, val_data=val_data, num_epochs=num_epochs, ctx=ctx)

In [None]:
run_train(ctx=ctx)

2018-09-04 03:56:23,964 - Epoch[8] Batch [0] Lr: 0.100000     Speed: 3939.521365 samples/sec   top1-acc=0.457031     top5-acc=0.695312
2018-09-04 04:03:53,888 - Epoch[8] Batch [500] Lr: 0.100000     Speed: 284.497977 samples/sec   top1-acc=0.455651     top5-acc=0.701456
2018-09-04 04:11:19,034 - Epoch[8] Batch [1000] Lr: 0.100000     Speed: 287.554524 samples/sec   top1-acc=0.454421     top5-acc=0.701026
2018-09-04 04:18:45,976 - Epoch[8] Batch [1500] Lr: 0.100000     Speed: 286.398349 samples/sec   top1-acc=0.455514     top5-acc=0.701938
2018-09-04 04:26:09,098 - Epoch[8] Batch [2000] Lr: 0.100000     Speed: 288.863464 samples/sec   top1-acc=0.454821     top5-acc=0.701977
2018-09-04 04:33:30,467 - Epoch[8] Batch [2500] Lr: 0.100000     Speed: 290.014732 samples/sec   top1-acc=0.454801     top5-acc=0.701641
2018-09-04 04:40:55,000 - Epoch[8] Batch [3000] Lr: 0.100000     Speed: 287.947042 samples/sec   top1-acc=0.455490     top5-acc=0.701825
2018-09-04 04:48:23,477 - Epoch[8] Batch [35