In [2]:
from typing import Callable

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class GatedLinearUnit(nn.Module):
    """Gated activation: split the input in half and gate one half with the other.

    The input is split into two equal halves ``a`` (first) and ``b`` (second)
    along ``channel_dim``. The module returns ``activation_fn(a) * sigmoid(b)``,
    so the output is half the input's size along ``channel_dim``.

    Note: with the default ``activation_fn=F.tanh`` this is the *gated tanh*
    variant. The classic GLU uses the identity as the activation, i.e.
    ``GatedLinearUnit(activation_fn=lambda t: t)``, which matches
    ``torch.nn.functional.glu``. The module has no learnable parameters.

    Args:
        activation_fn: Element-wise callable applied to the first half.
        channel_dim: Dimension along which the input is halved. Its size
            must be even.
    """

    def __init__(
        self,
        activation_fn: Callable = F.tanh,
        channel_dim: int = -1,
    ):
        super().__init__()
        self.activation_fn = activation_fn
        self.channel_dim = channel_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``activation_fn(first_half) * sigmoid(second_half)``.

        Raises:
            ValueError: if ``x`` has an odd size along ``channel_dim``.
        """
        size = x.shape[self.channel_dim]
        # torch.chunk does not require an even size. An odd size yields unequal
        # halves that either crash obscurely or (e.g. 2 vs 1) silently broadcast
        # into a wrong result, so reject it explicitly like F.glu does.
        if size % 2 != 0:
            raise ValueError(
                f"GatedLinearUnit needs an even size along dim {self.channel_dim}, got {size}"
            )

        act_in, gate_in = torch.chunk(x, 2, dim=self.channel_dim)
        return self.activation_fn(act_in) * torch.sigmoid(gate_in)

In [16]:
glu = GatedLinearUnit(activation_fn=F.tanh, channel_dim=-1)  # defaults spelled out: tanh activation, split on last dim

In [23]:
x = torch.tensor([10, 10, 0, -10, -10] + [1] * 5)  # first half -> activation input, second half -> gate input

In [24]:
x  # display the input: the first 5 values feed the activation, the last 5 feed the sigmoid gate

tensor([ 10,  10,   0, -10, -10,   1,   1,   1,   1,   1])

In [25]:
glu(x)  # tanh(x[:5]) * sigmoid(x[5:]); tanh(+/-10) ~ +/-1 and sigmoid(1) ~ 0.7311, hence +/-0.7311 (and 0 for tanh(0))

tensor([ 0.7311,  0.7311,  0.0000, -0.7311, -0.7311])