In [1]:
import os
import math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
from common.util import download_file
import subprocess

from IPython.core.debugger import Tracer

In [2]:
class Options:
    """Hyper-parameters, paths and feature flags for one semi-supervised SVHN run.

    Any attribute set below can be overridden by keyword, e.g.
    ``Options(seed_val=1, ctx=mx.gpu(0))``. ``exp_name`` is derived *after*
    overrides are applied, so it always reflects the final ``num_train_sup`` and
    ``seed_val`` unless ``exp_name`` itself is passed explicitly.

    Raises:
        TypeError: if a keyword does not name a known option.
    """

    def __init__(self, **overrides):
        # Labelled-subset size and the seed used to pick it.
        self.seed_val = 0
        self.num_train_sup = 1000
        # Total batch size; the iterators each use half of it.
        self.batch_size = 128
        # NOTE(review): machine-specific absolute paths and GPU index; override per machine.
        self.data_dir = '/tanData/datasets/svhn'
        self.log_dir = '/tanData/logs'
        self.model_dir = '/tanData/models'
        self.ctx = mx.gpu(5)
        # Weights of the auxiliary loss terms combined in train().
        self.alpha_drm = 0.5   # L2 reconstruction loss
        self.alpha_pn = 1.0    # `loss_pn` returned by the model
        self.alpha_kl = 0.5    # entropy term (-KL to the uniform distribution)
        self.alpha_nn = 0.5    # `loss_nn` returned by the model

        # Architecture flags forwarded to VGG_DRM.
        self.use_bias = True
        self.use_bn = True
        self.do_topdown = True
        self.do_countpath = False
        self.do_pn = True
        self.relu_td = False
        self.do_nn = True

        unknown = set(overrides) - set(vars(self)) - {'exp_name'}
        if unknown:
            raise TypeError('Unknown option(s): %s' % ', '.join(sorted(unknown)))
        for key, value in overrides.items():
            setattr(self, key, value)
        if 'exp_name' not in overrides:
            # Derived last so it matches any overridden num_train_sup / seed_val.
            self.exp_name = ('svhn_nlabels_%i_allconv13_lddrm_mm_pathnorm_seed_%i'
                             % (self.num_train_sup, self.seed_val))

opt = Options()

In [3]:
# Read the full SVHN training list (tab-separated: index, label, path) into {index: label}.
with open(os.path.join(opt.data_dir, 'svhn_train.lst'), 'r') as f:
    lines = f.readlines()
    tokens = [i.rstrip().split('\t') for i in lines]
    # Labels stay as the raw strings from the .lst file; they become folder names further down.
    idx_label = dict((int(idx), label) for idx, label, _ in tokens)
    
labels = set(idx_label.values())

# NOTE(review): assumes the files in train/ are indexed exactly 0..num_train-1 and every
# index appears in svhn_train.lst (otherwise idx_label[i] below raises KeyError) — confirm.
num_train = len(os.listdir(os.path.join(opt.data_dir, 'train')))

num_train_sup = opt.num_train_sup

# Class-balanced labelled set; any remainder of the division is silently dropped.
num_train_sup_per_label = num_train_sup // len(labels)

# Unlabelled images per labelled image; each labelled image is later copied this many times.
ratio_unsup_sup = (num_train - num_train_sup) // num_train_sup

# select labeled data
# Seeded permutation, so the labelled subset is reproducible for a given opt.seed_val.
data_rng = np.random.RandomState(opt.seed_val)
indx = data_rng.permutation(range(0, num_train))
indx_sup = []

# For each class, keep the first num_train_sup_per_label indices of that class in permutation
# order. The chosen set does not depend on the (unordered) iteration order of `labels`.
for c in labels:
    c_count = 0
    for i in indx:
        if idx_label[i] == c and c_count < num_train_sup_per_label:
            indx_sup.append(i)
            c_count += 1
        if c_count >= num_train_sup_per_label:
            break
            
label_count = dict()  # NOTE(review): not used anywhere in this notebook
# print(indx_sup)
# for i in indx_sup:
#     print(idx_label[i])

def mkdir_if_not_exist(path):
    """Create the directory ``os.path.join(*path)`` (with parents) unless something already exists there.

    Args:
        path: sequence of path components, joined with ``os.path.join``.
    """
    target = os.path.join(*path)
    if os.path.exists(target):
        return
    os.makedirs(target)
        
# Write the labelled / unlabelled split as class sub-folders so im2rec (next cell) can pack them.
sup_dir_name = 'train_valid_sup_nsup_%i_seed_%i' % (opt.num_train_sup, opt.seed_val)
unsup_dir_name = 'train_valid_unsup_nsup_%i_seed_%i' % (opt.num_train_sup, opt.seed_val)
# Set for O(1) membership; `indx_sup` is a list and this loop visits every training file.
indx_sup_set = set(indx_sup)
# Each labelled image is replicated so both .rec files have roughly the same number of records
# (they are zipped batch-by-batch in train()). Clamp to 1 so labelled images are never dropped
# when the labelled set is more than half of the data.
num_sup_copies = max(1, ratio_unsup_sup)

for train_file in os.listdir(os.path.join(opt.data_dir, 'train')):
    idx = int(train_file.split('.')[0])
    label = idx_label[idx]
    src_path = os.path.join(opt.data_dir, 'train', train_file)
    if idx in indx_sup_set:
        mkdir_if_not_exist([opt.data_dir, sup_dir_name, label])
        for i in range(num_sup_copies):
            shutil.copy(src_path,
                        os.path.join(opt.data_dir, sup_dir_name, label, '%i_%i.png' % (idx, i)))
    else:
        mkdir_if_not_exist([opt.data_dir, unsup_dir_name, label])
        shutil.copy(src_path, os.path.join(opt.data_dir, unsup_dir_name, label))

In [4]:
%%bash -s "$opt.num_train_sup" "$opt.seed_val" "$opt.data_dir"
# Pack the labelled / unlabelled image folders written by the previous cell into RecordIO
# files with MXNet's im2rec, then delete the folders.
# Args: $1 = number of labelled images, $2 = seed, $3 = data directory (falls back to the
# previously hard-coded path). MX_DIR may be set in the environment (default /mxnet).
# Stop on the first error so the source folders are NOT deleted after a failed pack.
set -eu

DATA_DIR="${3:-/tanData/datasets/svhn}"
MX_DIR="${MX_DIR:-/mxnet}"
data_name=( "train_valid_sup_nsup_$1_seed_$2" "train_valid_unsup_nsup_$1_seed_$2" )
list_name=( "svhn_train_valid_sup_nsup_$1_seed_$2" "svhn_train_valid_unsup_nsup_$1_seed_$2" )

for ((i=0; i<${#data_name[@]}; ++i)); do
    # clean outputs of a previous run
    rm -rf "${DATA_DIR}/${list_name[i]}".*
    # make list for all classes
    python "${MX_DIR}/tools/im2rec.py" --list --exts '.png' --recursive "${DATA_DIR}/${list_name[i]}" "${DATA_DIR}/${data_name[i]}"
    # make .rec file for all classes
    python "${MX_DIR}/tools/im2rec.py" --exts '.png' --quality 95 --num-thread 16 --color 1 "${DATA_DIR}/${list_name[i]}" "${DATA_DIR}/${data_name[i]}"
    # remove the image folders (only reached if both im2rec calls succeeded)
    rm -rf "${DATA_DIR}/${data_name[i]}"
done

0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
Creating .rec file from /tanData/datasets/svhn/svhn_train_valid_sup_nsup_1000_seed_0.lst in /tanData/datasets/svhn
time: 0.00288820266724  count: 0
time: 0.0664539337158  count: 1000
time: 0.0585939884186  count: 2000
time: 0.0557899475098  count: 3000
time: 0.0568461418152  count: 4000
time: 0.05894780159  count: 5000
time: 0.0648090839386  count: 6000
time: 0.0610620975494  count: 7000
time: 0.0570049285889  count: 8000
time: 0.0585730075836  count: 9000
time: 0.0569579601288  count: 10000
time: 0.0546579360962  count: 11000
time: 0.0615811347961  count: 12000
time: 0.0521509647369  count: 13000
time: 0.0509130954742  count: 14000
time: 0.0539779663086  count: 15000
time: 0.0516760349274  count: 16000
time: 0.0499148368835  count: 17000
time: 0.0468111038208  count: 18000
time: 0.0478849411011  count: 19000
time: 0.0483860969543  count: 20000
time: 0.0504598617554  count: 21000
time: 0.0529749393463  count: 22000
time: 0.0521230697632  count: 2

In [5]:
def gpu_device(ctx=mx.gpu(0)):
    """Probe whether `ctx` is usable by allocating a tiny NDArray on it.

    Returns:
        `ctx` if the allocation succeeds, None if MXNet raises ``MXNetError``
        (e.g. the device does not exist).
    """
    try:
        _probe = mx.nd.array([1, 2, 3], ctx=ctx)
    except mx.MXNetError:
        ctx = None
    return ctx

# Fail fast, before iterators and models are built, if the configured device is unusable.
assert gpu_device(opt.ctx), 'No GPU device found!'

In [6]:
# Create every output directory before anything writes to it: TensorBoard events go to
# log_dir, and train() saves .params checkpoints into opt.model_dir (ParameterDict.save does
# not create missing directories, so a missing model_dir would fail at the first checkpoint).
log_dir = os.path.join(opt.log_dir, opt.exp_name)
for _out_dir in (log_dir, opt.model_dir):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)

In [7]:
# Input geometry (C, H, W) shared by all iterators.
data_shape = (3, 32, 32)

# Per-channel normalisation on the 0-255 pixel scale, identical for every split.
# NOTE(review): mean_r == mean_g (115.4) may be a copy-paste slip — confirm against the
# SVHN training-set statistics.
_svhn_norm = dict(mean_r=115.4, mean_g=115.4, mean_b=119.6,
                  std_r=56.0, std_g=57.8, std_b=58.3)


def make_record_iter(rec_file, augment, batch_size=None):
    """Build an ImageRecordIter over ``<opt.data_dir>/<rec_file>`` with the shared normalisation.

    Args:
        rec_file: RecordIO file name relative to ``opt.data_dir``.
        augment: True for training (shuffle, random 32-pixel crops with pad=4 and
            fill_value=0, no mirroring); False for evaluation (no shuffle or augmentation).
        batch_size: defaults to ``opt.batch_size // 2``; each training step zips one
            labelled and one unlabelled batch.

    Returns:
        The configured ``mx.io.ImageRecordIter``.
    """
    if batch_size is None:
        batch_size = opt.batch_size // 2
    params = dict(path_imgrec=os.path.join(opt.data_dir, rec_file),
                  data_shape=data_shape,
                  batch_size=batch_size,
                  **_svhn_norm)
    if augment:
        params.update(shuffle=True, rand_crop=True, max_crop_size=32, min_crop_size=32,
                      pad=4, fill_value=0, rand_mirror=False)
    else:
        params.update(rand_crop=False, rand_mirror=False)
    return mx.io.ImageRecordIter(**params)


_split_tag = 'nsup_%i_seed_%i' % (opt.num_train_sup, opt.seed_val)
train_data_unsup = make_record_iter('svhn_train_valid_unsup_%s.rec' % _split_tag, augment=True)
train_data_sup = make_record_iter('svhn_train_valid_sup_%s.rec' % _split_tag, augment=True)
valid_data = make_record_iter('svhn_val.rec', augment=False)

In [8]:
# Loss objects used by train(): cross-entropy for labelled predictions, L2 for reconstructions.
# L1_loss is defined but not used in this notebook.
criterion, L2_loss, L1_loss = (
    gluon.loss.SoftmaxCrossEntropyLoss(),
    gluon.loss.L2Loss(),
    gluon.loss.L1Loss(),
)

In [9]:
class Normal(mx.init.Initializer):
    """Initializer that fills arrays with samples from N(mean, sigma**2).

    Unlike the built-in ``mx.init.Normal`` it accepts a non-zero ``mean``; this
    notebook uses it to initialise normalisation-layer ``gamma`` around 1.

    Args:
        mean: mean of the normal distribution.
        sigma: standard deviation of the normal distribution.
    """
    def __init__(self, mean=0, sigma=0.01):
        # Forward both hyper-parameters to the base class so its recorded kwargs
        # (used by dumps()) describe this initializer faithfully; `mean` used to be dropped.
        super(Normal, self).__init__(mean=mean, sigma=sigma)
        self.sigma = sigma
        self.mean = mean

    def _init_weight(self, _, arr):
        # Fill `arr` in place; the parameter descriptor is not needed.
        mx.random.normal(self.mean, self.sigma, out=arr)

In [10]:
# from resnet import ResNet164_v2
from mxnet.gluon.model_zoo import vision as models
from vgg_ld_opt import VGG_DRM

In [11]:
import datetime
# One TensorBoard writer for the whole notebook; train() logs every curve through it.
_tb_logdir = os.path.join(opt.log_dir, opt.exp_name)
writer = SummaryWriter(_tb_logdir)

def get_acc(output, label):
    """Count (not average) the rows whose argmax over axis 1 equals `label`.

    Args:
        output: per-class scores; presumably a (batch, num_class) NDArray — confirm.
        label: one class index per row.

    Returns:
        The number of correct predictions, as the scalar from ``asscalar()``.
    """
    hits = output.argmax(axis=1, keepdims=False) == label
    return hits.sum().asscalar()

def extract_acts(net, x, layer_indx):
    """Run `x` through ``net.features`` and collect intermediate activations.

    For each boundary ``hi`` in `layer_indx` (expected ascending), the layers
    ``net.features._children[lo:hi]`` are applied in order and the current value
    is recorded, i.e. the input *to* layer ``hi``. Layers after the last boundary
    are never run; a non-increasing boundary applies no layers.

    NOTE(review): ``_children`` is sliced like a list here; confirm that holds for the
    installed MXNet version. This helper is not called in the visible notebook.

    Returns:
        List with one recorded value per entry of `layer_indx`.
    """
    acts, lo = [], 0
    for hi in layer_indx:
        segment = net.features._children[lo:hi]
        for layer in segment:
            x = layer(x)
        acts.append(x)
        lo = hi
    return acts

In [12]:
# Names of the per-epoch loss curves sent to TensorBoard, in logging order.
_LOSS_NAMES = ('loss', 'loss_xentropy', 'loss_drm', 'loss_pn', 'loss_kl', 'loss_nn')


def _entropy_penalty(logits, num_class=10):
    """Per-sample ``-sum_k p_k * log(num_class * p_k)`` with ``p = softmax(logits)``.

    This equals ``-KL(p || uniform) = H(p) - log(num_class)``, so minimising it
    minimises prediction entropy. The ``1e-8`` keeps ``log`` finite when a
    probability underflows to 0.
    """
    probs = nd.softmax(logits)
    return -nd.sum(nd.log(float(num_class) * probs + 1e-8) * probs, axis=1)


def _log_epoch(sums, num_batch, tag, epoch):
    """Write the per-batch mean of every accumulated loss to TensorBoard under `tag`.

    Returns:
        Dict mapping each name in ``_LOSS_NAMES`` to its mean. An epoch with no
        batches yields zeros instead of raising ZeroDivisionError.
    """
    denom = max(num_batch, 1)
    means = {name: sums[name] / denom for name in _LOSS_NAMES}
    for name in _LOSS_NAMES:
        writer.add_scalars(name, {tag: means[name]}, epoch)
    return means


def train(net, train_data_unsup, train_data_sup, valid_data, num_epochs, lr, wd, ctx, lr_decay, relu_indx,
          adam_lr=0.001, adam_epochs=20, min_lr=0.0001):
    """Semi-supervised training on paired unlabelled / labelled batches.

    Every step draws one batch from each training iterator. Both halves pay the
    L2 reconstruction, ``loss_pn``, ``loss_nn`` and entropy terms weighted by
    ``opt.alpha_*``; the labelled half additionally pays softmax cross-entropy.

    Learning-rate schedule: Adam at `adam_lr` for the first `adam_epochs` epochs,
    then SGD starting at `lr` and divided every epoch by a constant factor chosen
    so that `lr` would reach `min_lr` after ``num_epochs - 2`` decays.

    After each epoch the model is evaluated on `valid_data` (if given); the best
    parameters are saved to ``<opt.model_dir>/<opt.exp_name>_best.params`` and a
    snapshot every 20 epochs. Curves go to TensorBoard through the global `writer`.

    Args:
        net: model returning ``[output, xhat, _, loss_pn, loss_nn]`` from
            ``net(data)`` or ``net(data, label)``.
        train_data_unsup, train_data_sup: DataIters with equal batch sizes, zipped per step.
        valid_data: validation DataIter, or None to skip validation.
        num_epochs: nominal epoch count (see NOTE below).
        lr: initial SGD learning rate. It used to be ignored in favour of a
            hard-coded 0.15, which is the value this notebook passes.
        wd: weight decay for both optimisers.
        ctx: context that batches are copied to.
        lr_decay, relu_indx: accepted for interface compatibility; unused.
        adam_lr, adam_epochs, min_lr: schedule knobs above; the defaults are the
            previously hard-coded values.

    Returns:
        Best validation accuracy seen (0 if `valid_data` is None).

    NOTE(review): only ``num_epochs - 100`` epochs are run, while the decay factor
    assumes ``num_epochs - 2`` decays, so `min_lr` is never reached — confirm intended.
    """
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': adam_lr, 'wd': wd})

    prev_time = datetime.datetime.now()
    best_valid_acc = 0
    decay_val = None  # set when switching to SGD

    for epoch in range(num_epochs - 100):
        train_data_unsup.reset(); train_data_sup.reset()
        sums = dict.fromkeys(_LOSS_NAMES, 0)
        correct = 0; total = 0
        num_batch_train = 0

        if epoch == adam_epochs:
            # Switch Adam -> SGD. Start one decay step above `lr` because the per-epoch
            # division just below brings it to exactly `lr` in this epoch.
            decay_val = np.exp(np.log(lr / min_lr) / (num_epochs - 2))
            trainer = gluon.Trainer(net.collect_params(), 'SGD', {'learning_rate': lr * decay_val, 'wd': wd})

        if epoch >= adam_epochs:
            trainer.set_learning_rate(trainer.learning_rate / decay_val)

        for batch_unsup, batch_sup in zip(train_data_unsup, train_data_sup):
            assert batch_unsup.data[0].shape[0] == batch_sup.data[0].shape[0], "batch_unsup and batch_sup must have the same size"
            bs = batch_unsup.data[0].shape[0]
            data_unsup = batch_unsup.data[0].as_in_context(ctx)
            label_unsup = batch_unsup.label[0].as_in_context(ctx)
            data_sup = batch_sup.data[0].as_in_context(ctx)
            label_sup = batch_sup.label[0].as_in_context(ctx)
            with autograd.record():
                # Unlabelled half: no label reaches the network or the loss.
                output_unsup, xhat_unsup, _, loss_pn_unsup, loss_nn_unsup = net(data_unsup)
                loss_drm_unsup = L2_loss(xhat_unsup, data_unsup)
                loss_kl_unsup = _entropy_penalty(output_unsup)
                loss_unsup = (opt.alpha_drm * loss_drm_unsup + opt.alpha_kl * loss_kl_unsup
                              + opt.alpha_nn * loss_nn_unsup + opt.alpha_pn * loss_pn_unsup)

                # Labelled half: same terms plus cross-entropy; the label is also fed to the network.
                output_sup, xhat_sup, _, loss_pn_sup, loss_nn_sup = net(data_sup, label_sup)
                loss_xentropy_sup = criterion(output_sup, label_sup)
                loss_drm_sup = L2_loss(xhat_sup, data_sup)
                loss_kl_sup = _entropy_penalty(output_sup)
                loss_sup = (loss_xentropy_sup + opt.alpha_drm * loss_drm_sup + opt.alpha_kl * loss_kl_sup
                            + opt.alpha_nn * loss_nn_sup + opt.alpha_pn * loss_pn_sup)

                loss = loss_unsup + loss_sup

            loss.backward()
            trainer.step(bs)

            batch_terms = {
                'loss': loss,
                'loss_xentropy': loss_xentropy_sup,
                'loss_drm': loss_drm_unsup + loss_drm_sup,
                'loss_pn': loss_pn_unsup + loss_pn_sup,
                'loss_kl': loss_kl_unsup + loss_kl_sup,
                'loss_nn': loss_nn_unsup + loss_nn_sup,
            }
            for name, value in batch_terms.items():
                sums[name] += nd.mean(value).asscalar()
            # get_acc returns hit counts; averaging the two halves keeps correct / total a
            # fraction. The unlabelled half's labels are used only for this diagnostic.
            correct += (get_acc(output_sup, label_sup) + get_acc(output_unsup, label_unsup)) / 2

            total += bs
            num_batch_train += 1

        train_means = _log_epoch(sums, num_batch_train, 'train', epoch)
        train_acc = correct / total if total else 0.0
        writer.add_scalars('acc', {'train': train_acc}, epoch)

        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_data.reset()
            valid_sums = dict.fromkeys(_LOSS_NAMES, 0)
            valid_correct = 0; valid_total = 0
            num_batch_valid = 0
            for batch in valid_data:
                bs = batch.data[0].shape[0]
                data = batch.data[0].as_in_context(ctx)
                label = batch.label[0].as_in_context(ctx)
                # NOTE(review): the true label is fed to the network here, as in the labelled
                # training branch; confirm `output` does not depend on it, otherwise the
                # validation accuracy is optimistic.
                output, xhat, _, loss_pn, loss_nn = net(data, label)
                loss_xentropy = criterion(output, label)
                loss_drm = L2_loss(xhat, data)
                loss_kl = _entropy_penalty(output)
                loss = (loss_xentropy + opt.alpha_drm * loss_drm + opt.alpha_kl * loss_kl
                        + opt.alpha_nn * loss_nn + opt.alpha_pn * loss_pn)

                batch_terms = {'loss': loss, 'loss_xentropy': loss_xentropy, 'loss_drm': loss_drm,
                               'loss_pn': loss_pn, 'loss_kl': loss_kl, 'loss_nn': loss_nn}
                for name, value in batch_terms.items():
                    valid_sums[name] += nd.mean(value).asscalar()
                valid_correct += get_acc(output, label)

                valid_total += bs
                num_batch_valid += 1
            valid_acc = valid_correct / valid_total if valid_total else 0.0
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                net.collect_params().save('%s/%s_best.params' % (opt.model_dir, opt.exp_name))
            valid_means = _log_epoch(valid_sums, num_batch_valid, 'valid', epoch)
            writer.add_scalars('acc', {'valid': valid_acc}, epoch)
            epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Best valid acc %f, "
                         % (epoch, train_means['loss'], train_means['loss_xentropy'], train_means['loss_drm'],
                            train_means['loss_pn'], train_acc, valid_means['loss'], valid_acc, best_valid_acc))
            # Periodic snapshot every 20 epochs (only taken when validating).
            if not epoch % 20:
                net.collect_params().save('%s/%s_epoch_%i.params' % (opt.model_dir, opt.exp_name, epoch))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
                         % (epoch, train_means['loss'], train_acc))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))

    return best_valid_acc

In [13]:
# Schedule constants handed to train() by run_train():
#   num_epochs    - nominal epoch count (train() runs num_epochs - 100 epochs)
#   learning_rate - passed to train() as `lr`
#   weight_decay  - passed to train() as `wd`
#   lr_decay      - passed to train() but not used there
num_epochs, learning_rate, weight_decay, lr_decay = 500, 0.15, 5e-4, 0.1

In [14]:
def run_train(num_exp, ctx):
    """Train `num_exp` freshly initialised AllConv13 VGG_DRM models and report validation accuracy.

    Each run builds a new model from the flags in the global `opt`, initialises its
    parameters by name, and calls `train` on the module-level iterators with the
    module-level schedule constants. Prints each run's best validation accuracy
    and the mean over runs.

    Args:
        num_exp: number of independent runs.
        ctx: MXNet context on which parameters are initialised and batches are placed.
    """
    valid_acc = 0
    for run in range(num_exp):
        model = VGG_DRM('AllConv13', batch_size=opt.batch_size // 2, num_class=10, use_bias=opt.use_bias,
                        use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath,
                        do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn)
        # Initialise by parameter name: Xavier for conv/dense weights, N(1, 0.02) for
        # batchnorm/instancenorm gammas, zeros for every other conv/dense/norm/biasadder
        # parameter (presumably including norm running statistics — confirm intended).
        # NOTE(review): parameters matching none of these patterns are left uninitialised.
        for param in model.collect_params().values():
            name = param.name
            if 'conv' in name or 'dense' in name:
                init = mx.initializer.Xavier() if 'weight' in name else mx.init.Zero()
            elif 'batchnorm' in name or 'instancenorm' in name:
                init = Normal(mean=1, sigma=0.02) if 'gamma' in name else mx.init.Zero()
            elif 'biasadder' in name:
                init = mx.init.Zero()
            else:
                continue
            param.initialize(init=init, ctx=ctx)

        # Positions of the ReLU layers in model.features, passed through to train().
        # `_children` is indexed positionally, presumably a list in the MXNet version used.
        relu_indx = [idx for idx, child in enumerate(model.features._children) if 'relu' in child.name]

        acc = train(model, train_data_unsup, train_data_sup, valid_data, num_epochs, learning_rate,
                    weight_decay, ctx, lr_decay, relu_indx)
        # `run` is kept distinct from the layer index above so the printed run number is correct.
        print('Validation Accuracy - Run %i = %f' % (run, acc))
        valid_acc += acc

    print('Validation Accuracy = %f' % (valid_acc / max(num_exp, 1)))

In [None]:
# Launch one training run on the configured device.
run_train(num_exp=1, ctx=opt.ctx)