Reference implementation: https://github.com/AladdinPerzon/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

In [2]:
import torch
import torch.nn as nn

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    The last dimension of each input (``embedSize``) is split into ``heads``
    slices of ``headDim`` features. Every slice goes through the same
    ``headDim -> headDim`` linear projections (the projection weights are
    shared by all heads), attention is computed per head, and the heads are
    concatenated and passed through a final ``embedSize -> embedSize`` layer.

    Args:
        embedSize: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embedSize`` evenly.
    """

    def __init__(self, embedSize, heads):
        super(MultiHeadAttention, self).__init__()

        self.embedSize = embedSize
        self.heads = heads
        self.headDim = embedSize // heads

        # The reshape in forward() only works when the embedding splits
        # evenly across the heads.
        assert (
            self.headDim * self.heads == self.embedSize
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.headDim, self.headDim, bias=False)
        self.keys = nn.Linear(self.headDim, self.headDim, bias=False)
        self.queries = nn.Linear(self.headDim, self.headDim, bias=False)
        self.fc = nn.Linear(self.heads * self.headDim, embedSize)

    def forward(self, value, key, query, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            value: Tensor of shape ``(N, valueLen, embedSize)``.
            key: Tensor of shape ``(N, keyLen, embedSize)``; ``keyLen`` must
                equal ``valueLen`` for the weighted sum below.
            query: Tensor of shape ``(N, queryLen, embedSize)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, queryLen, keyLen)``. Positions where it equals 0
                are excluded from attention. ``None`` disables masking.

        Returns:
            Tensor of shape ``(N, queryLen, embedSize)``.
        """
        # N: number of examples in the batch.
        N = query.shape[0]

        valueLen, keyLen, queryLen = value.shape[1], key.shape[1], query.shape[1]

        # Split the embedding into heads.
        values = value.reshape(N, valueLen, self.heads, self.headDim)
        keys = key.reshape(N, keyLen, self.heads, self.headDim)
        queries = query.reshape(N, queryLen, self.heads, self.headDim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Dot product of every query with every key, per head.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy shape: (N, heads, queryLen, keyLen)

        if mask is not None:
            # A very large negative score becomes a zero weight after softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # NOTE: the scores are scaled by sqrt(embedSize); "Attention Is All
        # You Need" scales by sqrt(headDim). Kept as-is so existing results
        # and trained weights are unaffected.
        attention = torch.softmax(energy / (self.embedSize ** (1 / 2)), dim=3)
        # attention shape: (N, heads, queryLen, keyLen)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, queryLen, self.heads * self.headDim
        )
        # After the weighted sum: (N, queryLen, heads, headDim); the reshape
        # flattens the last two dimensions back to embedSize.

        out = self.fc(out)
        # final shape: (N, queryLen, embedSize)

        return out

In [7]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(LayerNorm(sublayer_out + input))``.

    Args:
        embedSize: Width of the inputs and of the output.
        heads: Number of attention heads.
        dropout: Probability used by the shared ``nn.Dropout``.
        forwardExpansion: The hidden layer of the feed-forward network has
            ``forwardExpansion * embedSize`` units.
    """

    def __init__(self, embedSize, heads, dropout, forwardExpansion):
        super().__init__()

        hiddenSize = forwardExpansion * embedSize

        self.attention = MultiHeadAttention(embedSize, heads)
        self.norm1 = nn.LayerNorm(embedSize)
        self.norm2 = nn.LayerNorm(embedSize)

        expand = nn.Linear(embedSize, hiddenSize)
        project = nn.Linear(hiddenSize, embedSize)
        self.feedForward = nn.Sequential(expand, nn.ReLU(), project)

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run both sub-layers; ``query`` feeds the first residual connection.

        Args:
            value, key, query: Inputs handed to the attention module.
            mask: Attention mask handed to the attention module.

        Returns:
            The block output, same shape as ``query``.
        """
        # Sub-layer 1: attention with a residual connection to the query.
        attended = self.attention(value, key, query, mask)
        x = self.norm1(attended + query)
        x = self.dropout(x)

        # Sub-layer 2: feed-forward with a residual connection to x.
        expanded = self.feedForward(x)
        normed = self.norm2(expanded + x)
        return self.dropout(normed)

In [8]:
class Encoder(nn.Module):
    """Token + position embeddings followed by a stack of TransformerBlocks.

    Args:
        srcVocabSize: Number of rows in the token embedding table.
        embedSize: Embedding width used by every layer.
        numLayers: Number of ``TransformerBlock`` layers in the stack.
        heads: Attention heads per layer.
        device: Stored as ``self.device`` for callers that read it.
            ``forward`` places its tensors on the input's device instead.
        forwardExpansion: Feed-forward expansion factor of each layer.
        dropout: Dropout probability for the embeddings and every layer.
        maxLength: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` can index.
    """

    def __init__(
        self,
        srcVocabSize,
        embedSize,
        numLayers,
        heads,
        device,
        forwardExpansion,
        dropout,
        maxLength,
    ):

        super(Encoder, self).__init__()
        self.embedSize = embedSize
        self.device = device
        self.wordEmbedding = nn.Embedding(srcVocabSize, embedSize)
        self.positionEmbedding = nn.Embedding(maxLength, embedSize)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embedSize,
                    heads,
                    dropout=dropout,
                    forwardExpansion=forwardExpansion,
                )
                for _ in range(numLayers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape ``(N, seqLength)``.
            mask: Attention mask, passed unchanged to every layer.

        Returns:
            The output of the last layer, or the embedded input itself when
            the stack is empty.
        """
        N, seqLength = x.shape

        # Create the position ids directly on the input's device. Relying on
        # self.device breaks when it differs from where the data lives
        # (e.g. the "cuda" default on a CPU-only host, or multi-GPU replicas).
        positions = torch.arange(0, seqLength, device=x.device).expand(N, seqLength)
        out = self.dropout(self.wordEmbedding(x) + self.positionEmbedding(positions))

        # In the encoder the value, key and query are all the same tensor;
        # only the decoder feeds different ones.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [9]:
class DecoderBlock(nn.Module):
    """Self-attention over the target followed by a ``TransformerBlock``.

    Args:
        embedSize: Width of the inputs and of the output.
        heads: Number of attention heads in both attention modules.
        forwardExpansion: Feed-forward expansion factor of the inner block.
        dropout: Dropout probability.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, embedSize, heads, forwardExpansion, dropout, device):
        super().__init__()

        self.norm = nn.LayerNorm(embedSize)
        self.attention = MultiHeadAttention(embedSize, heads=heads)
        self.transformerBlock = TransformerBlock(
            embedSize=embedSize,
            heads=heads,
            dropout=dropout,
            forwardExpansion=forwardExpansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, srcMask, trgMask):
        """Decode one layer.

        Args:
            x: Target-side input; used as value, key and query of the
                self-attention step.
            value, key: Tensors handed to the inner ``TransformerBlock``.
            srcMask: Mask for the inner ``TransformerBlock``.
            trgMask: Mask for the self-attention step.

        Returns:
            Output of the inner ``TransformerBlock``.
        """
        # Self-attention on the target, residual + norm, then dropout.
        selfAttended = self.attention(x, x, x, trgMask)
        normed = self.norm(selfAttended + x)
        query = self.dropout(normed)

        # The result is the query of the inner block.
        return self.transformerBlock(value, key, query, srcMask)

In [10]:
class Decoder(nn.Module):
    """Token + position embeddings, a stack of DecoderBlocks, and a projection.

    Args:
        trgVocabSize: Number of rows in the token embedding table and width
            of the final projection.
        embedSize: Embedding width used by every layer.
        numLayers: Number of ``DecoderBlock`` layers in the stack.
        heads: Attention heads per layer.
        forwardExpansion: Feed-forward expansion factor of each layer.
        dropout: Dropout probability for the embeddings and every layer.
        device: Stored as ``self.device`` and forwarded to each layer.
            ``forward`` places its tensors on the input's device instead.
        maxLength: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` can index.
    """

    def __init__(
        self,
        trgVocabSize,
        embedSize,
        numLayers,
        heads,
        forwardExpansion,
        dropout,
        device,
        maxLength,
    ):
        super(Decoder, self).__init__()

        self.device = device
        self.wordEmbedding = nn.Embedding(trgVocabSize, embedSize)
        self.positionEmbedding = nn.Embedding(maxLength, embedSize)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embedSize, heads, forwardExpansion, dropout, device)
                for _ in range(numLayers)
            ]
        )
        self.fc = nn.Linear(embedSize, trgVocabSize)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encOut, srcMask, trgMask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Integer tensor of target token ids, shape ``(N, seqLength)``.
            encOut: Encoder output; passed to every layer as both value and key.
            srcMask: Mask passed to every layer for attention over ``encOut``.
            trgMask: Mask passed to every layer for target self-attention.

        Returns:
            Result of the final linear layer; its last dimension has
            ``trgVocabSize`` entries (no softmax is applied here).
        """
        N, seqLength = x.shape

        # Create the position ids directly on the input's device. Relying on
        # self.device breaks when it differs from where the data lives
        # (e.g. the "cuda" default on a CPU-only host, or multi-GPU replicas).
        positions = torch.arange(0, seqLength, device=x.device).expand(N, seqLength)
        x = self.dropout((self.wordEmbedding(x) + self.positionEmbedding(positions)))

        for layer in self.layers:
            x = layer(x, encOut, encOut, srcMask, trgMask)

        out = self.fc(x)

        return out


In [11]:
class Transformer(nn.Module):
    """Encoder-decoder model that builds its own attention masks.

    Args:
        srcVocabSize: Source vocabulary size, given to the encoder.
        trgVocabSize: Target vocabulary size, given to the decoder.
        srcPadIdx: Source token id treated as padding by ``makeSrcMask``.
        trgPadIdx: Stored as ``self.trgPadIdx``; the mask builders in this
            class do not read it.
        embedSize: Embedding width of both stacks.
        numLayers: Number of layers in each stack.
        forwardExpansion: Feed-forward expansion factor.
        heads: Attention heads per layer.
        dropout: Dropout probability.
        device: Forwarded to the encoder and decoder and stored as
            ``self.device``. The masks built here follow the device of the
            input tensors instead.
        maxLength: Size of the position embedding tables.
    """

    def __init__(
        self,
        srcVocabSize,
        trgVocabSize,
        srcPadIdx,
        trgPadIdx,
        embedSize=512,
        numLayers=6,
        forwardExpansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        maxLength=100,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            srcVocabSize,
            embedSize,
            numLayers,
            heads,
            device,
            forwardExpansion,
            dropout,
            maxLength,
        )

        self.decoder = Decoder(
            trgVocabSize,
            embedSize,
            numLayers,
            heads,
            forwardExpansion,
            dropout,
            device,
            maxLength,
        )

        self.srcPadIdx = srcPadIdx
        self.trgPadIdx = trgPadIdx
        self.device = device

    def makeSrcMask(self, src):
        """Return a boolean mask that is False at padding positions of ``src``.

        Args:
            src: Tensor of source token ids with shape ``(N, srcLen)``.

        Returns:
            Boolean tensor of shape ``(N, 1, 1, srcLen)`` on ``src``'s device.
        """
        # The comparison result already lives on src's device; forcing it
        # onto self.device would break whenever the two differ.
        srcMask = (src != self.srcPadIdx).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, srcLen)
        return srcMask

    def makeTrgMask(self, trg):
        """Return a lower-triangular (causal) mask for ``trg``.

        Args:
            trg: Tensor of target token ids with shape ``(N, trgLen)``.

        Returns:
            Float tensor of shape ``(N, 1, trgLen, trgLen)`` on ``trg``'s
            device; entry ``[.., i, j]`` is 1 when ``j <= i`` and 0 otherwise.
        """
        N, trgLen = trg.shape
        # Built on trg's device so it matches the attention scores it masks.
        trgMask = torch.tril(torch.ones((trgLen, trgLen), device=trg.device)).expand(
            N, 1, trgLen, trgLen
        )

        return trgMask

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it.

        Args:
            src: Source token ids, shape ``(N, srcLen)``.
            trg: Target token ids, shape ``(N, trgLen)``.

        Returns:
            Whatever the decoder returns for these inputs.
        """
        srcMask = self.makeSrcMask(src)
        trgMask = self.makeTrgMask(trg)
        encSrc = self.encoder(src, srcMask)
        out = self.decoder(trg, encSrc, srcMask, trgMask)
        return out

# Simple Example

In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy batch: two source sequences of 9 token ids and two target sequences
# of 8 token ids. Id 0 is the padding index on both sides.
x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
    device
)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10

# Pass the device explicitly: Transformer defaults to device="cuda", which
# would make the model try to create tensors on a GPU even when the selected
# device above is the CPU.
model = Transformer(
    src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device
).to(device)

# Feed the target without its last token; the output has one row of
# trg_vocab_size scores per remaining target position.
out = model(x, trg[:, :-1])
print(out.shape)