In [1]:
import os
import math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
from common.util import download_file
import subprocess

from IPython.core.debugger import Tracer

In [2]:
class Options:
    """Hyper-parameters and switches for this CIFAR-100 experiment.

    A plain attribute bag: every value is fixed in ``__init__`` and read by
    the rest of the notebook through the module-level ``opt`` instance.
    """

    def __init__(self):
        # Run identity and batch size. In the code visible in this notebook,
        # seed_val and num_train_sup are only used to build exp_name.
        self.seed_val = 0
        self.num_train_sup = 50000
        self.batch_size = 128

        # Locations of the input records, TensorBoard logs and saved parameters.
        self.data_dir = '/tanData/datasets/cifar100'
        self.log_dir = '/tanData/logs'
        self.model_dir ='/tanData/models'

        # Experiment name; used as log sub-directory and checkpoint file prefix.
        name_template = 'cifar100_nlabels_%i_allconv13_lddrm_mm_pathnorm_maxmin_seed_%i'
        self.exp_name = name_template % (self.num_train_sup, self.seed_val)

        # Device on which the model and the batches live.
        self.ctx = mx.gpu(0)

        # Weights of the individual loss terms summed in train():
        #   alpha_drm -> reconstruction (L2) term, alpha_pn -> loss_pn,
        #   alpha_kl  -> loss_kl,                  alpha_nn -> loss_nn.
        self.alpha_drm = 0.5
        self.alpha_pn = 1.0
        self.alpha_kl = 0.5
        self.alpha_nn = 0.5

        # Blend between the two network outputs (`output` / `output_min`);
        # the two weights always sum to one.
        self.alpha_min = 0.5
        self.alpha_max = 1.0 - self.alpha_min

        # Flags forwarded unchanged to the VGG_DRM constructor in run_train().
        self.use_bias = True
        self.use_bn = True
        self.do_topdown = True
        self.do_countpath = False
        self.do_pn = True
        self.relu_td = False
        self.do_nn = True
        self.min_max = True


# Single shared configuration instance used throughout the notebook.
opt = Options()

In [3]:
def gpu_device(ctx=mx.gpu(0)):
    """Probe `ctx` by creating a tiny NDArray on it.

    Returns `ctx` itself when the array can be created, or None when MXNet
    raises an MXNetError while doing so.
    """
    try:
        probe = mx.nd.array([1, 2, 3], ctx=ctx)
    except mx.MXNetError:
        return None
    else:
        return ctx


# Fail fast, before any data or model is built, if the configured device is unusable.
assert gpu_device(opt.ctx), 'No GPU device found!'

In [4]:
# Per-experiment log directory: <opt.log_dir>/<opt.exp_name>.
# exist_ok=True replaces the racy "check, then create" pattern and makes the
# cell safe to re-run.
log_dir = os.path.join(opt.log_dir, opt.exp_name)
os.makedirs(log_dir, exist_ok=True)

# train() saves checkpoints under opt.model_dir; create it up front so the
# first save cannot fail on a missing directory.
os.makedirs(opt.model_dir, exist_ok=True)

In [5]:
data_shape = (3, 32, 32)

# Per-channel mean / std passed identically to the training and validation
# iterators.
_norm_args = dict(
    mean_r=129.3, mean_g=124.1, mean_b=112.4,
    std_r=68.2, std_g=65.4, std_b=70.4,
)

# Training iterator: shuffled, with data augmentation
# (random crop, padding and random mirroring options enabled).
train_data_sup = mx.io.ImageRecordIter(
    path_imgrec=os.path.join(opt.data_dir,'cifar100_train.rec'),
    data_shape=data_shape,
    batch_size=opt.batch_size,
    shuffle=True,
    rand_crop=True,
    max_crop_size=32,
    min_crop_size=32,
    pad=4,
    fill_value=0,
    rand_mirror=True,
    **_norm_args)

# Validation iterator: same normalisation, no data augmentation.
valid_data = mx.io.ImageRecordIter(
    path_imgrec=os.path.join(opt.data_dir,'cifar100_val.rec'),
    data_shape=data_shape,
    batch_size=opt.batch_size,
    rand_crop=False,
    rand_mirror=False,
    **_norm_args)

In [6]:
# Loss objects shared by the training and validation loops:
#   criterion -> applied to the classifier outputs against the labels,
#   L2_loss   -> applied to the reconstructions (xhat) against the input,
#   L1_loss   -> created here but not referenced by the code in this notebook.
criterion = gluon.loss.SoftmaxCrossEntropyLoss()
L2_loss, L1_loss = gluon.loss.L2Loss(), gluon.loss.L1Loss()

In [7]:
class Normal(mx.init.Initializer):
    """Initializes arrays with samples from a normal distribution.

    Parameters
    ----------
    mean : float, default 0
        Mean of the distribution.
    sigma : float, default 0.01
        Standard deviation of the distribution.
    """
    def __init__(self, mean=0, sigma=0.01):
        # Forward BOTH hyper-parameters to the base class. The base
        # Initializer keeps the keyword arguments it receives as the
        # initializer's configuration; passing only `sigma` silently dropped
        # a non-default `mean` from that record.
        super(Normal, self).__init__(mean=mean, sigma=sigma)
        self.sigma = sigma
        self.mean = mean

    def _init_weight(self, _, arr):
        # Fill `arr` in place; the first argument (the parameter name/desc)
        # is not needed.
        mx.random.normal(self.mean, self.sigma, out=arr)

In [8]:
# from resnet import ResNet164_v2
from mxnet.gluon.model_zoo import vision as models
from vgg_ld_opt_min_max import VGG_DRM

In [9]:
import datetime
# TensorBoard event writer; events go to <opt.log_dir>/<opt.exp_name>.
tensorboard_dir = os.path.join(opt.log_dir, opt.exp_name)
writer = SummaryWriter(tensorboard_dir)

def get_acc(output, label):
    """Count how many rows of `output` have their arg-max equal to `label`.

    Returns the number of matches as a scalar (a count, not a ratio); callers
    divide by the number of samples themselves.
    """
    predicted = output.argmax(axis=1, keepdims=False)
    hits = predicted == label
    return hits.sum().asscalar()

def extract_acts(net, x, layer_indx):
    """Run `x` through ``net.features`` and collect intermediate results.

    `layer_indx` is a sequence of increasing stop positions into
    ``net.features._children``. For each position, the children from the
    previous stop (initially 0) up to, but excluding, that position are
    applied in order, and the value of `x` at that point is recorded.

    Returns the list of recorded values, one per entry of `layer_indx`.
    """
    collected = []
    lower = 0
    for upper in layer_indx:
        segment = net.features._children[lower:upper]
        for layer in segment:
            x = layer(x)
        collected.append(x)
        lower = upper
    return collected

In [10]:
def train(net, train_data_sup, valid_data, num_epochs, lr, wd, ctx, lr_decay, relu_indx):
    """Train `net`, validate after every epoch, log to TensorBoard and checkpoint.

    Relies on module-level objects defined elsewhere in the notebook: `opt`
    (loss weights, `model_dir`, `exp_name`), `criterion`, `L2_loss`, `writer`
    and `get_acc`.

    Optimisation schedule, as implemented below:
      * the loop runs for ``num_epochs - 100`` epochs;
      * epochs 0-19 use Adam with a fixed learning rate of 0.001;
      * at epoch 20 a new SGD trainer replaces Adam. Its learning rate starts
        at `lr` and is divided by a constant factor every following epoch;
        the factor is chosen so that `lr` would shrink to 1e-4 after
        ``num_epochs - 2`` divisions.

    Checkpoints (written only when `valid_data` is given) go to
    ``opt.model_dir`` with ``opt.exp_name`` as prefix: one file per "best so
    far" metric (`_best`, `_best_2`, `_best_max`, `_best_min`) and one
    `_epoch_<n>` file every 20 epochs.

    Args:
        net: model called as ``net(data, label)``; must return a 7-item
            sequence ``[output, output_min, xhat, xhat_min, _, loss_pn, loss_nn]``.
        train_data_sup: resettable iterator of batches exposing ``.data[0]``
            and ``.label[0]``.
        valid_data: same kind of iterator, or None to skip validation.
        num_epochs: nominal epoch count (see schedule above).
        lr: initial learning rate of the SGD phase.
        wd: weight decay passed to both trainers.
        ctx: device the batches are copied to.
        lr_decay: unused; kept for interface compatibility.
        relu_indx: unused; kept for interface compatibility.

    Returns:
        The best validation accuracy of the blended prediction
        (``alpha_max * output + alpha_min * output_min``) seen during
        training, or 0 when `valid_data` is None.
    """
    # Warm-up optimiser: Adam with a fixed step size.
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': 0.001, 'wd': wd})

    prev_time = datetime.datetime.now()
    best_valid_acc = 0
    best_valid_acc2 = 0
    best_valid_accmax = 0
    best_valid_accmin = 0

    for epoch in range(num_epochs-100):
        train_data_sup.reset()
        # Per-epoch running sums (losses are summed batch means).
        train_loss = 0; train_loss_xentropy = 0; train_loss_drm = 0; train_loss_pn = 0; train_loss_kl = 0; train_loss_nn = 0
        correct = 0; total = 0; correct2 = 0; correctmax = 0; correctmin = 0
        num_batch_train = 0

        if epoch == 20:
            # Switch from Adam to SGD. The rate is pre-multiplied by decay_val
            # because the `epoch >= 20` branch below divides it straight away,
            # so the first SGD epoch effectively runs at `lr`.
            sgd_lr = lr
            decay_val = np.exp(np.log(sgd_lr / 0.0001) / (num_epochs - 2))
            sgd_lr = sgd_lr * decay_val
            trainer = gluon.Trainer(net.collect_params(), 'SGD', {'learning_rate': sgd_lr, 'wd': wd})

        if epoch >= 20:
            # Geometric decay: one division per epoch.
            trainer.set_learning_rate(trainer.learning_rate / decay_val)

        for batch_sup in train_data_sup:
            # NOTE(review): `bs` is the batch dimension of the data array. If
            # the iterator pads the final batch (batch.pad), the padded samples
            # are counted in `total` and in the metrics - confirm whether they
            # should be excluded.
            bs = batch_sup.data[0].shape[0]
            data_sup = batch_sup.data[0].as_in_context(ctx)
            label_sup = batch_sup.label[0].as_in_context(ctx)
            with autograd.record():
                [output_sup, output_min_sup, xhat_sup, xhat_min_sup, _, loss_pn_sup, loss_nn_sup] = net(data_sup, label_sup)
                # Classification and reconstruction terms, each blended over
                # the two network outputs with alpha_max / alpha_min.
                loss_xentropy_sup = opt.alpha_max * criterion(output_sup, label_sup) + opt.alpha_min * criterion(output_min_sup, label_sup)
                loss_drm_sup = opt.alpha_max * L2_loss(xhat_sup, data_sup) + opt.alpha_min * L2_loss(xhat_min_sup, data_sup)
                # loss_kl = -sum_c p_c * log(100 * p_c + 1e-8) on the blended softmax p.
                softmax_sup = opt.alpha_max * nd.softmax(output_sup) + opt.alpha_min * nd.softmax(output_min_sup)
                loss_kl_sup = -nd.sum(nd.log(100.0*softmax_sup + 1e-8) * softmax_sup, axis=1)
                loss = loss_xentropy_sup + opt.alpha_drm * loss_drm_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_nn * loss_nn_sup + opt.alpha_pn * loss_pn_sup

            loss.backward()
            trainer.step(bs)

            train_loss_xentropy += nd.mean(loss_xentropy_sup).asscalar()
            train_loss_drm += nd.mean(loss_drm_sup).asscalar()
            train_loss_pn += nd.mean(loss_pn_sup).asscalar()
            train_loss_kl += nd.mean(loss_kl_sup).asscalar()
            train_loss_nn += nd.mean(loss_nn_sup).asscalar()
            train_loss += nd.mean(loss).asscalar()

            # Each get_acc call is evaluated once and reused for the weighted
            # count (correct2) and for the per-output counts.
            hits_max = get_acc(output_sup, label_sup)
            hits_min = get_acc(output_min_sup, label_sup)
            correct += get_acc(opt.alpha_max * output_sup + opt.alpha_min * output_min_sup, label_sup)
            correct2 += opt.alpha_max * hits_max + opt.alpha_min * hits_min
            correctmax += hits_max
            correctmin += hits_min

            total += bs
            num_batch_train += 1

        writer.add_scalars('loss', {'train': train_loss / num_batch_train}, epoch)
        writer.add_scalars('loss_xentropy', {'train': train_loss_xentropy / num_batch_train}, epoch)
        writer.add_scalars('loss_drm', {'train': train_loss_drm / num_batch_train}, epoch)
        writer.add_scalars('loss_pn', {'train': train_loss_pn / num_batch_train}, epoch)
        writer.add_scalars('loss_kl', {'train': train_loss_kl / num_batch_train}, epoch)
        writer.add_scalars('loss_nn', {'train': train_loss_nn / num_batch_train}, epoch)
        writer.add_scalars('acc', {'train': correct / total}, epoch)
        writer.add_scalars('acc', {'train2': correct2 / total}, epoch)
        writer.add_scalars('acc', {'trainmax': correctmax / total}, epoch)
        writer.add_scalars('acc', {'trainmin': correctmin / total}, epoch)

        # Wall-clock time since the previous epoch report. total_seconds()
        # is used because timedelta.seconds drops whole days.
        cur_time = datetime.datetime.now()
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_data.reset()
            valid_loss = 0; valid_loss_xentropy = 0; valid_loss_drm = 0; valid_loss_pn = 0; valid_loss_kl = 0; valid_loss_nn = 0
            valid_correct = 0; valid_total = 0; valid_correct2 = 0; valid_correctmax = 0; valid_correctmin = 0
            num_batch_valid = 0
            for batch in valid_data:
                # NOTE(review): same padded-batch caveat as in the training loop.
                bs = batch.data[0].shape[0]
                data = batch.data[0].as_in_context(ctx)
                label = batch.label[0].as_in_context(ctx)
                # Same loss composition as in training, outside autograd.record().
                [output, output_min, xhat, xhat_min, _, loss_pn, loss_nn] = net(data, label)
                loss_xentropy = opt.alpha_max * criterion(output, label) + opt.alpha_min * criterion(output_min, label)
                loss_drm = opt.alpha_max * L2_loss(xhat, data) + opt.alpha_min * L2_loss(xhat_min, data)
                softmax_val = opt.alpha_max * nd.softmax(output) + opt.alpha_min * nd.softmax(output_min)
                loss_kl = -nd.sum(nd.log(100.0*softmax_val + 1e-8) * softmax_val, axis=1)
                loss = loss_xentropy + opt.alpha_drm * loss_drm + opt.alpha_kl * loss_kl + opt.alpha_nn * loss_nn + opt.alpha_pn * loss_pn

                valid_loss_xentropy += nd.mean(loss_xentropy).asscalar()
                valid_loss_drm += nd.mean(loss_drm).asscalar()
                valid_loss_pn += nd.mean(loss_pn).asscalar()
                valid_loss_kl += nd.mean(loss_kl).asscalar()
                valid_loss_nn += nd.mean(loss_nn).asscalar()
                valid_loss += nd.mean(loss).asscalar()

                hits_max = get_acc(output, label)
                hits_min = get_acc(output_min, label)
                valid_correct += get_acc(opt.alpha_max * output + opt.alpha_min * output_min, label)
                valid_correct2 += opt.alpha_max * hits_max + opt.alpha_min * hits_min
                valid_correctmax += hits_max
                valid_correctmin += hits_min

                valid_total += bs
                num_batch_valid += 1
            valid_acc = valid_correct / valid_total
            valid_acc2 = valid_correct2 / valid_total
            valid_accmax = valid_correctmax / valid_total
            valid_accmin = valid_correctmin / valid_total
            # Keep one "best so far" checkpoint per validation metric.
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                net.collect_params().save('%s/%s_best.params'%(opt.model_dir, opt.exp_name))
            if valid_acc2 > best_valid_acc2:
                best_valid_acc2 = valid_acc2
                net.collect_params().save('%s/%s_best_2.params'%(opt.model_dir, opt.exp_name))
            if valid_accmax > best_valid_accmax:
                best_valid_accmax = valid_accmax
                net.collect_params().save('%s/%s_best_max.params'%(opt.model_dir, opt.exp_name))
            if valid_accmin > best_valid_accmin:
                best_valid_accmin = valid_accmin
                net.collect_params().save('%s/%s_best_min.params'%(opt.model_dir, opt.exp_name))
            writer.add_scalars('loss', {'valid': valid_loss / num_batch_valid}, epoch)
            writer.add_scalars('loss_xentropy', {'valid': valid_loss_xentropy / num_batch_valid}, epoch)
            writer.add_scalars('loss_drm', {'valid': valid_loss_drm / num_batch_valid}, epoch)
            writer.add_scalars('loss_pn', {'valid': valid_loss_pn / num_batch_valid}, epoch)
            writer.add_scalars('loss_kl', {'valid': valid_loss_kl / num_batch_valid}, epoch)
            writer.add_scalars('loss_nn', {'valid': valid_loss_nn / num_batch_valid}, epoch)
            writer.add_scalars('acc', {'valid': valid_acc}, epoch)
            writer.add_scalars('acc', {'valid2': valid_acc2}, epoch)
            writer.add_scalars('acc', {'validmax': valid_accmax}, epoch)
            writer.add_scalars('acc', {'validmin': valid_accmin}, epoch)
            epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Valid acc2 %f, Valid accmax %f, Valid accmin %f, Best valid acc %f, Best valid acc2 %f, Best valid accmax %f, Best valid accmin %f, "
                         % (epoch, train_loss / num_batch_train, train_loss_xentropy / num_batch_train, train_loss_drm / num_batch_train, train_loss_pn / num_batch_train,
                            correct / total, valid_loss / num_batch_valid, valid_acc, valid_acc2, valid_accmax, valid_accmin, best_valid_acc, best_valid_acc2, best_valid_accmax, best_valid_accmin))

            # net.collect_params().save('%s/%s_latest.params'%(opt.model_dir, opt.exp_name))
            # Periodic snapshot every 20 epochs (including epoch 0).
            if not epoch % 20:
                net.collect_params().save('%s/%s_epoch_%i.params'%(opt.model_dir, opt.exp_name, epoch))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
                         % (epoch, train_loss / num_batch_train,
                            correct / total))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))

    return best_valid_acc

In [11]:
# Optimisation settings handed to train() by run_train().
num_epochs = 500         # passed as `num_epochs`
learning_rate = 0.15     # passed as `lr`
weight_decay = 0.0005    # passed as `wd`
lr_decay = 0.1           # passed as `lr_decay`

In [12]:
def run_train(num_exp, ctx):
    """Build, initialise and train a fresh VGG_DRM model `num_exp` times.

    Uses the module-level configuration (`opt`), data iterators
    (`train_data_sup`, `valid_data`) and hyper-parameters (`num_epochs`,
    `learning_rate`, `weight_decay`, `lr_decay`).

    Parameter initialisation is selected by substring match on the parameter
    name:
      * conv / dense: Xavier for weights, zeros for everything else;
      * batchnorm / instancenorm: Normal(mean=1, sigma=0.02) for gamma,
        zeros for everything else;
      * biasadder: zeros.
    Parameters matching none of these patterns are not initialised here.

    Args:
        num_exp: number of independent training runs.
        ctx: device on which the parameters are initialised and trained.

    Prints the best validation accuracy of each run and their mean; returns
    None.
    """
    valid_acc = 0
    # `run_idx` gets its own name so that no inner loop can clobber it before
    # it is printed (the layer scan below previously reused the same
    # variable, so the report showed a layer index instead of the run number).
    for run_idx in range(num_exp):
        ### CIFAR VGG_DRM
        model = VGG_DRM('AllConv13', batch_size=opt.batch_size, num_class=100, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath, do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn, min_max=opt.min_max)
        for param in model.collect_params().values():
            pname = param.name
            if 'conv' in pname or 'dense' in pname:
                if 'weight' in pname:
                    param.initialize(init=mx.initializer.Xavier(), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif 'batchnorm' in pname or 'instancenorm' in pname:
                if 'gamma' in pname:
                    param.initialize(init=Normal(mean=1, sigma=0.02), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif 'biasadder' in pname:
                param.initialize(init=mx.init.Zero(), ctx=ctx)

        # model.collect_params().load('%s/%s_epoch_%i.params'%(opt.model_dir, opt.exp_name, 380), ctx=ctx)
        # model.hybridize()

        # Positions of the ReLU layers inside model.features (passed to
        # train(), which currently does not use them).
        relu_indx = [
            layer_pos
            for layer_pos, child in enumerate(model.features._children)
            if 'relu' in child.name
        ]

        acc = train(model, train_data_sup, valid_data, num_epochs, learning_rate, weight_decay, ctx, lr_decay, relu_indx)
        print('Validation Accuracy - Run %i = %f'%(run_idx, acc))
        valid_acc += acc

    print('Validation Accuracy = %f'%(valid_acc/num_exp))

In [None]:
# Launch a single training run on the device configured in `opt`.
run_train(1, ctx=opt.ctx)

Epoch 0. Train Loss: 2.808803, Train Xent: 1.703139, Train Reconst: 0.327884, Train Pn: 0.004020, Train acc 0.431406, Valid Loss: 5.060799, Valid acc 0.490506, Valid acc2 0.452482, Valid accmax 0.407041, Valid accmin 0.497923, Best valid acc 0.490506, Best valid acc2 0.452482, Best valid accmax 0.407041, Best valid accmin 0.497923, Time 00:04:12, lr 0.001
Epoch 1. Train Loss: 0.877728, Train Xent: 1.196728, Train Reconst: 0.137875, Train Pn: 0.004591, Train acc 0.619585, Valid Loss: 7.082472, Valid acc 0.670573, Valid acc2 0.649089, Valid accmax 0.666767, Valid accmin 0.631410, Best valid acc 0.670573, Best valid acc2 0.649089, Best valid accmax 0.666767, Best valid accmin 0.631410, Time 00:04:22, lr 0.001
Epoch 2. Train Loss: 0.345412, Train Xent: 0.941546, Train Reconst: 0.089564, Train Pn: 0.004024, Train acc 0.706871, Valid Loss: 5.863862, Valid acc 0.751603, Valid acc2 0.727764, Valid accmax 0.720453, Valid accmin 0.735076, Best valid acc 0.751603, Best valid acc2 0.727764, Best v