---
title: What is Split Batch Normalization and how can we implement it?
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/06b_tutorial_splitbn.ipynb"
---
<!--
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# file to edit: nbs/06b_tutorial_splitbn.ipynb
# command to build the docs after a change: nbdev_build_docs
-->
<div class="container" id="notebook-container">
{% raw %}
<div class="cell border-box-sizing code_cell rendered">
</div>
{% endraw %}
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Split Batch Normalization was first introduced in <a href="https://arxiv.org/abs/1904.03515">Split Batch Normalization: Improving Semi-Supervised Learning under Domain Shift</a>.</p>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>From the abstract of the paper:</p>
<pre><code>Recent work has shown that using unlabeled data in semisupervised learning is not always beneficial and can even hurt generalization, especially when there is a class mismatch between the unlabeled and labeled examples. We investigate this phenomenon for image classification on the CIFAR-10 and the ImageNet datasets, and with many other forms of domain shifts applied (e.g. salt-and-pepper noise). Our main contribution is Split Batch Normalization (Split-BN), a technique to improve SSL when the additional unlabeled data comes from a shifted distribution. We achieve it by using separate batch normalization statistics for unlabeled examples. Due to its simplicity, we recommend it as a standard practice. Finally, we analyse how domain shift affects the SSL training process. In particular, we find that during training the statistics of hidden activations in late layers become markedly different between the unlabeled and the labeled examples.</code></pre>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>In simple words, they propose to compute batch normalization statistics separately for the unlabeled and the labeled examples. That is, to have a separate BN layer per group instead of a single one for the whole batch.</p>
</div>
</div>
</div>
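<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>To make the idea concrete, here is a small self-contained sketch (the shapes and augmentation stand-ins are illustrative, not from the paper): each split of the batch gets its own <code>nn.BatchNorm2d</code>, so the normalization statistics never mix across splits.</p>
</div>
</div>
</div>

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration: a training batch built from 3
# augmentation splits stacked along the batch dimension (cf. --aug-splits 3).
clean = torch.randn(8, 3, 4, 4)
aug1 = clean + 0.5 * torch.randn_like(clean)  # stand-in for real augmentation
aug2 = clean + 0.5 * torch.randn_like(clean)
batch = torch.cat([clean, aug1, aug2], dim=0)  # shape (24, 3, 4, 4)

# The core of Split-BN: one BatchNorm per split, so statistics stay separate.
bns = [nn.BatchNorm2d(3) for _ in range(3)]
out = torch.cat([bn(x) for bn, x in zip(bns, batch.split(8))], dim=0)
```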
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>You might say that is easy to say, but how do we implement it in code?</p>
<p>Well, in <code>timm</code> training, you just do:</p>
<pre><code>python train.py ../imagenette2-320 --aug-splits 3 --split-bn --aa rand-m9-mstd0.5-inc1 --resplit</code></pre>
<p>And that's it. But what does this command mean?</p>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>Running the above command:</p>
<ol>
<li>Creates 3 groups of training batches, also called augmentation splits:<ol>
<li>The first is the original batch (with minimal or zero augmentation).</li>
<li>The second has random augmentation applied to the first.</li>
<li>The third again has random augmentation applied to the first.
{% include note.html content='Random augmentations are stochastic. Therefore, the second and the third splits are different from each other. ' %}</li>
</ol>
</li>
<li>Converts every Batch Normalization layer inside the model to a Split Batch Normalization layer.</li>
<li>Does not apply random erasing to the first batch, also referred to as the first augmentation split.</li>
</ol>
</div>
</div>
</div>
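<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>The second step, swapping every BN layer in the model, can be sketched with a small recursive helper. This is only illustrative: <code>timm</code> ships its own converter for this, which additionally copies the existing weights and running statistics into the new layers.</p>
</div>
</div>
</div>

```python
import torch.nn as nn

def convert_bn(module: nn.Module, factory) -> nn.Module:
    # Recursively replace every nn.BatchNorm2d child with factory(num_features).
    # Simplified sketch: a production converter would also copy the old layer's
    # weight, bias, and running statistics into the replacement.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, factory(child.num_features))
        else:
            convert_bn(child, factory)
    return module

# Usage with a stand-in factory (any BN-like layer works for the demo):
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
convert_bn(model, lambda nf: nn.GroupNorm(1, nf))
```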
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h2 id="SplitBatchNorm2d"><code>SplitBatchNorm2d</code><a class="anchor-link" href="#SplitBatchNorm2d"> </a></h2>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>The <code>SplitBatchNorm2d</code> on its own is just a few lines of code:</p>
<div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">SplitBatchNorm2d</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">track_running_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_splits</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">affine</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">num_splits</span> <span class="o">></span> <span class="mi">1</span><span class="p">,</span> <span class="s1">'Should have at least one aux BN layer (num_splits at least 2)'</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_splits</span> <span class="o">=</span> <span class="n">num_splits</span>
<span class="bp">self</span><span class="o">.</span><span class="n">aux_bn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span>
<span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="n">affine</span><span class="p">,</span> <span class="n">track_running_stats</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_splits</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)])</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span> <span class="c1"># aux BN only relevant while training</span>
<span class="n">split_size</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_splits</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">split_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_splits</span><span class="p">,</span> <span class="s2">"batch size must be evenly divisible by num_splits"</span>
<span class="n">split_input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">split_size</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">split_input</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">a</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">aux_bn</span><span class="p">):</span>
<span class="n">x</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">a</span><span class="p">(</span><span class="n">split_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
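<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>A quick sanity check of the behaviour (the class is reproduced in condensed form so the snippet runs standalone): during training each split updates its own running statistics, while at evaluation time only the main BN's statistics are used.</p>
</div>
</div>
</div>

```python
import torch
import torch.nn as nn

class SplitBatchNorm2d(nn.BatchNorm2d):
    # Condensed version of the class shown above.
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.num_splits = num_splits
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
            for _ in range(num_splits - 1)])

    def forward(self, input):
        if self.training:
            split_input = input.split(input.shape[0] // self.num_splits)
            x = [super().forward(split_input[0])]
            for i, a in enumerate(self.aux_bn):
                x.append(a(split_input[i + 1]))
            return torch.cat(x, dim=0)
        return super().forward(input)

bn = SplitBatchNorm2d(3, num_splits=2)
# First half all zeros (the "clean" split), second half all ones:
batch = torch.cat([torch.zeros(4, 3, 2, 2), torch.ones(4, 3, 2, 2)])

bn.train()
_ = bn(batch)                        # each split updates its own running stats
main_mean = bn.running_mean.clone()  # stays at 0 (mean of the zeros split)
aux_mean = bn.aux_bn[0].running_mean.clone()  # moves toward 1 (ones split)

bn.eval()
out = bn(batch)                      # eval: only the main BN stats are used
```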
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>In the <a href="https://arxiv.org/abs/1911.09665">Adversarial Examples Improve Image Recognition</a> paper, the authors refer to this split batch norm as auxiliary batch norm. Accordingly, in the code, <code>self.aux_bn</code> is a list of length <code>num_splits-1</code>.</p>
<p>Because we subclass <code>torch.nn.BatchNorm2d</code>, <code>SplitBatchNorm2d</code> is itself an instance of Batch Normalization: the parent <code>nn.BatchNorm2d</code> serves as the first batch norm layer, normalizing the first augmentation split (the clean batch).</p>
<p>We then create <code>num_splits-1</code> auxiliary batch norms to normalize the remaining splits in the input batch.</p>
<p>This way, we normalize the input batch <code>X</code> separately depending on the number of splits. This is achieved inside these lines of code:</p>
<div class="highlight"><pre><span></span><span class="n">split_input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">split_size</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">split_input</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">a</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">aux_bn</span><span class="p">):</span>
<span class="n">x</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">a</span><span class="p">(</span><span class="n">split_input</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</pre></div>
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<p>And that's how <code>timm</code> implements <code>SplitBatchNorm2d</code> in PyTorch :)</p>
</div>
</div>
</div>
</div>