This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

implemented gelu activations #829

Closed
14 changes: 14 additions & 0 deletions pytext/config/module_config.py
@@ -22,6 +22,12 @@ class CNNParams(ConfigBase):
    kernel_num: int = 100
    # Kernel sizes to use in convolution
    kernel_sizes: List[int] = [3, 4]
    # Use weight norm in convolution
    weight_norm: bool = False
    # Enables dilated convolutions
    dilated: bool = False
    # Enables causal convolutions
    causal: bool = False


class PoolingType(Enum):
@@ -42,3 +48,11 @@ class PerplexityType(Enum):
    MEAN = "mean"
    MEDIAN = "median"
    EOS = "eos"


class Activation(Enum):
    RELU = "relu"
    LEAKYRELU = "leakyrelu"
    TANH = "tanh"
    GELU = "gelu"
    GLU = "glu"
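
For orientation, a minimal usage sketch (not from the diff) of the new options; it assumes CNNParams, like other ConfigBase configs, can be instantiated with keyword arguments:

from pytext.config.module_config import Activation, CNNParams

# Dilated, causal convolutions with weight norm, paired with GELU activations.
cnn = CNNParams(
    kernel_num=100,
    kernel_sizes=[3, 5],   # odd kernel sizes; deepcnn.py asserts (k - 1) % 2 == 0
    weight_norm=True,
    dilated=True,
    causal=True,
)
activation = Activation.GELU
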
60 changes: 51 additions & 9 deletions pytext/models/representations/deepcnn.py
@@ -4,8 +4,25 @@

import torch
import torch.nn as nn
from pytext.config.module_config import Activation, CNNParams
from pytext.models.representations.representation_base import RepresentationBase
from pytext.optimizer.activations import get_activation


class Trim1d(nn.Module):
    """
    Trims a 1d convolutional output. Used to implement history-padding
    by removing excess padding from the right.

    """

    def __init__(self, trim):
        super(Trim1d, self).__init__()

        self.trim = trim

    def forward(self, x):
        return x[:, :, : -self.trim].contiguous()
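
An illustrative sketch (not part of the diff, reusing the imports and Trim1d above) of the history-padding idea: a causal layer pads both sides by (kernel_size - 1) * dilation, and Trim1d then strips that amount from the right, so the output keeps the input length while each position only sees current and past timesteps.

conv = nn.Conv1d(8, 8, kernel_size=3, padding=2)  # pad (k - 1) = 2 on both sides
trim = Trim1d(2)                                  # strip the 2 extra steps from the right
x = torch.randn(4, 8, 50)                         # bsz x channels x seq_len
y = trim(conv(x))
assert y.shape == x.shape                         # length preserved, receptive field is causal
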


class DeepCNNRepresentation(RepresentationBase):
@@ -24,34 +41,57 @@ class Config(RepresentationBase.Config):
    class Config(RepresentationBase.Config):
        cnn: CNNParams = CNNParams()
        dropout: float = 0.3
        activation: Activation = Activation.GLU

    def __init__(self, config: Config, embed_dim: int) -> None:
        super().__init__(config)

        out_channels = config.cnn.kernel_num
        kernel_sizes = config.cnn.kernel_sizes
        weight_norm = config.cnn.weight_norm
        dilated = config.cnn.dilated
        causal = config.cnn.causal
        activation = config.activation

        conv_layers = []
        trim_layers = []
        linear_layers = []
        in_channels = embed_dim

        for i, k in enumerate(kernel_sizes):
            assert (k - 1) % 2 == 0

            proj = (
                nn.Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            linear_layers.append(proj)

            # Dilation grows as 2 ** i across layers when enabled, widening the
            # receptive field exponentially with depth.
            dilation = 2 ** i if dilated else 1
            # Causal layers pad the full (k - 1) * dilation and rely on Trim1d
            # to cut that amount off the right, so no position sees future
            # timesteps; otherwise symmetric padding preserves sequence length.
            padding = (k - 1) * dilation if causal else ((k - 1) // 2) * dilation

            single_conv = nn.Conv1d(
                in_channels,
                (out_channels * 2 if activation == Activation.GLU else out_channels),
                k,
                padding=padding,
                dilation=dilation,
            )
            single_conv = (
                nn.utils.weight_norm(single_conv) if weight_norm else single_conv
            )
            conv_layers.append(single_conv)

            trim = Trim1d(padding) if causal else None
            trim_layers.append(trim)

            in_channels = out_channels

        self.convs = nn.ModuleList(conv_layers)
        self.trims = nn.ModuleList(trim_layers)
        self.projections = nn.ModuleList(linear_layers)
        self.activation = get_activation(activation)

        self.representation_dim = out_channels
        self.dropout = nn.Dropout(p=config.dropout)
@@ -60,13 +100,15 @@ def forward(self, inputs: torch.Tensor, *args) -> torch.Tensor:
        inputs = self.dropout(inputs)
        # bsz * seq_len * embed_dim -> bsz * embed_dim * seq_len
        words = inputs.permute(0, 2, 1)
        for conv, trim, proj in zip(self.convs, self.trims, self.projections):
            if proj:
                tranposed = words.permute(0, 2, 1)
                residual = proj(tranposed).permute(0, 2, 1)
            else:
                residual = words
            words = conv(words)
            if trim:
                words = trim(words)
            words = self.activation(words)
            words = (words + residual) * math.sqrt(0.5)
        return words.permute(0, 2, 1)
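
End to end, the representation can be exercised roughly as follows (a hedged sketch, not from the PR; it assumes the Config objects accept keyword arguments as above):

import torch
from pytext.config.module_config import Activation, CNNParams
from pytext.models.representations.deepcnn import DeepCNNRepresentation

config = DeepCNNRepresentation.Config(
    cnn=CNNParams(kernel_num=64, kernel_sizes=[3, 5], dilated=True, causal=True),
    dropout=0.3,
    activation=Activation.GELU,
)
rep = DeepCNNRepresentation(config, embed_dim=32)
tokens = torch.randn(2, 20, 32)  # bsz x seq_len x embed_dim
out = rep(tokens)
print(out.shape)  # torch.Size([2, 20, 64]); representation_dim == kernel_num
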
42 changes: 42 additions & 0 deletions pytext/optimizer/activations.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math

import torch
import torch.nn as nn
from pytext.config.module_config import Activation


class GeLU(nn.Module):
    """
    Implements Gaussian Error Linear Units (GELUs). Note: x * x * x is used
    instead of torch.pow(x, 3) due to issues with ONNX compatibility:
    https://github.com/pytorch/pytorch/issues/18475

    Reference:
        Gaussian Error Linear Units (GELUs). Dan Hendrycks, Kevin Gimpel.
        Technical Report, 2017. https://arxiv.org/pdf/1606.08415.pdf
    """

    def forward(self, x):
        return (
            0.5
            * x
            * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * (x * x * x))))
        )


def get_activation(name):
    if name == Activation.RELU:
        return nn.ReLU()
    elif name == Activation.LEAKYRELU:
        return nn.LeakyReLU()
    elif name == Activation.TANH:
        return torch.tanh
    elif name == Activation.GELU:
        return GeLU()
    elif name == Activation.GLU:
        return nn.GLU(dim=1)
    else:
        raise RuntimeError(f"{name} is not supported")
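
As a quick sanity check (illustrative only, not part of the PR), the tanh approximation in GeLU should stay within roughly 1e-3 of the exact erf-based GELU, and get_activation simply dispatches on the Activation enum:

import math

import torch
from pytext.config.module_config import Activation
from pytext.optimizer.activations import GeLU, get_activation

x = torch.linspace(-5, 5, steps=101)
exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))  # exact GELU via the error function
approx = GeLU()(x)                                   # tanh approximation defined above
assert torch.max(torch.abs(approx - exact)) < 1e-3

act = get_activation(Activation.GELU)                # returns the GeLU module
print(act(torch.tensor([-1.0, 0.0, 1.0])))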