
implemented gelu activations (#829)
Summary:
Pull Request resolved: #829

Implements Gaussian Error Linear Units (GELUs) as an activation function for the DeepCNNRepresentation module, and adds an interface for choosing among different activation functions.

Differential Revision: D16462672

fbshipit-source-id: 177dff360ef46f0e6041c76712e18eee89666b0a
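For reference, GELU is x·Φ(x), where Φ is the standard normal CDF; the GeLU module added in this diff uses the tanh approximation from the referenced paper:

$$\mathrm{GELU}(x) = x\,\Phi(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)$$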
shreydesai authored and facebook-github-bot committed Jul 27, 2019
1 parent e8dca0f commit 3a5675a
Showing 3 changed files with 61 additions and 4 deletions.
8 changes: 8 additions & 0 deletions pytext/config/module_config.py
@@ -48,3 +48,11 @@ class PerplexityType(Enum):
     MEAN = "mean"
     MEDIAN = "median"
     EOS = "eos"
+
+
+class Activation(Enum):
+    RELU = "relu"
+    LEAKYRELU = "leakyrelu"
+    TANH = "tanh"
+    GELU = "gelu"
+    GLU = "glu"
15 changes: 11 additions & 4 deletions pytext/models/representations/deepcnn.py
@@ -4,8 +4,9 @@

 import torch
 import torch.nn as nn
-from pytext.config.module_config import CNNParams
+from pytext.config.module_config import Activation, CNNParams
 from pytext.models.representations.representation_base import RepresentationBase
+from pytext.optimizer.activations import get_activation


 class Trim1d(nn.Module):
@@ -40,6 +41,7 @@ class DeepCNNRepresentation(RepresentationBase):
     class Config(RepresentationBase.Config):
         cnn: CNNParams = CNNParams()
         dropout: float = 0.3
+        activation: Activation = Activation.GLU

     def __init__(self, config: Config, embed_dim: int) -> None:
         super().__init__(config)
@@ -49,6 +51,7 @@ def __init__(self, config: Config, embed_dim: int) -> None:
         weight_norm = config.cnn.weight_norm
         dilated = config.cnn.dilated
         causal = config.cnn.causal
+        activation = config.activation

         conv_layers = []
         trim_layers = []
@@ -69,7 +72,11 @@ def __init__(self, config: Config, embed_dim: int) -> None:
             padding = (k - 1) * dilation if causal else ((k - 1) // 2) * dilation

             single_conv = nn.Conv1d(
-                in_channels, 2 * out_channels, k, padding=padding, dilation=dilation
+                in_channels,
+                (out_channels * 2 if activation == Activation.GLU else out_channels),
+                k,
+                padding=padding,
+                dilation=dilation,
             )
             single_conv = (
                 nn.utils.weight_norm(single_conv) if weight_norm else single_conv
@@ -84,7 +91,7 @@ def __init__(self, config: Config, embed_dim: int) -> None:
         self.convs = nn.ModuleList(conv_layers)
         self.trims = nn.ModuleList(trim_layers)
         self.projections = nn.ModuleList(linear_layers)
-        self.glu = nn.GLU(dim=1)
+        self.activation = get_activation(activation)

         self.representation_dim = out_channels
         self.dropout = nn.Dropout(p=config.dropout)
@@ -102,6 +109,6 @@ def forward(self, inputs: torch.Tensor, *args) -> torch.Tensor:
             words = conv(words)
             if trim:
                 words = trim(words)
-            words = self.glu(words)
+            words = self.activation(words)
             words = (words + residual) * math.sqrt(0.5)
         return words.permute(0, 2, 1)
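A minimal usage sketch of the new config option (illustrative only: the keyword-style Config/CNNParams construction, field values, embed_dim, and shape comments are assumptions based on this diff, not guaranteed by it). It also shows why the Conv1d output channels are doubled only for GLU: nn.GLU(dim=1) halves the channel dimension on its way out, while GELU and the other activations keep it unchanged.

```python
import torch

from pytext.config.module_config import Activation, CNNParams
from pytext.models.representations.deepcnn import DeepCNNRepresentation
from pytext.optimizer.activations import GeLU

# Hypothetical config values; odd kernel sizes keep the sequence length intact.
config = DeepCNNRepresentation.Config(
    cnn=CNNParams(kernel_num=64, kernel_sizes=[3, 5]),
    activation=Activation.GELU,
)
representation = DeepCNNRepresentation(config, embed_dim=100)

# With Activation.GELU, the shared activation attribute is the new GeLU module;
# with Activation.GLU it would be nn.GLU(dim=1) and each conv would emit 2x channels.
assert isinstance(representation.activation, GeLU)

# Assumed shape convention: (batch, seq_len, embed_dim) in,
# (batch, seq_len, representation_dim) out.
tokens = torch.rand(2, 7, 100)
print(representation(tokens).shape)  # expected (2, 7, 64) under the assumptions above
```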
42 changes: 42 additions & 0 deletions pytext/optimizer/activations.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import math
+
+import torch
+import torch.nn as nn
+from pytext.config.module_config import Activation
+
+
+class GeLU(nn.Module):
+    """
+    Implements Gaussian Error Linear Units (GELUs). Note: x * x * x is used
+    instead of torch.pow(x, 3) due to issues with ONNX compatibility:
+    https://github.com/pytorch/pytorch/issues/18475
+    Reference:
+    Gaussian Error Linear Units (GELUs). Dan Hendrycks, Kevin Gimpel.
+    Technical Report, 2017. https://arxiv.org/pdf/1606.08415.pdf
+    """
+
+    def forward(self, x):
+        return (
+            0.5
+            * x
+            * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * (x * x * x))))
+        )
+
+
+def get_activation(name):
+    if name == Activation.RELU:
+        return nn.ReLU()
+    elif name == Activation.LEAKYRELU:
+        return nn.LeakyReLU()
+    elif name == Activation.TANH:
+        return torch.tanh
+    elif name == Activation.GELU:
+        return GeLU()
+    elif name == Activation.GLU:
+        return nn.GLU(dim=1)
+    else:
+        raise RuntimeError(f"{name} is not supported")

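A small sketch of the new helper on its own (illustrative; assumes only torch and this module are importable): get_activation dispatches on the Activation enum, and the tanh-based GeLU stays very close to the exact erf form x * Φ(x).

```python
import math

import torch

from pytext.config.module_config import Activation
from pytext.optimizer.activations import GeLU, get_activation

act = get_activation(Activation.GELU)  # returns the GeLU module defined above
assert isinstance(act, GeLU)

x = torch.linspace(-4.0, 4.0, steps=81)

# Exact GELU: x * Phi(x), with Phi the standard normal CDF (via erf).
exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
approx = act(x)

# The tanh approximation tracks the exact form closely on this range
# (max absolute difference well below 1e-2).
print((approx - exact).abs().max().item())
```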