In [None]:
#|default_exp layers

### Notebook to contain customised versions of layers and torch functions for model construction

In [None]:
#|export
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF,torch.nn.functional as F
import fastcore.test as fct

In [None]:
#|export
class GeneralRelu(nn.Module):
    """Extension of leaky ReLU with an optional constant shift and an optional upper clamp.

    The forward pass computes ``min(act(x) - sub, maxv)``, where ``act`` is
    ``F.leaky_relu`` with negative slope ``leak`` (or plain ``F.relu`` when
    ``leak`` is None). The subtraction and the clamp are each skipped when the
    corresponding argument is None.

    Subtracting ``sub`` shifts the whole activation down, so the transition
    point sits at ``-sub`` rather than zero (presumably to move the output
    mean towards zero; the intent is not stated here).

    Args:
        leak: Negative slope passed to ``F.leaky_relu``; None means plain ReLU.
        sub: Constant subtracted from the activation; None means no shift.
        maxv: Upper bound applied after the shift; None means no clamp.
    """
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv

    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        # Out-of-place ops on purpose: ReLU's autograd node saves its output
        # for the backward pass. Modifying that output in place
        # (``x -= ...`` / ``clamp_max_``) makes ``backward()`` raise an
        # "modified by an inplace operation" error.
        if self.sub is not None:
            x = x - self.sub
        if self.maxv is not None:
            x = x.clamp_max(self.maxv)
        return x

#### Tests for GeneralRelu

Tests are needed to:
1. Check that, without a leak, the value returned for inputs below zero is -sub
2. Check that inputs between zero and maxv + sub return the input minus sub (the identity when sub is None)
3. Check that inputs large enough for the shifted output to exceed maxv return maxv
4. Check that if leak is not supplied then it behaves in the same way as relu
5. Check that it works properly if sub and maxv are None

In [None]:
# Check that with no options GeneralRelu behaves like relu.
# fct.test_close raises on mismatch. fct.is_close only returns a bool, which a
# notebook cell silently discards unless it is the last expression, so it cannot
# fail a test run.
gru = GeneralRelu(leak=None, sub=None, maxv=None)
high_val = 50000.0
low_val = -50000.0
# The outputs are float32 (machine epsilon ~1.2e-7), so a much tighter
# tolerance would amount to demanding bit-exact results.
eps = 1.e-6
fct.test_close(gru(torch.tensor(-1.e-6)).numpy(), 0., eps=eps)
fct.test_close(gru(torch.tensor(high_val)).numpy(), high_val, eps=eps)
fct.test_close(gru(torch.tensor(low_val)).numpy(), 0., eps=eps)

True

In [None]:
# Check leak working with sub and maxv
leaky_slope = 0.05
sub = 0.1
maxv = 5.0
gru = GeneralRelu(leak=leaky_slope, sub=sub, maxv=maxv)

# (input, expected output) pairs
cases = [
    # A very large input is limited to maxv
    (50000.0, maxv),
    # An input of exactly maxv is reduced by sub (maxv - sub < maxv, so no clamping)
    (maxv, maxv - sub),
    # An input of zero returns -sub
    (0., -sub),
    # An arbitrary negative input is scaled by the leaky slope, then reduced by sub
    (-5., -5. * leaky_slope - sub),
]
x = np.array([inp for inp, _ in cases], dtype=np.float32)
y = np.array([expected for _, expected in cases], dtype=np.float32)

# TODO: it might be desirable to check that a negative leaky slope raises an
# exception, as does a positive value of sub or a negative value of maxv.
# GeneralRelu needs to validate its arguments first.

# test_close raises on mismatch (is_close would only return a bool)
fct.test_close(gru(torch.tensor(x)).numpy(), y, eps=eps)

True

### Export

In [None]:
# Regenerate the library modules from the notebooks' `#|export` cells
# (this notebook's exported cells go to `layers`, per `#|default_exp`).
import nbdev
nbdev.nbdev_export()