In [1]:
import os
import math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
from common.util import download_file
import subprocess

from IPython.core.debugger import Tracer

In [2]:
class Options:
    def __init__(self):
        self.seed_val = 1
        self.num_train_sup = 50000
        self.batch_size = 128
        self.data_dir = '/tanData/datasets/cifar100'
        self.log_dir = '/tanData/logs'
        self.model_dir ='/tanData/models'
        self.exp_name = 'cifar100_nlabels_%i_allconv13_lddrm_mm_pathnorm_meanteacher_rampup_seed_%i'%(self.num_train_sup, self.seed_val)
        self.ctx = mx.gpu(5)
        self.alpha_drm = 0.5
        self.alpha_pn = 1.0
        self.alpha_kl = 0.5
        self.alpha_nn = 0.5
        self.alpha_consistent = 33.0
        
        self.use_bias = True
        self.use_bn = True
        self.do_topdown = True
        self.do_countpath = False
        self.do_pn = True
        self.relu_td = False
        self.do_nn = True

opt = Options()

In [3]:
def gpu_device(ctx=mx.gpu(0)):
    try:
        _ = mx.nd.array([1, 2, 3], ctx=ctx)
    except mx.MXNetError:
        return None
    return ctx

assert gpu_device(opt.ctx), 'No GPU device found!'

In [4]:
log_dir = os.path.join(opt.log_dir, opt.exp_name)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

In [5]:
data_shape = (3, 32, 32)
train_data_sup = mx.io.ImageRecordIter(
    path_imgrec = os.path.join(opt.data_dir,'cifar100_train.rec'),
    data_shape  = data_shape,
    batch_size  = opt.batch_size,
    mean_r             = 129.3,
    mean_g             = 124.1,
    mean_b             = 112.4,
    std_r              = 68.2,
    std_g              = 65.4,
    std_b              = 70.4,
    shuffle = True,
    ## Data augmentation
    rand_crop   = True,
    max_crop_size = 32,
    min_crop_size = 32,
    pad = 4,
    fill_value = 0,
    rand_mirror = True)
valid_data = mx.io.ImageRecordIter(
    path_imgrec = os.path.join(opt.data_dir,'cifar100_val.rec'),
    data_shape  = data_shape,
    batch_size  = opt.batch_size,
    mean_r             = 129.3,
    mean_g             = 124.1,
    mean_b             = 112.4,
    std_r              = 68.2,
    std_g              = 65.4,
    std_b              = 70.4,
    ## No data augmentation
    rand_crop   = False,
    rand_mirror = False)

In [6]:
criterion = gluon.loss.SoftmaxCrossEntropyLoss()
L2_loss = gluon.loss.L2Loss()
L1_loss = gluon.loss.L1Loss()

In [7]:
class Normal(mx.init.Initializer):
    """Initializes weights with random values sampled from a normal distribution
    with a mean and standard deviation of `sigma`.
    """
    def __init__(self, mean=0, sigma=0.01):
        super(Normal, self).__init__(sigma=sigma)
        self.sigma = sigma
        self.mean = mean

    def _init_weight(self, _, arr):
        mx.random.normal(self.mean, self.sigma, out=arr)

In [8]:
# from resnet import ResNet164_v2
from mxnet.gluon.model_zoo import vision as models
from vgg_ld_opt import VGG_DRM

In [9]:
import datetime
writer = SummaryWriter(os.path.join(opt.log_dir, opt.exp_name))

def get_acc(output, label):
    pred = output.argmax(1, keepdims=False)
    correct = (pred == label).sum()
    return correct.asscalar()

def extract_acts(net, x, layer_indx):
    start_layer = 0
    out = []
    for i in layer_indx:
        for block in net.features._children[start_layer:i]:
            x = block(x)
        out.append(x)
        start_layer = i
    return out

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.collect_params().values(), model.collect_params().values()):
        ema_param.set_data(alpha * ema_param.data() + (1 - alpha) * param.data())
        
def test(net, valid_data, ctx):
    valid_data.reset()
    valid_loss = 0; valid_loss_xentropy = 0; valid_loss_drm = 0; valid_loss_pn = 0; valid_loss_kl = 0; valid_loss_nn = 0
    valid_correct = 0; valid_total = 0
    num_batch_valid = 0
    for batch in valid_data:
        bs = batch.data[0].shape[0]
        data = batch.data[0].as_in_context(ctx)
        label = batch.label[0].as_in_context(ctx)
        [output, xhat, _, loss_pn, loss_nn] = net(data, label)
        loss_xentropy = criterion(output, label)
        loss_drm = L2_loss(xhat, data)
        softmax_val = nd.softmax(output)
        loss_kl = -nd.sum(nd.log(100.0*softmax_val + 1e-8) * softmax_val, axis=1)
        loss = loss_xentropy + opt.alpha_drm * loss_drm + opt.alpha_kl * loss_kl + opt.alpha_nn * loss_nn + opt.alpha_pn * loss_pn

        valid_loss_xentropy += nd.mean(loss_xentropy).asscalar()
        valid_loss_drm += nd.mean(loss_drm).asscalar()
        valid_loss_pn += nd.mean(loss_pn).asscalar()
        valid_loss_kl += nd.mean(loss_kl).asscalar()
        valid_loss_nn += nd.mean(loss_nn).asscalar()
        valid_loss += nd.mean(loss).asscalar()
        valid_correct += get_acc(output, label)

        valid_total += bs
        num_batch_valid += 1
    valid_acc = valid_correct / valid_total
    return valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid

def write_results(writer, name, valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid, epoch):
    writer.add_scalars('loss', {'%s'%name: valid_loss / num_batch_valid}, epoch)
    writer.add_scalars('loss_xentropy', {'%s'%name: valid_loss_xentropy / num_batch_valid}, epoch)
    writer.add_scalars('loss_drm', {'%s'%name: valid_loss_drm / num_batch_valid}, epoch)
    writer.add_scalars('loss_pn', {'%s'%name: valid_loss_pn / num_batch_valid}, epoch)
    writer.add_scalars('loss_kl', {'%s'%name: valid_loss_kl / num_batch_valid}, epoch)
    writer.add_scalars('loss_nn', {'%s'%name: valid_loss_nn / num_batch_valid}, epoch)
    writer.add_scalars('acc', {'%s'%name: valid_acc}, epoch)

In [10]:
def train(net, ema_net, train_data_sup, valid_data, num_epochs, lr, wd, ctx, lr_decay, relu_indx, ema_decay):
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': 0.0001, 'wd': wd})
    
    prev_time = datetime.datetime.now()
    best_valid_acc = 0
    best_valid_acc_ema = 0
    iter_indx = 0
    global_step = 0
    
    for epoch in range(num_epochs-100):
        train_data_sup.reset()
        train_loss = 0; train_loss_xentropy = 0; train_loss_drm = 0; train_loss_pn = 0; train_loss_kl = 0; train_loss_nn = 0
        correct = 0; total = 0
        num_batch_train = 0
        
        if epoch == 20:
            sgd_lr = 0.15
            decay_val = np.exp(np.log(sgd_lr / 0.0001) / (num_epochs - 2))
            sgd_lr = sgd_lr * decay_val
            trainer = gluon.Trainer(net.collect_params(), 'SGD', {'learning_rate': sgd_lr, 'wd': wd})
            
        if epoch >= 20:
            trainer.set_learning_rate(trainer.learning_rate / decay_val)
        
        for batch_sup in train_data_sup:
            bs = batch_sup.data[0].shape[0]
            data_sup = batch_sup.data[0].as_in_context(ctx)
            label_sup = batch_sup.label[0].as_in_context(ctx)
            with autograd.record():
                [output_sup, xhat_sup, _, loss_pn_sup, loss_nn_sup] = net(data_sup, label_sup)
                if global_step == 0:
                    for ema_param, param in zip(ema_net.collect_params().values(),net.collect_params().values()):
                        ema_param.initialize(init=mx.initializer.Constant(param.data()), ctx=ctx)
                [output_sup_ema, _, _, _, _] = ema_net(data_sup, label_sup)
                loss_xentropy_sup = criterion(output_sup, label_sup)
                loss_drm_sup = L2_loss(xhat_sup, data_sup)
                softmax_sup = nd.softmax(output_sup)
                softmax_sup_ema = nd.softmax(output_sup_ema)
                loss_kl_sup = -nd.sum(nd.log(100.0*softmax_sup + 1e-8) * softmax_sup, axis=1)
                loss_consistent_sup = L2_loss(softmax_sup, softmax_sup_ema)
                loss = loss_xentropy_sup + opt.alpha_drm * loss_drm_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_nn * loss_nn_sup + opt.alpha_pn * loss_pn_sup + opt.alpha_consistent * loss_consistent_sup
                
            loss.backward()
            trainer.step(bs)
            global_step += 1
            if global_step < 40000:
                update_ema_variables(net, ema_net, ema_decay[0], global_step)
            else:
                update_ema_variables(net, ema_net, ema_decay[1], global_step)
            
            loss_drm = loss_drm_sup
            loss_pn = loss_pn_sup
            loss_xentropy = loss_xentropy_sup
            loss_kl = loss_kl_sup
            loss_nn = loss_nn_sup
            
            train_loss_xentropy += nd.mean(loss_xentropy).asscalar()
            train_loss_drm += nd.mean(loss_drm).asscalar()
            train_loss_pn += nd.mean(loss_pn).asscalar()
            train_loss_kl += nd.mean(loss_kl).asscalar()
            train_loss_nn += nd.mean(loss_nn).asscalar()
            train_loss += nd.mean(loss).asscalar()
            correct += get_acc(output_sup, label_sup)
            
            total += bs
            num_batch_train += 1
            iter_indx += 1
        
        writer.add_scalars('loss', {'train': train_loss / num_batch_train}, epoch)
        writer.add_scalars('loss_xentropy', {'train': train_loss_xentropy / num_batch_train}, epoch)
        writer.add_scalars('loss_drm', {'train': train_loss_drm / num_batch_train}, epoch)
        writer.add_scalars('loss_pn', {'train': train_loss_pn / num_batch_train}, epoch)
        writer.add_scalars('loss_kl', {'train': train_loss_kl / num_batch_train}, epoch)
        writer.add_scalars('loss_nn', {'train': train_loss_nn / num_batch_train}, epoch)
        writer.add_scalars('acc', {'train': correct / total}, epoch)
        
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid = test(net, valid_data, ctx)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                net.collect_params().save('%s/%s_best.params'%(opt.model_dir, opt.exp_name))
            write_results(writer, 'valid', valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid, epoch)
            epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Best valid acc %f, "
                         % (epoch, train_loss / num_batch_train, train_loss_xentropy / num_batch_train, train_loss_drm / num_batch_train, train_loss_pn / num_batch_train,
                            correct / total, valid_loss / num_batch_valid, valid_acc, best_valid_acc))
            
            valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid = test(ema_net, valid_data, ctx)
            if valid_acc > best_valid_acc_ema:
                best_valid_acc_ema = valid_acc
                net.collect_params().save('%s/%s_ema_best.params'%(opt.model_dir, opt.exp_name))
            write_results(writer, 'valid_ema', valid_acc, valid_loss_xentropy, valid_loss_drm, valid_loss_pn, valid_loss_kl, valid_loss_nn, valid_loss, num_batch_valid, epoch)    
            epoch_str_ema = ("Epoch %d. Valid Loss EMA: %f, Valid acc EMA %f, Best valid acc EMA %f, "
                         % (epoch, valid_loss / num_batch_valid, valid_acc, best_valid_acc))
            if not epoch % 20:
                net.collect_params().save('%s/%s_epoch_%i.params'%(opt.model_dir, opt.exp_name, epoch))
                net.collect_params().save('%s/%s_ema_epoch_%i.params'%(opt.model_dir, opt.exp_name, epoch))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
                         % (epoch, train_loss / num_batch_train,
                            correct / total))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))
        print(epoch_str_ema + time_str)
        
    return best_valid_acc

In [11]:
num_epochs = 500
learning_rate = 0.15
weight_decay = 5e-4
lr_decay = 0.1
ema_decay = [0.99, 0.999]

In [12]:
def run_train(num_exp, ctx):
    valid_acc = 0
    for i in range(num_exp):
        ### CIFAR VGG_DRM
        model = VGG_DRM('AllConv13', batch_size=opt.batch_size, num_class=100, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath, do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn)
        for param in model.collect_params().values():
            if param.name.find('conv') != -1 or param.name.find('dense') != -1:
                if param.name.find('weight') != -1:
                    param.initialize(init=mx.initializer.Xavier(), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif param.name.find('batchnorm') != -1 or param.name.find('instancenorm') != -1:
                if param.name.find('gamma') != -1:
                    param.initialize(init=Normal(mean=1, sigma=0.02), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif param.name.find('biasadder') != -1:
                param.initialize(init=mx.init.Zero(), ctx=ctx)
                
        ema_model = VGG_DRM('AllConv13', batch_size=opt.batch_size, num_class=100, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath, do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn)
        # model.hybridize()
        
        relu_indx = []
        
        for i in range(len(model.features._children)):
            if model.features._children[i].name.find('relu') != -1:
                relu_indx.append(i)
                
        acc = train(model, ema_model, train_data_sup, valid_data, num_epochs, learning_rate, weight_decay, ctx, lr_decay, relu_indx, ema_decay)
        print('Validation Accuracy - Run %i = %f'%(i, acc))
        valid_acc += acc

    print('Validation Accuracy = %f'%(valid_acc/num_exp))

In [None]:
run_train(1, ctx=opt.ctx)

Epoch 0. Train Loss: 7.182033, Train Xent: 4.166586, Train Reconst: 0.699669, Train Pn: 0.000295, Train acc 0.076666, Valid Loss: 7.368681, Valid acc 0.092761, Best valid acc 0.092761, Time 00:03:29, lr 0.0001
Epoch 0. Valid Loss EMA: 6.656285, Valid acc EMA 0.101062, Best valid acc EMA 0.092761, Time 00:03:29
Epoch 1. Train Loss: 5.424252, Train Xent: 3.786090, Train Reconst: 0.289311, Train Pn: 0.000947, Train acc 0.133032, Valid Loss: 6.095891, Valid acc 0.132212, Best valid acc 0.132212, Time 00:03:44, lr 0.0001
Epoch 1. Valid Loss EMA: 5.613025, Valid acc EMA 0.148638, Best valid acc EMA 0.132212, Time 00:03:44
Epoch 2. Train Loss: 4.500122, Train Xent: 3.571621, Train Reconst: 0.214945, Train Pn: 0.002538, Train acc 0.165465, Valid Loss: 8.571475, Valid acc 0.097656, Best valid acc 0.132212, Time 00:03:43, lr 0.0001
Epoch 2. Valid Loss EMA: 5.910604, Valid acc EMA 0.171975, Best valid acc EMA 0.132212, Time 00:03:43
Epoch 3. Train Loss: 3.813007, Train Xent: 3.395284, Train Recon