In [0]:
#@title imports
from functools import partial
import importlib
import math

import tensorflow as tf

import tf.contrib.slim as slim
from slim import preprocess

import tf.app as app

from sklearn import metrics

import numpy as np

import matplotlib.pylab as pl
import matplotlib.patheffects as PathEffects

from IPython.display import clear_output, display, Image, HTML


from OpenCVX import cvx2 as cv2

import semisup, mnist_tools, svhn_tools, synth_tools, train

import PIL.Image

import itertools as it
from cStringIO import StringIO

# Handles to TensorFlow's command-line flag module (tf.app.flags) and its
# FLAGS values object.
flags = tf.app.flags
FLAGS = flags.FLAGS

In [0]:
#@title boilerplate

def plot_conf_mtx(conf_mtx, num_labels=None):
  """Plot a confusion matrix as a row-normalised heat map annotated with counts.

  Each cell is coloured by its count divided by its row total, i.e. the share
  of a true class that was assigned to each predicted class. It is labelled
  with the raw count.

  Args:
    conf_mtx: 2-D array of counts, with true labels on rows and predictions on
      columns (the layout returned by sklearn.metrics.confusion_matrix).
    num_labels: Number of classes to put ticks on. Defaults to the number of
      rows of conf_mtx, so the ticks always match the matrix that is drawn.
  """
  counts = np.asarray(conf_mtx)
  if num_labels is None:
    num_labels = counts.shape[0]

  row_totals = counts.sum(axis=1, keepdims=True).astype(float)
  # A class with no true samples has an all-zero row. Show it as 0 instead of
  # dividing by zero.
  norm_conf = np.divide(counts, row_totals, out=np.zeros(counts.shape),
                        where=row_totals > 0)

  fig = pl.figure(figsize=(10, 10))
  ax = fig.add_subplot(111)
  ax.set_aspect(1)
  ax.imshow(norm_conf, cmap=pl.cm.jet, interpolation='nearest')

  for row, col in np.ndindex(counts.shape):
    # imshow maps columns to x and rows to y, hence xy=(col, row).
    ax.annotate(str(counts[row, col]), xy=(col, row),
                horizontalalignment='center',
                verticalalignment='center',
                color='w', weight='bold',
                path_effects=[PathEffects.withStroke(linewidth=2,
                                                     foreground='k')])

  ticks = range(num_labels)
  ax.set_xticks(ticks)
  ax.set_yticks(ticks)
  tick_labels = [str(i) for i in ticks]
  ax.set_xticklabels(tick_labels)
  ax.set_yticklabels(tick_labels)
  ax.grid()

def eval_mnist():
  """Evaluate the current model on the test input queue.

  Despite the name, this is not MNIST-specific. It scores `test_predictions`
  against `t_test_labels` over `len(test_labels) // eval_batch_size` batches.
  In this notebook that is the target-domain test set.

  Relies on notebook globals: `sess`, `test_predictions`, `t_test_labels`,
  `test_labels`, `eval_batch_size` and `num_labels`.

  Returns:
    (test_err, conf_mtx): the fraction of misclassified samples, and a
    num_labels x num_labels confusion matrix (rows = true labels).

  Raises:
    ValueError: if the test set is smaller than one evaluation batch.
  """
  runs = len(test_labels) // eval_batch_size
  if runs == 0:
    raise ValueError(
        'Test set (%d samples) is smaller than eval_batch_size (%d).'
        % (len(test_labels), eval_batch_size))

  pred_batches, label_batches = [], []
  for i in range(runs):
    print('%d / %d' % (i + 1, runs))
    clear_output(True)
    batch_pred, batch_lbls = sess.run([test_predictions, t_test_labels])
    pred_batches.append(batch_pred)
    label_batches.append(batch_lbls)

  test_pred = np.concatenate(pred_batches).ravel()
  test_lbls = np.concatenate(label_batches).ravel()
  test_err = 1.0 - np.mean(test_pred == test_lbls)
  # Pin the label set so the matrix shape does not depend on which classes
  # happen to occur in the sampled labels or predictions.
  conf_mtx = metrics.confusion_matrix(test_lbls, test_pred,
                                      labels=np.arange(num_labels))

  print('Test error: %.2f %%' % (test_err * 100))

  return test_err, conf_mtx

def showarray(a, fmt='jpeg'):
    """Display an image array inline in the notebook.

    Arrays of any floating dtype are treated as intensities in [0, 1]. They
    are clipped to that range and scaled to uint8 before encoding. Other
    dtypes are handed to PIL unchanged.

    Args:
      a: Array-like image that PIL.Image.fromarray accepts after the float
        conversion (e.g. (H, W) or (H, W, 3)).
      fmt: PIL format name used to encode the image (default 'jpeg').
    """
    a = np.asarray(a)
    # Covers float16 too, which PIL cannot encode directly.
    if np.issubdtype(a.dtype, np.floating):
        a = np.uint8(np.clip(a, 0, 1) * 255)
    f = StringIO()
    PIL.Image.fromarray(a).save(f, fmt)
    display(Image(data=f.getvalue()))

def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks.

    grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

    Args:
      iterable: Items to group.
      n: Chunk length.
      fillvalue: Value used to pad the last chunk up to length n.

    Returns:
      An iterator of n-tuples.
    """
    # itertools.izip_longest (Python 2) was renamed zip_longest in Python 3.
    zip_longest = getattr(it, 'izip_longest', None) or it.zip_longest
    # n references to the *same* iterator, so each output tuple consumes n
    # consecutive items.
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def tile2d(a, w=16):
    """Tile a stack of equally-shaped images into a single 2-D mosaic.

    Images are laid out left to right in rows of `w`. The last row is padded
    with all-zero images of the same shape and dtype.

    Args:
      a: Sequence or array of images with identical shape, e.g. (N, H, W) or
        (N, H, W, C).
      w: Number of images per mosaic row.

    Returns:
      For (H, W[, C]) images, an array of shape (ceil(N / w) * H, w * W[, C]).

    Raises:
      ValueError: if `a` contains no images.
    """
    images = list(a)
    if not images:
        raise ValueError('tile2d needs at least one image.')
    pad = np.zeros_like(images[0])
    images.extend([pad] * (-len(images) % w))
    # Build a real list: np.vstack rejects map/generator objects in newer NumPy.
    rows = [np.hstack(images[i:i + w]) for i in range(0, len(images), w)]
    return np.vstack(rows)

def plot_results(ckpt_fn):
  """Restore a checkpoint, evaluate it and plot a class-removal confusion matrix.

  The number of removed classes k is parsed from the checkpoint path, which
  must contain 'rem<k>/' (e.g. '.../rem3/model.ckpt'). Class ids
  num_labels - k .. num_labels - 1 are treated as removed. The printed test
  error only counts samples whose true label is a kept class. The plotted
  confusion matrix covers all classes; the removed ones are separated by
  white lines and labelled in parentheses.

  Uses a dedicated session on the global `graph`, which is closed before
  returning. Also reads the notebook globals `test_predictions`,
  `t_test_labels`, `test_labels`, `eval_batch_size` and `num_labels`.

  Args:
    ckpt_fn: Path of the checkpoint to restore; must contain 'rem<k>/'.
  """
  removed_classes = int(ckpt_fn.split('rem')[1].split('/')[0])
  kept_classes = num_labels - removed_classes

  test_pred, test_lbls = _plot_results_predict(ckpt_fn)

  # Score only samples whose true class was kept, i.e. label in [0, kept).
  kept = (test_lbls >= 0) & (test_lbls < kept_classes)
  truth_array = test_pred[kept] == test_lbls[kept]
  test_err = 1.0 - truth_array.mean() if truth_array.size else float('nan')
  conf_mtx = metrics.confusion_matrix(test_lbls, test_pred,
                                      labels=np.arange(num_labels))

  print('%d classes removed' % removed_classes)
  print('fraction of samples used for eval: %s'
        % (float(truth_array.size) / len(test_pred)))
  print('Test error for not-removed classes: %.2f %%' % (test_err * 100))

  _plot_results_conf_mtx(conf_mtx, kept_classes)


def _plot_results_predict(ckpt_fn):
  """Restore ckpt_fn in a fresh session; return flat (predictions, labels)."""
  runs = len(test_labels) // eval_batch_size
  if runs == 0:
    raise ValueError(
        'Test set (%d samples) is smaller than eval_batch_size (%d).'
        % (len(test_labels), eval_batch_size))

  cfg = tf.ConfigProto(gpu_options={'allow_growth': True})
  sess = tf.InteractiveSession(graph=graph, config=cfg)
  coord = tf.Coordinator()
  threads = []
  pred_batches, label_batches = [], []
  try:
    tf.initialize_all_variables().run()
    threads = tf.start_queue_runners(sess=sess, coord=coord)
    saver = tf.Saver()
    saver.restore(sess, ckpt_fn)

    for i in range(runs):
      print('%d / %d' % (i + 1, runs))
      clear_output(True)
      # NOTE(review): fetches the same test-set tensors as eval_mnist; confirm
      # these are the intended outputs for this evaluation.
      batch_pred, batch_lbls = sess.run([test_predictions, t_test_labels])
      pred_batches.append(batch_pred)
      label_batches.append(batch_lbls)
  finally:
    # Always stop the queue threads and release the session, even on error.
    coord.request_stop()
    try:
      coord.join(threads)
    finally:
      sess.close()

  return (np.concatenate(pred_batches).ravel(),
          np.concatenate(label_batches).ravel())


def _plot_results_conf_mtx(conf_mtx, kept_classes):
  """Draw conf_mtx row-normalised, splitting kept from removed classes."""
  row_totals = conf_mtx.sum(axis=1, keepdims=True).astype(float)
  # Rows of classes with no test samples would divide by zero; show them as 0.
  norm_conf = np.divide(conf_mtx, row_totals, out=np.zeros(conf_mtx.shape),
                        where=row_totals > 0)

  fig = pl.figure(figsize=(10, 10))
  ax = fig.add_subplot(111)
  ax.set_aspect(1)
  ax.imshow(norm_conf, cmap=pl.cm.jet, interpolation='nearest')

  for row, col in np.ndindex(conf_mtx.shape):
    # imshow maps columns to x and rows to y, hence xy=(col, row).
    ax.annotate(str(conf_mtx[row, col]), xy=(col, row),
                horizontalalignment='center',
                verticalalignment='center',
                color='w', weight='bold',
                path_effects=[PathEffects.withStroke(linewidth=2,
                                                     foreground='k')])

  # White lines between the last kept class and the first removed one.
  edge = kept_classes - 0.5
  ax.vlines(x=edge, ymin=-0.5, ymax=num_labels - 0.5, color='w')
  ax.hlines(y=edge, xmin=-0.5, xmax=num_labels - 0.5, color='w')

  ticks = range(num_labels)
  ax.set_xticks(ticks)
  ax.set_yticks(ticks)
  tick_labels = [str(i) if i < kept_classes else '(%d)' % i for i in ticks]
  ax.set_xticklabels(tick_labels)
  ax.set_yticklabels(tick_labels)
  ax.grid()

In [0]:
#@title load model trained on source domain

FLAGS = train.FLAGS
FLAGS.visit_weight = 0.5

eval_batch_size = 2000

# Dataset helper modules: labeled source domain and unlabeled/test target
# domain. source_tools also provides the label count, image shape and the
# network definition used below.
source_tools = synth_tools
target_tools = svhn_tools

num_labels = source_tools.NUM_LABELS
image_shape = source_tools.IMAGE_SHAPE
logit_weight = FLAGS.logit_weight

# Load data: labeled source train set, unlabeled target train set, and the
# target test set used for evaluation.
train_images, train_labels = source_tools.get_data('train')
train_images_unlabeled, _ = target_tools.get_data('train')
test_images, test_labels = target_tools.get_data('test')

# Sample labeled training subset (FLAGS.sup_seed == -1 means no fixed seed).
seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                       FLAGS.sup_per_class, num_labels, seed)


graph = tf.Graph()
with graph.as_default():
  with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):

    # Set up inputs.
    t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                          FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Test pipeline. shuffle=False makes the producer cycle through the test
    # set in a fixed order. An evaluation pass of len // batch consecutive
    # samples then never repeats a sample, as a reshuffled epoch boundary can.
    t_test_image, t_test_label = tf.train.slice_input_producer(
        [test_images, test_labels], shuffle=False)
    t_test_images, t_test_labels = tf.train.batch(
        [t_test_image, t_test_label], batch_size=eval_batch_size)
    t_test_labels = tf.cast(t_test_labels, tf.int64)

    # Resize if necessary.
    if FLAGS.new_size > 0:
      new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
    else:
      new_shape = None

    # No domain adaptation of the unlabeled data and no augmentation are
    # applied (augmentation_function=None below).

    # Create function that defines the network.
    model_function = partial(
        source_tools.default_model,
        new_shape=new_shape,
        img_shape=image_shape,
        augmentation_function=None,
        batch_norm_decay=FLAGS.batch_norm_decay)

    # Set up semisup model.
    model = semisup.SemisupModel(model_function, num_labels, image_shape,
                                 test_in=t_test_images)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
    if FLAGS.visit_weight_sigmoid:
      # NOTE(review): logistic_growth is not defined in this notebook; it is
      # assumed to be the helper of that name in the imported `train` module.
      visit_weight = train.logistic_growth(model.step, FLAGS.visit_weight,
                                           FLAGS.max_steps)
    else:
      visit_weight = FLAGS.visit_weight
    slim.summaries.add_scalar_summary(visit_weight, 'VisitLossWeight')

    if FLAGS.unsup_samples != 0:
      model.add_semisup_loss(
          t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels, weight=logit_weight)

    # The learning rate is fed at run time through this placeholder, so the
    # FLAGS learning-rate schedule options are not used by this notebook.
    lr_placeholder = tf.placeholder(tf.float32)
    train_op = model.create_train_op(lr_placeholder)

    # Get prediction tensor from semisup model.
    test_predictions = tf.argmax(model.test_logit, 1)
    

# Re-running this cell would otherwise leak the previous session and leave its
# queue-runner threads alive, so shut those down first.
if 'coord' in globals():
  coord.request_stop()
  coord.join(threads)
if 'sess' in globals():
  sess.close()

cfg = tf.ConfigProto(gpu_options={'allow_growth': True})
sess = tf.InteractiveSession(graph=graph, config=cfg)
tf.initialize_all_variables().run()
coord = tf.Coordinator()
threads = tf.start_queue_runners(sess=sess, coord=coord)

# Variables are initialised above, then overwritten from the checkpoint.
saver = tf.Saver()
ckpt_fn = 'path_to_checkpoint' #@param
saver.restore(sess, ckpt_fn)

In [0]:
#@title visualize batches
res = sess.run([t_sup_images, t_unsup_images, t_test_images])

# Show the first 16 images of each fetched batch as one mosaic.
captions = ['sup images (source)', 'unsup images (target)', 'test images']
for caption, batch in zip(captions, res):
  print(caption)
  showarray(tile2d(batch[:16]))

In [0]:
#@title evaluate model on target domain
# Evaluate the restored model on the test input queue (target-domain test set)
# and plot the resulting confusion matrix.
test_err, conf_mtx = eval_mnist()
plot_conf_mtx(conf_mtx)

In [0]:
#@title reset network to train from scratch using data from both domains
reset = False #@param
if reset:
  # Stop the queue threads and close the current session so its resources are
  # released before a new one is opened.
  coord.request_stop()
  coord.join(threads)
  sess.close()

  # Fresh session with freshly initialised (untrained) variables.
  cfg = tf.ConfigProto(gpu_options={'allow_growth': True})
  sess = tf.InteractiveSession(graph=graph, config=cfg)
  tf.initialize_all_variables().run()
  coord = tf.Coordinator()
  threads = tf.start_queue_runners(sess=sess, coord=coord)

In [0]:
# Histories filled by the training loop: (step, avg loss) and (step, test
# error) pairs. `step` is the counter the loop resumes from after an interrupt.
loss_steps, loss_vals = [], []
test_steps, test_errs = [], []

step = 0

In [0]:
# Training loop. Interrupt the kernel to stop early; `step` survives the
# interrupt, so re-running this cell resumes from where it stopped.
MAX_STEPS = 50000
LR_DECAY_STEP = 10000  # learning rate drops from 1e-4 to 1e-5 here
EVAL_EVERY = 500
LOG_EVERY = 50

try:
  for step in xrange(step, MAX_STEPS):

    lr = 1e-4 if step < LR_DECAY_STEP else 1e-5
    res = sess.run([train_op, model.train_loss_average],
                   feed_dict={lr_placeholder: lr})

    if step % EVAL_EVERY == 0:
      test_err, conf_mtx = eval_mnist()
      test_steps.append(step)
      test_errs.append(test_err)

    if step % LOG_EVERY == 0:
      loss_avg = res[1]
      loss_steps.append(step)
      loss_vals.append(loss_avg)

      # Report from the recorded history rather than whatever `test_err`
      # happens to hold; there may be no evaluation yet after a resume.
      if test_errs:
        print('step %d | avg loss %s | test err %s %%'
              % (step, loss_avg, test_errs[-1] * 100))
        best_iter = int(np.argmin(test_errs))
        print('best %s %% @ iteration %d'
              % (test_errs[best_iter] * 100, test_steps[best_iter]))
      else:
        print('step %d | avg loss %s | test err n/a' % (step, loss_avg))

      fig, ax1 = pl.subplots()
      ax1.set_xlabel('Iteration')
      ax1.set_ylabel('Train Loss (Avg)', color='b')
      ax1.plot(loss_steps, loss_vals, label='avg loss', color='b')

      ax2 = ax1.twinx()
      ax2.plot(test_steps, np.array(test_errs) * 100, color='r')
      ax2.set_ylabel('Test Err (%)', color='r')

      if test_errs:
        # conf_mtx is assigned together with every entry of test_errs.
        plot_conf_mtx(conf_mtx)
      pl.show()
      # Close explicitly so figures do not pile up over a long run.
      pl.close('all')
      clear_output(True)
except KeyboardInterrupt:
  # Deliberate: interrupting the kernel is how training is stopped early.
  print('Interrupted at step %d.' % step)

In [0]:
# Start over with a fresh, newly initialised session on `graph`. Stop and
# close the current one first so its queue threads and resources are released.
coord.request_stop()
coord.join(threads)
sess.close()

cfg = tf.ConfigProto(gpu_options={'allow_growth': True})
sess = tf.InteractiveSession(graph=graph, config=cfg)
tf.initialize_all_variables().run()
coord = tf.Coordinator()
threads = tf.start_queue_runners(sess=sess, coord=coord)

In [0]:
# Shut down the input-queue threads, then release the session's resources.
coord.request_stop()
coord.join(threads)
sess.close()