# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/19_callback.mixup.ipynb (unless otherwise specified).
__all__ = ['reduce_loss', 'MixUp']
# Cell
from ..basics import *
from .progress import *
from ..vision.core import *
from torch.distributions.beta import Beta
# Cell
def reduce_loss(loss, reduction='mean'):
    "Reduce `loss` with the given `reduction`: 'mean', 'sum', or no reduction at all"
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss
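
# Illustration (not exported): `reduce_loss` mirrors PyTorch's `reduction` argument,
# collapsing a per-item loss tensor or leaving it untouched. Assuming `tensor` comes
# from fastai.basics:
#   reduce_loss(tensor([1.,2.,3.]))          # -> tensor(2.)
#   reduce_loss(tensor([1.,2.,3.]), 'sum')   # -> tensor(6.)
#   reduce_loss(tensor([1.,2.,3.]), 'none')  # -> tensor([1., 2., 3.])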
# Cell
class MixUp(Callback):
    "Callback that implements the mixup data augmentation (https://arxiv.org/abs/1710.09412)"
    run_after,run_valid = [Normalize],False  # run after normalization; skip validation batches

    def __init__(self, alpha=0.4): self.distrib = Beta(tensor(alpha), tensor(alpha))

    def begin_fit(self):
        # If the loss function expects integer targets (`y_int`), the targets can't be
        # mixed directly, so swap in `self.lf`, which interpolates the losses instead.
        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

    def after_fit(self):
        # Restore the original loss function once training is done.
        if self.stack_y: self.learn.loss_func = self.old_lf

    def begin_batch(self):
        # Sample one mixing weight per item, then keep max(lam, 1-lam) so each mixed
        # sample stays closest to its own original (lam is always >= 0.5).
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        # Pair each item with a shuffled partner from the same batch.
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))
        # torch.lerp(a, b, w) = a + w*(b-a), so this yields lam*xb + (1-lam)*xb1.
        nx_dims = len(self.x.size())
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))
        if not self.stack_y:
            # Continuous targets can be mixed directly with the same weights.
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))

    def lf(self, pred, *yb):
        # At validation time, fall back to the unmodified loss.
        if not self.training: return self.old_lf(pred, *yb)
        # Compute the unreduced loss against both sets of targets, interpolate with
        # lam, then apply the original reduction.
        with NoneReduce(self.old_lf) as lf:
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))
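
# Usage sketch (not part of the exported module): `MixUp` plugs into a fastai
# `Learner` via the `cbs` argument. The names `dls` and `model` below are
# assumptions for a DataLoaders and a model you have already built:
#
#   learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=MixUp(0.4))
#   learn.fit_one_cycle(1)
#
# Because `CrossEntropyLossFlat` sets `y_int=True`, `begin_fit` swaps in `lf`,
# which interpolates the two per-item losses rather than the integer targets.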