# Speed-up inference with Batch Normalization Folding
> How to remove the batch normalization layer to make your neural networks faster.

- toc: true
- badges: false
- categories: [Deep Learning]
- comments: true

## **Introduction**

Batch Normalization {% fn 1 %} {% fn 2 %} is a technique that normalizes the input of each layer to make training faster and more stable. In practice, it is an extra layer that we generally add after the computation layer (e.g. a convolution) and before the non-linearity.

<!-- Create a div where the graph will take place -->
<div id="my_dataviz"></div>

<style>

.axisGray text{
  fill: rgb(169,169,179);
}  
  
</style>

<!-- Load color palettes -->


<script>

// set the dimensions and margins of the graph
var margin = {top: 80, right: 25, bottom: 30, left: 40},
  width = 600 - margin.left - margin.right,
  height = 600 - margin.top - margin.bottom;

// append the svg object to the body of the page
var svg = d3.select("#my_dataviz")
.append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
.append("g")
  .attr("transform",
        "translate(" + margin.left + "," + margin.top + ")");

//Read the data
d3.csv("{{site.baseurl}}/assets/csv/out.csv", function(data) {

  // Labels of row and columns -> unique identifier of the column called 'group' and 'variable'
  var myGroups = d3.map(data, function(d){return d.group;}).keys()
  var myVars = d3.map(data, function(d){return d.variable;}).keys()
    
    

  // Build X scales and axis:
  var x = d3.scaleBand()
    .range([ 0, width ])
    .domain(myGroups)
    .padding(0.05);
  svg.append("g")
    .style("font-size", 15)
    .attr("class", "axisGray")
    .attr("transform", "translate(0," + height + ")")
    .call(d3.axisBottom(x).tickSize(0))
    .select(".domain").remove()

  // Build Y scales and axis:
  var y = d3.scaleBand()
    .range([ height, 0 ])
    .domain(myVars)
    .padding(0.05);
  svg.append("g")
    .style("font-size", 15)
    .attr("class", "axisGray")
    .call(d3.axisLeft(y).tickSize(0))
    .select(".domain").remove()


  // Build color scale
  var myColor = d3.scaleSequential()
    .interpolator(d3.interpolateInferno)
    .domain([1,100])

  // create a tooltip
  var tooltip = d3.select("#my_dataviz")
    .append("div")
    .style("opacity", 0)
    .attr("class", "tooltip")
    .style("background-color", "white")
    .style("border", "solid")
    .style("border-width", "2px")
    .style("border-radius", "5px")
    .style("padding", "5px")

  // Three functions that change the tooltip when the user hovers over / moves within / leaves a cell
  var mouseover = function(d) {
    tooltip
      .style("opacity", 1)
    d3.select(this)
      .style("stroke", "black")
      .style("opacity", 1)
  }
  var mousemove = function(d) {
    tooltip
      .html("Pixel value: " + d.value)
      .style("left", (d3.mouse(this)[0]+70) + "px")
      .style("top", (d3.mouse(this)[1]) + "px")
  }
  var mouseleave = function(d) {
    tooltip
      .style("opacity", 0)
    d3.select(this)
      .style("stroke", "none")
      .style("opacity", 0.8)
  }

  // add the squares
  svg.selectAll()
    .data(data, function(d) {return d.group+':'+d.variable;})
    .enter()
    .append("rect")
      .attr("x", function(d) { return x(d.variable) })
      .attr("y", function(d) { return y(d.group) })
      .attr("rx", 4)
      .attr("ry", 4)
      .attr("width", x.bandwidth() )
      .attr("height", y.bandwidth() )
      .style("fill", function(d) { return myColor(d.value)} )
      .style("stroke-width", 4)
      .style("stroke", "none")
      .style("opacity", 0.8)
    .on("mouseover", mouseover)
    .on("mousemove", mousemove)
    .on("mouseleave", mouseleave)
})

// Add subtitle to graph
svg.append("text")
        .attr("x", 0)
        .attr("y", -20)
        .attr("text-anchor", "start")
        .style("font-size", "14px")
        .style("fill", "grey")
        .style("max-width", 400)
        .text("MNIST visualization");


</script>

It consists of **2** steps:

1. Normalize the batch by first subtracting its mean $\mu$, then dividing it by its standard deviation $\sigma$.
2. Further scale by a factor $\gamma$ and shift by an offset $\beta$. These are the learnable parameters of the batch normalization layer; they let the network rescale and shift the normalized values in case it does not need the data to have a mean of **0** and a standard deviation of **1**.

$$
\Large
\begin{aligned}
&\mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^{m} x_{i}\\
&\sigma_{\mathcal{B}}^{2} \leftarrow \frac{1}{m} \sum_{i=1}^{m}\left(x_{i}-\mu_{\mathcal{B}}\right)^{2}\\
&\widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}}\\
&y_{i} \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{BN}_{\gamma, \beta}\left(x_{i}\right)
\end{aligned}
$$
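As an illustration, the four equations above translate almost line by line into NumPy. This is a minimal sketch, assuming a 2-D batch of shape `(m, features)` with statistics computed per feature; `batch_norm_train` is a hypothetical name, not the PyTorch implementation used later in this post.

```python
import numpy as np


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch normalization of a batch x of shape (m, features)."""
    mu = x.mean(axis=0)                     # mini-batch mean
    var = ((x - mu) ** 2).mean(axis=0)      # mini-batch variance (biased, 1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)   # normalize
    return gamma * x_hat + beta             # scale by gamma, shift by beta


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(256, 4))
y = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6), y.std(axis=0).round(3))
```

With $\gamma = 1$ and $\beta = 0$, each feature of `y` ends up with a mean close to 0 and a standard deviation close to 1, whatever the mean and scale of `x`.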

Because it makes the training of neural networks faster and more stable, batch normalization is now widely used. But how useful is it at inference time?

Once training has ended, each batch normalization layer holds a fixed set of $\gamma$ and $\beta$, but also of $\mu$ and $\sigma$, the latter being running estimates of the batch statistics computed with an exponentially weighted moving average during training. This means that, at inference time, batch normalization no longer depends on the batch: it acts as a simple per-channel linear transformation (a scaling followed by a shift) of what comes out of the previous layer, often a convolution.
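A quick PyTorch check of this claim, as a sketch: the `nn.BatchNorm1d` below has hand-filled $\gamma$, $\beta$ and running statistics that stand in for a trained layer (an assumption made for illustration). In `eval()` mode, its output matches `scale * x + shift`, with one `scale` and one `shift` per channel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(8)

# Simulate a trained layer by filling in learned parameters and running statistics
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)          # gamma
    bn.bias.uniform_(-1.0, 1.0)           # beta
    bn.running_mean.uniform_(-2.0, 2.0)   # mu
    bn.running_var.uniform_(0.5, 3.0)     # sigma^2
bn.eval()  # use the running statistics instead of the batch statistics

# At inference, batch norm reduces to y = scale * x + shift, per channel
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale

x = torch.randn(16, 8)
with torch.no_grad():
    print(torch.allclose(bn(x), x * scale + shift, atol=1e-5))  # True
```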

Since a convolution is also a linear transformation, both operations can be merged into a single linear transformation!

This removes some unnecessary parameters and also reduces the number of operations to perform at inference time.

---

<br>

## **How to do that in practice?**


With a little bit of math, we can easily rearrange the terms of the convolution to take the batch normalization into account.

As a little reminder, the convolution operation followed by the batch normalization operation can be expressed, for an input $x$, as:

$$
\Large
\begin{aligned}
z &=W * x+b \\
\text { out } &=\gamma \cdot \frac{z-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta
\end{aligned}
$$

So, if we re-arrange the $W$ and $b$ of the convolution to take the parameters of the batch normalization into account, as such:

$$
\Large
\begin{aligned}
w_{\text {fold }} &=\gamma \cdot \frac{W}{\sqrt{\sigma^{2}+\epsilon}} \\
b_{\text {fold }} &=\gamma \cdot \frac{b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta
\end{aligned}
$$

We can remove the batch normalization layer and still have the same results!
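Substituting $z$ into the batch normalization and using the linearity of the convolution (scaling the filter of an output channel scales that output channel, and $\gamma$, $\beta$, $\mu$ and $\sigma$ hold one value per output channel) shows where these expressions come from:

$$
\Large
\begin{aligned}
\text { out } &=\gamma \cdot \frac{W * x+b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \\
&=\left(\gamma \cdot \frac{W}{\sqrt{\sigma^{2}+\epsilon}}\right) * x+\left(\gamma \cdot \frac{b-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta\right) \\
&=w_{\text {fold }} * x+b_{\text {fold }}
\end{aligned}
$$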

> Note: Usually, you don’t have a bias in a layer preceding a batch normalization layer. It is useless and a waste of parameters, as any constant it adds is canceled out when the batch normalization subtracts the mean. In that case, simply take $b = 0$ in the formulas above.
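A minimal PyTorch sketch of the folding formulas for a single `nn.Conv2d` followed by an `nn.BatchNorm2d`. The name `fold_conv_bn` is hypothetical (this is not the fasterai code linked at the end of the post); the sketch assumes the batch norm has learnable `weight`/`bias` (`affine=True`) and tracked running statistics, and it treats a missing convolution bias as $b = 0$.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a new Conv2d computing bn(conv(x)), with bn using its running statistics."""
    # Per-output-channel factor: gamma / sqrt(sigma^2 + eps)
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # A convolution without bias is treated as b = 0
    b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)

    folded = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True, padding_mode=conv.padding_mode,
    ).to(device=conv.weight.device, dtype=conv.weight.dtype)

    # w_fold = gamma * W / sqrt(sigma^2 + eps): scale each output filter (dim 0)
    folded.weight.copy_(conv.weight * factor.reshape(-1, 1, 1, 1))
    # b_fold = gamma * (b - mu) / sqrt(sigma^2 + eps) + beta
    folded.bias.copy_((b - bn.running_mean) * factor + bn.bias)
    return folded


# Usage: compare the folded layer with conv -> batch norm in eval mode
torch.manual_seed(0)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
nn.init.uniform_(bn.weight, 0.5, 1.5)   # non-trivial gamma
nn.init.uniform_(bn.bias, -1.0, 1.0)    # non-trivial beta
with torch.no_grad():
    for _ in range(10):                 # train-mode passes fill the running statistics
        bn(conv(torch.randn(8, 3, 32, 32)))
bn.eval()

x = torch.randn(2, 3, 32, 32)
folded = fold_conv_bn(conv, bn)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), folded(x), atol=1e-5))
```

The folded layer always has a bias, even when the original convolution had none, because $b_{\text{fold}}$ has to hold the shift that the batch normalization used to apply.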

---
<br>

## **How efficient is it?**

We will try it on **2** common architectures:

1. VGG16 with batch norm
2. ResNet50

Just for the demonstration, we will use the Imagenette dataset and PyTorch (through fastai). Both networks will be trained for **5** epochs, and we will then look at what changes in terms of number of parameters and inference time.

<br>

### **VGG16**

Let’s start by training VGG16 for **5** epochs (the final accuracy doesn’t matter):

```python
learn.fit_one_cycle(5, 1e-3)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.985012 | 3.945934 | 0.226497 | 00:31 |
| 1 | 1.868819 | 1.620619 | 0.472611 | 00:31 |
| 2 | 1.574975 | 1.295385 | 0.576815 | 00:31 |
| 3 | 1.305211 | 1.16146 | 0.617325 | 00:32 |
| 4 | 1.072395 | 0.955824 | 0.684076 | 00:32 |


Then we look at its number of parameters:

```python
count_parameters(model)
```

```
Total parameters : 134,309,962
```
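The post does not show `count_parameters`. A plausible sketch, assuming it simply sums the sizes of all parameter tensors and prints them in the format above (the actual helper may differ):

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Print and return the number of parameters of `model`.

    Only nn.Parameter tensors are counted: the running mean and variance
    of BatchNorm layers are buffers, so they are not part of this total."""
    total = sum(p.numel() for p in model.parameters())
    print(f"Total parameters : {total:,}")
    return total


count_parameters(nn.BatchNorm2d(64))   # prints "Total parameters : 128" (gamma and beta)
count_parameters(nn.Linear(4096, 10))  # prints "Total parameters : 40,970"
```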


We can get the initial inference time by using the `%%timeit` magic command:

```python
%%timeit
model(x[0][None].cuda())
```

```
2.77 ms ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
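Because CUDA kernels run asynchronously, a wall-clock timer around a GPU call can partly measure kernel launches rather than their execution. A hedged alternative to `%%timeit` is PyTorch's `torch.utils.benchmark.Timer`, which synchronizes CUDA around the measurement. The hypothetical `time_inference` helper below also assumes that inference should run in `eval()` mode with autograd disabled (the post does not say how its model was set up).

```python
import torch
import torch.nn as nn
from torch.utils import benchmark


def time_inference(model: nn.Module, x: torch.Tensor, number: int = 100) -> float:
    """Mean duration of one forward pass, in milliseconds.

    The model is switched to eval mode and autograd is disabled, so only the
    inference computation is measured; Timer also synchronizes CUDA."""
    model.eval()
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): model(x)",
        globals={"model": model, "x": x, "torch": torch},
    )
    return timer.timeit(number).mean * 1e3


# Usage on a small CPU example (move both model and input to "cuda" for GPU timings)
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
x = torch.randn(1, 3, 64, 64)
print(f"{time_inference(model, x):.3f} ms")
```

By default, `Timer` runs with a single CPU thread (`num_threads=1`), so CPU timings can differ from those of `%%timeit`.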


So now if we apply batch normalization folding, we have:

```python
count_parameters(folded_model)
```

```
Total parameters : 134,301,514
```


And: 

```python
%%timeit
folded_model(x[0][None].cuda())
```

```
2.41 ms ± 2.49 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


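So **8,448** parameters are removed (the $\gamma$ and $\beta$ of every batch normalization layer, since the convolutions of VGG16 already have a bias that can hold $b_{\text{fold}}$) and, even better, inference is almost **0.4 ms** faster! Most importantly, folding is lossless: it computes exactly the same function, up to floating-point rounding, so the validation loss and accuracy are the same as after the last training epoch: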

```python
folded_learner.validate()
```

```
[0.9558241, tensor(0.6841)]
```
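The post does not show how `folded_model` was built. For a purely sequential stack such as the convolutional part of VGG16 with batch norm (Conv2d, BatchNorm2d, ReLU, ...), a minimal sketch could walk the layers and fuse every Conv2d/BatchNorm2d pair. It relies on PyTorch's own `torch.nn.utils.fusion.fuse_conv_bn_eval` rather than the author's code, and `fold_sequential` is a hypothetical name. Residual networks such as ResNet50 keep their convolution/batch norm pairs inside blocks, so this flat traversal does not cover them.

```python
import copy

import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fold_sequential(seq: nn.Sequential) -> nn.Sequential:
    """Return a copy of `seq` in which every Conv2d directly followed by a
    BatchNorm2d is replaced by one folded Conv2d (valid in eval mode only)."""
    modules = list(copy.deepcopy(seq).eval())  # folding uses the running statistics
    layers, i = [], 0
    while i < len(modules):
        m = modules[i]
        nxt = modules[i + 1] if i + 1 < len(modules) else None
        if isinstance(m, nn.Conv2d) and isinstance(nxt, nn.BatchNorm2d):
            layers.append(fuse_conv_bn_eval(m, nxt))  # PyTorch's conv + BN fusion
            i += 2                                    # the BatchNorm2d is consumed
        else:
            layers.append(m)
            i += 1
    return nn.Sequential(*layers).eval()


# Usage on a small VGG-like stack (Conv2d -> BatchNorm2d -> ReLU, twice)
torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
)
with torch.no_grad():
    for _ in range(5):                 # train-mode passes fill the running statistics
        net(torch.randn(4, 3, 16, 16))
net.eval()

folded = fold_sequential(net)
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    print(torch.allclose(net(x), folded(x), atol=1e-5))
print(len(net), len(folded))  # 6 4
```

For a torchvision-style VGG, whose convolutions live in `model.features`, this could be applied as `model.features = fold_sequential(model.features)`.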

Let’s see how it behaves in the case of ResNet50!

<br>

### **ResNet50**

As before, we start by training it for **5** epochs:

```python
learn.fit_one_cycle(5, 1e-3)
```

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 2.076416 | 2.491038 | 0.246624 | 00:20 |
| 1 | 1.69675 | 1.517581 | 0.489427 | 00:19 |
| 2 | 1.313028 | 1.206347 | 0.606115 | 00:20 |
| 3 | 1.0576 | 0.890211 | 0.716943 | 00:21 |
| 4 | 0.828224 | 0.79313 | 0.740892 | 00:19 |


The initial number of parameters is:

```python
count_parameters(model)
```

```
Total parameters : 23,528,522
```


And inference time is:

```python
%%timeit
model(x[0][None].cuda())
```

```
6.17 ms ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


After using batch normalization folding, we have:

```python
count_parameters(final_model)
```

```
Total parameters : 23,501,962
```


And:

```python
%%timeit
final_model(x[0][None].cuda())
```

```
4.47 ms ± 8.97 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


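So now, we have **26,560** parameters removed and, even more impressive, an inference time reduced by **1.7 ms**! This is only half of the $\gamma$ and $\beta$ parameters of the batch normalization layers: the convolutions of ResNet50 have no bias, so each folded convolution gains one to hold $b_{\text{fold}}$. And still without any drop in performance: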

```python
final_learner.validate()
```

```
[0.7931296, tensor(0.7409)]
```

<br>

<span style="font-size:larger;">So if we can reduce the inference time and the number of parameters of our models without enduring any drop in performance, why shouldn’t we always do it?</span>

<br>

**I hope that this blog post helped you! Feel free to give me feedback or ask me questions if something is not clear enough.**

The code is available at [this address](https://github.com/nathanhubens/fasterai)!

---

<br>

## **References**

- {{ '[The Batch Normalization paper](https://arxiv.org/pdf/1502.03167.pdf)' | fndetail: 1 }} 
- {{ '[DeepLearning.ai Batch Normalization Lesson](https://www.youtube.com/watch?v=tNIpEZLv_eg&t=1s)' | fndetail: 2 }}