# MobileNetV2 in XABY

In this notebook I'm going to implement the MobileNetV2 model in XABY. Read the original paper here: https://arxiv.org/abs/1801.04381. I'm following the paper as well as the [torchvision implementation](https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv2.py).

The goal here is to add the features XABY needs for the MobileNetV2 model. My typical workflow is to develop new things in a notebook like this, then copy them over to the actual module files.

In [None]:
from typing import Optional, List, Union

import xaby as xb
import xaby.nn as xn
from xaby import jnp

import numpy as np
import jax

The original paper uses ReLU6 activations, min(max(x, 0), 6), so I implement that here.

In [None]:
@xb.fn
def relu6(x: jnp.DeviceArray) -> jnp.DeviceArray:
    return jnp.clip(x, a_min=0, a_max=6)

The MobileNetV2 model uses [batch normalization](https://arxiv.org/abs/1502.03167) (batchnorm for short). When I started this notebook, XABY didn't have batchnorm, so I implement it here.

Typically, at evaluation time batch normalization uses global statistics (mean and variance) instead of the statistics of the current batch. You would implement this by keeping running averages of the mean and variance of the training batches inside the batchnorm layer. But updating that state makes the function impure, and JAX's `jit` only works correctly with pure functions. I haven't found a good way around this yet, so this batchnorm uses batch statistics at evaluation time as well.
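For comparison, here is a minimal NumPy sketch of evaluation-time statistics handled in a pure way: the running mean and variance are passed in as arguments and returned as outputs rather than stored and mutated. The function names and the `momentum=0.1` default (which matches PyTorch's `BatchNorm2d`) are my own choices for illustration, and the sketch is not wired into XABY's `xb.Fn`; doing that would require threading the returned statistics through the model, which this notebook doesn't implement.

```python
import numpy as np

# Hypothetical helpers for illustration; not part of XABY.


def _bcast(v: np.ndarray) -> np.ndarray:
    """Reshape a per-channel vector (C,) to (1, C, 1, 1) for (B, C, H, W) broadcasting."""
    return v.reshape(1, -1, 1, 1)


def batchnorm2d_train(x, gamma, beta, running_mean, running_var, momentum=0.1, eps=1e-5):
    """Training-mode batchnorm. Returns (y, new_running_mean, new_running_var).

    Nothing is mutated: the updated statistics are returned, so the caller
    threads them into the next step. That keeps the function pure.
    """
    mean = x.mean(axis=(0, 2, 3))  # per-channel batch mean, shape (C,)
    var = x.var(axis=(0, 2, 3))    # per-channel biased batch variance, shape (C,)
    y = _bcast(gamma) * (x - _bcast(mean)) / np.sqrt(_bcast(var) + eps) + _bcast(beta)
    # Exponential moving averages of the batch statistics
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * var
    return y, new_mean, new_var


def batchnorm2d_eval(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Evaluation-mode batchnorm: normalize with the stored running statistics."""
    x_norm = (x - _bcast(running_mean)) / np.sqrt(_bcast(running_var) + eps)
    return _bcast(gamma) * x_norm + _bcast(beta)


# Usage: feed synthetic batches with mean 3 and std 2 through the training step
rng = np.random.default_rng(0)
C = 4
gamma, beta = np.ones(C), np.zeros(C)
running_mean, running_var = np.zeros(C), np.ones(C)
for _ in range(100):
    x = rng.normal(loc=3.0, scale=2.0, size=(8, C, 5, 5))
    y, running_mean, running_var = batchnorm2d_train(x, gamma, beta, running_mean, running_var)

print(running_mean.round(2), running_var.round(2))  # close to 3 and 4 respectively
y_eval = batchnorm2d_eval(x, gamma, beta, running_mean, running_var)
```

Since all state flows through arguments and return values, swapping `numpy` for `jax.numpy` should give functions that `jax.jit` can compile. One detail differs from PyTorch: PyTorch stores the unbiased variance in its running estimate, while this sketch uses the biased batch variance throughout.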

In [None]:
class batchnorm2d(xb.Fn):
    def __init__(self, num_features: int, epsilon: float = 1e-5):
        """ BatchNorm for 2D input, expects 4D input array in (B, C, H, W) format """
        
        @jax.jit
        def batchnorm2d(inputs: xb.ArrayList, params: dict) -> xb.ArrayList:
            x, = inputs
            weights, bias = params["weights"], params["bias"]
            num_features = x.shape[1]
            
            # Reshaping for broadcasting ease
            x_mean = jnp.mean(x, axis=(0, 2, 3)).reshape(1, num_features, 1, 1)
            x_var = jnp.mean((x - x_mean)**2, axis=(0, 2, 3)).reshape(1, num_features, 1, 1)
                    
            x_norm = (x - x_mean) / jnp.sqrt(x_var + epsilon)
            y = weights * x_norm + bias
            return xb.pack(y)
        
        super().__init__(batchnorm2d, 1, 1, name="batchnorm2d")
        
        self.params["weights"] = jnp.ones((1, num_features, 1, 1))
        self.params["bias"] = jnp.zeros((1, num_features, 1, 1))
        
        self.num_features = num_features
        self.epsilon = epsilon
    
    def __repr__(self):
        return f"batchnorm({self.num_features}, epsilon={self.epsilon})"

Next, I'll write helper functions that build the sub-functions (blocks) of the overall model.

In [None]:
def residual(func: xb.Fn) -> xb.Fn:
    return xb.parallel(func, xb.skip) >> xb.add

def conv_norm(in_f: int, out_f: int, kernel_size: int = 3, stride: int = 1, padding:int = 0, groups: int = 1, 
              norm: Optional[xb.Fn] = None, activation: Optional[xb.Fn] = None):
    """ Returns a function conv2d >> batchnorm2d >> relu6 by default """
    
    if norm is None:
        norm = batchnorm2d
    if activation is None:
        activation = relu6
        
    conv = xn.conv2d(in_f, out_f, kernel_size, stride, padding, groups, bias=False)
    func = conv >> norm(out_f) >> activation
    func.name = "conv_norm"
    return func

def bottleneck(in_features: int, out_features: int, stride: int=1, expand_ratio: Union[int, float] = 1.0) -> xb.Fn:
    """ Returns an Inverted Residual Bottleneck function """
    
    expand_f = int(round(expand_ratio * in_features))
    
    if expand_ratio != 1:
        expand = conv_norm(in_features, expand_f, kernel_size=1)
    else:
        expand = None
    
    depthwise = conv_norm(expand_f, expand_f, stride=stride, padding=1, groups=expand_f)
    
    transform = xn.conv2d(expand_f, out_features, kernel_size=1, stride=1, padding=0, bias=False)
    
    if expand is not None:
        func = expand >> depthwise >> transform >> batchnorm2d(out_features)
    else:
        func = depthwise >> transform >> batchnorm2d(out_features)
    
    if in_features == out_features and stride == 1:
        func = residual(func)
        
    func.name = "bottleneck"
    
    return func
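To see why the depthwise convolution keeps these blocks cheap, here is a small parameter count in plain Python. The helper names `bottleneck_params` and `dense_conv_params` are hypothetical; they assume the layout of `bottleneck` above (bias-free convolutions, each followed by a batchnorm with a scale and a shift per channel) and count only learnable parameters.

```python
# Hypothetical helpers for illustration; not part of XABY.


def bottleneck_params(in_f: int, out_f: int, expand_ratio: float = 1.0, kernel_size: int = 3) -> dict:
    """Count learnable parameters in one inverted residual block.

    Assumes the same layout as `bottleneck` above: bias-free convolutions,
    each followed by a batchnorm with 2 parameters (scale, shift) per channel.
    """
    hidden = int(round(expand_ratio * in_f))
    counts = {}
    if expand_ratio != 1:
        counts["expand"] = in_f * hidden + 2 * hidden              # 1x1 conv + batchnorm
    counts["depthwise"] = hidden * kernel_size ** 2 + 2 * hidden   # one kxk filter per channel + batchnorm
    counts["project"] = hidden * out_f + 2 * out_f                  # linear 1x1 conv + batchnorm
    counts["total"] = sum(counts.values())
    return counts


def dense_conv_params(in_f: int, out_f: int, kernel_size: int = 3) -> int:
    """Weights of an ordinary (groups=1), bias-free kxk convolution, for comparison."""
    return in_f * out_f * kernel_size ** 2


block = bottleneck_params(24, 24, expand_ratio=6)  # hidden width 144
print(block)
print(dense_conv_params(144, 144))  # a full 3x3 conv at the hidden width
```

For the 24 to 24 block with t=6 this gives 8,832 parameters in total, of which the depthwise stage contributes 1,584. A dense 3x3 convolution at the same 144-channel width would need 186,624 weights on its own, 144 times the 1,296 depthwise weights.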

Time to put the parts together into one big function: the model.
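Before writing the model, it helps to trace how the (t, c, n, s) table from the paper expands into individual layers. This plain-Python sketch lists the input/output channels, stride, whether the block gets a residual connection, and the output spatial size for each bottleneck. It assumes, as torchvision does, that every stage's output channels are scaled by `width_mult` and rounded with `_make_divisible`, that the stem is a 3x3 stride-2 convolution with padding 1, and that the depthwise convolutions are 3x3 with padding 1; `layer_plan` and `ARCH` are hypothetical names.

```python
from typing import List, Optional, Tuple


def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Same rounding rule as `_make_divisible` in the cell below."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # don't round down by more than 10%
        new_v += divisor
    return new_v


# (t, c, n, s): expansion ratio, output channels, repeats, stride of the first repeat
ARCH = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2],
        [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1]]


def layer_plan(width_mult: float = 1.0, round_nearest: int = 8,
               input_size: int = 224) -> List[Tuple[int, int, int, bool, int]]:
    """Return (in_ch, out_ch, stride, has_residual, out_size) for each bottleneck."""
    in_ch = make_divisible(32 * width_mult, round_nearest)
    size = (input_size - 1) // 2 + 1  # stem: 3x3 conv, stride 2, padding 1
    plan = []
    for _t, c, n, s in ARCH:  # the expansion ratio t doesn't affect shapes
        out_ch = make_divisible(c * width_mult, round_nearest)
        for i in range(n):
            stride = s if i == 0 else 1  # only the first repeat downsamples
            size = (size - 1) // stride + 1  # 3x3 depthwise conv, padding 1
            plan.append((in_ch, out_ch, stride, in_ch == out_ch and stride == 1, size))
            in_ch = out_ch
    return plan


for row in layer_plan():
    print(row)
```

With the defaults, the last bottleneck outputs a 7x7 map for a 224x224 input (an overall downsampling factor of 32), and 10 of the 17 bottlenecks get a residual connection.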

In [None]:
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
    
def mobilenetv2(num_classes: int = 1000, width_mult: float = 1.0, round_nearest: int = 8,
                architecture: Optional[List[List[int]]] = None) -> xb.Fn:
    
    in_features = 32
    last_features = 1280
    
    # Architecture from the MobileNetV2 paper
    if architecture is None:
        architecture = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        
    in_features = _make_divisible(in_features * width_mult, round_nearest)
    last_features = _make_divisible(last_features * max(1.0, width_mult), round_nearest)
    
    features = xn.conv2d(3, in_features, stride=2, padding=1) >> batchnorm2d(in_features)
    
    # Add bottleneck layers
    for t, c, n, s in architecture:
        for i in range(n):
            out_features = _make_divisible(c * width_mult, round_nearest)
            # We only change the dimensions at the first bottleneck layer of any series
            stride = s if i == 0 else 1
            
            features = features >> bottleneck(in_features, out_features, stride=stride, expand_ratio=t)
            
            in_features = out_features

    features = features >> conv_norm(in_features, last_features, kernel_size=1)
    features.name = "features"
    
    classifier = xn.dropout(0.2) >> xn.linear(last_features, num_classes)
    classifier.name = "classifier"
    
    # The mean here computes the spatial average of each feature. This is sometimes known as 
    # global average pooling. It's becoming more popular in new models. I like it because it
    # avoids imposing a fixed input size for the model. You can train the model on 224x224 images 
    # like normal, but then use the model for images with other sizes and it still works.
    model = features >> xb.mean(axis=(2, 3)) >> classifier
    model.name = "mobilenetv2"
    
    return model
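The comment about global average pooling can be checked with a small NumPy sketch: averaging over the two spatial axes turns any (B, C, H, W) feature map into a (B, C) vector, so the same linear classifier applies whatever the input resolution. The random weights and feature maps are stand-ins, and the map sizes assume the network's overall downsampling factor of 32.

```python
import numpy as np


def global_avg_pool(x: np.ndarray) -> np.ndarray:
    """Average every channel over its spatial positions: (B, C, H, W) -> (B, C)."""
    return x.mean(axis=(2, 3))


rng = np.random.default_rng(0)
W = rng.normal(size=(1280, 10))  # hypothetical linear classifier: 1280 features -> 10 classes

# Stand-in feature maps for 224x224, 320x320 and 384x288 inputs (downsampled by 32)
for h, w in [(7, 7), (10, 10), (12, 9)]:
    feats = rng.normal(size=(2, 1280, h, w))
    logits = global_avg_pool(feats) @ W
    print((h, w), logits.shape)  # (2, 10) every time
```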

Now we can create the model itself and look at the structure.

In [14]:
model = mobilenetv2()
model

sequential('mobilenetv2') {
  (features): sequential('features') {
    (conv2d): conv2d(3, 32, kernel=(3, 3), stride=(2, 2), padding=((1, 1), (1, 1)), bias=True)
    (batchnorm2d): batchnorm(32, epsilon=1e-05)
    (bottleneck): sequential('bottleneck') {
      (conv_norm): sequential('conv_norm') {
        (conv2d): conv2d(32, 32, kernel=(3, 3), stride=(1, 1), padding=((1, 1), (1, 1)), bias=False)
        (batchnorm2d): batchnorm(32, epsilon=1e-05)
        (relu6): relu6
      }
      (conv2d): conv2d(32, 16, kernel=(1, 1), stride=(1, 1), padding=((0, 0), (0, 0)), bias=False)
      (batchnorm2d): batchnorm(16, epsilon=1e-05)
    }
    (bottleneck_1): sequential('bottleneck') {
      (conv_norm): sequential('conv_norm') {
        (conv2d): conv2d(16, 96, kernel=(1, 1), stride=(1, 1), padding=((0, 0), (0, 0)), bias=False)
        (batchnorm2d): batchnorm(96, epsilon=1e-05)
        (relu6): relu6
      }
      (conv_norm_1): sequential('conv_norm') {
        (conv2d): conv2d(96, 96, kerne

In [None]:
# The first time you run this, it tends to be slow because the functions are being compiled.
# Subsequent runs should be much faster.
a = xb.random.uniform((10, 3, 224, 224))
xb.pack(a) >> model

In [None]:
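# log_softmax over the class axis (axis=1 for (batch, classes) logits), then NLL against the targets
loss = xb.split(model >> xn.log_softmax(axis=1), xb.skip) >> xn.losses.nll_loss()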
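The log-softmax has to normalize over the class axis, which is axis 1 for logits shaped (batch, classes); normalizing over axis 0 would instead make each class's scores sum to one across the batch. This NumPy sketch shows the difference and computes a mean negative log-likelihood on top. It assumes `xn.log_softmax` and `xn.losses.nll_loss` follow the usual conventions (log-probabilities of shape (batch, classes), integer class targets, mean reduction) and uses flat (B,) targets; the helper names are my own.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def nll_loss(log_probs: np.ndarray, targets: np.ndarray) -> float:
    """Mean negative log-likelihood. log_probs: (B, num_classes), targets: (B,) ints."""
    return float(-log_probs[np.arange(len(targets)), targets].mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(10, 1000))  # stand-in model output, (batch, classes)
targets = rng.integers(0, 2, size=10)

good = log_softmax(logits, axis=1)  # each row is a distribution over the 1000 classes
bad = log_softmax(logits, axis=0)   # each class column sums to 1 across the 10 samples instead
print(np.exp(good).sum(axis=1)[:3], np.exp(bad).sum(axis=1)[:3])
print(nll_loss(good, targets))
```

With all-zero logits every class gets probability 1/1000, so the loss is log(1000), about 6.91. That makes a handy sanity check, since an untrained 1000-class model should start somewhere near that value.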

In [None]:
a = xb.random.uniform((10, 3, 224, 224))
b = xb.random.bernoulli((10, 1)).astype(xb.jnp.int32)

xb.pack(a, b) >> loss

In [None]:
# Run the model again: the jitted functions are already compiled, so this should be faster
a = xb.random.uniform((10, 3, 224, 224))
xb.pack(a) >> model

In [None]:
# Run only the feature extractor (everything before the pooling and the classifier)
a = xb.random.uniform((10, 3, 224, 224))
xb.pack(a) >> model.features