# In [4]:
import numpy as np
import tensorflow as tf

# Data

# In [6]:
class Dataset:
    """Image dataset splits prepared for a conv net, plus shuffled batch iterators.

    Data is loaded with ``input_data.read_data_sets(DATA_DIR, one_hot=True)``
    (presumably MNIST, given the 28x28 reshape). Images of every split are
    reshaped to ``(N, 28, 28, 1)`` and centred by subtracting the scalar mean
    of the *training* images; labels are kept as one-hot vectors.
    """

    def __init__(self):
        # np.random.seed(100)  # use a fixed seed instead for reproducible shuffling
        # Seed from the clock; the modulo keeps the value inside the range
        # accepted by np.random.seed.
        np.random.seed(int(time.time() * 1e6) % 2 ** 31)
        dataset = input_data.read_data_sets(DATA_DIR, one_hot=True)

        self.train_x = dataset.train.images.reshape([-1, 28, 28, 1])
        self.train_y = dataset.train.labels
        self.valid_x = dataset.validation.images.reshape([-1, 28, 28, 1])
        self.valid_y = dataset.validation.labels
        self.test_x = dataset.test.images.reshape([-1, 28, 28, 1])
        self.test_y = dataset.test.labels

        # Centre every split with the training mean only, so no statistics
        # leak from validation/test data.
        train_mean = self.train_x.mean()
        self.train_x -= train_mean
        self.valid_x -= train_mean
        self.test_x -= train_mean

    def __get_batches_from(self, set_x, set_y, num_batches=100):
        """Shuffle ``(set_x, set_y)`` jointly and cut them into ``num_batches`` parts.

        ``np.array_split`` is used so the number of samples need not be
        divisible by ``num_batches`` (``np.split`` would raise ValueError);
        batch sizes then differ by at most one. If ``num_batches`` exceeds the
        number of samples, some batches are empty.

        Returns:
            A single-pass iterator of ``(batch_x, batch_y)`` pairs.
        """
        order = np.random.permutation(set_x.shape[0])
        groups = np.array_split(order, num_batches)
        return zip([set_x[g] for g in groups], [set_y[g] for g in groups])

    def get_train_batches(self, num_batches=100):
        """Return shuffled ``(x, y)`` mini-batches over the training split."""
        return self.__get_batches_from(self.train_x, self.train_y, num_batches)

    def get_valid_batches(self, num_batches=100):
        """Return shuffled ``(x, y)`` mini-batches over the validation split."""
        return self.__get_batches_from(self.valid_x, self.valid_y, num_batches)

    def get_test_batches(self, num_batches=100):
        """Return shuffled ``(x, y)`` mini-batches over the test split."""
        return self.__get_batches_from(self.test_x, self.test_y, num_batches)

# In [None]:
class Params:
    """Empty parameter holder.

    NOTE(review): presumably intended to collect hyper-parameters; it currently
    defines no attributes.
    """

    def __init__(self):
        """Create an instance with no attributes."""

# Model

# In [3]:
class Model:
    def __init__(self):
        """Build the TensorFlow 1.x graph for the classifier and open a session.

        Defines input/label placeholders, a fully connected head that ends in
        ``classesNum`` unscaled logits, a loss equal to softmax cross-entropy
        plus the sum of the ``REGULARIZATION_LOSSES`` collection, a plain
        gradient-descent training op, and ``self.sess``.

        NOTE(review): this method does not run as shown. ``net`` is read before
        any visible assignment, and the ``fc3`` line is over-indented as if it
        were the body of a removed ``with`` block. Presumably the conv stack was
        lost here, including the ``conv1`` scope that ``train`` reads. ``layers``
        and ``learning_rate`` are also not defined in this method and are
        presumably module-level names. TODO restore the missing lines.
        """
        # Start from a fresh default graph so a new Model does not add its ops
        # on top of a previously built one.
        tf.reset_default_graph()
        imageSize = 28
        classesNum = 10

        # Images of shape (batch, imageSize, imageSize, 1).
        self.inputs = tf.placeholder(tf.float32, (None, imageSize, imageSize, 1))
        # One-hot targets, one row of length classesNum per image.
        self.labels = tf.placeholder(tf.float32, (None, classesNum))

        # NOTE(review): the layers that should define `net` appear to be missing above.
            net = layers.fully_connected(net, 512, scope='fc3')
            # net = layers.fully_connected(net, 10, scope='fc4')

        # Linear output layer (activation_fn=None): the loss expects raw logits.
        self.logits = layers.fully_connected(net, classesNum, activation_fn=None, scope='logits')
        # Cross-entropy against one-hot labels plus every registered regularisation term.
        self.loss = tf.losses.softmax_cross_entropy(self.labels, self.logits) + sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        self.train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(self.loss)
        # Session used by train() to initialise variables and run ops.
        self.sess = tf.Session()

    def draw_conv_filters(self, epoch, step, weights, save_dir):
        """Save the convolution kernels as a tiled grayscale image per input channel.

        Args:
            epoch: epoch number, embedded in the output file name.
            step: step/batch counter, embedded in the output file name.
            weights: kernel array in TensorFlow conv layout
                ``(k, k, in_channels, num_filters)``.
            save_dir: directory the PNG files are written into.

        All kernels are min-max normalised jointly to [0, 1]. They are tiled
        8 per row with a 1-pixel border, and one file
        ``epoch_EE_step_SSSSSS_input_CCC.png`` is written per input channel.
        """
        # Reorder axes (k, k, C, F) -> (F, C, k, k). A plain reshape would
        # reinterpret the memory order and mix values from different filters.
        # astype() also makes a float copy, so `weights` is never modified.
        w = weights.transpose(3, 2, 0, 1).astype(np.float64)
        num_filters, num_channels, k = w.shape[0], w.shape[1], w.shape[2]

        w -= w.min()
        w_range = w.max()
        if w_range > 0:  # constant kernels would otherwise produce a NaN image
            w /= w_range

        border = 1
        cols = 8
        rows = math.ceil(num_filters / cols)
        width = cols * k + (cols - 1) * border
        height = rows * k + (rows - 1) * border
        for channel in range(num_channels):
            img = np.zeros([height, width])
            for j in range(num_filters):
                r = (j // cols) * (k + border)
                c = (j % cols) * (k + border)
                img[r:r + k, c:c + k] = w[j, channel]
            filename = 'epoch_%02d_step_%06d_input_%03d.png' % (epoch, step, channel)
            ski.io.imsave(os.path.join(save_dir, filename), img)

    def train(self, dataset, param_niter=10, num_batches=100):
        """(Re)initialise all variables, then train for ``param_niter`` epochs.

        Args:
            dataset: provides ``get_train_batches(n)`` / ``get_valid_batches(n)``
                yielding ``(x, y)`` pairs with one-hot ``y``, plus ``train_y`` /
                ``valid_y`` label arrays (e.g. the ``Dataset`` class in this file).
            param_niter: number of epochs (converted with ``int``).
            num_batches: number of mini-batches each split is divided into.

        After each epoch, the first trainable variable in the ``conv1`` scope is
        drawn into ``SAVE_DIR``. The mean batch loss and the per-sample
        classification accuracy (printed as "Precision") are then printed for
        the training and validation splits.
        """
        self.sess.run(tf.global_variables_initializer())

        for i in range(1, 1 + int(param_niter)):
            train_loss = 0
            train_correct = 0
            batch_i = 0
            for x, y in dataset.get_train_batches(num_batches):
                loss_val, logits, _ = self.sess.run(
                    [self.loss, self.logits, self.train_step],
                    feed_dict={self.inputs: x, self.labels: y})
                train_loss += loss_val
                # Count samples whose predicted class matches the label. Comparing
                # one-hot vectors element-wise would also credit the agreeing
                # zero entries of wrong predictions and inflate the metric.
                train_correct += np.sum(np.argmax(logits, axis=1) == np.argmax(y, axis=1))
                batch_i += 1
            filters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'conv1')[0].eval(self.sess)
            self.draw_conv_filters(i, batch_i, filters, SAVE_DIR)

            train_loss /= num_batches
            # Fraction of correctly classified samples.
            train_correct /= dataset.train_y.shape[0]

            valid_loss = 0
            valid_correct = 0
            for x, y in dataset.get_valid_batches(num_batches):
                loss_val, logits = self.sess.run(
                    [self.loss, self.logits],
                    feed_dict={self.inputs: x, self.labels: y})
                valid_loss += loss_val
                valid_correct += np.sum(np.argmax(logits, axis=1) == np.argmax(y, axis=1))
            valid_loss /= num_batches
            valid_correct /= dataset.valid_y.shape[0]

            print("Iteration", i, "has loss:", (train_loss, valid_loss))
            print("Precision:", (train_correct, valid_correct))