The ReLU activation function is one of the breakthroughs in deep neural networks, owing to its simplicity and to the fact that its local derivative equals $1$ for positive inputs, which mitigates the vanishing gradient problem. Despite its effectiveness and simplicity, there have been several efforts to make ReLU more expressive. Some of these generalize ReLU by using a nonzero slope $\alpha$ for negative inputs. As of November 2020, there are three variants of ReLU that perform this kind of generalization. Absolute value rectification fixes $\alpha = -1$. [LeakyReLU](https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf) fixes $\alpha$ to a small value, while [PReLU](https://arxiv.org/pdf/1502.01852.pdf) treats $\alpha$ as a learnable parameter.
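The variants differ only in the slope applied to negative inputs, so they can all be written as $\max(x, 0) + \alpha \min(x, 0)$. The following is a minimal scalar sketch of that idea; the helper name `generalized_relu` and the slope `0.01` are illustrative assumptions, not taken from the cited papers.

```python
def generalized_relu(x: float, alpha: float) -> float:
    """Return x for x >= 0 and alpha * x for x < 0."""
    return max(x, 0.0) + alpha * min(x, 0.0)


x = -2.0
print(generalized_relu(x, 0.0))    # plain ReLU: slope 0 for negative inputs
print(generalized_relu(x, -1.0))   # absolute value rectification: alpha = -1
print(generalized_relu(x, 0.01))   # LeakyReLU-style: small fixed slope (illustrative value)
# PReLU uses the same formula, but alpha is a parameter updated during training.
```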

Let us begin with the formal definition of Dynamic ReLU, based on the [paper](https://arxiv.org/pdf/2003.10027.pdf).


**Definition**: Let us denote the traditional, or static, ReLU as $\textbf{y} = \max\{\textbf{x}, 0\}$, where $\textbf{x}$ is the input vector. For the input $x_c$ at the $c$-th channel, the activation is computed as $y_c = \max\{x_c, 0\}$. ReLU can be generalized to a parametric piecewise linear function $y_c = \max_k\{a^k_c x_c + b^k_c\}$. Dynamic ReLU further extends this piecewise linear function from static to dynamic by adapting $a^k_c$, $b^k_c$ based upon all input elements $\mathbf{x}=\{x_c\}$ as follows: $y_c = f_{\theta(\mathbf{x})}(x_c) = \max_{1 \le k \le K}\{a^k_c(\mathbf{x})x_c+b^k_c(\mathbf{x})\}$, where the coefficients $(a^k_c, b^k_c)$ are the output of a hyper function $\theta(\mathbf{x})$: $[a^1_1,\dots,a^1_C,\dots,a^K_1,\dots,a^K_C,b^1_1,\dots,b^1_C,\dots,b^K_1,\dots,b^K_C]^{\mathsf{T}}=\theta(\mathbf{x})$. Here $K$ is the number of linear functions and $C$ is the number of channels.

Let us rewrite the vanilla ReLU $\textbf{y} = \max\{\textbf{x}, 0\}$ in the generalized form from the definition above. It is easy to show that $K=2$ for vanilla ReLU and that, for each channel $c$, the entries of $\theta(\mathbf{x})$ are $[a^1_c(\mathbf{x}), a^2_c(\mathbf{x}), b^1_c(\mathbf{x}), b^2_c(\mathbf{x})]^{\mathsf{T}} = [1,0,0,0]^{\mathsf{T}}$, that is, $a^1_c = 1$ and $a^2_c = b^1_c = b^2_c = 0$. Let us plug a number into $y_c=f_{\theta(\mathbf{x})}(x_c)$. Let $x_c=-1$:

$f_{\theta(\mathbf{x})}(-1) = \max_{1 \le k \le 2}\{a^k_c(\mathbf{x})(-1)+b^k_c(\mathbf{x})\} = \max\{1 \cdot (-1) + 0,\ 0 \cdot (-1) + 0\} = a^2_c(\mathbf{x})(-1)+b^2_c(\mathbf{x}) = 0 \cdot (-1) + 0 = 0$
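The same calculation can be checked numerically. This is a minimal single-channel sketch; the function name `piecewise_linear_max` is a hypothetical helper introduced here for illustration.

```python
def piecewise_linear_max(x_c, a, b):
    """Evaluate max_k (a[k] * x_c + b[k]) for one channel."""
    return max(a_k * x_c + b_k for a_k, b_k in zip(a, b))


# Static ReLU as a special case: K = 2, slopes (1, 0), intercepts (0, 0).
a, b = [1.0, 0.0], [0.0, 0.0]
for x_c in (-1.0, 0.0, 2.5):
    print(x_c, piecewise_linear_max(x_c, a, b))
```

For $x_c = -1$ the second linear function wins and the result is $0$, while for positive inputs the first one wins and the input passes through unchanged.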




Now let us look at how the value of $\theta$ is computed.

<img src="theta.jpg">

In [1]:
import torch
import torch.nn as nn


class Dynamic_ReLUB2D(nn.Module):
    def __init__(self, channels, reduction=4, k=2):
        super(Dynamic_ReLUB2D, self).__init__()
        self.channels = channels
        self.k = k
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channels // reduction, 2 * k * channels)
        self.sigmoid = nn.Sigmoid()
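        # Scales for the residuals: 1.0 for the slopes a^k, 0.5 for the intercepts b^k
        self.register_buffer('lambdas', torch.Tensor([1.] * k + [0.5] * k).float())
        # Initial coefficients: a^1 = 1, all other slopes and intercepts 0 (i.e. plain ReLU)
        self.register_buffer('init_v', torch.Tensor([1.] + [0.] * (2 * k - 1)).float())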

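    def compute_theta(self, x):
        # Global average pooling over the spatial dimensions: (N, C, H, W) -> (N, C)
        theta = torch.mean(x, dim=-1)
        theta = torch.mean(theta, dim=-1)
        # Two fully connected layers with a channel bottleneck: (N, C) -> (N, 2 * k * C)
        theta = self.fc1(theta)
        theta = self.relu(theta)
        theta = self.fc2(theta)
        # Squash to (-1, 1) so theta acts as a bounded residual on the initial coefficients
        theta = 2 * self.sigmoid(theta) - 1
        return theta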

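    def forward(self, x):
        assert x.shape[1] == self.channels
        theta = self.compute_theta(x)
        # Per-channel coefficients [a^1..a^k, b^1..b^k]: (N, C, 2k)
        relu_coefs = theta.view(-1, self.channels, 2 * self.k) * self.lambdas + self.init_v
        # Move the spatial dims first so (N, C) lines up with relu_coefs: (H, W, N, C, 1)
        x_perm = x.permute(2, 3, 0, 1).unsqueeze(-1)
        # Evaluate all k linear functions: (H, W, N, C, k)
        output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:]
        # Take the max over the k functions and restore the (N, C, H, W) layout
        result = torch.max(output, dim=-1)[0].permute(2, 3, 0, 1)
        return result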
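
To make the role of `lambdas` and `init_v` concrete, here is a small functional sketch of the coefficient step alone. The function `apply_dynamic_relu` is a hypothetical helper written for this illustration (it is not part of the class above), and it uses broadcasting instead of `permute`. It assumes the same scales and initial values as the class. When the hyper function outputs all zeros, the coefficients stay at their initial values and the activation reduces to plain ReLU.

```python
import torch
import torch.nn.functional as F


def apply_dynamic_relu(x, theta, k=2):
    """Apply max_k(a_k * x + b_k) per channel.

    x:     (N, C, H, W) feature map
    theta: (N, 2 * k * C) hyper-function output, expected in (-1, 1)
    """
    n, c, _, _ = x.shape
    # Same residual scales and initial coefficients as in the class above (assumed here).
    lambdas = torch.tensor([1.0] * k + [0.5] * k, dtype=x.dtype, device=x.device)
    init_v = torch.tensor([1.0] + [0.0] * (2 * k - 1), dtype=x.dtype, device=x.device)
    coefs = theta.view(n, c, 2 * k) * lambdas + init_v   # (N, C, 2k)
    a = coefs[:, :, None, None, :k]                      # slopes:     (N, C, 1, 1, k)
    b = coefs[:, :, None, None, k:]                      # intercepts: (N, C, 1, 1, k)
    return (x.unsqueeze(-1) * a + b).max(dim=-1).values  # (N, C, H, W)


x = torch.randn(4, 8, 5, 5)
theta = torch.zeros(4, 2 * 2 * 8)   # zero residual: coefficients stay at init_v
y = apply_dynamic_relu(x, theta, k=2)
print(y.shape)
print(torch.equal(y, F.relu(x)))
```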


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


Next, we create a simple convolutional neural network for classifying FashionMNIST.

In [3]:
class Net(nn.Module):
    def __init__(self, dynamic_relu=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.is_dynamic_relu = dynamic_relu
        self.relu1 = F.relu
        self.relu2 = F.relu
        if dynamic_relu:
            self.relu1 = Dynamic_ReLUB2D(32, k=3)
            self.relu2 = Dynamic_ReLUB2D(64, k=3)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
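
The input size `9216` of `fc1` follows from the layer shapes. As a quick check, the arithmetic can be sketched as follows; `conv_out` is a hypothetical helper that applies the standard output-size formula for convolution and pooling without dilation.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                      # FashionMNIST images are 28x28
size = conv_out(size, 3)       # conv1: 3x3 kernel, stride 1 -> 26
size = conv_out(size, 3)       # conv2: 3x3 kernel, stride 1 -> 24
size = conv_out(size, 2, 2)    # max_pool2d with kernel 2 (stride defaults to kernel size) -> 12
flat_features = 64 * size * size
print(flat_features)           # 9216
```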


In [4]:
def train(model, device, train_loader, optimizer, epoch, dry_run, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            if dry_run:
                break

In [5]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [6]:
def main(dynamic_relu=False):
    # Hyperparameters
    batch_size = 64
    test_batch_size = 1000
    epochs = 14
    lr = 1.0
    gamma = 0.7
    no_cuda = False
    dry_run = False
    seed = 1
    log_interval = 10
    save_model = False

    use_cuda = not no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

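    # Note: 0.1307 and 0.3081 are the mean and std commonly used for MNIST;
    # FashionMNIST has different statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])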

    dataset1 = datasets.FashionMNIST('../data', train=True, download=True,
                                     transform=transform)
    dataset2 = datasets.FashionMNIST('../data', train=False,
                                     transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net(dynamic_relu=dynamic_relu).to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, dry_run, log_interval)
        test(model, device, test_loader)
        scheduler.step()

    if save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


In [7]:
main(False)


Test set: Average loss: 0.3101, Accuracy: 8874/10000 (89%)

Test set: Average loss: 0.2601, Accuracy: 9044/10000 (90%)

Test set: Average loss: 0.2495, Accuracy: 9139/10000 (91%)

Test set: Average loss: 0.2316, Accuracy: 9193/10000 (92%)

Test set: Average loss: 0.2192, Accuracy: 9208/10000 (92%)

Test set: Average loss: 0.2167, Accuracy: 9238/10000 (92%)

Test set: Average loss: 0.2145, Accuracy: 9253/10000 (93%)

Test set: Average loss: 0.2139, Accuracy: 9274/10000 (93%)

Test set: Average loss: 0.2128, Accuracy: 9270/10000 (93%)

Test set: Average loss: 0.2089, Accuracy: 9282/10000 (93%)

Test set: Average loss: 0.2104, Accuracy: 9280/10000 (93%)

Test set: Average loss: 0.2084, Accuracy: 9278/10000 (93%)

Test set: Average loss: 0.2090, Accuracy: 9277/10000 (93%)

Test set: Average loss: 0.2079, Accuracy: 9284/10000 (93%)



In [None]:
main(True)


Test set: Average loss: 0.3261, Accuracy: 8864/10000 (89%)

Test set: Average loss: 0.2423, Accuracy: 9141/10000 (91%)

Test set: Average loss: 0.2643, Accuracy: 9086/10000 (91%)

Test set: Average loss: 0.2514, Accuracy: 9233/10000 (92%)

