Commit
Starting to implement the Convolutional Seq2Seq.
gugarosa committed May 30, 2020
1 parent 28c0ef4 commit 5dd3067
Showing 3 changed files with 194 additions and 0 deletions.
74 changes: 74 additions & 0 deletions textformer/models/conv_seq2seq.py
@@ -0,0 +1,74 @@
import torch
from torch import distributions
from torchtext.data.metrics import bleu_score

import textformer.utils.logging as l
from textformer.core.model import Model
from textformer.models.decoders import LSTMDecoder
from textformer.models.encoders import ConvEncoder

logger = l.get_logger(__name__)


class ConvSeq2Seq(Model):
"""A ConvSeq2Seq class implements a Convolutional Sequence-To-Sequence learning architecture.
References:
J. Gehring, et al. Convolutional sequence to sequence learning.
Proceedings of the 34th International Conference on Machine Learning (2017).
"""

def __init__(self, n_input=128, n_output=128, n_hidden=128, n_embedding=128, n_layers=1, kernel_size=3,
dropout=0.5, max_length=100, ignore_token=None, init_weights=None, device='cpu'):
"""Initialization method.
Args:
n_input (int): Number of input units.
n_output (int): Number of output units.
n_hidden (int): Number of hidden units.
n_embedding (int): Number of embedding units.
n_layers (int): Number of convolutional layers.
kernel_size (int): Size of the convolutional kernels.
dropout (float): Amount of dropout to be applied.
max_length (int): Maximum length of positional embeddings.
            ignore_token (int): The index of a token to be ignored by the loss function.
init_weights (tuple): Tuple holding the minimum and maximum values for weights initialization.
device (str): Device that model should be trained on, e.g., `cpu` or `cuda`.
"""

logger.info('Overriding class: Model -> ConvSeq2Seq.')

# Creating the encoder network
E = ConvEncoder(n_input, n_hidden, n_embedding, n_layers, kernel_size, dropout, max_length)

        # Creating the decoder network (an LSTM-based decoder stands in while
        # the convolutional decoder is still to be implemented)
        D = LSTMDecoder(n_output, n_hidden, n_embedding, n_layers, dropout)

# Overrides its parent class with any custom arguments if needed
super(ConvSeq2Seq, self).__init__(E, D, ignore_token, init_weights, device)

        logger.info('Class overridden.')

def forward(self, x, y, teacher_forcing_ratio=0.0):
"""Performs a forward pass over the architecture.
Args:
x (torch.Tensor): Tensor containing the data.
y (torch.Tensor): Tensor containing the true labels.
            teacher_forcing_ratio (float): Probability of feeding the ground-truth token,
                rather than the model's own prediction, as the next decoder input.
Returns:
The predictions over the input tensor.
"""

# Performs the encoding
conv, output = self.E(x)

        # Decodes the encoded inputs
        preds, _ = self.D(y, conv, output)

return preds
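
To make the encoder/decoder contract in `forward` concrete, here is a self-contained sketch of the tensor hand-off in plain PyTorch. The stand-in modules and all sizes (vocabularies of 1000/800 tokens, batch of 4, hidden size 128) are illustrative assumptions, not part of this commit:

import torch
from torch import nn


class DummyEncoder(nn.Module):
    """Stand-in mirroring ConvEncoder.forward: returns (conv, output)."""

    def forward(self, x):
        # Assumed shapes: both tensors are (batch, source_length, n_hidden)
        conv = torch.randn(x.shape[0], x.shape[1], 128)
        output = torch.randn(x.shape[0], x.shape[1], 128)
        return conv, output


class DummyDecoder(nn.Module):
    """Stand-in mirroring the `preds, _ = self.D(y, conv, output)` call."""

    def forward(self, y, conv, output):
        # Assumed shape: predictions are (batch, target_length, n_output)
        return torch.randn(y.shape[0], y.shape[1], 800), None


# Dummy batch: 4 sequences, 10 source tokens and 12 target tokens each
x = torch.randint(0, 1000, (4, 10))
y = torch.randint(0, 800, (4, 12))

conv, output = DummyEncoder()(x)
preds, _ = DummyDecoder()(y, conv, output)
print(preds.shape)  # torch.Size([4, 12, 800])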
1 change: 1 addition & 0 deletions textformer/models/encoders/__init__.py
@@ -2,5 +2,6 @@
"""

from textformer.models.encoders.bi_gru import BiGRUEncoder
from textformer.models.encoders.conv import ConvEncoder
from textformer.models.encoders.gru import GRUEncoder
from textformer.models.encoders.lstm import LSTMEncoder
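
With the re-export above, the new encoder becomes importable from the package namespace alongside the existing ones; assuming the package is installed, a quick check would be:

from textformer.models.encoders import BiGRUEncoder, ConvEncoder, GRUEncoder, LSTMEncoder

# The convolutional encoder now sits next to the recurrent ones
print(ConvEncoder.__name__)  # 'ConvEncoder'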
119 changes: 119 additions & 0 deletions textformer/models/encoders/conv.py
@@ -0,0 +1,119 @@
import torch
import torch.nn.functional as F
from torch import nn

import textformer.utils.logging as l
from textformer.core import Encoder

logger = l.get_logger(__name__)


class ConvEncoder(Encoder):
"""A ConvEncoder is used to supply the encoding part of the Convolutional Seq2Seq architecture.
"""

def __init__(self, n_input=128, n_hidden=128, n_embedding=128, n_layers=1, kernel_size=3, dropout=0.5, max_length=100):
"""Initializion method.
Args:
n_input (int): Number of input units.
n_hidden (int): Number of hidden units.
n_embedding (int): Number of embedding units.
n_layers (int): Number of convolutional layers.
kernel_size (int): Size of the convolutional kernels.
dropout (float): Amount of dropout to be applied.
max_length (int): Maximum length of positional embeddings.
"""

logger.info('Overriding class: Encoder -> ConvEncoder.')

# Overriding its parent class
super(ConvEncoder, self).__init__()

# Number of input units
self.n_input = n_input

# Number of hidden units
self.n_hidden = n_hidden

# Number of embedding units
self.n_embedding = n_embedding

# Number of layers
self.n_layers = n_layers

        # Kernel size (forced to be odd so the padding preserves the sequence length)
        if kernel_size % 2 == 0:
            self.kernel_size = kernel_size + 1
        else:
            self.kernel_size = kernel_size

# Maximum length of positional embeddings
self.max_length = max_length

        # Scaling factor used to stabilize the variance of the residual sums
        self.scale = torch.sqrt(torch.FloatTensor([0.5]))

# Embedding layers
self.embedding = nn.Embedding(n_input, n_embedding)
self.pos_embedding = nn.Embedding(max_length, n_embedding)

# Fully connected layers
self.fc1 = nn.Linear(n_embedding, n_hidden)
self.fc2 = nn.Linear(n_hidden, n_embedding)

        # Convolutional layers
        self.conv = nn.ModuleList([nn.Conv1d(in_channels=n_hidden,
                                             out_channels=2 * n_hidden,
                                             kernel_size=self.kernel_size,
                                             padding=(self.kernel_size - 1) // 2)
                                   for _ in range(n_layers)])

# Dropout layer
self.dropout = nn.Dropout(dropout)

logger.debug(f'Size: ({self.n_input}, {self.n_hidden}) | Embeddings: {self.n_embedding} | Core: {self.conv}.')

def forward(self, x):
"""Performs a forward pass over the architecture.
Args:
x (torch.Tensor): Tensor containing the data.
        Returns:
            The convolved features and the combined output values.
"""

# Creates the positions tensor
pos = torch.arange(0, x.shape[1]).unsqueeze(0).repeat(x.shape[0], 1)

# Calculates the embedded outputs
x_embedded = self.embedding(x)
pos_embedded = self.pos_embedding(pos)

# Combines the embeddings
embedded = self.dropout(x_embedded + pos_embedded)

        # Passes the embeddings through the first linear layer and permutes to (batch, channels, length)
        hidden = self.fc1(embedded).permute(0, 2, 1)

        # For every convolutional layer
        for i, c in enumerate(self.conv):
            # Passes down through the convolutional layer
            conv = c(self.dropout(hidden))

            # Applies the gated linear unit, halving the channels back to `n_hidden`
            conv = F.glu(conv, dim=1)

            # Adds the residual connection and rescales the variance
            conv = (conv + hidden) * self.scale

            # Uses the current output as the input to the next layer
            hidden = conv

        # Passes down to the second linear layer, permuting back to (batch, length, channels)
        conv = self.fc2(conv.permute(0, 2, 1))

        # Combines the features with the embeddings as the attention-ready output
        output = (conv + embedded) * self.scale

return conv, output
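
As a sanity check on the encoder block above, the following standalone sketch (illustrative sizes; plain PyTorch, independent of this repository) traces the shapes through one convolution-GLU-residual step: the convolution doubles the channels to `2 * n_hidden`, `F.glu` gates and halves them back, and the scaled residual keeps the shape intact.

import torch
import torch.nn.functional as F
from torch import nn

n_hidden, kernel_size = 128, 3
block = nn.Conv1d(n_hidden, 2 * n_hidden, kernel_size, padding=(kernel_size - 1) // 2)
scale = torch.sqrt(torch.FloatTensor([0.5]))

# Dummy (batch, channels, length) input, as produced by fc1 + permute
hidden = torch.randn(4, n_hidden, 10)

# The convolution doubles the channel dimension...
conv = block(hidden)
print(conv.shape)  # torch.Size([4, 256, 10])

# ...the gated linear unit computes A * sigmoid(B) over the two halves...
conv = F.glu(conv, dim=1)
print(conv.shape)  # torch.Size([4, 128, 10])

# ...and the scaled residual connection preserves the shape
conv = (conv + hidden) * scale
print(conv.shape)  # torch.Size([4, 128, 10])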
