Skip to content

Commit

Permalink
Merge pull request #6 from nathanhubens/criteria
Browse files Browse the repository at this point in the history
fix regularizer
  • Loading branch information
nathanhubens committed May 5, 2023
2 parents e971710 + c118f55 commit 39e8be7
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 48 deletions.
6 changes: 3 additions & 3 deletions conda/fasterai/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package:
name: fasterai
version: 0.1.12
version: 0.1.13
source:
sha256: 41eb94135382410c0ed4475111c1a5c7dc18e48f9b6d1e1954fc67d3a2d8cb98
url: https://files.pythonhosted.org/packages/aa/ac/218034da9af0c94b1095d462e7d79d22dec255b96449f022ce491305c9bf/fasterai-0.1.12.tar.gz
sha256: 61fd53a32080b08543a91cfceaaf28423a5bdf48bfd047bc95704c5fdc10b555
url: https://files.pythonhosted.org/packages/59/17/5ace7091abdcf7f8c5e53a84c87d099a77ae6a22eb252c51ce0a28aaae76/fasterai-0.1.13.tar.gz
about:
description: "# Fasterai\n\n\n\n![header](https://capsule-render.vercel.app/api?type=waving&color=008080&height=300&section=header&text=fasterai%20&fontSize=90&animation=fadeIn&fontAlignY=38&desc=A%20Library%20to%20make%20smaller%20and%20faster%20neural%20networks&descAlignY=51&descAlign=62)\n\
\n<p align=\"center\">\n <a href=\"https://pypi.org/project/fasterai/\"><img\
Expand Down
5 changes: 2 additions & 3 deletions fasterai/core/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class Criteria():
def __init__(self, f, reducer='mean', normalizer=None, needs_init=False, needs_update=False, output_f=None, return_init=False):
store_attr()
assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."

@torch.no_grad()

def __call__(self, m, g):
try:
dim = Granularities.get_dim(m, g)
Expand All @@ -41,7 +40,7 @@ def __call__(self, m, g):
elif self.return_init: scores = wi
else: scores = wf

scores = self._rescale(scores).mul_(m._mask)
scores = self._rescale(scores)._mul(m._mask)
scores = self._reduce(scores, dim)
scores = self._normalize(scores)
return scores
Expand Down
5 changes: 3 additions & 2 deletions fasterai/regularize/regularization_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastai.callback.all import *
from fastcore.basics import store_attr
from ..core.criteria import *
from ..core.granularity import *

import torch
import torch.nn as nn
Expand All @@ -15,7 +16,7 @@
# %% ../../nbs/05_regularize.regularizer.ipynb 4
class RegularizationCallback(Callback):
"Callback to apply grouped weight decay"
def __init__(self, granularity, wd=0.01):
def __init__(self, g, wd=0.01):
store_attr()

def after_loss(self):
Expand All @@ -24,4 +25,4 @@ def after_loss(self):
self.learn.loss = self.learn.loss_grad.clone()

def get_norm(self):
return self.wd*torch.stack([large_final.get_scores(m, large_final(m),self.granularity).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()
return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()
5 changes: 2 additions & 3 deletions nbs/00b_core.criteria.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@
" def __init__(self, f, reducer='mean', normalizer=None, needs_init=False, needs_update=False, output_f=None, return_init=False):\n",
" store_attr()\n",
" assert (needs_init and needs_update)==False, \"The init values will be overwritten by the updating ones.\"\n",
" \n",
" @torch.no_grad() \n",
" \n",
" def __call__(self, m, g):\n",
" try:\n",
" dim = Granularities.get_dim(m, g)\n",
Expand All @@ -156,7 +155,7 @@
" elif self.return_init: scores = wi\n",
" else: scores = wf\n",
" \n",
" scores = self._rescale(scores).mul_(m._mask)\n",
" scores = self._rescale(scores)._mul(m._mask)\n",
" scores = self._reduce(scores, dim)\n",
" scores = self._normalize(scores)\n",
" return scores\n",
Expand Down
5 changes: 3 additions & 2 deletions nbs/05_regularize.regularizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"from fastai.callback.all import *\n",
"from fastcore.basics import store_attr\n",
"from fasterai.core.criteria import *\n",
"from fasterai.core.granularity import *\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
Expand All @@ -63,7 +64,7 @@
"#| export\n",
"class RegularizationCallback(Callback):\n",
" \"Callback to apply grouped weight decay\"\n",
" def __init__(self, granularity, wd=0.01):\n",
" def __init__(self, g, wd=0.01):\n",
" store_attr()\n",
"\n",
" def after_loss(self):\n",
Expand All @@ -72,7 +73,7 @@
" self.learn.loss = self.learn.loss_grad.clone()\n",
" \n",
" def get_norm(self):\n",
" return self.wd*torch.stack([large_final.get_scores(m, large_final(m),self.granularity).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()"
" return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()"
]
},
{
Expand Down
110 changes: 75 additions & 35 deletions nbs/09a_tutorial.regularizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@
"id": "ad51bb5a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
Expand All @@ -80,24 +104,24 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.681536</td>\n",
" <td>0.466989</td>\n",
" <td>0.835589</td>\n",
" <td>00:11</td>\n",
" <td>0.668929</td>\n",
" <td>0.511108</td>\n",
" <td>0.841001</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.358927</td>\n",
" <td>0.318825</td>\n",
" <td>0.865359</td>\n",
" <td>00:10</td>\n",
" <td>0.370381</td>\n",
" <td>0.225852</td>\n",
" <td>0.897158</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.201207</td>\n",
" <td>0.220008</td>\n",
" <td>0.923545</td>\n",
" <td>00:10</td>\n",
" <td>0.195002</td>\n",
" <td>0.208013</td>\n",
" <td>0.922192</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand Down Expand Up @@ -128,27 +152,43 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9d7cddff",
"id": "05dab4a5",
"metadata": {},
"outputs": [],
"source": [
"reg_cb = RegularizationCallback('filter')"
]
},
{
"cell_type": "markdown",
"id": "15e836e3",
"metadata": {},
"source": [
"Train a model with Regularization"
"reg_cb = RegularizationCallback('filter', wd=0.0001)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f962c44",
"id": "cb9146dd",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
Expand All @@ -165,24 +205,24 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.633497</td>\n",
" <td>1.468702</td>\n",
" <td>0.812585</td>\n",
" <td>00:10</td>\n",
" <td>16.619804</td>\n",
" <td>15.744213</td>\n",
" <td>0.817997</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.334702</td>\n",
" <td>1.173871</td>\n",
" <td>0.907307</td>\n",
" <td>00:10</td>\n",
" <td>14.298106</td>\n",
" <td>13.068723</td>\n",
" <td>0.901218</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1.152696</td>\n",
" <td>1.136654</td>\n",
" <td>0.933694</td>\n",
" <td>00:11</td>\n",
" <td>12.655948</td>\n",
" <td>12.294454</td>\n",
" <td>0.928958</td>\n",
" <td>00:07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand Down

0 comments on commit 39e8be7

Please sign in to comment.