This is an implementation of the [FractalNet](http://arxiv.org/abs/1605.07648) paper ("FractalNet: Ultra-Deep Neural Networks without Residuals").

In [43]:
#!/usr/bin/env python

import numpy as np
import tensorflow as tf
import tflearn    
from tflearn.layers.normalization import batch_normalization

class FractalNet():
    """Recursive fractal block with ``n`` columns.

    The block is built as ``join(atom(input), F_{n-1}(F_{n-1}(input)))``,
    where ``atom`` is a single convolution unit. The join is flattened: this
    level joins its own atom together with the *un-joined* per-column outputs
    of the second child, so exactly ``n`` tensors are combined by a weighted
    sum.

    The join weights (``is_active``) are non-trainable variables initialised
    to ``1/n`` (a mean join). The ``gen*`` methods build ops that overwrite
    those weights to implement drop-path (local or single-column) and to
    restore the mean join for testing.

    Args:
        input: 4-D tensor whose last dimension is the channel count (the
            convolution below uses TF's default NHWC layout).
        n: number of columns (recursion depth); must be >= 1.
        f_height, f_width: kernel size of every atom's convolution.
        out_chanell: output channel count; defaults to the input channel count.

    Raises:
        ValueError: if ``n < 1``.
    """

    def __init__(self, input, n, f_height=2, f_width=2, out_chanell=None):
        if n < 1:
            raise ValueError("FractalNet needs at least one column, got n=%r" % (n,))
        _, _, _, c = input.get_shape()
        in_channel = int(c)
        if out_chanell is None:
            out_chanell = in_channel

        self.n = n
        self.children = []  # sub-blocks with n - 1 columns; empty when n == 1
        with tf.name_scope("F%d" % n):
            with tf.name_scope("atom"):
                # Single computational "atom": conv -> bias -> ReLU -> batch norm.
                self.filter = tf.Variable(
                    tf.truncated_normal(
                        [f_height, f_width, in_channel, out_chanell],
                        stddev=0.35,
                    ),
                    name="filter",
                )
                self.bias = tf.Variable([0] * out_chanell, dtype=tf.float32, name="bias")
                # Stride 1 with SAME padding keeps the spatial size unchanged.
                atom = tf.nn.conv2d(input, self.filter, [1, 1, 1, 1], 'SAME')
                atom = tf.nn.relu(tf.nn.bias_add(atom, self.bias))
                atom = batch_normalization(atom)

            # Column 0 is the atom; columns 1..n-1 come from the second child.
            self.__tensors = [atom]
            if n > 1:
                # The first child moves halfway to the target channel count;
                # the second child finishes at out_chanell.
                Fp = FractalNet(input, n - 1, f_height, f_width,
                                (out_chanell + in_channel) // 2)
                self.children.append(Fp)
                Fp = FractalNet(Fp.get_tensor(), n - 1, f_height, f_width, out_chanell)
                self.children.append(Fp)
                # Flattened join: take the second child's per-column outputs
                # directly instead of its joined tensor, so this level joins
                # exactly n inputs (1 atom + n - 1 child columns).
                self.__tensors.extend(Fp.__tensors)

            with tf.name_scope("join"):
                # Join weights; for a mean join they are equal and sum to 1.
                # Not trainable: they are overwritten by the gen* methods.
                self.is_active = [
                    tf.Variable(1.0 / n, trainable=False, name="a%d" % i)
                    for i in range(n)
                ]

                # ``Variable * Tensor`` uses operator overloading, which works
                # both before and after tf.mul was renamed to tf.multiply.
                self.__tensor = tf.add_n(
                    [m * x for m, x in zip(self.is_active, self.__tensors)],
                    name="Average_pool_join",
                )

    def get_tensor(self):
        """Return the joined output tensor of this block."""
        return self.__tensor

    def genAssignJoinValues(self, values):
        """Return an op that sets this level's join weights to ``values``.

        Only this level's weights are assigned; children are not touched.
        Pairing is done with ``zip``, so only ``min(len(values), n)`` weights
        are assigned.

        NOTE(review): every call adds new assign ops to the graph, so calling
        the gen* methods once per training step grows the graph without bound;
        feeding the weights through placeholders would avoid that.
        """
        return tf.group(*[
            var.assign(val)
            for val, var in zip(values, self.is_active)
        ])

    def genColumn(self, column):
        """Return an op that makes ``column`` the only active path at every join.

        Column 0 is this level's atom. Column ``k > 0`` lives in the children,
        where it is column ``k - 1``, so both children are switched
        recursively. For column 0 the children are left unchanged: their
        outputs enter this level's join with weight 0 anyway.

        Raises:
            ValueError: if ``column`` is not in ``[0, n)``.
        """
        # Explicit check instead of ``assert``: under ``python -O`` a negative
        # column would otherwise silently index from the end of ``values``.
        if not 0 <= column < self.n:
            raise ValueError("column must be in [0, %d), got %r" % (self.n, column))

        values = np.zeros(len(self.is_active), dtype=np.float32)
        values[column] = 1

        ops = [self.genAssignJoinValues(values)]
        if column > 0:
            ops.extend(fp.genColumn(column - 1) for fp in self.children)
        return tf.group(*ops)

    def genRandomColumn(self):
        """Return a ``genColumn`` op for a column drawn uniformly from ``[0, n)``."""
        return self.genColumn(np.random.randint(self.n))

    def genLocalDropPath(self, dropout_prob):
        """Return an op applying local drop-path at this join and in all children.

        Each of the ``n`` join inputs is kept independently with probability
        ``1 - dropout_prob``. Sampling repeats until at least one input
        survives. The kept weights are renormalised to sum to 1, so the join
        stays a mean over the surviving inputs. Each child draws its own mask.

        Raises:
            ValueError: if ``dropout_prob >= 1``. No input could ever survive,
                so the resampling loop would never terminate.
        """
        if not dropout_prob < 1:
            raise ValueError("dropout_prob must be < 1, got %r" % (dropout_prob,))
        values = np.zeros(self.n)
        while np.sum(values) < 0.5:  # i.e. == 0, robust to floating point
            values = (np.random.random(self.n) > dropout_prob).astype(np.float32)
        values /= np.sum(values)  # normalise the kept weights to sum to 1
        return tf.group(
            self.genAssignJoinValues(values),
            *[fp.genLocalDropPath(dropout_prob) for fp in self.children]
        )

    def genTestMode(self):
        """Return an op that restores the mean join (weight 1/n) everywhere.

        This undoes any drop-path or single-column setting made by the other
        gen* methods. It recurses into all children, so every nested join is
        reset, not just this level's.
        """
        return tf.group(
            self.genAssignJoinValues(np.ones(self.n, dtype=np.float32) / self.n),
            *[fp.genTestMode() for fp in self.children]
        )

height, width, channels = 32, 32, 3
noClasses = 10

g = tf.Graph()

with g.as_default():
    X = tf.placeholder(tf.float32, [None, height, width, channels], name="input")
    Y = tf.placeholder(tf.float32, [None, noClasses], name="labels")
    FF = []  # one FractalNet per block, kept so their join weights can be driven
    net = X

    # Five 4-column fractal blocks, each followed by 2x2/stride-2 max-pooling,
    # halving the spatial size every time (32x32 -> 1x1, see printed shapes).
    for i, channel_no in enumerate([16, 32, 64, 128, 128]):
        with tf.name_scope("block_%d" % (i + 1)):
            net = FractalNet(net, 4, out_chanell=channel_no)
        FF.append(net)
        net = tf.nn.max_pool(net.get_tensor(), [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
        print(i, net.get_shape())

    net = tflearn.fully_connected(net, noClasses)
    yp = tf.nn.softmax(net)
    # Pass by keyword: the parameter names are logits/labels in both old and
    # new TF releases, and newer releases refuse positional arguments here.
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=net, labels=Y)

with tf.Session(graph=g) as sess:
    sess.run(tf.initialize_all_variables())
    merged = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter("/tmp/FractalNet", sess.graph)
    try:
        tflearn.is_training(True)
        # Even-indexed blocks get local drop-path; odd-indexed blocks each
        # keep one randomly chosen column.
        # NOTE(review): the paper alternates local/global drop-path per
        # mini-batch, and global drop-path picks one column for the whole
        # network. Here the mode is fixed per block, and each block picks its
        # own column.
        for F in FF[::2]:
            sess.run(F.genLocalDropPath(0.15))
        for F in FF[1::2]:
            sess.run(F.genRandomColumn())
    finally:
        # Flush and close the event file so the graph is actually written.
        writer.close()

0 (?, 16, 16, 16)
1 (?, 8, 8, 32)
2 (?, 4, 4, 64)
3 (?, 2, 2, 128)
4 (?, 1, 1, 128)
