In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Fix the RNG so Restart & Run All reproduces the same random tensors and weights.
torch.manual_seed(0)
batch_size = 8
n_var = 4
embedding_dim = 24
# Random stand-in for an embedding layer's output: (batch_size, n_var, embedding_dim).
embedding_output = torch.randn((batch_size, n_var, embedding_dim))

In [3]:
# Input shape: (batch_size, n_var, embedding_dim)
embedding_output.size()

torch.Size([8, 4, 24])

In [4]:
num_attention_heads = 2
# Each head gets an equal slice of the embedding. Derive the slice width from the
# head count so changing num_attention_heads keeps the two consistent.
assert embedding_dim % num_attention_heads == 0, "embedding_dim must be divisible by num_attention_heads"
attention_head_size = embedding_dim // num_attention_heads

In [5]:
# Every dimension except the embedding axis
embedding_output.shape[:-1]

torch.Size([8, 4])

In [6]:
def transpose_for_scores(x, num_heads=None, head_size=None):
    """Split the last dim of ``x`` into attention heads and move the head axis forward.

    Args:
        x: tensor of shape (batch, seq_len, num_heads * head_size).
        num_heads: number of heads. Defaults to the notebook-level
            ``num_attention_heads`` so existing one-argument calls keep working.
        head_size: width of each head. Defaults to the notebook-level
            ``attention_head_size``.

    Returns:
        A view of shape (batch, num_heads, seq_len, head_size).
    """
    if num_heads is None:
        num_heads = num_attention_heads
    if head_size is None:
        head_size = attention_head_size
    new_x_shape = x.size()[:-1] + (num_heads, head_size)
    # (batch, seq, heads, head_size) -> (batch, heads, seq, head_size)
    return x.view(*new_x_shape).permute(0, 2, 1, 3)

In [7]:
# Target shape after splitting the embedding into heads: (batch, seq, heads, head_size)
new_x_shape = torch.Size([*embedding_output.shape[:-1], num_attention_heads, attention_head_size])

In [8]:
# Display the per-head target shape computed in the previous cell.
new_x_shape

torch.Size([8, 4, 2, 12])

In [9]:
# NOTE(review): this rebinds embedding_output to a 4-D view, so the original 3-D
# tensor is no longer reachable by that name; a later cell draws a fresh input.
embedding_output = embedding_output.view(new_x_shape)


In [10]:
# Now (batch, seq, heads, head_size)
embedding_output.size()

torch.Size([8, 4, 2, 12])

In [11]:
# Swapping the seq and head axes gives (batch, heads, seq, head_size)
embedding_output.transpose(1, 2).shape

torch.Size([8, 2, 4, 12])

In [12]:
hidden_size = embedding_dim
# Total projection width across all heads, computed the same way as
# SelfAttention.all_head_size. This guarantees transpose_for_scores can view it
# as (num_attention_heads, attention_head_size).
all_head_size = num_attention_heads * attention_head_size

In [13]:
# Query/key/value projections for the step-by-step walkthrough.
# hidden_states is bound in the next cell to a fresh 3-D input; binding it here
# would capture the 4-D reshaped embedding_output.
query = nn.Linear(hidden_size, all_head_size)
key = nn.Linear(hidden_size, all_head_size)
value = nn.Linear(hidden_size, all_head_size)

In [14]:
# embedding_output was rebound to a 4-D view earlier, so draw a fresh
# (batch_size, n_var, embedding_dim) input and use it as the attention input.
hidden_states = embedding_output = torch.randn(batch_size, n_var, embedding_dim)

In [15]:
# Run the same input through each of the three projections, in Q, K, V order.
mixed_query_layer, mixed_key_layer, mixed_value_layer = (
    projection(hidden_states) for projection in (query, key, value)
)

In [16]:
mixed_query_layer.size()

torch.Size([8, 4, 24])

In [17]:
# Split each projection into heads: (batch, seq, all_head) -> (batch, heads, seq, head_size)
query_layer, key_layer, value_layer = map(
    transpose_for_scores, (mixed_query_layer, mixed_key_layer, mixed_value_layer)
)

In [18]:
query_layer.size()

torch.Size([8, 2, 4, 12])

In [19]:
# K^T per head: the last two axes swapped
key_layer.transpose(-2, -1).size()

torch.Size([8, 2, 12, 4])

In [20]:
# Raw dot-product scores per head: Q @ K^T -> (batch, heads, seq, seq)
attention_scores = query_layer @ key_layer.transpose(-2, -1)

In [21]:
attention_scores.size()

torch.Size([8, 2, 4, 4])

In [22]:
# Scaled dot-product scores: Q @ K^T / sqrt(head_size).
# Recomputed from Q and K instead of dividing attention_scores by itself, so
# re-running this cell does not apply the scaling twice.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(attention_head_size)


In [23]:
# Normalize over the key axis so each query's weights sum to 1.
attention_probs = F.softmax(attention_scores, dim=-1)

In [24]:
# Attention-weighted sum of values: (batch, heads, seq, head_size)
context_layer = attention_probs @ value_layer

In [25]:
# (batch, heads, seq, head_size) -> (batch, seq, heads, head_size), made contiguous
# for the view in the merge step. Built from attention_probs/value_layer instead of
# permuting context_layer in place, so re-running this cell gives the same result.
context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()

In [26]:
context_layer.size()

torch.Size([8, 4, 2, 12])

In [27]:
# Merge the head axes: (batch, seq, heads, head_size) -> (batch, seq, all_head_size).
# The leading (batch, seq) dims are taken from the front, so re-running the cell on
# the already-merged 3-D tensor is a no-op rather than a view size error.
new_context_layer_shape = context_layer.size()[:2] + (all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)

In [28]:
context_layer.size()

torch.Size([8, 4, 24])

In [29]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into queries, keys and values and splits each into
    ``num_attention_heads`` heads. Each head computes
    ``softmax(Q K^T / sqrt(head_size)) V``, and the heads are then merged back.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads``,
            ``output_attentions`` and ``attention_probs_dropout_prob``.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads`` and ``config`` has no ``embedding_size``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        # Integer division: no float round-trip as with int(a / b).
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape (batch, seq_len, all_head_size) -> (batch, heads, seq_len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply multi-head self-attention.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            ``(context_layer,)``, or ``(context_layer, attention_probs)`` when
            ``output_attentions`` is set. ``context_layer`` has shape
            (batch, seq_len, all_head_size). ``attention_probs`` has shape
            (batch, heads, seq_len, seq_len) and is taken after dropout.
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Functional softmax avoids constructing a new nn.Softmax module on every call.
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs

In [30]:
class TransformerConfig:
    """Plain container of hyper-parameters for the transformer modules in this notebook.

    Every constructor argument is stored as an attribute of the same name.
    """

    def __init__(self,
                 output_attentions = True,
                 num_attention_heads = 2,
                 hidden_size = 24,
                 attention_probs_dropout_prob = 0.1,
                 hidden_dropout_prob = 0.1,
                 intermediate_size = 24):
        settings = dict(
            output_attentions=output_attentions,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            intermediate_size=intermediate_size,
        )
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

In [31]:
# Default hyper-parameters: hidden_size=24, 2 heads.
transformer_config = TransformerConfig()
self_attention = SelfAttention(config=transformer_config)

In [32]:
# Shape of the context output (first tuple element)
self_attention(embedding_output)[0].size()

torch.Size([8, 4, 24])

In [33]:
class TransformerOutput(nn.Module):
    """Output sub-layer: project feed-forward activations to ``hidden_size``, then
    apply dropout, add the residual and apply LayerNorm.

    Args:
        config: object exposing ``hidden_size`` and ``hidden_dropout_prob``, and
            optionally ``intermediate_size`` (defaults to ``hidden_size``).
    """

    def __init__(self, config):
        super().__init__()
        # The input is the intermediate (feed-forward) activation, whose last dim is
        # intermediate_size, so project intermediate_size -> hidden_size.
        # Projecting hidden_size -> hidden_size only worked when both sizes were equal.
        intermediate_size = getattr(config, "intermediate_size", config.hidden_size)
        self.dense = nn.Linear(intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Args:
            hidden_states: (..., intermediate_size) activations to project.
            input_tensor: (..., hidden_size) residual added before LayerNorm.

        Returns:
            Tensor of shape (..., hidden_size).
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
    

In [34]:
class TransformerIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden_size -> intermediate_size) followed by ReLU.

    Args:
        config: object exposing ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = F.relu

    def forward(self, hidden_states):
        """Return ReLU(dense(hidden_states)), shape (..., intermediate_size)."""
        projected = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)

In [35]:
class TransformerLayer(nn.Module):
    """One layer: self-attention, then the feed-forward block with residual + LayerNorm.

    ``forward`` returns a tuple whose first element is the layer output. Any
    extra elements returned by the attention module (the attention
    probabilities when ``config.output_attentions`` is set) follow it.
    """

    def __init__(self, config):
        super().__init__()
        # Construction order is kept as-is: it fixes the order of random weight init.
        self.attention = SelfAttention(config)
        self.output = TransformerOutput(config)
        self.intermediate = TransformerIntermediate(config)

    def forward(self, hidden):
        attention_output, *extras = self.attention(hidden)
        # NOTE(review): the residual inside self.output adds attention_output rather
        # than the layer input `hidden`. No residual/LayerNorm wraps the attention
        # step itself; confirm this simplification is intended.
        feed_forward = self.intermediate(attention_output)
        layer_output = self.output(feed_forward, attention_output)
        return (layer_output, *extras)

In [36]:
transformer_layer = TransformerLayer(config=transformer_config)

In [37]:
# Shape of the layer output (first tuple element)
transformer_layer(embedding_output)[0].size()

torch.Size([8, 4, 24])