Dependencies
- Batch Renormalization
- LayerNorm
- InstanceNorm
- GroupNorm
- ResNetV2-50
- InceptionV3
- Augmentation from [26]
- VGG-A
- PReLU
- Maxout

In [1]:
import tensorflow as tf

In [17]:
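class FRN(tf.keras.layers.Layer):
    """Filter Response Normalization with per-channel gamma and beta."""

    def __init__(self, eps=None, eps_mode='abs', **kwargs):
        super(FRN, self).__init__(**kwargs)
        self.fixed_eps = eps       # None -> learn eps as a scalar weight
        self.eps_mode = eps_mode   # 'abs' (assumed meaning): use |eps| so a learned eps stays >= 0

    def build(self, input_shape):
        channels = int(input_shape[-1])
        # Per-channel affine parameters, shape (1, 1, 1, F)
        self.gamma = self.add_weight(name="gamma", shape=[1, 1, 1, channels],
                                     initializer="ones")
        self.beta = self.add_weight(name="beta", shape=[1, 1, 1, channels],
                                    initializer="zeros")
        if self.fixed_eps is None:
            # Learned scalar eps; the initial value 1e-6 is an assumption
            self.eps = self.add_weight(
                name="eps", shape=[],
                initializer=tf.keras.initializers.Constant(1e-6))
        super(FRN, self).build(input_shape)

    def call(self, x):
        """
        x: feature maps of shape (B, H, W, F)
        """
        if self.fixed_eps is not None:
            eps = self.fixed_eps
        elif self.eps_mode == 'abs':
            eps = tf.abs(self.eps)
        else:
            eps = self.eps
        # (B, 1, 1, F); if H = W = 1, then mean_sq_norm = x**2
        mean_sq_norm = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
        # If H = W = 1, then x_hat = x / sqrt(x**2 + eps) ~ x / |x| ~ sign(x)
        x_hat = x * tf.math.rsqrt(mean_sq_norm + eps)
        # Per-channel scale and shift (broadcast over B, H, W)
        return self.gamma * x_hat + self.beta


class TLU(tf.keras.layers.Layer):
    """Thresholded Linear Unit: max(x, tau) with a learned per-channel tau."""

    def build(self, input_shape):
        channels = int(input_shape[-1])
        self.tau = self.add_weight(name="tau", shape=[1, 1, 1, channels],
                                   initializer="zeros")
        super(TLU, self).build(input_shape)

    def call(self, x):
        return tf.maximum(x, self.tau)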

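To check the math outside Keras, the same FRN + TLU forward pass can be sketched as a single stateless NumPy function. This is a minimal sketch that assumes NHWC inputs and per-channel `gamma`, `beta`, `tau` of shape `(1, 1, 1, C)`. The function name `frn_tlu` is hypothetical, and the default `eps=1e-6` is an assumed value.

```python
import numpy as np


def frn_tlu(x, gamma, beta, tau, eps=1e-6):
    """Hypothetical stateless FRN followed by TLU.

    x: array of shape (B, H, W, C); gamma, beta, tau: arrays of shape (1, 1, 1, C).
    """
    # nu^2: mean squared activation per sample and per channel, shape (B, 1, 1, C)
    nu2 = np.mean(np.square(x), axis=(1, 2), keepdims=True)
    # Divide by the root mean square; no mean is subtracted
    x_hat = x / np.sqrt(nu2 + abs(eps))
    # Per-channel affine transform, then the learned threshold
    return np.maximum(gamma * x_hat + beta, tau)
```

With `gamma = 1`, `beta = 0` and `tau = -inf` (which disables the threshold), the mean of the squared output over `H, W` is close to 1 for every channel. For a 1x1 input, the output is close to `sign(x)`, as the comment in the layer's code predicts.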
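- Gradients of TLU
    - $y = \max(x, \tau) = x\,\mathbb{1}[x > \tau] + \tau\,\mathbb{1}[x \leq \tau]$
    - $\partial y / \partial x = \mathbb{1}[x > \tau]$: gradients flow back to $x$ only through the normalized elements that are larger than $\tau$
    - $\partial y / \partial \tau = \mathbb{1}[x \leq \tau]$: $\tau$ receives gradients only from the normalized elements that are at most $\tau$, summed over all elements that share the same $\tau$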
    
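The indicator gradients above can be checked numerically. This sketch compares them against central finite differences on a few hand-picked values kept away from the kink at $x = \tau$, where $\max$ is not differentiable. The helper names `tlu`, `tlu_grads` and `numeric_grads` are hypothetical.

```python
import numpy as np


def tlu(x, tau):
    return np.maximum(x, tau)


def tlu_grads(x, tau):
    """Analytic elementwise (sub)gradients of y = max(x, tau)."""
    dy_dx = (x > tau).astype(float)     # 1 where x passes through
    dy_dtau = (x <= tau).astype(float)  # 1 where tau is the output
    return dy_dx, dy_dtau


def numeric_grads(x, tau, h=1e-6):
    """Central finite differences; only valid away from x == tau."""
    # Each y_i depends only on x_i, so shifting all of x at once yields dy_i/dx_i
    dy_dx = (tlu(x + h, tau) - tlu(x - h, tau)) / (2 * h)
    dy_dtau = (tlu(x, tau + h) - tlu(x, tau - h)) / (2 * h)
    return dy_dx, dy_dtau


x = np.array([-2.0, -0.5, 0.3, 1.5])
tau = 0.1
analytic = tlu_grads(x, tau)
numeric = numeric_grads(x, tau)
```

Elements above $\tau$ send their gradient to $x$, and the rest send it to $\tau$. For a shared $\tau$, backpropagation sums the latter contributions.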
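- No mean normalization
    - Without mean centering, the normalized activations can be arbitrarily far from 0; this is why FRN is paired with TLU (a learned threshold) rather than ReLU (a threshold fixed at 0)
    - Mean centering is arbitrary unless it is done over the batch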

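Here is a small NumPy illustration of the missing mean centering, using hypothetical activations drawn around 3. FRN's normalization rescales them but leaves their mean far from 0, whereas a mean-centered per-sample scheme such as InstanceNorm brings the mean to 0. The helper names and the synthetic data are assumptions made for illustration.

```python
import numpy as np


def frn_normalize(x, eps=1e-6):
    """FRN normalization step only: divide by the per-channel RMS over H, W."""
    nu2 = np.mean(np.square(x), axis=(1, 2), keepdims=True)
    return x / np.sqrt(nu2 + eps)


def instance_normalize(x, eps=1e-6):
    """InstanceNorm-style step: subtract the per-channel mean, divide by the std."""
    mu = np.mean(x, axis=(1, 2), keepdims=True)
    var = np.var(x, axis=(1, 2), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
# Synthetic activations (B=1, H=W=8, C=1) centred around 3 rather than 0
x = rng.normal(loc=3.0, scale=1.0, size=(1, 8, 8, 1))

x_hat = frn_normalize(x)       # rescaled, but the mean stays far from 0
x_in = instance_normalize(x)   # mean-centred, so the mean is ~0

frn_mean, frn_positive = x_hat.mean(), (x_hat > 0).mean()
in_mean = x_in.mean()
```

Because almost every FRN-normalized entry is positive here, a ReLU at 0 would pass nearly all of them through unchanged and act close to the identity. A learned threshold such as TLU's $\tau$ can move the kink to where the activations actually are.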
- Implementation details
    - SGD
    - 8 GPUs 
    - 300K steps
    - BN normalization statistics computed per GPU
    - lr = 0.1 * batch_size / 256, with cosine decay
    - other details from [8, 9]
    - Metrics
        - Top-1 accuracy: the highest-scoring class is correct (precision at 1)
        - Top-5 accuracy: the correct class is among the five highest-scoring classes (recall at 5)
    - For comparing with other normalization methods
        - 32 images/GPU × 8 GPUs = batch size 256
         
    - Object detection
        - Fine-tuned baseline
            - steps = 25000
        - Rest from scratch
            - steps = 125000
            - batch_size in {64, 32, 16}
            - lr = {0.01, 0.05, 0.1} * batch_size / 64
            - train_steps = 125000 * 64 / batch_size
            - momentum = 0.9
            - weight_decay = $4 \times 10^{-4}$

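Both learning-rate rules in the notes are simple arithmetic, and the following sketch implements them. It assumes that the cosine decay starts at the linearly scaled peak with no warmup and reaches 0 at the final step, and that `base_lr` is one of the listed {0.01, 0.05, 0.1}. The function names are hypothetical.

```python
import math


def imagenet_lr(step, batch_size, total_steps=300_000, base_lr=0.1):
    """Linearly scaled peak LR (0.1 * batch_size / 256) with cosine decay.

    Assumptions: no warmup, and the decay reaches 0 at total_steps.
    """
    peak = base_lr * batch_size / 256
    progress = min(step, total_steps) / total_steps
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


def detection_schedule(batch_size, base_lr=0.1, base_steps=125_000):
    """From-scratch detection: LR and step count rescaled from batch size 64."""
    lr = base_lr * batch_size / 64
    train_steps = base_steps * 64 // batch_size
    return lr, train_steps
```

For example, `detection_schedule(32)` gives an LR of 0.05 and 250,000 steps. Halving the batch halves the LR and doubles the step count, so the number of training examples seen stays at 125000 × 64.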
- Results
    - FRN consistently outperforms the other normalization methods in the reported experiments