Skip to content
Branch: master
Find file Copy path
Find file Copy path
3 contributors

Users who have contributed to this file

@jph00 @sgugger @lgvaz
671 lines (568 sloc) 25.6 KB
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_layers.ipynb (unless otherwise specified).
# Public names of this module, i.e. what a star-import of it exposes.
__all__ = ['module', 'Identity', 'Lambda', 'PartialLambda', 'Flatten', 'View', 'ResizeBatch', 'Debugger',
'sigmoid_range', 'SigmoidRange', 'AdaptiveConcatPool2d', 'PoolType', 'adaptive_pool', 'PoolFlatten',
'NormType', 'BatchNorm', 'InstanceNorm', 'BatchNorm1dFlat', 'LinBnDrop', 'sigmoid', 'sigmoid_',
'vleaky_relu', 'init_default', 'init_linear', 'ConvLayer', 'AdaptiveAvgPool', 'MaxPool', 'AvgPool',
'BaseLoss', 'CrossEntropyLossFlat', 'BCEWithLogitsLossFlat', 'BCELossFlat', 'MSELossFlat', 'L1LossFlat',
'LabelSmoothingCrossEntropy', 'trunc_normal_', 'Embedding', 'SelfAttention', 'PooledSelfAttention2d',
'SimpleSelfAttention', 'icnr_init', 'PixelShuffle_ICNR', 'sequential', 'SequentialEx', 'MergeLayer', 'Cat',
'SimpleCNN', 'ProdLayer', 'inplace_relu', 'SEModule', 'ResBlock', 'SEBlock', 'SEResNeXtBlock',
'SeparableBlock', 'swish', 'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule',
'children_and_parameters', 'flatten_model', 'NoneReduce', 'in_channels']
# Cell
from .imports import *
from .torch_imports import *
from .torch_core import *
from torch.nn.utils import weight_norm, spectral_norm
# Cell
def module(*flds, **defaults):
    """Decorator to create an `nn.Module` using `f` as `forward` method

    `flds` are required constructor arguments, `defaults` are optional ones
    with their default values. Every argument is stored as an attribute of the
    generated module so that `f` can read it through `self`.
    """
    pa = [inspect.Parameter(o, inspect.Parameter.POSITIONAL_OR_KEYWORD) for o in flds]
    pb = [inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=v)
          for k,v in defaults.items()]
    params = pa+pb
    all_flds = [*flds,*defaults.keys()]

    def _f(f):
        class c(nn.Module):
            def __init__(self, *args, **kwargs):
                # `nn.Module` needs its own initialisation before the module can be called
                super().__init__()
                # Map positional arguments onto field names, in declaration order
                for i,o in enumerate(args): kwargs[all_flds[i]] = o
                kwargs = merge(defaults,kwargs)
                for k,v in kwargs.items(): setattr(self,k,v)
            __repr__ = basic_repr(all_flds)
            forward = f
        # Make the generated class look like `f` to introspection tools
        c.__signature__ = inspect.Signature(params)
        c.__name__ = c.__qualname__ = f.__name__
        c.__doc__ = f.__doc__
        return c
    return _f
# Cell
# `forward`-style function: it only becomes a layer once wrapped by `module`
@module()
def Identity(self, x):
    "Do nothing at all"
    return x
# Cell
# Must be a class (not a bare function): `PartialLambda` subclasses it and
# `sequential` instantiates it with the wrapped callable as `func`.
@module('func')
def Lambda(self, x):
    "An easy way to create a pytorch layer for a simple `func`"
    return self.func(x)
# Cell
class PartialLambda(Lambda):
    "Layer that applies `partial(func, **kwargs)`"
    def __init__(self, func, **kwargs):
        bound = partial(func, **kwargs)
        super().__init__(bound)
        # Kept only for display: name of the wrapped function and the frozen kwargs
        self.repr = f'{func.__name__}, {kwargs}'

    def forward(self, x):
        return self.func(x)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.repr})'
# Cell
# `full` defaults to False so that `Flatten()` (as used by `PoolFlatten`) keeps the batch axis
@module(full=False)
def Flatten(self, x):
    "Flatten `x` to a single dimension, e.g. at end of a model. `full` for rank-1 tensor"
    return x.view(-1) if self.full else x.view(x.size(0), -1)
# Cell
class View(Module):
    "Reshape `x` to `size`"
    def __init__(self, *size):
        # Target shape, stored as the tuple of positional arguments
        self.size = size

    def forward(self, x):
        return x.view(self.size)
# Cell
class ResizeBatch(Module):
    "Reshape `x` to `size`, keeping batch dim the same size"
    def __init__(self, *size):
        self.size = size

    def forward(self, x):
        # Prepend the current batch size to the requested per-item shape
        batch = (x.size(0),)
        return x.view(batch + self.size)
# Cell
# NOTE(review): only the pass-through is visible here; if a breakpoint call is
# expected before the return, confirm against the originating notebook.
@module()
def Debugger(self,x):
    "A module to debug inside a model."
    return x
# Cell
def sigmoid_range(x, low, high):
    "Sigmoid of `x`, rescaled so the output lies in `(low, high)`"
    span = high - low
    return torch.sigmoid(x) * span + low
# Cell
# `low` and `high` are required constructor arguments, read back through `self`
@module('low','high')
def SigmoidRange(self, x):
    "Sigmoid module with range `(low, high)`"
    return sigmoid_range(x, self.low, self.high)
# Cell
class AdaptiveConcatPool2d(nn.Module):
    """Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`

    `size` is the output size handed to both pooling layers (1 when omitted).
    The result has twice the input channels: max-pooled features first, then
    average-pooled ones.
    """
    def __init__(self, size=None):
        super().__init__()
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)

    def forward(self, x):
        # Concatenate along the channel axis
        return torch.cat([self.mp(x), self.ap(x)], 1)
# Cell
class PoolType:
    "String constants naming the pooling strategies"
    Avg = 'Avg'
    Max = 'Max'
    Cat = 'Cat'
# Cell
def adaptive_pool(pool_type):
    "Return the adaptive pooling class for `pool_type` ('Avg', 'Max', anything else means concat)"
    if pool_type == 'Avg':
        return nn.AdaptiveAvgPool2d
    if pool_type == 'Max':
        return nn.AdaptiveMaxPool2d
    return AdaptiveConcatPool2d
# Cell
class PoolFlatten(nn.Sequential):
    "Adaptive pooling of kind `pool_type` to size 1, followed by `Flatten`."
    def __init__(self, pool_type=PoolType.Avg):
        pool_cls = adaptive_pool(pool_type)
        super().__init__(pool_cls(1), Flatten())
# Cell
# Normalisation flavours; the `*Zero` members ask for a zero-initialised norm weight
NormType = Enum('NormType', ['Batch', 'BatchZero', 'Weight', 'Spectral', 'Instance', 'InstanceZero'])
# Cell
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    """Norm layer `nn.{prefix}{ndim}d` with `nf` features.

    When the layer is affine, its weight is filled with 0 if `zero` else 1.
    Extra `kwargs` are forwarded to the layer constructor.
    """
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        # NOTE(review): the small bias constant follows the upstream library; the
        # original line was unreadable in this copy - confirm.
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn
# Cell
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    zero_init = norm_type == NormType.BatchZero
    return _get_norm('BatchNorm', nf, ndim, zero=zero_init, **kwargs)
# Cell
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    zero_init = norm_type == NormType.InstanceZero
    return _get_norm('InstanceNorm', nf, ndim, zero=zero_init, affine=affine, **kwargs)
# Cell
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d`, but first flattens leading dimensions"
    def forward(self, x):
        if x.dim() == 2:
            return super().forward(x)
        # Collapse every leading axis into one, normalise, then restore the shape
        *lead, n_ftrs = x.shape
        flat = x.contiguous().view(-1, n_ftrs)
        normed = super().forward(flat)
        return normed.view(*lead, n_ftrs)
# Cell
class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers

    - `bn`: include a batchnorm layer (the linear layer then has no bias)
    - `p`: dropout probability; no dropout layer when 0
    - `act`: optional activation module appended right after the linear layer
    - `lin_first`: put the linear (+activation) before batchnorm/dropout instead of after
    """
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Batchnorm sees `n_out` features when it comes after the linear layer
        layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        # Without this the assembled layers would never be registered
        super().__init__(*layers)
# Cell
def sigmoid(input, eps=1e-7):
    "Same as `torch.sigmoid`, plus clamping to `(eps,1-eps)`"
    lo, hi = eps, 1-eps
    return input.sigmoid().clamp(lo, hi)
# Cell
def sigmoid_(input, eps=1e-7):
    "Same as `torch.sigmoid_`, plus clamping to `(eps,1-eps)`"
    # Both steps work in place on `input`, which is also what gets returned
    input.sigmoid_()
    return input.clamp_(eps, 1-eps)
# Cell
from torch.nn.init import kaiming_uniform_,uniform_,xavier_uniform_,normal_
# Cell
def vleaky_relu(input, inplace=True):
    "`F.leaky_relu` with 0.3 slope"
    slope = 0.3
    return F.leaky_relu(input, negative_slope=slope, inplace=inplace)
# Cell
# Tag the ReLU family with the initialiser `init_linear` picks up in 'auto' mode
for o in (F.relu, nn.ReLU, F.relu6, nn.ReLU6, F.leaky_relu, nn.LeakyReLU):
    setattr(o, '__default_init__', kaiming_uniform_)
# Cell
# Saturating activations get Xavier initialisation as their 'auto' default
for o in (F.sigmoid, nn.Sigmoid, F.tanh, nn.Tanh, sigmoid, sigmoid_):
    setattr(o, '__default_init__', xavier_uniform_)
# Cell
def init_default(m, func=nn.init.kaiming_normal_):
    """Initialize `m` weights with `func` and set `bias` to 0.

    `func` may be falsy to leave the weights alone. Layers whose `weight` or
    `bias` is missing or `None` (e.g. norm layers built with `affine=False`)
    are skipped for that part. Returns `m`.
    """
    # `hasattr` is not enough: the attribute can exist and still be None
    if func and getattr(m, 'weight', None) is not None: func(m.weight)
    with torch.no_grad():
        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
    return m
# Cell
def init_linear(m, act_func=None, init='auto', bias_std=0.01):
    """Initialize the bias and weight of `m` in place.

    - bias: drawn from a normal of std `bias_std`, zeroed when `bias_std` is 0,
      left untouched when `bias_std` is None or `m` has no bias
    - weight: `init(m.weight)`; with `init='auto'` the function is looked up
      from `act_func` (its class's, then its own, `__default_init__`), and
      nothing is done when none is found
    """
    if getattr(m,'bias',None) is not None and bias_std is not None:
        # A zero std is a request for an all-zero bias, not a degenerate normal draw
        if bias_std != 0: normal_(m.bias, 0, bias_std)
        else: m.bias.data.zero_()
    if init=='auto':
        if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_
        else: init = getattr(act_func.__class__, '__default_init__', None)
        if init is None: init = getattr(act_func, '__default_init__', None)
    if init is not None: init(m.weight)
# Cell
def _conv_func(ndim=2, transpose=False):
    "Return the `nn` convolution class for `ndim` dimensions, transposed if `transpose`."
    assert 1 <= ndim <=3
    infix = "Transpose" if transpose else ""
    return getattr(nn, f'Conv{infix}{ndim}d')
# Cell
# Cell
class ConvLayer(nn.Sequential):
    """Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers.

    - `padding` defaults to `(ks-1)//2` (0 for transposed convolutions)
    - `bias` defaults to True only when no batch/instance norm follows
    - `norm_type` Batch/Instance variants append a norm layer; Weight/Spectral wrap the conv
    - `bn_1st` puts the norm layer before the activation
    - `xtra` is an optional module appended last; `kwargs` go to the conv constructor
    """
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        act = None if act_cls is None else act_cls()
        # Initialise before wrapping: weight/spectral norm re-parametrise the weight
        init_linear(conv, act, init=init, bias_std=bias_std)
        if norm_type==NormType.Weight: conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act is not None: act_bn.append(act)
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        # Register the assembled layers with `nn.Sequential`
        super().__init__(*layers)
# Cell
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AdaptiveAvgPool{ndim}d")
    return pool_cls(sz)
# Cell
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    """nn.MaxPool layer for `ndim`

    `ks`, `stride`, `padding` and `ceil_mode` are forwarded to the
    `nn.MaxPool{ndim}d` constructor; `ndim` must be 1, 2 or 3.
    """
    assert 1 <= ndim <= 3
    # `ceil_mode` used to be accepted but silently dropped
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# Cell
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AvgPool{ndim}d")
    return pool_cls(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# Cell
class BaseLoss():
    """Same as `loss_cls`, but flattens input and target.

    - `loss_cls(*args, **kwargs)` builds the wrapped loss, kept in `self.func`
    - `axis`: axis moved to the last position before flattening
    - `flatten`: flatten input (and target) before calling the loss
    - `floatify`: cast the target to float (unless it is already float16)
    - `is_2d`: flatten the input to `(-1, n)` instead of `(-1,)`
    """
    _methods = "activation decodes".split()
    def __init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):
        self.axis,self.flatten,self.floatify,self.is_2d = axis,flatten,floatify,is_2d
        self.func = loss_cls(*args,**kwargs)
        functools.update_wrapper(self, self.func)

    def __repr__(self): return f"FlattenedLoss of {self.func}"

    # `reduction` is a proxy for the wrapped loss: reading and writing go to `self.func`
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, inp, targ, **kwargs):
        inp  = inp .transpose(self.axis,-1).contiguous()
        targ = targ.transpose(self.axis,-1).contiguous()
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        # Narrow integer targets are widened to int64
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
# Cell
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True

    def __init__(self, *args, axis=-1, **kwargs):
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def decodes(self, x):
        # Index of the largest score along the class axis
        return x.argmax(dim=self.axis)

    def activation(self, x):
        return F.softmax(x, dim=self.axis)
# Cell
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        # Cut-off used by `decodes`
        self.thresh = thresh

    def decodes(self, x):
        return x>self.thresh

    def activation(self, x):
        return torch.sigmoid(x)
# Cell
def BCELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.BCELoss`, but flattens input and target."
    loss = BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return loss
# Cell
def MSELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.MSELoss`, but flattens input and target."
    loss = BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return loss
# Cell
def L1LossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.L1Loss`, but flattens input and target."
    loss = BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return loss
# Cell
class LabelSmoothingCrossEntropy(Module):
    """Cross entropy on logits with label smoothing.

    The loss mixes, with weight `eps`, the average negative log-probability
    over all classes and, with weight `1-eps`, the usual negative
    log-likelihood of the target. `reduction` is 'mean', 'sum' or anything
    else for a per-item loss.
    """
    y_int = True
    def __init__(self, eps:float=0.1, reduction='mean'): self.eps,self.reduction = eps,reduction

    def forward(self, output, target):
        c = output.size()[-1]
        # Keep true log-probabilities: `F.nll_loss` negates them itself
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1) #We divide by that size at the return line so sum and not mean
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), reduction=self.reduction)

    def activation(self, out): return F.softmax(out, dim=-1)
    def decodes(self, out): return out.argmax(dim=-1)
# Cell
def trunc_normal_(x, mean=0., std=1.):
    "Truncated normal initialization (approximation)"
    # Standard normal draws are folded into (-2, 2) with `fmod_`, then scaled
    # and shifted; every step mutates `x` in place.
    x.normal_()
    x.fmod_(2)
    x.mul_(std)
    return x.add_(mean)
# Cell
class Embedding(nn.Embedding):
    "Embedding layer (`ni` entries of size `nf`) with truncated normal initialization"
    def __init__(self, ni, nf):
        super().__init__(ni, nf)
        # Overwrite the default initialisation of the embedding matrix in place
        trunc_normal_(self.weight.data, std=0.01)
# Cell
class SelfAttention(nn.Module):
    """Self attention layer for `n_channels`.

    The input is flattened over its trailing axes, attended over, and returned
    with its original shape. `gamma` starts at 0, so the layer is initially
    the identity.
    """
    def __init__(self, n_channels):
        # Required before sub-modules and parameters can be assigned
        super().__init__()
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels)]
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        #Notation from the paper.
        size = x.size()
        x = x.view(*size[:2],-1)
        f,g,h = self.query(x),self.key(x),self.value(x)
        beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()
# Cell
class PooledSelfAttention2d(nn.Module):
    """Pooled self attention layer for 2d.

    Keys and values are 2x2 max-pooled before attention. `gamma` starts at 0,
    so the layer is initially the identity.
    """
    def __init__(self, n_channels):
        # Required before sub-modules and parameters can be assigned
        super().__init__()
        self.n_channels = n_channels
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels//2)]
        self.out = self._conv(n_channels//2, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        n_ftrs = x.shape[2]*x.shape[3]
        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)
        # Pooling divides the number of spatial positions by 4
        g = F.max_pool2d(self.key(x), [2,2]).view(-1, self.n_channels//8, n_ftrs//4)
        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)
        o = self.out(torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, x.shape[2], x.shape[3]))
        return self.gamma * o + x
# Cell
def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    """Create and initialize a `nn.Conv1d` layer with spectral normalization.

    The spectrally-normalised layer is returned whether or not it has a bias;
    a bias, when present, starts at zero.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    # NOTE(review): weight initialisation follows the upstream library; the line
    # was not readable in this copy of the file - confirm.
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return spectral_norm(conv)
# Cell
class SimpleSelfAttention(Module):
    "Self attention over `n_in` channels using one spectrally-normalised 1d conv; `sym` symmetrises its weight"
    def __init__(self, n_in:int, ks=1, sym=False):
        self.sym = sym
        self.n_in = n_in
        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self,x):
        if self.sym:
            # Replace the conv weight by the average of itself and its transpose
            w = self.conv.weight.view(self.n_in,self.n_in)
            w = (w + w.t())/2
            self.conv.weight = w.view(self.n_in,self.n_in,1)

        orig_size = x.size()
        x = x.view(*orig_size[:2],-1)

        conv_out = self.conv(x)
        gram = torch.bmm(x,x.permute(0,2,1).contiguous())
        attended = torch.bmm(gram, conv_out)
        res = self.gamma * attended + x
        return res.view(*orig_size).contiguous()
# Cell
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    ni,nf,h,w = x.shape
    n_groups = scale**2
    ni2 = int(ni/n_groups)
    # Initialise a kernel with `scale**2` times fewer output channels...
    base = init(x.new_zeros([ni2,nf,h,w]))
    k = base.transpose(0, 1).contiguous().view(ni2, nf, -1)
    # ...then repeat it so the channels of a group start out identical
    k = k.repeat(1, 1, n_groups)
    k = k.contiguous().view([nf,ni,h,w])
    return k.transpose(0, 1)
# Cell
class PixelShuffle_ICNR(nn.Sequential):
    """Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`.

    A 1x1 `ConvLayer` produces `nf*scale**2` channels, which `nn.PixelShuffle`
    rearranges into an output `scale` times larger. With `blur`, a replication
    pad followed by a 2x2 average pool (stride 1) is appended.
    """
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls, bias_std=0),
                  nn.PixelShuffle(scale)]
        # NOTE(review): ICNR initialisation of the conv weight reconstructed from
        # the class name and `icnr_init`; the original lines were missing - confirm.
        layers[0][0].weight.data.copy_(icnr_init(layers[0][0].weight.data, scale=scale))
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        super().__init__(*layers)
# Cell
def sequential(*args):
    "Create an `nn.Sequential`, wrapping items with `Lambda` if needed"
    # A lone OrderedDict is handed to `nn.Sequential` untouched
    lone_dict = len(args) == 1 and isinstance(args[0], OrderedDict)
    if not lone_dict:
        args = [o if isinstance(o, nn.Module) else Lambda(o) for o in args]
    return nn.Sequential(*args)
# Cell
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers):
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        cur = x
        for layer in self.layers:
            # Expose the block input to the layer through `.orig`
            cur.orig = x
            nxt = layer(cur)
            # Drop the reference again so no tensor is kept alive by accident
            cur.orig = None
            cur = nxt
        return cur

    def __getitem__(self,i):
        return self.layers[i]

    def append(self,l):
        return self.layers.append(l)

    def extend(self,l):
        return self.layers.extend(l)

    def insert(self,i,l):
        return self.layers.insert(i,l)
# Cell
class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False): self.dense=dense

    def forward(self, x):
        # `x.orig` is the shortcut input attached by `SequentialEx`
        return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)
# Cell
class Cat(nn.ModuleList):
    """Concatenate layers outputs over a given dim

    Every module in `layers` is applied to the same input and the results are
    concatenated along `dim`.
    """
    def __init__(self, layers, dim=1):
        super().__init__(layers)
        self.dim = dim

    def forward(self, x):
        return torch.cat([l(x) for l in self], dim=self.dim)
# Cell
class SimpleCNN(nn.Sequential):
    """Create a simple CNN with `filters`.

    One `ConvLayer` is built per consecutive pair in `filters`, using
    `kernel_szs` (default 3) and `strides` (default 2). With `bn`, every
    layer but the last has batchnorm.
    """
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        nl = len(filters)-1
        kernel_szs = ifnone(kernel_szs, [3]*nl)
        strides    = ifnone(strides   , [2]*nl)
        layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],
                            norm_type=(NormType.Batch if bn and i<nl-1 else None)) for i in range(nl)]
        # NOTE(review): pooling head reconstructed from the upstream library; the
        # original tail of this method was missing - confirm.
        layers.append(PoolFlatten())
        super().__init__(*layers)
# Cell
class ProdLayer(Module):
    "Merge a shortcut with the result of the module by multiplying them."
    def forward(self, x):
        # `x.orig` is the shortcut input attached by `SequentialEx`
        return x * x.orig
# Cell
# Factory for `nn.ReLU` layers that modify their input in place
inplace_relu = partial(nn.ReLU, inplace=True)
# Cell
def SEModule(ch, reduction, act_cls=defaults.activation):
    """Squeeze-and-excitation style block for `ch` channels.

    Global average pooling, a 1x1 conv down to `nf` channels (`ch//reduction`
    rounded up to a multiple of 8), a 1x1 conv back to `ch` with a sigmoid,
    and finally a product with the block input.
    """
    nf = math.ceil(ch//reduction/8)*8
    return SequentialEx(nn.AdaptiveAvgPool2d(1),
                        ConvLayer(ch, nf, ks=1, norm_type=None, act_cls=act_cls),
                        ConvLayer(nf, ch, ks=1, norm_type=None, act_cls=nn.Sigmoid),
                        # NOTE(review): closing layer reconstructed (gates the input
                        # through `x.orig`); the original line was missing - confirm.
                        ProdLayer())
# Cell
class ResBlock(nn.Module):
    """Resnet block from `ni` to `nh` with `stride`

    - `expansion`: 1 builds two `ks` convs; otherwise a 1x1 / `ks` / 1x1 bottleneck
    - `nh1`, `nh2`: hidden widths (default `nf`); `groups`, `g2`, `dw`: grouping of the convs
    - `reduction`: when set, append an `SEModule`; `sa`/`sym`: append `SimpleSelfAttention`
    - `pool`, `pool_first`: pooling used on the identity path when `stride != 1`
    - the last conv of the main path uses the zero-initialised variant of `norm_type`
    """
    def __init__(self, expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=defaults.activation, ndim=2, ks=3,
                 pool=AvgPool, pool_first=True, **kwargs):
        # Required before sub-modules can be assigned
        super().__init__()
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        convpath  = [ConvLayer(ni,  nh2, ks, stride=stride, groups=ni if dw else groups, **k0),
                     ConvLayer(nh2,  nf, ks, groups=g2, **k1)
        ] if expansion == 1 else [
                     ConvLayer(ni,  nh1, 1, **k0),
                     ConvLayer(nh1, nh2, ks, stride=stride, groups=nh1 if dw else groups, **k0),
                     ConvLayer(nh2,  nf, 1, groups=g2, **k1)]
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        # `(1,0)[pool_first]` is index 0 when pooling goes before the 1x1 conv, else 1
        if stride!=1: idpath.insert((1,0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()

    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
# Cell
def SEBlock(expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs):
    "`ResBlock` with squeeze-excitation, hidden widths `nh1=nf*2` and `nh2=nf*expansion`"
    hidden = dict(nh1=nf*2, nh2=nf*expansion)
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, **hidden, **kwargs)
# Cell
def SEResNeXtBlock(expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs):
    "`ResBlock` with squeeze-excitation and grouped convs; inner width derived from `base_width` and `groups`"
    per_group = math.floor(nf * (base_width / 64))
    w = per_group * groups
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, nh2=w, **kwargs)
# Cell
def SeparableBlock(expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs):
    "`ResBlock` with squeeze-excitation, depthwise convs (`dw=True`) and `nh2=nf*2`; `base_width` is unused"
    inner = nf*2
    return ResBlock(expansion, ni, nf, stride=stride, reduction=reduction, nh2=inner, dw=True, **kwargs)
# Cell
from torch.jit import script
@script
def _swish_jit_fwd(x):
    "Swish forward pass: `x * sigmoid(x)`"
    return x.mul(torch.sigmoid(x))

@script
def _swish_jit_bwd(x, grad_output):
    "Gradient of swish with respect to `x`, chained with `grad_output`"
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))

class _SwishJitAutoFn(torch.autograd.Function):
    "Swish as a custom autograd function: only the input is saved, the gradient is recomputed from it"
    @staticmethod
    def forward(ctx, x):
        # `backward` needs the input back
        ctx.save_for_backward(x)
        return _swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return _swish_jit_bwd(x, grad_output)
# Cell
def swish(x, inplace=False):
    "Swish activation of `x`; `inplace` is accepted but not used"
    return _SwishJitAutoFn.apply(x)
# Cell
class Swish(Module):
    "Swish activation as a layer"
    def forward(self, x):
        return _SwishJitAutoFn.apply(x)
# Cell
@script
def _mish_jit_fwd(x):
    "Mish forward pass: `x * tanh(softplus(x))`"
    return x.mul(torch.tanh(F.softplus(x)))

@script
def _mish_jit_bwd(x, grad_output):
    "Gradient of mish with respect to `x`, chained with `grad_output`"
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))

class MishJitAutoFn(torch.autograd.Function):
    "Mish as a custom autograd function: only the input is saved, the gradient is recomputed from it"
    @staticmethod
    def forward(ctx, x):
        # `backward` needs the input back
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return _mish_jit_bwd(x, grad_output)
# Cell
def mish(x):
    "Mish activation of `x`"
    return MishJitAutoFn.apply(x)
# Cell
class Mish(Module):
    "Mish activation as a layer"
    def forward(self, x):
        return MishJitAutoFn.apply(x)
# Cell
# Swish and Mish use the same 'auto' initialiser as the ReLU family
for o in (swish, Swish, mish, Mish):
    setattr(o, '__default_init__', kaiming_uniform_)
# Cell
class ParameterModule(Module):
    "Register a lone parameter `p` in a module."
    def __init__(self, p):
        self.val = p

    def forward(self, x):
        # Pass-through: the module only exists to hold `val`
        return x
# Cell
def children_and_parameters(m):
    """Return the children of `m` and its direct parameters not registered in modules.

    Each parameter of `m` that no child owns is wrapped in a `ParameterModule`
    and appended after the children.
    """
    children = list(m.children())
    # ids of every parameter owned by a child; a set keeps the lookup below O(1)
    owned = {id(p) for c in children for p in c.parameters()}
    children += [ParameterModule(p) for p in m.parameters() if id(p) not in owned]
    return children
# Cell
def _has_children(m:nn.Module):
    "True when `m` has at least one direct child module"
    for _ in m.children():
        return True
    return False

# Expose the check as a read-only attribute on every `nn.Module`
nn.Module.has_children = property(_has_children)
# Cell
def flatten_model(m):
    "Return the list of all submodules and parameters of `m`"
    # A module without children is a leaf and stands for itself
    if not m.has_children:
        return [m]
    return sum(map(flatten_model, children_and_parameters(m)), [])
# Cell
class NoneReduce():
    "A context manager to evaluate `loss_func` with none reduce."
    def __init__(self, loss_func):
        self.loss_func = loss_func
        self.old_red = None

    def __enter__(self):
        # Plain functions get `reduction='none'` bound as a keyword instead
        if not hasattr(self.loss_func, 'reduction'):
            return partial(self.loss_func, reduction='none')
        # Remember the current setting so `__exit__` can put it back
        self.old_red = self.loss_func.reduction
        self.loss_func.reduction = 'none'
        return self.loss_func

    def __exit__(self, type, value, traceback):
        if self.old_red is not None:
            self.loss_func.reduction = self.old_red
# Cell
def in_channels(m):
    "Return `weight.shape[1]` of the first layer in `m` that has a 4-dimensional weight."
    for layer in flatten_model(m):
        w = getattr(layer, 'weight', None)
        if w is None or w.ndim != 4:
            continue
        return w.shape[1]
    raise Exception('No weight layer')
You can’t perform that action at this time.