This is an implementation of the [FractalNet](https://arxiv.org/abs/1605.07648) paper (*FractalNet: Ultra-Deep Neural Networks without Residuals*).

In [56]:
# Load the third-party `pep8_magic` IPython extension (presumably adds a
# PEP 8 style-check magic for cells — confirm; nothing below depends on it).
%load_ext pep8_magic

In [79]:
#!/usr/bin/env python

import numpy as np
import tensorflow as tf
import tflearn
from scipy import stats
import tflearn.layers.normalization.batch_normalization as BN

    
class FractalNet():
    """One fractal block of FractalNet (arXiv:1605.07648), built recursively.

    ``FractalNet(z, n)`` expands to

        f_1(z) = conv(z)
        f_n(z) = join(conv(z), f_{n-1}(f_{n-1}(z)))

    where ``conv`` is convolution -> batch norm -> ReLU. The join is
    flattened: instead of joining the second sub-fractal's already-joined
    output, this block joins its own conv branch with the ``n - 1`` per-column
    tensors of the second child. Every join therefore has exactly ``n``
    inputs: column 0 is the single-conv path and column ``n - 1`` is the
    deepest one.

    Each join is a weighted sum whose weights live in non-trainable variables
    (``self.is_active``). The ``gen*`` methods build ops that rewrite those
    weights. They can select one column, apply local drop-path, or restore the
    plain mean join.

    Args:
        input: 4-D tensor. Its last dimension is used as the channel count
            and every convolution preserves it (assumed NHWC, the
            ``tf.nn.conv2d`` default — TODO confirm for callers).
        n: number of columns (recursion depth); must be >= 1.
        f_height, f_width: convolution kernel size.

    Raises:
        ValueError: if ``n < 1``.
    """

    def __init__(self, input, n, f_height=3, f_width=3):
        if n < 1:
            raise ValueError("FractalNet needs at least one column, got n=%r" % (n,))
        # Imported locally. `batch_normalization` is a function defined in
        # the `tflearn.layers.normalization` module, not a submodule, so it
        # must be brought in with `from ... import`.
        from tflearn.layers.normalization import batch_normalization

        _, _, _, c = input.get_shape()
        f_channels = int(c)
        self.n = n
        self.children = []
        with tf.name_scope("F%d" % n):
            # Convolution filter. The channel count is kept constant so the
            # outputs of all columns can be summed in the join.
            self.filter = tf.Variable(tf.truncated_normal(
                [f_height, f_width, f_channels, f_channels], stddev=0.35),
                name="filter")

            # Column 0: conv -> batch norm -> ReLU.
            conv = tf.nn.conv2d(input, self.filter, [1, 1, 1, 1], 'SAME')
            self.__tensors = [tf.nn.relu(batch_normalization(conv))]

            if n > 1:
                # f_{n-1} applied twice in series. The second application's
                # per-column tensors become columns 1..n-1 of this join.
                first = FractalNet(input, n - 1, f_height, f_width)
                self.children.append(first)
                second = FractalNet(first.get_tensor(), n - 1, f_height, f_width)
                self.children.append(second)
                self.__tensors.extend(second.__tensors)

            with tf.name_scope("join"):
                # Per-column join weights. While they are equal and sum to 1,
                # the join is the element-wise mean of its inputs.
                self.is_active = [
                    tf.Variable(1.0 / n, trainable=False, name="a%d" % i)
                    for i in range(n)
                ]

                self.__tensor = tf.add_n(
                    [tf.mul(m, x) for m, x in zip(self.is_active, self.__tensors)],
                    name="Average_pool_join"
                )

    def get_tensor(self):
        """Return the joined output tensor of this block."""
        return self.__tensor

    def genAssignJoinValues(self, values):
        """Build an op that sets this block's join weights to normalized `values`.

        Only this block's join is assigned; children are not touched.

        Args:
            values: sequence of ``self.n`` non-negative numbers. They are
                divided by their sum so the join stays a weighted mean.

        Returns:
            A grouped op of the variable assignments.

        Raises:
            ValueError: if the number of values differs from the number of
                join inputs, or if the values do not sum to a positive number
                (which would otherwise divide by zero).
        """
        # float32 matches the dtype of the join-weight variables.
        weights = np.asarray(values, dtype=np.float32)
        if weights.shape != (len(self.is_active),):
            raise ValueError(
                "expected %d join values, got shape %r"
                % (len(self.is_active), weights.shape))
        total = weights.sum()
        if not total > 0:
            raise ValueError(
                "join values must sum to a positive number, got %r" % (total,))
        return tf.group(*[
            var.assign(val)
            for val, var in zip(weights / total, self.is_active)
        ])

    def genColumn(self, column):
        """Build an op that routes this block through a single column.

        This join gets weight 1 on `column` and 0 elsewhere. For
        ``column > 0`` the selected path runs through both children, so each
        child is set to its own column ``column - 1``.

        Args:
            column: index in ``[0, self.n)``; 0 is the single-conv path.
        """
        assert 0 <= column < self.n

        values = np.zeros(len(self.is_active))
        values[column] = 1

        child_ops = []
        if column > 0:
            child_ops = [fp.genColumn(column - 1) for fp in self.children]
        return tf.group(self.genAssignJoinValues(values), *child_ops)

    def genRandomColumn(self):
        """Build a `genColumn` op for a column drawn uniformly with `np.random`."""
        return self.genColumn(np.random.randint(self.n))

    def genLocalDropPath(self, dropout_prob):
        """Build an op applying local drop-path to this join and all nested joins.

        Each join input is dropped independently with probability
        `dropout_prob`. A mask is resampled until at least one input
        survives, and the surviving inputs are averaged.

        Args:
            dropout_prob: drop probability in ``[0, 1)``.

        Raises:
            ValueError: if `dropout_prob` is outside ``[0, 1)``. With a
                probability of 1 or more no input could ever survive and
                resampling would never terminate.
        """
        if not 0.0 <= dropout_prob < 1.0:
            raise ValueError(
                "dropout_prob must be in [0, 1), got %r" % (dropout_prob,))
        keep = np.zeros(self.n, dtype=np.float32)
        while not keep.any():
            keep = (np.random.random(self.n) > dropout_prob).astype(np.float32)
        return tf.group(
            self.genAssignJoinValues(keep),
            *[fp.genLocalDropPath(dropout_prob) for fp in self.children]
        )

    def genTestMode(self):
        """Build an op that undoes any column selection or drop-path mask.

        Restores the plain mean join in this block AND in every nested
        block. Children must be included, because `genColumn` and
        `genLocalDropPath` also rewrite the children's join weights.
        """
        return tf.group(
            self.genAssignJoinValues(np.ones(self.n)),
            *[fp.genTestMode() for fp in self.children]
        )

# Input geometry (28x28 single-channel images, MNIST-like) and class count.
height, width, channels = 28, 28, 1
noClasses = 10
# One seed for numpy (used by the column / drop-path samplers) and for TF
# graph-level ops (filter initialisation), so Restart & Run All is repeatable.
SEED = 42
LOG_DIR = "/tmp/FractalNet"

np.random.seed(SEED)

g = tf.Graph()

with g.as_default():
    tf.set_random_seed(SEED)
    X = tf.placeholder(tf.float32, [None, height, width, channels], name="input")
    Y = tf.placeholder(tf.float32, [None, noClasses], name="labels")

    F = FractalNet(X, 3)

    net = F.get_tensor()
    net = tflearn.fully_connected(net, noClasses)
    yp = tf.nn.softmax(net)
    # Per-example loss. Keywords avoid depending on the (logits, labels)
    # positional order, which differs between TF versions.
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=net, labels=Y)

with tf.Session(graph=g) as sess:
    sess.run(tf.initialize_all_variables())
    # NOTE(review): no summaries are defined above, so this is None.
    merged = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter(LOG_DIR, sess.graph)
    try:
        sess.run(F.genRandomColumn())
        sess.run(F.genLocalDropPath(0.15))
    finally:
        # Flush the graph event file and release the writer's file handle.
        writer.close()

In [70]:

# Broadcasting demo: multiplying a (5, 6) tensor by a length-6 vector scales
# each row element-wise by that vector. reduce_max over axis 1 then takes
# each row's maximum.
x = tf.constant(5.0, shape=[5, 6])
w = tf.constant([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
xw = tf.mul(x, w)
max_in_rows = tf.reduce_max(xw, 1)

# Static shapes, expected: (5, 6), (6,), (5, 6).
for static_tensor in (x, w, xw):
    print(static_tensor.get_shape())

with tf.Session() as sess:
    # Expected values:
    #   x           -> 5.0 everywhere
    #   xw          -> every row is [0.0, 5.0, 10.0, 15.0, 20.0, 25.0]
    #   max_in_rows -> [25.0, 25.0, 25.0, 25.0, 25.0]
    for evaluated in (x, xw, max_in_rows):
        print(sess.run(evaluated))


(5, 6)
(6,)
(5, 6)
[[ 5.  5.  5.  5.  5.  5.]
 [ 5.  5.  5.  5.  5.  5.]
 [ 5.  5.  5.  5.  5.  5.]
 [ 5.  5.  5.  5.  5.  5.]
 [ 5.  5.  5.  5.  5.  5.]]
[[  0.   5.  10.  15.  20.  25.]
 [  0.   5.  10.  15.  20.  25.]
 [  0.   5.  10.  15.  20.  25.]
 [  0.   5.  10.  15.  20.  25.]
 [  0.   5.  10.  15.  20.  25.]]
[ 25.  25.  25.  25.  25.]
