Skip to content

Commit

Permalink
New KD callback
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhubens committed May 29, 2022
1 parent 0006b7f commit 91f8e12
Show file tree
Hide file tree
Showing 6 changed files with 1,174 additions and 212 deletions.
187 changes: 179 additions & 8 deletions docs/knowledge_distillation.html
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">predictions</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">hard_preds</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">);</span> <span class="n">predictions</span>
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">predictions</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">);</span> <span class="n">predictions</span>
</pre></div>

</div>
Expand All @@ -140,7 +140,7 @@


<div class="output_text output_subarea output_execute_result">
<pre>tensor([0.0864, 0.6386, 0.0388, 0.2126, 0.0236])</pre>
<pre>tensor([0.1063, 0.6431, 0.0354, 0.1937, 0.0215])</pre>
</div>

</div>
Expand All @@ -165,7 +165,7 @@

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">soft_predictions</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">hard_preds</span><span class="o">/</span><span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">);</span> <span class="n">soft_predictions</span>
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">soft_predictions</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="o">/</span><span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">);</span> <span class="n">soft_predictions</span>
</pre></div>

</div>
Expand All @@ -180,7 +180,7 @@


<div class="output_text output_subarea output_execute_result">
<pre>tensor([0.1751, 0.3410, 0.1341, 0.2363, 0.1135])</pre>
<pre>tensor([0.1879, 0.3423, 0.1302, 0.2294, 0.1102])</pre>
</div>

</div>
Expand Down Expand Up @@ -252,7 +252,7 @@


<div class="output_markdown rendered_html output_subarea ">
<h2 id="KnowledgeDistillation" class="doc_header"><code>class</code> <code>KnowledgeDistillation</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L13" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>KnowledgeDistillation</code>(<strong><code>teacher</code></strong>, <strong><code>loss</code></strong>) :: <code>Callback</code></p>
<h2 id="KnowledgeDistillationCallback" class="doc_header"><code>class</code> <code>KnowledgeDistillationCallback</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L17" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>KnowledgeDistillationCallback</code>(<strong><code>teacher</code></strong>, <strong><code>loss</code></strong>, <strong><code>activations_student</code></strong>=<em><code>None</code></em>, <strong><code>activations_teacher</code></strong>=<em><code>None</code></em>, <strong><code>weight</code></strong>=<em><code>0.5</code></em>) :: <code>Callback</code></p>
</blockquote>
<p>Basic class handling tweaks of the training loop by changing a <code>Learner</code> in various events</p>

Expand All @@ -270,6 +270,61 @@ <h2 id="KnowledgeDistillation" class="doc_header"><code>class</code> <code>Knowl

<div class="cell border-box-sizing code_cell rendered">

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="get_model_layers" class="doc_header"><code>get_model_layers</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L62" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>get_model_layers</code>(<strong><code>model</code></strong>, <strong><code>getLayerRepr</code></strong>=<em><code>False</code></em>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="get_module_by_name" class="doc_header"><code>get_module_by_name</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L79" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>get_module_by_name</code>(<strong><code>module</code></strong>:<code>Union</code>[<code>Tensor</code>, <code>Module</code>], <strong><code>access_string</code></strong>:<code>str</code>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

</div>
{% endraw %}

Expand All @@ -280,6 +335,122 @@ <h2 id="KnowledgeDistillation" class="doc_header"><code>class</code> <code>Knowl
</div>
</div>
</div>
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h1 id="export">export<a class="anchor-link" href="#export"> </a></h1><p>def SoftTarget(pred, teacher_pred, T=5, *<em>kwargs):
return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(pred/T, dim=1), F.softmax(teacher_pred/T, dim=1)) </em> (T*T)</p>
<p>def Logits(pred, teacher_pred, **kwargs):
return F.mse_loss(preds, teacher_pred)</p>
<p>def Mutual(pred, teacher_pred, **kwargs):
return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(pred, dim=1), F.softmax(teacher_pred, dim=1))</p>
<p>def Attention(pred, teacher_pred, fm_s, fm_t, p=2, **kwargs):
return sum([F.mse_loss(F.normalize(fm_s[name_st].pow(p).mean(1),dim=(1,2)), F.normalize(fm_t[name_t].pow(p).mean(1),dim=(1,2))) for name_st, name_t in zip(fm_s, fm_t)])</p>
<p>def ActivationBoundaries(pred, teacher_pred, fm_s, fm_t, m=2, *<em>kwargs):
return sum([((fm_s[name_st] + m).pow(2) </em> ((fm_s[name_st] &gt; -m) &amp; (fm_t[name_t] &lt;= 0)).float() + (fm_s[name_st] - m).pow(2) * ((fm_s[name_st] &lt;= m) &amp; (fm_t[name_t] &gt; 0)).float()).mean() for name_st, name_t in zip(fm_s, fm_t)])</p>
<p>def FitNet(pred, teacher_pred, fm_s, fm_t, **kwargs):
return sum([F.mse_loss(fm_s[name_st],fm_t[name_t]) for name_st, name_t in zip(fm_s, fm_t)])</p>
<p>def Similarity(fm_s, fm_t, p=2, **kwargs):
return sum([F.mse_loss(F.normalize(fm_s[name_st].view(fm_s[name_st].size(0), -1) @ fm_s[name_st].view(fm_s[name_st].size(0), -1).t(), p=p, dim=1), F.normalize(fm_t[name_t].view(fm_t[name_t].size(0), -1) @ fm_t[name_t].view(fm_t[name_t].size(0), -1).t(), p=p, dim=1)) for name_st, name_t in zip(fm_s, fm_t)])</p>

</div>
</div>
</div>
{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="SoftTarget" class="doc_header"><code>SoftTarget</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L86" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>SoftTarget</code>(<strong><code>pred</code></strong>, <strong><code>teacher_pred</code></strong>, <strong><code>T</code></strong>=<em><code>5</code></em>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="Logits" class="doc_header"><code>Logits</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L89" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>Logits</code>(<strong><code>pred</code></strong>, <strong><code>teacher_pred</code></strong>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="Mutual" class="doc_header"><code>Mutual</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L92" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>Mutual</code>(<strong><code>pred</code></strong>, <strong><code>teacher_pred</code></strong>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">

<div class="output_wrapper">
<div class="output">

<div class="output_area">


<div class="output_markdown rendered_html output_subarea ">
<h4 id="Attention" class="doc_header"><code>Attention</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L96" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>Attention</code>(<strong><code>fm_s</code></strong>, <strong><code>fm_t</code></strong>, <strong><code>p</code></strong>=<em><code>2</code></em>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>

</div>

</div>
</div>

</div>
{% endraw %}

{% raw %}

<div class="cell border-box-sizing code_cell rendered">
Expand All @@ -291,7 +462,7 @@ <h2 id="KnowledgeDistillation" class="doc_header"><code>class</code> <code>Knowl


<div class="output_markdown rendered_html output_subarea ">
<h4 id="SoftTarget" class="doc_header"><code>SoftTarget</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L25" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>SoftTarget</code>(<strong><code>y</code></strong>, <strong><code>labels</code></strong>, <strong><code>teacher_scores</code></strong>, <strong><code>T</code></strong>=<em><code>20</code></em>, <strong><code>α</code></strong>=<em><code>0.7</code></em>, <strong>**<code>kwargs</code></strong>)</p>
<h4 id="ActivationBoundaries" class="doc_header"><code>ActivationBoundaries</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L99" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>ActivationBoundaries</code>(<strong><code>fm_s</code></strong>, <strong><code>fm_t</code></strong>, <strong><code>m</code></strong>=<em><code>2</code></em>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>
Expand All @@ -315,7 +486,7 @@ <h4 id="SoftTarget" class="doc_header"><code>SoftTarget</code><a href="https://g


<div class="output_markdown rendered_html output_subarea ">
<h4 id="LogitsRegression" class="doc_header"><code>LogitsRegression</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L28" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>LogitsRegression</code>(<strong><code>y</code></strong>, <strong><code>labels</code></strong>, <strong><code>teacher_scores</code></strong>, <strong>**<code>kwargs</code></strong>)</p>
<h4 id="FitNet" class="doc_header"><code>FitNet</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L102" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>FitNet</code>(<strong><code>fm_s</code></strong>, <strong><code>fm_t</code></strong>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>
Expand All @@ -339,7 +510,7 @@ <h4 id="LogitsRegression" class="doc_header"><code>LogitsRegression</code><a hre


<div class="output_markdown rendered_html output_subarea ">
<h4 id="WeightRegression" class="doc_header"><code>WeightRegression</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L31" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>WeightRegression</code>(<strong><code>y</code></strong>, <strong><code>labels</code></strong>, <strong><code>teacher_scores</code></strong>, <strong><code>student</code></strong>, <strong><code>teacher</code></strong>, <strong><code>α</code></strong>=<em><code>0.5</code></em>, <strong>**<code>kwargs</code></strong>)</p>
<h4 id="Similarity" class="doc_header"><code>Similarity</code><a href="https://github.com/nathanhubens/fasterai/tree/master/fasterai/distill/distillation_callback.py#L105" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>Similarity</code>(<strong><code>fm_s</code></strong>, <strong><code>fm_t</code></strong>, <strong><code>pred</code></strong>, <strong><code>p</code></strong>=<em><code>2</code></em>, <strong>**<code>kwargs</code></strong>)</p>
</blockquote>

</div>
Expand Down
Loading

0 comments on commit 91f8e12

Please sign in to comment.