In [132]:
import tensorflow as tf

In [133]:
from tqdm import tqdm_notebook
from cleverhans.model_zoo.soft_nearest_neighbor_loss.SNNL_regularized_model import ModelBasicCNN
from cleverhans.loss import SNNLCrossEntropy
from data_utils import Config, DataPipeline
from summary_utils import *
from graph_utils import *
from collections import namedtuple

%reload_ext autoreload
%autoreload 2

In [134]:
def load_mnist(seed=None):
    """Load MNIST via Keras and split it into train / valid / test subsets.

    The official training set is shuffled and split so that the validation
    subset has as many examples as the official test set; the remainder is
    the training subset. Images are cast to float32 and divided by 255,
    labels are cast to int32 (integer class ids, not one-hot).

    Args:
        seed: Optional integer seed for the shuffle. When None (the default)
            the global NumPy random state is used, exactly as before, so a
            prior `np.random.seed(...)` call is still honoured. Pass an int
            to get a reproducible split independent of global state.

    Returns:
        dict with keys 'train', 'valid' and 'test', each mapping to an
        `(images, labels)` tuple of NumPy arrays.
    """
    (x_trainval, y_trainval), (x_test, y_test) = [
        (x.astype('float32'), y.astype('int32'))
        for x, y in tf.keras.datasets.mnist.load_data()
    ]
    n_trainval = len(x_trainval)
    # Fall back to the module-level RNG so unseeded calls behave as they did.
    rng = np.random if seed is None else np.random.RandomState(seed)
    inds_train, inds_val = np.split(rng.permutation(n_trainval),
                                    [n_trainval - len(x_test)])
    return {'train': (x_trainval[inds_train]/255, y_trainval[inds_train]),
            'valid': (x_trainval[inds_val]/255, y_trainval[inds_val]),
            'test': (x_test/255, y_test)}

def rand_bool(prob=0.5):
    """Return a scalar boolean tensor that is True with probability `prob`.

    A fresh value is drawn from `tf.random_uniform([])` (range [0, 1)) and
    compared against `prob`, so `prob=0` never fires and `prob=1` always does.
    The previous `greater` comparison yielded True with probability
    `1 - prob`, which only coincided with the intent at the default of 0.5.

    Args:
        prob: Probability of the result being True.

    Returns:
        Scalar boolean tensor.
    """
    return tf.less(tf.random_uniform([]), prob)

def simple_aug(img, aug_prob):
    """Randomly flip, rotate and translate a single image.

    Each of the three transforms is applied independently with probability
    `aug_prob`. Previously `aug_prob` was accepted but ignored and every
    transform fired with a hard-coded probability of 0.5.

    Args:
        img: Image tensor without batch or channel axes; a leading batch axis
            and a trailing channel axis are added before transforming.
            Presumably a 2-D (height, width) image -- confirm against the
            data pipeline.
        aug_prob: Probability of applying each individual transform.

    Returns:
        The augmented image with the batch axis removed again, i.e. `img`
        with one extra trailing channel axis.
    """
    def _coin():
        # tf.random_uniform([]) samples from [0, 1), so this is True with
        # probability aug_prob.
        return tf.less(tf.random_uniform([]), aug_prob)

    img = img[None, ..., None]
    # Flip along one randomly chosen axis out of {1, 2} (maxval 3 is exclusive).
    img = tf.cond(_coin(),
                  lambda: tf.reverse(img, axis=tf.random_uniform([1], 1, 3, dtype=tf.int32)),
                  lambda: img)
    # Rotate by an angle drawn uniformly from [-pi/6, pi/6) radians.
    img = tf.cond(_coin(),
                  lambda: tf.contrib.image.rotate(img, tf.random_uniform([], -1/6, 1/6)*np.pi),
                  lambda: img)
    # Translate by a two-component offset, each component drawn from [0, 4).
    img = tf.cond(_coin(),
                  lambda: tf.contrib.image.translate(img, tf.random_uniform([2], 0, 4)),
                  lambda: img)
    return img[0]
    

In [135]:
def logsumexp_masked(x, mask, axis):
    """Numerically stabilised log-sum-exp over `axis`, restricted by `mask`.

    Computes `log(sum(exp(x) * mask, axis))` by subtracting the per-slice
    maximum of `x` before exponentiating.

    The earlier version added a keep-dims maximum to a sum whose reduced axis
    had been dropped, so the two terms broadcast against each other (a
    `[N, M]` input reduced over the last axis came back as `[N, N]`). Both
    terms now keep the reduced axis and it is squeezed once at the end.

    Args:
        x: Tensor of log-domain values.
        mask: Multiplicative weights broadcastable to `x`; entries equal to 0
            are excluded from the sum.
        axis: Axis (or axes) to reduce over.

    Returns:
        Tensor shaped like `x` with `axis` removed. A slice whose mask is all
        zero evaluates to log(0).
    """
    x_max = tf.reduce_max(x, axis=axis, keepdims=True)
    masked_sum = tf.reduce_sum(tf.exp(x - x_max)*mask, axis=axis, keepdims=True)
    return tf.squeeze(x_max + tf.log(masked_sum), axis=axis)

def get_snnl_loss(features, labels, temp=100.):
    """Soft-nearest-neighbour style loss of a batch of features.

    For every example the "affinity" to each other example is
    `exp(-||f_i - f_j||^2 / temp)` (self-affinity is zeroed out). The
    per-example term is the log of the affinity mass on same-label examples
    minus the log of the total affinity mass; the result is the negated batch
    mean of these terms.

    Args:
        features: Tensor whose leading axis is the batch; all remaining axes
            are flattened into one feature vector per example.
        labels: Per-example labels, compared pairwise for equality
            (assumed to be a 1-D tensor of class ids).
        temp: Temperature dividing the squared distances.

    Returns:
        Scalar tensor with the loss value.
    """
    eps = 0.00001  # keeps both logarithms finite when a mass is zero
    flat = tf.layers.flatten(features)
    batch = tf.shape(flat)[0]

    # Pairwise squared Euclidean distances, negated and temperature-scaled.
    pairwise_delta = flat[:, None] - flat[None]
    neg_scaled_sq_dist = -tf.reduce_sum(pairwise_delta**2, axis=-1)/temp

    off_diagonal = 1 - tf.eye(batch)
    same_label = tf.to_float(tf.equal(labels[:, None], labels[None]))

    affinity = tf.exp(neg_scaled_sq_dist)*off_diagonal
    same_class_mass = tf.reduce_sum(affinity*same_label, axis=-1)
    total_mass = tf.reduce_sum(affinity, axis=-1)

    log_ratio = tf.log(eps + same_class_mass) - tf.log(eps + total_mass)
    return -tf.reduce_mean(log_ratio)

def ce_snnl_loss(logits, layers, labels, temp=100., factor=-10.):
    """Cross-entropy plus a weighted sum of per-layer SNNL terms.

    Args:
        logits: Classifier logits passed to sparse softmax cross-entropy.
        layers: Iterable of intermediate activations; `get_snnl_loss` is
            evaluated on each one and the results are summed.
        labels: Integer class labels shared by the cross-entropy and SNNL terms.
        temp: Temperature forwarded to `get_snnl_loss`.
        factor: Weight on the summed SNNL term. With the default negative
            value, minimising the returned loss pushes the SNNL term up.

    Returns:
        Scalar tensor `ce_loss + factor*snnl_loss`.

    NOTE: the notebook later unpacks `loss.op.inputs` (and the inputs of the
    product) to recover `ce_loss` and `snnl_loss`, so the exact form and
    operand order of the final expression must not be changed.
    """
    ce_loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
    # get_snnl_loss flattens its input again, so the flatten here is redundant but harmless.
    snnl_loss = tf.reduce_sum([get_snnl_loss(tf.layers.flatten(layer), labels, temp) for layer in layers])
    return ce_loss + factor*snnl_loss


In [136]:
# Build the graph: MNIST inputs -> data pipeline -> CNN -> losses, metrics,
# summaries and the train op. All names defined here are used by the
# training cell below.
# Start from an empty default graph so re-running this cell does not pile up ops.
tf.reset_default_graph()
# NOTE(review): numpy is imported here, after load_mnist/simple_aug (which use
# `np`) were defined; they only work because they are first called below.
import numpy as np

data = load_mnist()
# Fold the held-out split back into the training set ...
data['train'] = (np.concatenate([data['train'][0], data['valid'][0]]), 
                 np.concatenate([data['train'][1], data['valid'][1]]))
# ... and use the test set as the validation set.
data['valid'] = data['test']
# Training examples go through simple_aug; all other modes only get a
# trailing channel axis added.
pipeline = DataPipeline(batch_size=128, 
                        map_fn=lambda x, y, mode: ((simple_aug(x, 0.5) if mode=='train' else x[...,None]), y))
(images, labels), dataset_handle = pipeline.preproc(data)

model = ModelBasicCNN(nb_classes=10, nb_filters=64, scope='')
outputs = model.fprop(images)
# Activations 'conv1'..'conv3' are the layers the SNNL term is computed on.
layers = [outputs['conv%i'%i] for i in range(1,4)]
logits = outputs['logits']
# Fed by the training cell; no op built in this cell takes it as an input.
training = tf.placeholder(shape=[], dtype=tf.bool)

loss = ce_snnl_loss(logits, layers, labels)
# Recover the individual terms by walking the graph backwards. This relies on
# ce_snnl_loss ending in `ce_loss + factor*snnl_loss`: the sum's inputs are
# (ce_loss, factor*snnl_loss) and the product's inputs are (factor, snnl_loss).
ce_loss_op, weighted_snnl_loss_op = loss.op.inputs
_, snnl_loss_op = weighted_snnl_loss_op.op.inputs

# Batch accuracy: fraction of argmax predictions equal to the labels.
labels_pred = tf.to_int32(tf.argmax(logits, axis=-1))
acc = tf.reduce_mean(tf.to_float(tf.equal(labels_pred, labels)))

losses = {'loss':loss, 'ce_loss':ce_loss_op, 'snnl_loss':snnl_loss_op}
losses['acc'] = acc
# Project helpers (graph_utils / summary_utils). Presumably losses_avg holds
# running averages of each metric and `resets` clears them -- confirm there.
losses_avg, resets = add_metric_avg_ops(losses)
# Training summaries use the per-batch values, validation summaries the averaged ones.
train_summary_op = add_scalar_summaries(losses)
valid_summary_op = add_scalar_summaries(losses_avg)
train_step = get_train_op(loss, 'AdamOptimizer', learning_rate=1e-3)

In [None]:
# Train the model: run `config.n_iters` further iterations, validate every
# `config.valid_every` iterations, checkpoint to '<save_path>/best' whenever the
# validation accuracy improves, and write '<save_path>/last' at the end.
config = Config('test')
config.set_train_stats(batch_size=128, 
                       num_train=len(data['train'][0]), 
                       num_val=len(data['valid'][0]),
                       n_epochs=10)
config.set_metric_attrs()
config.make_paths()
saver = tf.train.Saver()
# NOTE: despite the name, `best_loss` / `present_loss` hold validation accuracy.
best_loss = config.best_metric 
with tf.Session() as sess:
    if config.ckpt is not None:
        print('Restoring weights from', config.ckpt)
        saver.restore(sess, config.ckpt)
    else:
        sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    pipeline.prepare(sess)

    train_handle, valid_handle = list(map(pipeline.handles.get, ['train', 'valid']))
    # NOTE(review): the writers are never closed in this cell; if they are
    # tf.summary.FileWriter objects, consider closing them after training.
    train_writer, val_writer = get_summary_writers(sess, config.logs_path, ['train', 'valid'])

    # The bar counts absolute iterations: it starts at the number already done
    # and ends at done + n_iters. (Passing `initial` without a matching `total`
    # made the counter overrun the total, and start past it when resuming.)
    tq_train = tqdm_notebook(range(config.iters_done+1, config.iters_done+config.n_iters+1), 
                             initial=config.iters_done,
                             total=config.iters_done+config.n_iters)

    for it in tq_train:
        fetch = [train_step, losses_avg, train_summary_op]
        fetch_vals = sess.run(fetch, {dataset_handle: train_handle, training:True})
        _, losses_avg_val, train_sum_str = fetch_vals

        tq_train.set_postfix(**losses_avg_val)
        train_writer.add_summary(train_sum_str, it)

        if (it%config.valid_every) == 0:
            # Clear the averages accumulated during training before validating.
            sess.run(resets, {training:True, dataset_handle: train_handle})
            # No `initial` here: the bar should run from 0 to valid_iters.
            tq_valid = tqdm_notebook(range(1, config.valid_iters+1))

            for val_iter in tq_valid:
                fetch_valid = [losses_avg, valid_summary_op]
                fetch_valid_vals = sess.run(fetch_valid, {dataset_handle: valid_handle, training:False})
                losses_valid_val, val_sum_str = fetch_valid_vals

                tq_valid.set_postfix(**losses_valid_val)

            # Clear the validation averages so they do not leak into training stats.
            sess.run(resets, {dataset_handle: valid_handle, training:False})
            # Only the summary from the last validation batch is written.
            val_writer.add_summary(val_sum_str, it)

            present_loss = losses_valid_val['acc'] 
            if config.metric_compare(present_loss, best_loss):
                print('Validation accuracy increased from {:.4f} to {:.4f}'.format(best_loss, 
                                                                               present_loss))
                save_path = saver.save(sess=sess, 
                                       save_path='{}/best'.format(config.save_path))
                print('Saving to {}'.format(save_path))
                best_loss = present_loss
                config.update_best(float(best_loss), int(it))

            # Persist progress after every validation round so a run can be resumed.
            config.iters_done = int(it)
            config.save_json()

    saver.save(sess=sess, 
               save_path='{}/last'.format(config.save_path))

        
            
        

HBox(children=(IntProgress(value=0, max=4690), HTML(value='')))