In [1]:
import os
import math
import mxnet as mx
from mxnet import image
from mxnet import nd, gluon, autograd, init
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from tensorboardX import SummaryWriter
import numpy as np
import shutil
import _pickle as cPickle
import gzip
from sklearn import preprocessing
from mxnet.gluon.parameter import Parameter, ParameterDict
from common.util import download_file
import subprocess

from IPython.core.debugger import Tracer

In [2]:
class Options:
    """Bundle of every tunable setting used by this notebook.

    All values are fixed inside ``__init__``. ``exp_name`` is derived once,
    at construction time, from ``num_train_sup`` and ``seed_val``; changing
    those attributes afterwards does not refresh it.
    """

    def __init__(self):
        # --- data budget / reproducibility ---------------------------------
        self.seed_val = 0            # seed for the labeled-subset selection
        self.num_train_sup = 1000    # total number of labeled training examples
        self.batch_size = 512        # split in half between the labeled and unlabeled streams

        # --- filesystem layout ---------------------------------------------
        self.data_dir = '/tanData/datasets/mnist'
        self.log_dir = '/tanData/logs_mnist'
        self.model_dir ='/tanData/models'
        name_fields = (self.num_train_sup, self.seed_val)
        self.exp_name = 'mnist_nlabels_%i_lenet_lddrm_pathnorm_seed_%i' % name_fields

        # --- compute device --------------------------------------------------
        self.ctx = mx.gpu(0)

        # --- weights of the auxiliary loss terms combined in train() --------
        self.alpha_drm = 1.0   # reconstruction (L2) term
        self.alpha_pn = 1.0    # `loss_pn` term returned by the network
        self.alpha_kl = 0.2    # KL-style term on the softmax output
        self.alpha_nn = 1.0    # `loss_nn` term returned by the network

        # --- switches forwarded verbatim to the VGG_DRM constructor ---------
        self.use_bias = True
        self.use_bn = True
        self.do_topdown = True
        self.do_countpath = False
        self.do_pn = True
        self.relu_td = False
        self.do_nn = True


# Single shared configuration instance read by the rest of the notebook.
opt = Options()

In [3]:
# Preparing the data step 1
# NOTE(security): unpickling executes arbitrary code embedded in the file, so
# opt.data_dir must only ever point at a trusted copy of mnist.pkl.gz.
with gzip.open(os.path.join(opt.data_dir, 'mnist.pkl.gz'), 'rb') as f:
    train_set, valid_set, test_set = cPickle.load(f, encoding='bytes')

trainx = train_set[0]
trainy = train_set[1]

testx = test_set[0]
testy = test_set[1]
Ntest = np.shape(testx)[0]

validx = valid_set[0]
validy = valid_set[1]
Nvalid = np.shape(validx)[0]

# The pickled validation split is folded into the training pool; the test
# split is what the notebook later feeds to the `valid_data` iterator.
trainx = np.concatenate([trainx, validx], axis=0)
trainy = np.concatenate([trainy, validy], axis=0)
labels = np.unique(trainy)

num_train = np.shape(trainx)[0]

num_train_sup = opt.num_train_sup

num_train_sup_per_label = num_train_sup // len(labels)

# How many times each labeled example is repeated so that the labeled stream
# is roughly as long as the unlabeled one.
ratio_unsup_sup = (num_train - num_train_sup) // num_train_sup

# select labeled data
data_rng = np.random.RandomState(opt.seed_val)
indx = data_rng.permutation(range(0, num_train))
indx_sup = []

for c in labels:
    # First `num_train_sup_per_label` indices (in shuffled order) whose label is c.
    picked = indx[trainy[indx] == c][:num_train_sup_per_label]
    # Each picked index appears `ratio_unsup_sup` times, consecutively.
    indx_sup.extend(np.repeat(picked, ratio_unsup_sup))

# Shuffle with the seeded generator (not the global np.random state) so the
# labeled ordering is reproducible for a given opt.seed_val.
indx_sup = data_rng.permutation(indx_sup)

trainx_sup = trainx[indx_sup]
trainy_sup = trainy[indx_sup]

# Boolean mask of the examples that were NOT chosen as labeled. This keeps the
# original ascending order while avoiding an O(n) membership test per index.
unsup_mask = np.ones(num_train, dtype=bool)
unsup_mask[indx_sup] = False
trainx_unsup = trainx[unsup_mask]
trainy_unsup = trainy[unsup_mask]

# Reshape flat rows into (N, 1, 28, 28) image tensors.
trainx_sup = np.reshape(trainx_sup, newshape=(trainx_sup.shape[0], 1 , 28, 28))
trainx_unsup = np.reshape(trainx_unsup, newshape=(trainx_unsup.shape[0], 1 , 28, 28))
testx = np.reshape(testx, newshape=(testx.shape[0], 1 , 28, 28))

In [4]:
def gpu_device(ctx=mx.gpu(0)):
    """Probe ``ctx`` by allocating a tiny array on it.

    Returns ``ctx`` unchanged when the allocation succeeds and ``None`` when
    MXNet raises ``MXNetError``.
    """
    try:
        mx.nd.array([1, 2, 3], ctx=ctx)
    except mx.MXNetError:
        usable = None
    else:
        usable = ctx
    return usable


# Stop early when the configured device cannot be used.
assert gpu_device(opt.ctx), 'No GPU device found!'

In [5]:
# Create the output directories up front. `exist_ok=True` avoids the race in
# the check-then-create pattern, and the checkpoint directory is created too
# because train() saves parameter files into opt.model_dir.
log_dir = os.path.join(opt.log_dir, opt.exp_name)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(opt.model_dir, exist_ok=True)

In [6]:
data_shape = (1, 28, 28)

# Every stream uses half of opt.batch_size: one training step consumes one
# unlabeled batch and one labeled batch.
_half_batch_kwargs = dict(batch_size=opt.batch_size // 2)

# Training streams are shuffled; an incomplete final batch is rolled over.
train_data_unsup = mx.io.NDArrayIter(
    data=trainx_unsup, label=trainy_unsup,
    shuffle=True, last_batch_handle='roll_over', **_half_batch_kwargs)
train_data_sup = mx.io.NDArrayIter(
    data=trainx_sup, label=trainy_sup,
    shuffle=True, last_batch_handle='roll_over', **_half_batch_kwargs)

# Evaluation stream (built from the test split): fixed order, last batch padded.
valid_data = mx.io.NDArrayIter(
    data=testx, label=testy,
    shuffle=False, last_batch_handle='pad', **_half_batch_kwargs)

In [7]:
# Loss objects shared by the training and validation loops.
L1_loss = gluon.loss.L1Loss()
L2_loss = gluon.loss.L2Loss()                      # reconstruction term: L2_loss(xhat, data)
criterion = gluon.loss.SoftmaxCrossEntropyLoss()   # classification term: criterion(output, label)

In [8]:
class Normal(mx.init.Initializer):
    """Initializer that fills an array with samples from a normal distribution
    with mean `mean` and standard deviation `sigma`.

    Parameters
    ----------
    mean : float, default 0
        Mean of the sampling distribution.
    sigma : float, default 0.01
        Standard deviation of the sampling distribution.
    """
    def __init__(self, mean=0, sigma=0.01):
        # NOTE(review): only `sigma` is forwarded to the base class; `mean` is
        # kept solely as an instance attribute. Confirm that is intended.
        super().__init__(sigma=sigma)
        self.mean = mean
        self.sigma = sigma

    def _init_weight(self, _, arr):
        # Samples are written in place into `arr`; the first argument is ignored.
        mx.random.normal(loc=self.mean, scale=self.sigma, out=arr)

In [9]:
# from resnet import ResNet164_v2
from mxnet.gluon.model_zoo import vision as models
from lenet_ld_opt import VGG_DRM

In [10]:
import datetime
# Summary writer used by train() to log per-epoch scalars; its directory is
# <opt.log_dir>/<opt.exp_name>.
writer = SummaryWriter(os.path.join(opt.log_dir, opt.exp_name))

def get_acc(output, label):
    """Count the rows of ``output`` whose arg-max along axis 1 equals ``label``.

    Returns the number of matches as a scalar (a count, not a ratio).
    """
    hits = output.argmax(1, keepdims=False) == label
    return hits.sum().asscalar()

def extract_acts(net, x, layer_indx):
    """Run ``x`` through ``net.features`` and collect intermediate activations.

    Parameters
    ----------
    net : object exposing ``net.features._children``
        The child blocks are applied to ``x`` in order.
    x : input passed to the first child block.
    layer_indx : iterable of int
        Increasing cut points. For each ``i`` the activation obtained after
        applying children ``[0, i)`` is recorded.

    Returns
    -------
    list
        One activation per entry of ``layer_indx``, in the same order.
    """
    children = net.features._children
    # Depending on the Gluon version `_children` is either a list of blocks or
    # an ordered dict of them; normalise to a list so that slicing works.
    blocks = list(children.values()) if isinstance(children, dict) else list(children)

    start_layer = 0
    out = []
    for i in layer_indx:
        # Continue from the previous cut point so no block is applied twice.
        for block in blocks[start_layer:i]:
            x = block(x)
        out.append(x)
        start_layer = i
    return out

In [11]:
def _neg_kl_from_uniform(logits):
    """Per-sample ``-sum_k p_k * log(10 * p_k + 1e-8)`` with ``p = softmax(logits)``.

    Up to the ``1e-8`` stabiliser this is minus the KL divergence between the
    predicted class distribution and the uniform distribution over 10 classes.
    """
    probs = nd.softmax(logits)
    return -nd.sum(nd.log(10.0*probs + 1e-8) * probs, axis=1)


def train(net, train_data_unsup, train_data_sup, valid_data, num_epochs, lr, wd, ctx, lr_decay, relu_indx):
    """Train ``net`` on paired unlabeled / labeled batches and track validation accuracy.

    Each step consumes one batch from ``train_data_unsup`` and one from
    ``train_data_sup`` (they must have the same batch size). ``net(data)`` /
    ``net(data, label)`` is expected to return a 5-element sequence
    ``[output, xhat, _, loss_pn, loss_nn]``. The unlabeled loss is the weighted
    sum of reconstruction, KL-style, ``loss_nn`` and ``loss_pn`` terms (weights
    from the module-level ``opt``); the labeled loss adds softmax cross-entropy.

    Schedule: the loop runs for ``num_epochs - 100`` epochs. Adam with a
    learning rate of 0.001 is used until epoch 20, when a fresh SGD trainer
    starting at 0.2 takes over; from then on the learning rate is divided each
    epoch by ``(0.2 / 0.0001) ** (1 / (num_epochs - 2))``.

    Side effects: per-epoch scalars are logged through the module-level
    ``writer``; when ``valid_data`` is given, the best-so-far parameters and a
    snapshot every 20 epochs are saved under ``opt.model_dir``.

    Parameters
    ----------
    net : network to train (see the expected return value above).
    train_data_unsup, train_data_sup : resettable batch iterators.
    valid_data : resettable batch iterator or ``None`` to skip evaluation.
    num_epochs : int, nominal epoch budget (see schedule above).
    lr : currently unused; the learning rates are hard-coded in the body.
    wd : weight decay handed to both trainers.
    ctx : device the batches are copied to.
    lr_decay : currently unused.
    relu_indx : currently unused.

    Returns
    -------
    Best validation accuracy seen (``0`` when ``valid_data`` is ``None``).
    """
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': 0.001, 'wd': wd})

    prev_time = datetime.datetime.now()
    best_valid_acc = 0

    for epoch in range(num_epochs-100):
        train_data_unsup.reset(); train_data_sup.reset()
        train_loss = 0; train_loss_xentropy = 0; train_loss_drm = 0; train_loss_pn = 0; train_loss_kl = 0; train_loss_nn = 0
        correct = 0; total = 0
        num_batch_train = 0

        if epoch == 20:
            # Switch optimiser. The start value is pre-multiplied by decay_val
            # because the block below divides by it again in this same epoch,
            # so the effective learning rate at epoch 20 is exactly 0.2.
            sgd_lr = 0.2
            decay_val = np.exp(np.log(sgd_lr / 0.0001) / (num_epochs - 2))
            sgd_lr = sgd_lr * decay_val
            trainer = gluon.Trainer(net.collect_params(), 'SGD', {'learning_rate': sgd_lr, 'wd': wd})

        if epoch >= 20:
            trainer.set_learning_rate(trainer.learning_rate / decay_val)

        for batch_unsup, batch_sup in zip(train_data_unsup, train_data_sup):
            assert batch_unsup.data[0].shape[0] == batch_sup.data[0].shape[0], "batch_unsup and batch_sup must have the same size"
            bs = batch_unsup.data[0].shape[0]
            data_unsup = batch_unsup.data[0].as_in_context(ctx)
            label_unsup = batch_unsup.label[0].as_in_context(ctx)
            data_sup = batch_sup.data[0].as_in_context(ctx)
            label_sup = batch_sup.label[0].as_in_context(ctx)
            with autograd.record():
                # Unlabeled pass: the label is NOT given to the network.
                [output_unsup, xhat_unsup, _, loss_pn_unsup, loss_nn_unsup] = net(data_unsup)
                loss_drm_unsup = L2_loss(xhat_unsup, data_unsup)
                loss_kl_unsup = _neg_kl_from_uniform(output_unsup)
                loss_unsup = opt.alpha_drm * loss_drm_unsup + opt.alpha_kl * loss_kl_unsup + opt.alpha_nn * loss_nn_unsup + opt.alpha_pn * loss_pn_unsup

                # Labeled pass: the label is given to the network as well.
                [output_sup, xhat_sup, _, loss_pn_sup, loss_nn_sup] = net(data_sup, label_sup)
                loss_xentropy_sup = criterion(output_sup, label_sup)
                loss_drm_sup = L2_loss(xhat_sup, data_sup)
                loss_kl_sup = _neg_kl_from_uniform(output_sup)
                loss_sup = loss_xentropy_sup + opt.alpha_drm * loss_drm_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_nn * loss_nn_sup + opt.alpha_pn * loss_pn_sup

                loss = loss_unsup + loss_sup

            loss.backward()
            trainer.step(bs)

            loss_drm = loss_drm_unsup + loss_drm_sup
            loss_pn = loss_pn_unsup + loss_pn_sup
            loss_xentropy = loss_xentropy_sup
            loss_kl = loss_kl_unsup + loss_kl_sup
            loss_nn = loss_nn_unsup + loss_nn_sup

            train_loss_xentropy += nd.mean(loss_xentropy).asscalar()
            train_loss_drm += nd.mean(loss_drm).asscalar()
            train_loss_pn += nd.mean(loss_pn).asscalar()
            train_loss_kl += nd.mean(loss_kl).asscalar()
            train_loss_nn += nd.mean(loss_nn).asscalar()
            train_loss += nd.mean(loss).asscalar()
            # Average of the labeled and unlabeled hit counts, so that dividing
            # by `total` (which grows by one batch size per step) gives a ratio.
            correct += (get_acc(output_sup, label_sup) + get_acc(output_unsup, label_unsup))/2

            total += bs
            num_batch_train += 1

        writer.add_scalars('loss', {'train': train_loss / num_batch_train}, epoch)
        writer.add_scalars('loss_xentropy', {'train': train_loss_xentropy / num_batch_train}, epoch)
        writer.add_scalars('loss_drm', {'train': train_loss_drm / num_batch_train}, epoch)
        writer.add_scalars('loss_pn', {'train': train_loss_pn / num_batch_train}, epoch)
        writer.add_scalars('loss_kl', {'train': train_loss_kl / num_batch_train}, epoch)
        writer.add_scalars('loss_nn', {'train': train_loss_nn / num_batch_train}, epoch)
        writer.add_scalars('acc', {'train': correct / total}, epoch)

        cur_time = datetime.datetime.now()
        # total_seconds() rather than .seconds: the latter silently drops whole days.
        h, remainder = divmod(int((cur_time - prev_time).total_seconds()), 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_data.reset()
            valid_loss = 0; valid_loss_xentropy = 0; valid_loss_drm = 0; valid_loss_pn = 0; valid_loss_kl = 0; valid_loss_nn = 0
            valid_correct = 0; valid_total = 0
            num_batch_valid = 0
            for batch in valid_data:
                bs = batch.data[0].shape[0]
                # A padding iterator fills the final batch with extra samples
                # and reports how many in `batch.pad`. They still go through
                # the network (full batch), but must not be counted in the
                # accuracy that drives model selection.
                num_real = bs - (getattr(batch, 'pad', None) or 0)
                data = batch.data[0].as_in_context(ctx)
                label = batch.label[0].as_in_context(ctx)
                [output, xhat, _, loss_pn, loss_nn] = net(data, label)
                loss_xentropy = criterion(output, label)
                loss_drm = L2_loss(xhat, data)
                loss_kl = _neg_kl_from_uniform(output)
                loss = loss_xentropy + opt.alpha_drm * loss_drm + opt.alpha_kl * loss_kl + opt.alpha_nn * loss_nn + opt.alpha_pn * loss_pn

                # NOTE: the logged loss means below are still taken over the
                # whole (possibly padded) batch.
                valid_loss_xentropy += nd.mean(loss_xentropy).asscalar()
                valid_loss_drm += nd.mean(loss_drm).asscalar()
                valid_loss_pn += nd.mean(loss_pn).asscalar()
                valid_loss_kl += nd.mean(loss_kl).asscalar()
                valid_loss_nn += nd.mean(loss_nn).asscalar()
                valid_loss += nd.mean(loss).asscalar()
                valid_correct += get_acc(output[:num_real], label[:num_real])

                valid_total += num_real
                num_batch_valid += 1
            valid_acc = valid_correct / valid_total
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                net.collect_params().save('%s/%s_best.params'%(opt.model_dir, opt.exp_name))
            writer.add_scalars('loss', {'valid': valid_loss / num_batch_valid}, epoch)
            writer.add_scalars('loss_xentropy', {'valid': valid_loss_xentropy / num_batch_valid}, epoch)
            writer.add_scalars('loss_drm', {'valid': valid_loss_drm / num_batch_valid}, epoch)
            writer.add_scalars('loss_pn', {'valid': valid_loss_pn / num_batch_valid}, epoch)
            writer.add_scalars('loss_kl', {'valid': valid_loss_kl / num_batch_valid}, epoch)
            writer.add_scalars('loss_nn', {'valid': valid_loss_nn / num_batch_valid}, epoch)
            writer.add_scalars('acc', {'valid': valid_acc}, epoch)
            epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Best valid acc %f, "
                         % (epoch, train_loss / num_batch_train, train_loss_xentropy / num_batch_train, train_loss_drm / num_batch_train, train_loss_pn / num_batch_train,
                            correct / total, valid_loss / num_batch_valid, valid_acc, best_valid_acc))
            # Periodic snapshot (epochs 0, 20, 40, ...), only when validation is enabled.
            if not epoch % 20:
                net.collect_params().save('%s/%s_epoch_%i.params'%(opt.model_dir, opt.exp_name, epoch))
        else:
            epoch_str = ("Epoch %d. Loss: %f, Train acc %f, "
                         % (epoch, train_loss / num_batch_train,
                            correct / total))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))

    return best_valid_acc

In [12]:
# Hyper-parameters that run_train() forwards to train().
num_epochs = 500       # nominal epoch budget
weight_decay = 5e-4    # passed as `wd`
learning_rate = 0.2    # passed as `lr`
lr_decay = 0.1         # passed as `lr_decay`

In [13]:
def run_train(num_exp, ctx):
    """Build, initialise and train a fresh model ``num_exp`` times.

    For every run a new ``VGG_DRM('ConvSmallMNIST', ...)`` is created from the
    switches in the module-level ``opt``, its parameters are initialised by
    name pattern, and ``train`` is called with the module-level data iterators
    and hyper-parameters. The per-run and the mean best validation accuracy
    are printed.

    Parameters
    ----------
    num_exp : int
        Number of independent runs; must be at least 1 (the mean divides by it).
    ctx : device on which parameters are initialised and training runs.
    """
    valid_acc = 0
    for run in range(num_exp):
        # MNIST small-conv configuration of VGG_DRM.
        model = VGG_DRM('ConvSmallMNIST', batch_size=opt.batch_size // 2, num_class=10, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_countpath=opt.do_countpath, do_pn=opt.do_pn, relu_td=opt.relu_td, do_nn=opt.do_nn)

        # Initialisation is chosen from the parameter name. Parameters that
        # match none of the patterns are left untouched here.
        for param in model.collect_params().values():
            pname = param.name
            if 'conv' in pname or 'dense' in pname:
                if 'weight' in pname:
                    param.initialize(init=mx.initializer.Xavier(), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif 'batchnorm' in pname or 'instancenorm' in pname:
                if 'gamma' in pname:
                    param.initialize(init=Normal(mean=1, sigma=0.02), ctx=ctx)
                else:
                    param.initialize(init=mx.init.Zero(), ctx=ctx)
            elif 'biasadder' in pname:
                param.initialize(init=mx.init.Zero(), ctx=ctx)

        # Positions of the ReLU blocks inside model.features. `_children` is a
        # list or an ordered dict depending on the Gluon version; normalise it.
        # A dedicated loop variable is used so the run index is not clobbered.
        children = model.features._children
        blocks = list(children.values()) if isinstance(children, dict) else list(children)
        relu_indx = [pos for pos, block in enumerate(blocks) if 'relu' in block.name]

        acc = train(model, train_data_unsup, train_data_sup, valid_data, num_epochs, learning_rate, weight_decay, ctx, lr_decay, relu_indx)
        print('Validation Accuracy - Run %i = %f'%(run, acc))
        valid_acc += acc

    print('Validation Accuracy = %f'%(valid_acc/num_exp))

In [14]:
# Launch a single training run on the configured device.
run_train(num_exp=1, ctx=opt.ctx)

Epoch 0. Train Loss: 0.150739, Train Xent: 0.413044, Train Reconst: 0.299883, Train Pn: 0.055434, Train acc 0.739553, Valid Loss: 2.187451, Valid acc 0.596238, Best valid acc 0.596238, Time 00:01:43, lr 0.001


KeyboardInterrupt: 