In [1]:
from __future__ import print_function           
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import os
  
# LeNet-Conv
def lenet(x, y, name='lenet', reuse=None):
  relu = tf.nn.relu
  dense = tf.layers.dense
  flatten = tf.contrib.layers.flatten
  T = tf.get_variable(name='temperature', trainable=True, initializer=tf.constant(1.0))

  def conv(x, filters, kernel_size=3, strides=1, **kwargs):
    return tf.layers.conv2d(x, filters, kernel_size, strides,
        data_format='channels_first', **kwargs)

  def pool(x, **kwargs):
    return tf.layers.max_pooling2d(x, 2, 2,
        data_format='channels_first', **kwargs)

  def cross_entropy(logits, labels):
    return tf.losses.softmax_cross_entropy(logits=logits,
        onehot_labels=labels)

  def accuracy(logits, labels):
    correct = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    return tf.reduce_mean(tf.cast(correct, tf.float32))
  
  x = tf.reshape(x, [-1, 1, 28, 28])
  x = conv(x, 20, 5, name=name+'/conv1', reuse=reuse)
  x = relu(x)
  x = pool(x, name=name+'/pool1')
  x = conv(x, 50, 5, name=name+'/conv2', reuse=reuse)
  x = relu(x)
  x = pool(x, name=name+'/pool2')
  x = flatten(x)
  x = dense(x, 500, activation=relu, name=name+'/dense', reuse=reuse)
  logit = dense(x, 10, name=name+'/logits', reuse=reuse)
  
  net = {}
  all_vars = tf.trainable_variables()
  net['cent'] = cross_entropy(logit / T, y)
  net['acc'] = accuracy(logit / T, y)
  net['output'] = tf.nn.softmax(logit / T)
  net['weights'] = [v for v in all_vars]
  net['temp_var'] = T
  return net

# MNIST data loader
def mnist_input(path):
    mnist = input_data.read_data_sets(path, one_hot=True, validation_size=0)
    x, y = mnist.train.images, mnist.train.labels
    y_ = np.argmax(y, axis=1)

    xtr = [x[y_==k][:30,:] for k in range(10)]
    ytr = [y[y_==k][:30,:] for k in range(10)]
    xtr, ytr = np.concatenate(xtr, axis=0), np.concatenate(ytr, axis=0)

    xva = [x[y_==k][30:40,:] for k in range(10)]
    yva = [y[y_==k][30:40,:] for k in range(10)]
    xva, yva = np.concatenate(xva, axis=0), np.concatenate(yva, axis=0)

    xte, yte = mnist.test.images, mnist.test.labels
    return xtr, ytr, xva, yva, xte, yte

args = {
    'mnist_path': './data',
    'batch_size': 100,
    'n_epochs': 200,
    'gpu_num': 0    
}
os.environ['CUDA_VISIBLE_DEVICES'] = str(args['gpu_num'])
                                         
bs = args['batch_size']
xtr, ytr, xva, yva, xte, yte = mnist_input(args['mnist_path'])
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
net = lenet(x, y)
loss = net['cent']
  
# Training
def run():
  global_step = tf.train.get_or_create_global_step()
  lr_step = (300//args['batch_size'])*args['n_epochs']/2
  lr = tf.train.piecewise_constant(tf.cast(global_step, tf.float32),
      [lr_step], [1e-3, 1e-4])
  train_op = tf.train.AdamOptimizer(lr).minimize(loss, 
                                                 global_step=global_step,
                                                 var_list=net['weights'])

  val_op = tf.train.AdamOptimizer(1e-1).minimize(loss,
                                                 var_list=net['temp_var'])

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  # Training
  for i in range(args['n_epochs']):
    # shuffle the training data every epoch
    xytr = np.concatenate((xtr, ytr), axis=1)
    np.random.shuffle(xytr)
    xtr_, ytr_ = xytr[:,:784], xytr[:,784:]

    for j in range(300//args['batch_size']):
      bx, by = xtr_[j*bs:(j+1)*bs,:], ytr_[j*bs:(j+1)*bs,:]
      _, cent, acc = sess.run([train_op, net['cent'], net['acc']], 
                              {x:bx, y:by})    
    if i % 10 == 0:
      print('epoch %d: cent = %f, acc = %f' % (i, cent, acc))
    
  #print ECE before temperature scaling
  cent, acc, output = sess.run([net['cent'], net['acc'], net['output']], 
                               {x:xte, y:yte})
  label = yte
  print('ECE = %f' % ece(output, label))
  # temperature scaling
  for i in range(args['n_epochs']):
    _, cent, acc = sess.run([val_op, net['cent'], net['acc']], {x:xva, y:yva})
    if i % 10 == 0:
      print('epoch %d: cent = %f, acc = %f' % (i, cent, acc))

  # Test & ECE
  cent, acc, output = sess.run([net['cent'], net['acc'], net['output']], 
                               {x:xte, y:yte})
  label = yte
  print('Test: cent=%f, acc=%f' % (cent, acc))
  print('ECE = %f' % ece(output, label))
  
def ece(output, label):
  idx = (np.arange(10000),np.argmax(output,1))
  conf = output[idx]
  correct = label[idx]

  M = 10
  bins, confs, accs = np.zeros(M), np.zeros(M), np.zeros(M)

  for m in range(M):
    idx = (m*0.1 <= conf) * (conf <= (m+1)*0.1)
    nbin = sum(idx)
    bins[m] = nbin
    confs[m] = 0. if nbin == 0 else conf[idx].mean()
    accs[m] = 0. if nbin == 0 else correct[idx].mean()

  ece = np.sum((bins/float(10000))*np.abs(accs-confs))
  return ece

W0730 21:38:36.577927 27284 deprecation.py:323] From <ipython-input-1-2567aa2a3f8b>:53: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
W0730 21:38:36.578924 27284 deprecation.py:323] From c:\users\ironm\tf-nightly\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
W0730 21:38:36.579921 27284 deprecation.py:323] From c:\users\ironm\tf-nightly\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Ple

Extracting ./data\train-images-idx3-ubyte.gz
Extracting ./data\train-labels-idx1-ubyte.gz
Extracting ./data\t10k-images-idx3-ubyte.gz
Extracting ./data\t10k-labels-idx1-ubyte.gz


W0730 21:38:36.778420 27284 deprecation.py:323] From c:\users\ironm\tf-nightly\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
W0730 21:38:37.161392 27284 deprecation.py:323] From <ipython-input-1-2567aa2a3f8b>:17: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
W0730 21:38:37.164388 27284 deprecation.py:506] From c:\users\ironm\tf-nightly\lib\site-packages\tensorflow\python\ops\init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtyp

In [2]:
run()

epoch 0: cent = 2.218958, acc = 0.340000
epoch 10: cent = 0.105859, acc = 0.980000
epoch 20: cent = 0.018365, acc = 1.000000
epoch 30: cent = 0.004111, acc = 1.000000
epoch 40: cent = 0.002644, acc = 1.000000
epoch 50: cent = 0.001318, acc = 1.000000
epoch 60: cent = 0.000909, acc = 1.000000
epoch 70: cent = 0.000581, acc = 1.000000
epoch 80: cent = 0.000536, acc = 1.000000
epoch 90: cent = 0.000373, acc = 1.000000
epoch 100: cent = 0.000350, acc = 1.000000
epoch 110: cent = 0.000331, acc = 1.000000
epoch 120: cent = 0.000356, acc = 1.000000
epoch 130: cent = 0.000349, acc = 1.000000
epoch 140: cent = 0.000351, acc = 1.000000
epoch 150: cent = 0.000322, acc = 1.000000
epoch 160: cent = 0.000269, acc = 1.000000
epoch 170: cent = 0.000277, acc = 1.000000
epoch 180: cent = 0.000274, acc = 1.000000
epoch 190: cent = 0.000281, acc = 1.000000
ECE = 0.076206
epoch 0: cent = 0.454206, acc = 0.940000
epoch 10: cent = 0.293485, acc = 0.940000
epoch 20: cent = 0.278875, acc = 0.940000
epoch 30: c