Adding Residual Attention layer.
gugarosa committed Jun 2, 2020
1 parent 9e68850 commit 268af6d
Showing 4 changed files with 79 additions and 1 deletion.
2 changes: 1 addition & 1 deletion textformer/models/encoders/conv.py
@@ -90,7 +90,7 @@ def forward(self, x):
x (torch.Tensor): Tensor containing the data.
Returns:
The hidden state and cell values.
The convolutions and output values.
"""

1 change: 1 addition & 0 deletions textformer/models/layers/__init__.py
@@ -2,3 +2,4 @@
"""

from textformer.models.layers.attention import Attention
from textformer.models.layers.residual_attention import ResidualAttention
4 changes: 4 additions & 0 deletions textformer/models/layers/attention.py
@@ -6,6 +6,10 @@
class Attention(nn.Module):
"""An Attention class is used to provide attention-based mechanisms in a neural network layer.
References:
D. Bahdanau, K. Cho, Y. Bengio. Neural machine translation by jointly learning to align and translate.
Preprint arXiv:1409.0473 (2014).
"""

def __init__(self, n_hidden_enc, n_hidden_dec):
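For context, the existing Attention layer's constructor (shown above) takes the encoder and decoder hidden sizes; a hypothetical instantiation, with illustrative sizes that are not taken from this repository, could look like this:

from textformer.models.layers import Attention

# Hypothetical hidden sizes for the encoder and decoder states (assumptions)
attention = Attention(n_hidden_enc=512, n_hidden_dec=512)
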
73 changes: 73 additions & 0 deletions textformer/models/layers/residual_attention.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualAttention(nn.Module):
"""A ResidualAttention class is used to provide attention-based mechanisms
in a neural network layer among residual connections.
References:
F. Wang, et al. Residual attention network for image classification.
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017).
"""

def __init__(self, n_hidden, n_embedding, scale):
"""Initialization method.
Args:
n_hidden (int): Number of hidden units.
n_embedding (int): Number of embedding units.
scale (float): Scaling factor applied to the residual connections.
"""

# Initializing the parent class
super(ResidualAttention, self).__init__()

# Defining the energy-based layer (projects hidden units to embedding units)
self.e = nn.Linear(n_hidden, n_embedding)

# Defining the weight-based layer (projects embedding units back to hidden units)
self.v = nn.Linear(n_embedding, n_hidden)

# Defining the scale for the residual connections
self.scale = scale

def forward(self, emb, c, enc_c, enc_o):
"""Performs a forward pass over the layer.
Args:
emb (torch.Tensor): Tensor containing the embedded outputs.
c (torch.Tensor): Tensor containing the convolved features from the decoder.
enc_c (torch.Tensor): Tensor containing the convolved features from the encoder.
enc_o (torch.Tensor): Tensor containing the encoder outputs.
Returns:
The attention weights and the residual attention-based features.
"""

# Transforms to the embedding dimension
emb_c = self.e(c.permute(0, 2, 1))

# Combines the convolutional features and embeddings
combined = (emb_c + emb) * self.scale

# Calculating the energy between the combined states and the encoder convolved features
energy = torch.matmul(combined, enc_c.permute(0, 2, 1))

# Calculating the attention weights over the source positions
attention = F.softmax(energy, dim=2)

# Applying the attention weights to the encoder outputs
encoded_attention = torch.matmul(attention, enc_o)

# Converting back to the hidden dimension
encoded_attention = self.v(encoded_attention)

# Applying residual connections
residual_attention = (c + encoded_attention.permute(0, 2, 1)) * self.scale

return attention, residual_attention
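
As a quick sanity check (not part of the diff itself), here is a minimal usage sketch of the new layer with dummy tensors; the batch size, sequence lengths, hidden/embedding sizes, and the sqrt(0.5) scale are illustrative assumptions, not values taken from this commit:

import torch

from textformer.models.layers import ResidualAttention

# Illustrative dimensions (assumptions)
batch_size, src_len, trg_len = 2, 10, 7
n_hidden, n_embedding = 512, 256

# sqrt(0.5) is a common scale in convolutional seq2seq models, chosen so that
# the sum of two terms keeps roughly unit variance (assumption)
layer = ResidualAttention(n_hidden, n_embedding, scale=0.5 ** 0.5)

# Embedded decoder outputs: (batch_size, trg_len, n_embedding)
emb = torch.randn(batch_size, trg_len, n_embedding)

# Convolved decoder features: (batch_size, n_hidden, trg_len)
c = torch.randn(batch_size, n_hidden, trg_len)

# Convolved encoder features, already in the embedding size: (batch_size, src_len, n_embedding)
enc_c = torch.randn(batch_size, src_len, n_embedding)

# Encoder outputs: (batch_size, src_len, n_embedding)
enc_o = torch.randn(batch_size, src_len, n_embedding)

attention, residual = layer(emb, c, enc_c, enc_o)

print(attention.shape)  # torch.Size([2, 7, 10]) -> (batch, target length, source length)
print(residual.shape)   # torch.Size([2, 512, 7]) -> same shape as the decoder features c

The residual output keeps the shape of the decoder features c, so it can be passed on in place of them, while the attention weights can be inspected separately.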
