In [None]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [None]:
import warnings

# Suppress every Python warning for the rest of the kernel session.
warnings.filterwarnings(action="ignore")

# Read MNIST from ../MNIST_data; one_hot=True requests one-hot label vectors.
mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)

In [None]:
# Hyper-parameters used by the cells below.

# Mini-batch size (presumably consumed by the training loop -- confirm).
batch_size = 128

# Number of target classes; also the number of class centers kept by center_loss().
num_classes = 10

# Passed to center_loss() as `alpha`, the step size of the class-center update.
center_loss_alpha = 0.5

# Presumably the weight of the center-loss term in the total loss -- confirm.
Lambda = 0.5

In [None]:
# Graph inputs, grouped under the "input" scope.
with tf.variable_scope("input"):
    # Flattened images: 784 values per example (reshaped to 28x28x1 below).
    X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name="input_X")
    # Label vectors with 10 entries per example.
    Y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name="labels")
    # Image view of X: [batch, 28, 28, 1].
    X_img = tf.reshape(X, shape=[-1, 28, 28, 1], name="input_image")

In [None]:
def center_loss(features, labels, alpha, num_classes, scope=None):
    """Build the center-loss term and the op that updates the class centers.

    Args:
        features: float32 tensor of shape [batch, feature_dim].
        labels: integer class indices (NOT one-hot); any shape that flattens
            to [batch]. They are used as indices for ``tf.gather`` and
            ``tf.scatter_sub``, so the dtype must be int32 or int64.
        alpha: scalar step size applied to the center update.
        num_classes: number of classes, i.e. number of rows in ``centers``.
        scope: optional variable-scope name; defaults to "center_loss".

    Returns:
        A tuple ``(loss, centers, centers_update_op)``:
        loss -- ``tf.nn.l2_loss(features - centers[labels])``, i.e. half the
            sum of squared distances (not averaged over the batch).
        centers -- non-trainable, zero-initialised float32 variable of shape
            [num_classes, feature_dim].
        centers_update_op -- op that subtracts
            ``alpha * (center - feature) / (1 + count_of_that_label_in_batch)``
            from the center of every sample's class. ``centers`` only changes
            when this op is run.
    """
    with tf.variable_scope(name_or_scope=scope, default_name="center_loss"):
        len_features = features.get_shape()[1]
        centers = tf.get_variable("centers", [num_classes, len_features], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0), trainable=False)
        labels = tf.reshape(labels, [-1])

        # Center of each sample's own class, aligned row-by-row with `features`.
        center_batch = tf.gather(centers, labels)
        loss = tf.nn.l2_loss(features - center_batch)
        diff = center_batch - features

        # How many times each sample's label occurs in this batch, as a column
        # vector so it broadcasts over the feature dimension.
        _, unique_idx, unique_count = tf.unique_with_counts(labels)
        appear_times = tf.gather(unique_count, unique_idx)
        appear_times = tf.reshape(appear_times, [-1, 1])

        # The counts are integers; cast before dividing the float differences.
        diff = diff / tf.cast(1 + appear_times, tf.float32)
        diff = alpha * diff

        centers_update_op = tf.scatter_sub(centers, labels, diff)

        return loss, centers, centers_update_op

In [None]:
def prelu(x, scope=None):
    """Parametric ReLU: ``max(0, x) + alpha * min(0, x)``.

    Args:
        x: input tensor; one slope is created per entry of its last dimension.
        scope: optional variable-scope name; defaults to "prelu".

    Returns:
        A tensor with the same shape as ``x``. Positive values pass through
        unchanged; negative values are scaled by the slope ``alpha``, a
        variable named "prelu" initialised to 0.1.
    """
    with tf.variable_scope(name_or_scope=scope, default_name="prelu"):
        alpha = tf.get_variable("prelu", shape=x.get_shape()[-1], dtype=x.dtype,
                                initializer=tf.constant_initializer(0.1))
        # The slope applies to the negative part only, hence minimum (not
        # maximum) in the second term.
        return tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x)

In [None]:
def conv2d(name, x, output_channels, kernel_size=3, strides=1, padding="same"):
    """Apply a 2-D convolution followed by PReLU inside variable scope ``name``.

    Args:
        name: variable-scope name that groups the layer's variables.
        x: input tensor handed to ``tf.layers.conv2d``.
        output_channels: number of convolution filters.
        kernel_size: convolution kernel size (default 3).
        strides: convolution stride (default 1).
        padding: padding mode forwarded to ``tf.layers.conv2d`` (default "same").

    Returns:
        The PReLU-activated convolution output.
    """
    conv_options = dict(filters=output_channels, kernel_size=kernel_size,
                        strides=strides, padding=padding)
    with tf.variable_scope(name):
        activated = prelu(tf.layers.conv2d(x, **conv_options))
    return activated

In [None]:
def max_pooling2d(x, pool_size=2, strides=2, padding="same", name="pool"):
    """Thin wrapper around ``tf.layers.max_pooling2d`` with this notebook's defaults.

    Args:
        x: input tensor to pool.
        pool_size: pooling window size (default 2).
        strides: pooling stride (default 2).
        padding: padding mode (default "same").
        name: name given to the pooling layer (default "pool").

    Returns:
        Whatever ``tf.layers.max_pooling2d`` returns for these arguments.
    """
    pool_options = {
        "pool_size": pool_size,
        "strides": strides,
        "padding": padding,
        "name": name,
    }
    return tf.layers.max_pooling2d(x, **pool_options)

In [None]:
def network(x, num_classes=10, feature_dim=2):
    """Build the convolutional classifier and return its logits and embedding.

    The graph is three stages of (conv, conv, max-pool) with 32, 64 and 128
    filters, followed by a dense embedding layer with PReLU and a dense
    classification layer.

    Args:
        x: image tensor fed to the first convolution (in this notebook,
            ``X_img`` with shape [-1, 28, 28, 1]).
        num_classes: number of units of the final "fc2" layer (default 10).
        feature_dim: number of units of the "fc1" embedding layer (default 2).

    Returns:
        A tuple ``(logit, feature)``: the un-normalised class scores and the
        PReLU-activated embedding they are computed from.
    """
    net = conv2d("conv1", x, 32)
    net = conv2d("conv2", net, 32)
    net = max_pooling2d(net, name="pool1")

    net = conv2d("conv3", net, 64)
    net = conv2d("conv4", net, 64)
    net = max_pooling2d(net, name="pool2")

    net = conv2d("conv5", net, 128)
    net = conv2d("conv6", net, 128)
    net = max_pooling2d(net, name="pool3")

    flatten = tf.layers.flatten(net, name="flatten")

    feature = tf.layers.dense(flatten, units=feature_dim, name="fc1")
    feature = prelu(feature)

    # Classify from the local embedding; do not rely on a global `features`.
    logit = tf.layers.dense(feature, units=num_classes, name="fc2")

    return logit, feature

In [None]:
logits, features = network(X_img)

with tf.variable_scope("loss"):
    # Y holds one-hot label vectors ([None, 10] float32). center_loss() indexes
    # its centers by label, and the sparse cross-entropy op takes class ids of
    # shape [batch], so both need integer class indices rather than Y itself.
    class_ids = tf.argmax(Y, axis=1)
    with tf.variable_scope("center_loss"):
        # NOTE(review): this rebinds the name `center_loss` from the function to
        # the loss tensor, so re-running this cell requires re-running the cell
        # that defines center_loss() first. The name is kept in case later cells
        # read `center_loss` as the tensor.
        # centers_update_op must be run during training for `centers` to change.
        center_loss, centers, centers_update_op = center_loss(features, class_ids, center_loss_alpha, num_classes)
    with tf.variable_scope("softmax_loss"):
        softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=class_ids, logits=logits))
    with tf.variable_scope("total_loss"):
        # Weight the center-loss term with the configured Lambda (0.5) instead of
        # repeating the literal.
        total_loss = softmax_loss + Lambda * center_loss