In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention

In [144]:
class Attention(nn.Module):
    def __init__(self, d_in, d_out):
        """
        Parameters:
        - d_in (int): size of input - the hidden size aka d_model
        - d_out (int): size of output - the length of the k and q vectors
        """
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        # x is (batch_size, seq_len, hidden_size)
        # queries, keys, values are (batch_size, seq_len, d_out)
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        
        # transpose keys => (batch_size, d_out, seq_len)
        #
        # Example: 
        # seq_len = 5, d_out = 2
        #                    v
        # Q: | x x |<   K: | x x x x x |      | X x x x x |
        #    | x x |       | x x x x x |      | x x x x x |
        #    | x x |  @                    =  | x x x x x |
        #    | x x |                          | x x x x x |
        #    | x x |                          | x x x x x |
        #
        # X represents the dot product of the 1x2 Q vector in row 1 of Q and the 2x1 K vector in col 1 of K
        # See < and v for positions in Q and K
        #
        # scores, attention = (batch_size, seq_len, seq_len)
        scores = torch.bmm(queries, keys.transpose(1, 2))
        scores = scores / (self.d_out ** 0.5)
        attention = F.softmax(scores, dim=2)

        # Hidden states
        # Iterate over weights for each query, generating weighted sum of all values in seq
        # A: | x x x x x |     V: | x x |   | x x |
        #    | x x x x x |        | x x |   | x x |
        #    | x x x x x |  @     | x x | = | x x |
        #    | x x x x x |        | x x |   | x x |
        #    | x x x x x |        | x x |   | x x |
        # 
        hidden_states = torch.bmm(attention, values)
        return hidden_states        

In [145]:
SOS_token = 0
EOS_token = 0

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split())
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

In [107]:
def convert2tensors(sentence):
    words_list = sentence.lower().split(' ')
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

In [110]:
sentence = "How are you doing ? I am good and you ?"
convert2tensors(sentence)

tensor([[3, 6, 1, 7, 2, 8, 4, 9, 5, 1, 2]])

In [124]:
HIDDEN_SIZE = 10
VOCAB_SIZE = len(words2index)

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
attention = Attention(HIDDEN_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = attention(embedded)

# Multihead Attention

In [146]:
class MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.out = nn.Linear(hidden_size, hidden_size)
        self.heads = nn.ModuleList([
            Attention(hidden_size, hidden_size // num_heads)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        # List[ (batch, seq_len, hidden_size // num_heads ) ], len(list) = num_heads 
        outputs = [head(x) for head in self.heads]
        # (batch, seq_len, (hidden_size // num_heads) * num_heads) = (batch, seq_len, hidden_size)
        outputs = torch.cat(outputs, dim=2)
        # (batch, seq_len, hidden_size)
        hidden_states = self.out(outputs)
        return hidden_states

In [147]:
NUM_HEADS = 3
HIDDEN_SIZE = 12
VOCAB_SIZE = len(words2index)

multi_att = MultiheadAttention(HIDDEN_SIZE, NUM_HEADS)
embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = multi_att(embedded)

In [148]:
hidden_states.shape

torch.Size([1, 5, 12])

In [129]:
hidden_states

tensor([[[ 0.0108,  0.0522,  0.1712,  0.2370, -0.0738, -0.2954,  0.0350,
          -0.0551,  0.0423, -0.1061, -0.0453, -0.0933],
         [ 0.0239,  0.0556,  0.1297,  0.1967, -0.0386, -0.2766,  0.0329,
          -0.0894,  0.0359, -0.1271, -0.0508, -0.1351],
         [-0.0068,  0.1204,  0.0525,  0.1229,  0.0398, -0.2572,  0.0457,
          -0.1234,  0.0786, -0.1508, -0.0979, -0.2468],
         [-0.0199,  0.0468,  0.1574,  0.2117, -0.0436, -0.2889,  0.0429,
          -0.0565,  0.0578, -0.1142, -0.0644, -0.1291],
         [-0.0412,  0.0683,  0.1340,  0.1760, -0.0055, -0.2682,  0.0511,
          -0.0569,  0.0727, -0.1113, -0.0643, -0.1850]]],
       grad_fn=<ViewBackward0>)

In [149]:
class MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.qkv_linear = nn.Linear(hidden_size, hidden_size * 3)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_length, hidden_size = x.size()

        # Compute Q, K, and V in one shot
        # We have one Wq matrix of size (hidden_size, hidden_size) instead of
        # num_heads Q matrices of size (hidden_size // num_heads, hidden_size)
        # It's equivalent, since the former implementation is segmenting the output layer of Q into num_head partitions
        # (batch_size, seq_length, hidden_size * 3)
        qkv = self.qkv_linear(x)

        # Split the Q,K,V output into num_heads partitions, each with it's own mini Q,K,V vector
        # (batch_size, seq_length, num_heads, head_dim * 3)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)

        # Promote the heads group so we focus on the smaller seq_length, qkv group
        # (batch_size, num_heads, seq_length, head_dim * 3)
        qkv = qkv.transpose(1, 2)

        # Break apart the glued-together qkv tensors into separate tensors for each type
        # (batch_size, num_heads, seq_length, head_dim)
        queries, keys, values = qkv.chunk(3, dim=-1)

        # Same as in single-head attention layer, compute self-attention scores.  Only now we're doing it in parallel across num_heads.
        # So we windo up with (batch_size, num_heads) attention grids.
        # (batch_size, num_heads, seq_length, seq_length)
        scores = torch.matmul(queries, keys.transpose(2, 3))        
        # (batch_size, num_heads, seq_length, seq_length)
        scores = scores / (self.head_dim ** 0.5)
        # (batch_size, num_heads, seq_length, seq_length)
        attention = F.softmax(scores, dim=-1)

        # same as single-head attention, except we're doing num_heads in parallel
        # so for each head, we have a list of hidden states of size head_dim
        # (batch_size, num_heads, seq_length, head_dim)
        context = torch.matmul(attention, values)

        # prepare to merge mini-hidden states but collecting all the head outputs per token
        # Example token 1 -> heads(1, 2, 3) -> mini-states (1, 2, 3)
        # (batch_size, seq_length, num_heads, head_dim)
        context = context.transpose(1, 2)

        # squish the (num_heads, head_dim) into a vector of num_heads * head_dim = hidden_size
        # (batch_size, seq_length, hidden_size)
        context = context.reshape(batch_size, seq_length, hidden_size)

        # Transform the context.  Not sure how this helps but whatever.
        # (batch_size, seq_length, hidden_size)
        output = self.out(context)
        return output
        
        
        

In [150]:
NUM_HEADS = 3
HIDDEN_SIZE = 12
VOCAB_SIZE = len(words2index)

multi_att = MultiheadAttention(HIDDEN_SIZE, NUM_HEADS)
embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)

sentence = "How are you doing ?"
input_tensor = convert2tensors(sentence)
embedded = embedding(input_tensor)
hidden_states = multi_att(embedded)

In [151]:
hidden_states.size()

torch.Size([1, 5, 12])

In [152]:
multi_att = nn.MultiheadAttention(HIDDEN_SIZE, NUM_HEADS)

In [153]:
multi_att

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
)

# Positional Encoding

In [227]:
class PositionalEncoding(nn.Module):
    def __init__(self, context_size, d_model):
        super().__init__()

        self.encoding = torch.zeros(context_size, d_model)
        
        pos = torch.arange(0, context_size).unsqueeze(dim=1)
        dim = torch.arange(0, d_model, 2)
        
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (2 * dim / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (2 * dim / d_model)))

    def forward(self, x):
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

# Position-wise Feed-Forward Network

In [210]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# Encoder

In [238]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden_states, _ = self.self_attn(query=x, key=x, value=x)
        x = self.norm1(hidden_states + x)
        ff_output = self.feed_forward(x)
        x = self.norm2(ff_output + x)
        return x
        
        

In [236]:
class Encoder(nn.Module):
    def __init__(self, input_size, context_size, d_model, d_ff, num_heads, n_blocks):
        super().__init__()

        self.embedding = nn.Embedding(input_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)
        self.blocks = nn.ModuleList([EncoderBlock(d_model, num_heads, d_ff) for _ in range(n_blocks)])

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding(x)
        for block in self.blocks:
            x = block(x)
        return x

# Decoder

In [217]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm3 = nn.LayerNorm(d_model)


    def forward(self, x, enc_output):
        hidden_states, _ = self.self_attn(query=x, key=x, value=x)
        x = self.norm1(hidden_states + x)
        hidden_states, _ = self.cross_attn(query=x, key=enc_output, value=enc_output)
        x = self.norm2(hidden_states + x)
        ff_output = self.feed_forward(x)
        x = self.norm3(ff_output + x)
        return x

In [229]:
class Decoder(nn.Module):
    def __init__(self, output_size, context_size, d_model, d_ff, num_heads, num_blocks):
        super().__init__()

        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)
        self.blocks = nn.ModuleList([DecoderBlock(d_model, num_heads, d_ff) for _ in range(num_blocks)])
        self.out = nn.Linear(d_model, output_size)

    def forward(self, x, enc_output):
        x = self.embedding(x) + self.pos_embedding(x)
        
        for block in self.blocks:
            x = block(x, enc_output)

        output = self.out(x)
        return output            

In [224]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, context_size, d_model, d_ff, num_heads, n_blocks):
        super().__init__()
        self.encoder = Encoder(vocab_size, context_size, d_model, d_ff, num_heads, n_blocks)
        self.decoder = Decoder(vocab_size, context_size, d_model, d_ff, num_heads, n_blocks)

    def forward(self, input_encoder, input_decoder):
        enc_output = self.encoder(input_encoder)
        output = self.decoder(input_decoder, enc_output)
        return output

# Testing

In [233]:
SOS_token = 0
EOS_token = 1
PAD_token = 2

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
    PAD_token: 'PAD'
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split(' '))
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

def convert2tensors(sentence, max_len):
    words_list = sentence.lower().split(' ')
    padding = ['PAD'] * (max_len - len(words_list))
    words_list.extend(padding)
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)


In [239]:
D_MODEL = 10
VOCAB_SIZE = len(words2index)
N_BLOCKS = 10
D_FF = 20
CONTEXT_SIZE = 100
NUM_HEADS = 2

transformer = Transformer(
    vocab_size=VOCAB_SIZE,
    context_size=CONTEXT_SIZE,
    d_model=D_MODEL,
    d_ff=D_FF,
    num_heads=NUM_HEADS,
    n_blocks=N_BLOCKS
)

input_sentence = "How are you doing ?"
output_sentence = "I am good and"

input_encoder = convert2tensors(input_sentence, CONTEXT_SIZE)
input_decoder = convert2tensors(output_sentence, CONTEXT_SIZE)

output = transformer(input_encoder, input_decoder)

In [240]:
output.size()

torch.Size([1, 100, 12])

In [251]:
_, indexes = output.squeeze().topk(1)
# 3 is index of next predicted word
index2words[indexes[3].item()]

'and'