In [1]:
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np

from sklearn.metrics import accuracy_score

In [2]:
import tensorflow.contrib.layers as tcl
class SSL:
    '''
        Semi-supervised generative model made of three sub-networks that
        live in one TensorFlow graph:
            classifier : x      -> class scores
            encoder    : (x, y) -> z, mu, logvar
            decoder    : (z, y) -> reconstruction of x

        Creating an instance builds the graph and the training op in the
        default graph and prints the trainable variables.

        Attributes created while building:
            x_l, y_l, x_u : placeholders (labelled inputs, integer labels,
                            unlabelled inputs)
            pred_y        : argmax class of the classifier for x_u
            loss_l, loss_u, loss_clf, loss : scalar losses
            global_step, train_op          : step counter / update op
    '''
    def __init__(self):
        # classifier params
        self.hidden_size = 500
        self.num_labels = 10

        # encode & decoder params
        self.z_dim = 50
        self.x_dim = 28*28

        # training
        self.learning_rate = 1e-3

        self._build_graph()
        self._build_train_op()
        self.check_parameters()

    def check_parameters(self):
        ''' Print the name and shape of every trainable variable. '''
        for var in tf.trainable_variables():
            print('%s: %s' % (var.name, var.get_shape()))
        print()

    def classify(self, x, reuse = False):
        '''
            Map inputs to class scores with two fully connected layers
            (hidden_size units, then num_labels units).

            NOTE: softplus is also applied to the output layer, so the
            returned scores are non-negative; the rest of the graph feeds
            them to softmax / cross-entropy as logits.

            Args:
                x:     tensor whose last dimension is the input features.
                reuse: reuse the variables of the 'classifier' scope.
            Returns:
                scores tensor with num_labels columns.
        '''
        with tf.variable_scope('classifier', reuse = reuse):
            h = tcl.fully_connected(x, self.hidden_size, activation_fn = tf.nn.softplus)
            y = tcl.fully_connected(h, self.num_labels, activation_fn = tf.nn.softplus)
        return y

    def reparameterize(self, mu, logvar):
        '''
            Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I), so
            the sample stays differentiable w.r.t. mu and logvar.
        '''
        batch_size = tf.shape(mu)[0]
        std = tf.exp(logvar * 0.5)
        eps = tf.random_normal([batch_size, self.z_dim])
        z = mu + eps * std
        return z

    def encode(self, x, y, reuse = False):
        '''
            Encode the concatenation of x and y (along axis 1).

            Returns:
                (z, mu, logvar): a sample of the latent code together with
                the mean and log-variance it was drawn from, each with
                z_dim columns.
        '''
        with tf.variable_scope('encoder', reuse = reuse):
            concat = tf.concat([x, y], 1)
            h = tcl.fully_connected(concat, self.hidden_size, activation_fn = tf.nn.softplus)

            mu     = tcl.fully_connected(h, self.z_dim, activation_fn = None)
            logvar = tcl.fully_connected(h, self.z_dim, activation_fn = None)

            z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def decode(self, z, y, reuse = False):
        '''
            Decode the concatenation of z and y (along axis 1) into a
            reconstruction with x_dim columns; the sigmoid output keeps
            every value in (0, 1).
        '''
        with tf.variable_scope('decoder', reuse = reuse):
            concat = tf.concat([z, y], 1)
            h = tcl.fully_connected(concat, self.hidden_size, activation_fn = tf.nn.softplus)
            x_recon = tcl.fully_connected(h, self.x_dim, activation_fn = tf.nn.sigmoid)
        return x_recon

    def L(self, x, x_recon, y, mu, logvar):
        '''
            Per-example NEGATIVE lower bound, -L(x, y), estimated with a
            single sample of z:

                L(x, y) ~ log p(x|y,z) + log p(y) - KL(q(z|x,y) || p(z))

            Bernoulli decoder, with p = x_recon:
                log p(x|y,z) = sum_j x_j log p_j + (1 - x_j) log(1 - p_j)
            Gaussian posterior against a standard normal prior:
                KL = -0.5 * sum_k (1 + logvar_k - mu_k^2 - exp(logvar_k))
            Label term:
                minus the cross-entropy between a uniform distribution and
                softmax(y). For one-hot y it takes the same value for
                every class, so it only shifts the loss by a constant.

            Args:
                x:       inputs, (batch, x_dim).
                x_recon: decoder output for x, (batch, x_dim).
                y:       one-hot labels, (batch, num_labels).
                mu, logvar: encoder outputs, (batch, z_dim).
            Returns:
                tensor of shape (batch,) holding KL - log p(x|y,z) - log p(y).
        '''
        def KLD(mu, logvar):
            return - 0.5*(1+logvar-tf.square(mu)-tf.exp(logvar))
        def log_bernoulli(p, x):
            # epsilon keeps the logs finite when p saturates at 0 or 1
            epsilon = 1e-8
            return x * tf.log(p + epsilon) + (1-x) * tf.log(1-p + epsilon)

        # uniform dist
        prior_y = (1. / self.num_labels) * tf.ones_like( y )
        logpy = - tf.nn.softmax_cross_entropy_with_logits(labels = prior_y, logits = y )

        # (batch_size, z_dim) -> batch_size,
        kldloss = tf.reduce_sum(KLD(mu, logvar),1)
        # (batch_size, 784) -> batch_size,
        logpx   = tf.reduce_sum(log_bernoulli(x_recon, x), 1)

        loss = kldloss - logpx - logpy
        return loss

    def _build_graph(self, reuse = False):
        '''
            Create the placeholders, wire classifier / encoder / decoder
            for the labelled and the unlabelled batch, and define the
            losses. Unlabelled examples are encoded and decoded once per
            candidate label and the resulting losses are averaged under
            the classifier's distribution.

            Args:
                reuse: passed to the first use of each variable scope; all
                       later uses inside this method always reuse.
        '''
        self.x_l = tf.placeholder(tf.float32, shape = (None, self.x_dim))
        self.y_l = tf.placeholder(tf.int32, shape = (None, ))
        self.x_u = tf.placeholder(tf.float32, shape = (None, self.x_dim))

        # classifier, labelled & unlabelled
        scores_l = self.classify(self.x_l, reuse = reuse)
        scores_u = self.classify(self.x_u, reuse = True)
        self.pred_y = tf.argmax(scores_u, 1)

        # q(y|x) for the unlabelled batch. The log-probabilities come from
        # log_softmax instead of log(softmax(.)): once a probability
        # underflows to 0 the latter produces 0 * log(0) = NaN in the
        # entropy term (and an infinite gradient).
        y_u_prob     = tf.nn.softmax(scores_u, dim=-1)
        y_u_log_prob = tf.nn.log_softmax(scores_u)

        y_l_onehot =  tcl.one_hot_encoding(self.y_l, num_classes = self.num_labels)

        # encoder, labelled data
        z_l, mu_l, logvar_l = self.encode(self.x_l, y_l_onehot, reuse = reuse)

        # encoder, unlabelled data: one pass per candidate label
        z_u, mu_u, logvar_u, y_us = [], [], [], []
        for i in range(self.num_labels):
            _y = i * tf.ones([tf.shape(self.x_u)[0]], tf.int32)
            y_i = tcl.one_hot_encoding(_y, num_classes = self.num_labels)
            z_i, mu_i, logvar_i = self.encode(self.x_u, y_i, reuse = True)
            y_us.append(y_i)
            z_u.append(z_i)
            mu_u.append(mu_i)
            logvar_u.append(logvar_i)

        # decoder, labelled data
        x_recon = self.decode(z_l, y_l_onehot, reuse = reuse)

        # decoder, unlabelled data
        x_recon_u = [self.decode(z_u[i], y_us[i], reuse = True)
                     for i in range(self.num_labels)]

        # loss of classifier
        self.loss_clf = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels = self.y_l, logits = scores_l)
        # loss of labelled data, refered as L(x, y)
        self.loss_l = self.L(self.x_l, x_recon, y_l_onehot, mu_l, logvar_l)

        # loss of unlabelled data, refered as U(x):
        # column i holds -L(x, y=i), shape (batch_size, num_labels)
        loss_u_per_label = tf.stack(
            [self.L(self.x_u, x_recon_u[i], y_us[i], mu_u[i], logvar_u[i])
             for i in range(self.num_labels)],
            axis = 1)

        # expectation under q(y|x) plus the -H(q(y|x)) term:
        #   sum_y q(y|x) * (-L(x, y) + log q(y|x))
        self.loss_u = tf.multiply(y_u_prob, loss_u_per_label + y_u_log_prob)
        self.loss_u = tf.reduce_sum(self.loss_u, 1)

        print('loss_u  : '+str(self.loss_u.shape))
        print('loss_l  : '+str(self.loss_l.shape))
        print('loss_clf: '+str(self.loss_clf.shape))

        # weight of the classification loss in the total objective
        alpha = 0.1*100
        self.loss_clf = tf.reduce_mean(self.loss_clf, 0)
        self.loss_l   = tf.reduce_mean(self.loss_l,   0)
        self.loss_u   = tf.reduce_mean(self.loss_u,   0)
        self.loss = self.loss_l + self.loss_u + alpha * self.loss_clf

    def _build_train_op(self):
        '''
            Create global_step and an Adam update of self.loss whose
            gradients are clipped element-wise to [-1, 1].
        '''
        self.global_step = tf.Variable(0, name="global_step", trainable = False)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        grads_and_vars = optimizer.compute_gradients(self.loss)
        def ClipIfNotNone(grad):
            # variables that do not influence the loss have no gradient
            if grad is None:
                return grad
            return tf.clip_by_value(grad, -1, 1)
        capped_gvs = [(ClipIfNotNone(grad), var) for grad, var in grads_and_vars]
        self.train_op = optimizer.apply_gradients(capped_gvs, self.global_step)

    def predict(self, x, sess):
        '''
            Run the classifier on x and return the predicted class indices.

            Args:
                x:    inputs fed to the x_u placeholder.
                sess: session in which the graph variables are initialised.
            Returns:
                the evaluated pred_y (argmax over the class scores).
        '''
        feed_dict = {
            self.x_u: x,
        }
        pred = sess.run([self.pred_y], feed_dict = feed_dict)[0]
        return pred

In [3]:
'''
    The following 3 functions come from the original implementation by D.P. Kingma
        @https://github.com/dpkingma/nips14-ssl
    These functions are slightly modified for convenience.
'''
import cPickle
import gzip
import random
import numpy as np
MNIST_PATH = './data/mnist_28.pkl.gz'

def load_mnist(path):
    '''
        Load a gzip-compressed pickle holding the MNIST splits.

        The pickle is expected to contain a 3-tuple (train, valid, test)
        in which every element is an (x, y) pair.

        Args:
            path: location of the .pkl.gz file.
        Returns:
            train_x, train_y, valid_x, valid_y, test_x, test_y
    '''
    # NOTE(security): unpickling can execute arbitrary code; only load
    # files that come from a trusted source.
    # The with-block closes the file even when unpickling fails.
    with gzip.open(path, 'rb') as f:
        train, valid, test = cPickle.load(f)
    train_x, train_y = train
    valid_x, valid_y = valid
    test_x,  test_y  = test
    return train_x, train_y, valid_x, valid_y, test_x, test_y

# Loads data where data is split into class labels
def load_mnist_split(path = MNIST_PATH):
    '''
        Load MNIST and regroup the training set by class.

        Returns:
            train_x, train_y: lists with 10 entries; entry i holds the
                              training samples / labels whose label is i.
            valid_x, valid_y, test_x, test_y: returned as loaded.
    '''
    def split_by_class(x, y, num_classes):
        # row indices of every class, in label order
        rows_per_class = [np.where(y == label)[0] for label in range(num_classes)]
        xs = [x[rows] for rows in rows_per_class]
        ys = [y[rows] for rows in rows_per_class]
        return xs, ys

    loaded = load_mnist(path)
    train_x, train_y, valid_x, valid_y, test_x, test_y = loaded
    train_x, train_y = split_by_class(train_x, train_y, 10)
    return train_x, train_y, valid_x, valid_y, test_x, test_y

def create_semisupervised(x, y, n_labeled):
    '''
        Split per-class data into a labelled and an unlabelled subset with
        the same number of labelled samples in every class.

        Args:
            x: sequence with one array of samples per class (10 classes).
            y: sequence with the matching label arrays.
            n_labeled: total number of labelled samples wanted; must be a
                       multiple of the number of classes.
        Returns:
            x_labeled, y_labeled, x_unlabeled, y_unlabeled, each stacked
            over the classes in class order.
        Raises:
            ValueError: if n_labeled is not divisible by the number of classes.
    '''
    n_classes = 10
    if n_labeled % n_classes != 0:
        raise ValueError("n_labeled (wished number of labeled samples) not divisible by n_classes (number of classes)")
    n_labels_per_class = n_labeled//n_classes

    x_labeled, y_labeled = [], []
    x_unlabeled, y_unlabeled = [], []
    for i in range(n_classes):
        # shuffle needs a mutable sequence, hence the explicit list
        idx = list(range(x[i].shape[0]))
        random.shuffle(idx)
        labelled_idx   = idx[:n_labels_per_class]
        unlabelled_idx = idx[n_labels_per_class:]
        x_labeled.append(x[i][labelled_idx])
        y_labeled.append(y[i][labelled_idx])
        x_unlabeled.append(x[i][unlabelled_idx])
        y_unlabeled.append(y[i][unlabelled_idx])
    return np.vstack(x_labeled), np.hstack(y_labeled), np.vstack(x_unlabeled), np.hstack(y_unlabeled)

In [4]:
def batch_generator(data, batch_size, num_epoch, shuffle = True):
    '''
        Yield consecutive mini-batches of `data` for `num_epoch` epochs.

        Args:
            data:       iterable of samples; it is materialised into a
                        numpy array on first use.
            batch_size: maximum number of samples per batch; the last
                        batch of an epoch may be smaller.
            num_epoch:  number of passes over the data.
            shuffle:    draw a new random order at the start of each epoch.
        Yields:
            numpy slices along the first axis of the (shuffled) data.
    '''
    samples = np.array(list(data))
    n_samples = samples.shape[0]
    # ceil(n_samples / batch_size)
    n_batches = (n_samples + batch_size - 1)//batch_size

    for _ in range(num_epoch):
        if shuffle:
            order = np.random.permutation(np.arange(n_samples))
            epoch_samples = samples[order]
        else:
            epoch_samples = samples

        for lo in range(0, n_batches * batch_size, batch_size):
            hi = min(lo + batch_size, n_samples)
            yield epoch_samples[lo:hi]
            
import time
def time_since(since):
    '''
        Format the wall-clock time elapsed since `since` (a time.time()
        timestamp) as "<minutes> m <seconds> s", both truncated to integers.
    '''
    elapsed = time.time() - since
    minutes, seconds = divmod(elapsed, 60)
    return "%d m %d s" % (minutes, seconds)

In [5]:
'''
    here are the config params for the experiment.
'''
data_size = 50000
n_batch_size = 100
n_labelled = 100
n_epoch = 100
# number of updates: n_epoch passes over the unlabelled part of the data
max_iter = n_epoch*(data_size-n_labelled)//n_batch_size

# load data from mnist
train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist_split()
# split training set
x_l, y_l, x_u, y_u = create_semisupervised(train_x, train_y, n_labelled)
# With n_labelled == n_batch_size the labelled generator yields one batch per
# epoch, so it is asked for max_iter epochs to provide one batch per update.
l_batch_gen = batch_generator(zip(x_l, y_l), n_batch_size, max_iter)
u_batch_gen = batch_generator(zip(x_u), n_batch_size, n_epoch)

# Set config for tensorflow session.
tf_config = tf.ConfigProto(
    device_count = {'GPU': 1}, # single gpu
)
tf_config.gpu_options.allow_growth=True
with tf.Session(config = tf_config) as sess:
    model = SSL()
    sess.run(tf.global_variables_initializer())

    eval_train = [model.train_op, model.global_step]
    eval_loss  = [model.loss, model.loss_l, model.loss_u, model.loss_clf]

    start = time.time()
    # Pull the batches lazily: a list-returning zip() over the two generators
    # would materialise every batch of every epoch before the first update.
    for u_batch in u_batch_gen:
        l_batch = next(l_batch_gen, None)
        if l_batch is None:
            # labelled generator exhausted: stop, as zip() would
            break

        # Batch-local names, so the dataset arrays x_l / y_l / x_u defined
        # above are not overwritten by the loop.
        x_l_batch, y_l_batch = zip(*l_batch)
        x_u_batch = [row[0] for row in u_batch]

        feed_dict = {
            model.x_l: x_l_batch,
            model.y_l: y_l_batch,

            model.x_u: x_u_batch,
        }
        _, step = sess.run(eval_train, feed_dict = feed_dict)
        if step % 500 == 0:
            # losses on the current training batch, measured after the update
            loss, loss_l, loss_u, loss_clf = sess.run(eval_loss, feed_dict = feed_dict)

            pred_valid = model.predict(valid_x, sess)
            pred_test  = model.predict(test_x,  sess)
            accuracy_valid = accuracy_score(valid_y, pred_valid)
            accuracy_test  = accuracy_score(test_y,  pred_test)

            print('time: %s' % time_since(start))
            print(' Iteration %d/%d'    % (step, max_iter))
            print('  labelled loss: %.2f' % loss_l)
            print('unlabelled loss: %.2f' % loss_u)
            print('classifier loss: %.2f' % loss_clf)
            print('     total loss: %.2f' % loss)
            print(' valid accuracy: %.2f' % accuracy_valid)
            print('  test accuracy: %.2f' % accuracy_test)
            print()

loss_u  : (?,)
loss_l  : (?,)
loss_clf: (?,)
classifier/fully_connected/weights:0: (784, 500)
classifier/fully_connected/biases:0: (500,)
classifier/fully_connected_1/weights:0: (500, 10)
classifier/fully_connected_1/biases:0: (10,)
encoder/fully_connected/weights:0: (794, 500)
encoder/fully_connected/biases:0: (500,)
encoder/fully_connected_1/weights:0: (500, 50)
encoder/fully_connected_1/biases:0: (50,)
encoder/fully_connected_2/weights:0: (500, 50)
encoder/fully_connected_2/biases:0: (50,)
decoder/fully_connected/weights:0: (60, 500)
decoder/fully_connected/biases:0: (500,)
decoder/fully_connected_1/weights:0: (500, 784)
decoder/fully_connected_1/biases:0: (784,)

time: 0 m 21 s
 Iteration 500/49900
  labelled loss: 66.78
unlabelled loss: 144.59
classifier loss: 0.00
     total loss: 211.43
 valid accuracy: 0.70
  test accuracy: 0.68

time: 0 m 34 s
 Iteration 1000/49900
  labelled loss: 60.92
unlabelled loss: 127.46
classifier loss: 0.00
     total loss: 188.40
 valid accuracy: 0.7