From 70d9977e896083736d1c2c6f91e1b5dd0c9f7061 Mon Sep 17 00:00:00 2001
From: nathanhubens
Date: Wed, 20 Mar 2024 13:50:08 +0100
Subject: [PATCH] update regularizer

---
 fasterai/_modidx.py                                |  8 ------
 .../regularize/regularization_callback.py          | 28 -------------------
 fasterai/regularize/regularize_callback.py         |  4 +--
 nbs/05_regularize.regularizer.ipynb                | 10 ++++---
 4 files changed, 8 insertions(+), 42 deletions(-)
 delete mode 100644 fasterai/regularize/regularization_callback.py

diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py
index 50dadf5..16ea9ce 100644
--- a/fasterai/_modidx.py
+++ b/fasterai/_modidx.py
@@ -163,14 +163,6 @@
                                                      'fasterai.quantize.quantizer.Quantizer.quantize': ( 'quantize.quantizer.html#quantizer.quantize',
                                                                                                          'fasterai/quantize/quantizer.py')},
             'fasterai.regularize.all': {},
-            'fasterai.regularize.regularization_callback': { 'fasterai.regularize.regularization_callback.RegularizationCallback': ( 'regularize.regularizer.html#regularizationcallback',
-                                                                                                                                      'fasterai/regularize/regularization_callback.py'),
-                                                             'fasterai.regularize.regularization_callback.RegularizationCallback.__init__': ( 'regularize.regularizer.html#regularizationcallback.__init__',
-                                                                                                                                               'fasterai/regularize/regularization_callback.py'),
-                                                             'fasterai.regularize.regularization_callback.RegularizationCallback.after_loss': ( 'regularize.regularizer.html#regularizationcallback.after_loss',
-                                                                                                                                                 'fasterai/regularize/regularization_callback.py'),
-                                                             'fasterai.regularize.regularization_callback.RegularizationCallback.get_norm': ( 'regularize.regularizer.html#regularizationcallback.get_norm',
-                                                                                                                                               'fasterai/regularize/regularization_callback.py')},
             'fasterai.regularize.regularize_callback': { 'fasterai.regularize.regularize_callback.RegularizeCallback': ( 'regularize.regularizer.html#regularizecallback',
                                                                                                                           'fasterai/regularize/regularize_callback.py'),
                                                          'fasterai.regularize.regularize_callback.RegularizeCallback.__init__': ( 'regularize.regularizer.html#regularizecallback.__init__',
diff --git a/fasterai/regularize/regularization_callback.py b/fasterai/regularize/regularization_callback.py
deleted file mode 100644
index 58e9e86..0000000
--- a/fasterai/regularize/regularization_callback.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_regularize.regularizer.ipynb.
-
-# %% auto 0
-__all__ = ['RegularizationCallback']
-
-# %% ../../nbs/05_regularize.regularizer.ipynb 3
-from fastai.callback.all import *
-from fastcore.basics import store_attr
-from ..core.criteria import *
-from ..core.granularity import *
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# %% ../../nbs/05_regularize.regularizer.ipynb 4
-class RegularizationCallback(Callback):
-    "Callback to apply grouped weight decay"
-    def __init__(self, g, wd=0.01):
-        store_attr()
-
-    def after_loss(self):
-        reg = self.get_norm()
-        self.learn.loss_grad += reg
-        self.learn.loss = self.learn.loss_grad.clone()
-
-    def get_norm(self):
-        return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()
diff --git a/fasterai/regularize/regularize_callback.py b/fasterai/regularize/regularize_callback.py
index 741f1a3..718a914 100644
--- a/fasterai/regularize/regularize_callback.py
+++ b/fasterai/regularize/regularize_callback.py
@@ -16,7 +16,7 @@
 # %% ../../nbs/05_regularize.regularizer.ipynb 4
 class RegularizeCallback(Callback):
     "Callback to apply grouped weight decay"
-    def __init__(self, g, wd=0.01):
+    def __init__(self, g, wd=0.01, layer_type=nn.Conv2d):
         store_attr()
 
     def after_loss(self):
@@ -25,4 +25,4 @@ def after_loss(self):
         self.learn.loss = self.learn.loss_grad.clone()
 
     def get_norm(self):
-        return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()
+        return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, self.layer_type)]).sum()
diff --git a/nbs/05_regularize.regularizer.ipynb b/nbs/05_regularize.regularizer.ipynb
index 4b5a651..401720a 100644
--- a/nbs/05_regularize.regularizer.ipynb
+++ b/nbs/05_regularize.regularizer.ipynb
@@ -64,7 +64,7 @@
     "#| export\n",
     "class RegularizeCallback(Callback):\n",
     "    \"Callback to apply grouped weight decay\"\n",
-    "    def __init__(self, g, wd=0.01):\n",
+    "    def __init__(self, g, wd=0.01, layer_type=nn.Conv2d):\n",
     "        store_attr()\n",
     "\n",
     "    def after_loss(self):\n",
@@ -73,7 +73,7 @@
     "        self.learn.loss = self.learn.loss_grad.clone()\n",
     "    \n",
     "    def get_norm(self):\n",
-    "        return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, nn.Conv2d)]).sum()"
+    "        return self.wd*torch.stack([large_final.f(m.weight)[None].sum(Granularities.get_dim(m, self.g)).sum() for m in self.learn.modules() if isinstance(m, self.layer_type)]).sum()"
    ]
   },
   {
@@ -89,7 +89,8 @@
     "\n",
     "### RegularizeCallback\n",
     "\n",
-    "> RegularizeCallback (g, wd=0.01)\n",
+    "> RegularizeCallback (g, wd=0.01, layer_type=<class\n",
+    ">                     'torch.nn.modules.conv.Conv2d'>)\n",
     "\n",
     "Callback to apply grouped weight decay"
    ],
@@ -98,7 +99,8 @@
     "\n",
     "### RegularizeCallback\n",
     "\n",
-    "> RegularizeCallback (g, wd=0.01)\n",
+    "> RegularizeCallback (g, wd=0.01, layer_type=<class\n",
+    ">                     'torch.nn.modules.conv.Conv2d'>)\n",
     "\n",
     "Callback to apply grouped weight decay"
    ]
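
Usage sketch (illustrative, not part of the commit): the snippet below shows how the updated
RegularizeCallback could be attached to a fastai Learner, including the new layer_type argument.
The dataset, model, and the granularity strings 'filter' and 'weight' are assumptions for the
sake of the example; pick whichever granularity your fasterai version supports for the chosen
layer type.

    import torch.nn as nn
    from fastai.vision.all import *
    from fasterai.regularize.regularize_callback import RegularizeCallback

    # Toy dataset and model; any fastai Learner works the same way.
    path = untar_data(URLs.MNIST_TINY)
    dls = ImageDataLoaders.from_folder(path)
    learn = vision_learner(dls, resnet18, metrics=accuracy)

    # Default behaviour (unchanged): grouped weight decay on nn.Conv2d weights,
    # grouped according to the granularity g ('filter' assumed valid here).
    learn.fit_one_cycle(1, cbs=RegularizeCallback(g='filter', wd=0.01))

    # New with this patch: target another layer type, e.g. fully-connected layers
    # ('weight' granularity assumed valid for nn.Linear).
    learn.fit_one_cycle(1, cbs=RegularizeCallback(g='weight', wd=0.01, layer_type=nn.Linear))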