Skip to content

Commit

Permalink
update regularizer
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhubens committed Mar 20, 2024
1 parent e8fbc1b commit 70d9977
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 42 deletions.
8 changes: 0 additions & 8 deletions fasterai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,6 @@
'fasterai.quantize.quantizer.Quantizer.quantize': ( 'quantize.quantizer.html#quantizer.quantize',
'fasterai/quantize/quantizer.py')},
'fasterai.regularize.all': {},
'fasterai.regularize.regularization_callback': { 'fasterai.regularize.regularization_callback.RegularizationCallback': ( 'regularize.regularizer.html#regularizationcallback',
'fasterai/regularize/regularization_callback.py'),
'fasterai.regularize.regularization_callback.RegularizationCallback.__init__': ( 'regularize.regularizer.html#regularizationcallback.__init__',
'fasterai/regularize/regularization_callback.py'),
'fasterai.regularize.regularization_callback.RegularizationCallback.after_loss': ( 'regularize.regularizer.html#regularizationcallback.after_loss',
'fasterai/regularize/regularization_callback.py'),
'fasterai.regularize.regularization_callback.RegularizationCallback.get_norm': ( 'regularize.regularizer.html#regularizationcallback.get_norm',
'fasterai/regularize/regularization_callback.py')},
'fasterai.regularize.regularize_callback': { 'fasterai.regularize.regularize_callback.RegularizeCallback': ( 'regularize.regularizer.html#regularizecallback',
'fasterai/regularize/regularize_callback.py'),
'fasterai.regularize.regularize_callback.RegularizeCallback.__init__': ( 'regularize.regularizer.html#regularizecallback.__init__',
Expand Down
28 changes: 0 additions & 28 deletions fasterai/regularize/regularization_callback.py

This file was deleted.

4 changes: 2 additions & 2 deletions fasterai/regularize/regularize_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# %% ../../nbs/05_regularize.regularizer.ipynb 4
class RegularizeCallback(Callback):
"Callback to apply grouped weight decay"
def __init__(self, g, wd=0.01):
def __init__(self, g, wd=0.01, layer_type=nn.Conv2d):
store_attr()

def after_loss(self):
Expand All @@ -25,4 +25,4 @@ def after_loss(self):
self.learn.loss = self.learn.loss_grad.clone()

def get_norm(self):
return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()
return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, self.layer_type)]).sum()
10 changes: 6 additions & 4 deletions nbs/05_regularize.regularizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"#| export\n",
"class RegularizeCallback(Callback):\n",
" \"Callback to apply grouped weight decay\"\n",
" def __init__(self, g, wd=0.01):\n",
" def __init__(self, g, wd=0.01, layer_type=nn.Conv2d):\n",
" store_attr()\n",
"\n",
" def after_loss(self):\n",
Expand All @@ -73,7 +73,7 @@
" self.learn.loss = self.learn.loss_grad.clone()\n",
" \n",
" def get_norm(self):\n",
" return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()"
" return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, self.layer_type)]).sum()"
]
},
{
Expand All @@ -89,7 +89,8 @@
"\n",
"### RegularizeCallback\n",
"\n",
"> RegularizeCallback (g, wd=0.01)\n",
"> RegularizeCallback (g, wd=0.01, layer_type=<class\n",
"> 'torch.nn.modules.conv.Conv2d'>)\n",
"\n",
"Callback to apply grouped weight decay"
],
Expand All @@ -98,7 +99,8 @@
"\n",
"### RegularizeCallback\n",
"\n",
"> RegularizeCallback (g, wd=0.01)\n",
"> RegularizeCallback (g, wd=0.01, layer_type=<class\n",
"> 'torch.nn.modules.conv.Conv2d'>)\n",
"\n",
"Callback to apply grouped weight decay"
]
Expand Down

0 comments on commit 70d9977

Please sign in to comment.