In [1]:
import torch
import torch.nn as nn
import numpy as np
import gensim.downloader as api
from gensim.utils import tokenize 

In [None]:
# Load the pre-trained 'fasttext-wiki-news-subwords-300' vectors through
# gensim's downloader API. The resulting object is later handed to
# `Transformer` as its `embedding_model`.
# NOTE(review): the model name suggests 300-dimensional vectors - confirm via
# `embedding_model.vector_size` before choosing the network's embedding size.
embedding_model = api.load('fasttext-wiki-news-subwords-300')

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Despite the name, query, key and value are passed in separately, so the
    same module is used for self-attention and for encoder-decoder attention.

    Args:
        embedding_size: Size of the last dimension of the inputs and output.
        n_heads: Number of attention heads; must divide ``embedding_size``.

    Raises:
        AssertionError: If ``embedding_size`` is not divisible by ``n_heads``.
    """

    def __init__(self, embedding_size, n_heads):
        super(SelfAttention, self).__init__()

        self.n_heads = n_heads
        self.head_dim = embedding_size // n_heads
        self.embed_size = embedding_size

        assert embedding_size == self.head_dim * n_heads, (
            f"embedding_size ({embedding_size}) must be divisible by "
            f"n_heads ({n_heads})"
        )

        # Creation order is kept as-is so parameter initialisation consumes
        # the random generator in the same sequence as before.
        self.query_weight = nn.Linear(embedding_size, embedding_size)
        self.value_weight = nn.Linear(embedding_size, embedding_size)
        self.key_weight = nn.Linear(embedding_size, embedding_size)
        self.fc = nn.Linear(embedding_size, embedding_size)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key`` / ``value`` positions.

        Args:
            query: Tensor of shape (batch, query_len, embedding_size).
            key: Tensor of shape (batch, key_len, embedding_size).
            value: Tensor of shape (batch, key_len, embedding_size).
            mask: Optional tensor broadcastable to
                (batch, n_heads, query_len, key_len). Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, query_len, embedding_size).
        """
        batch_size = query.shape[0]
        query_len = query.shape[1]
        key_len = key.shape[1]

        # Project, then split the embedding dimension into heads:
        # (batch, len, n_heads, head_dim).
        query = self.query_weight(query).reshape(
            batch_size, query_len, self.n_heads, self.head_dim
        )
        key = self.key_weight(key).reshape(
            batch_size, key_len, self.n_heads, self.head_dim
        )
        value = self.value_weight(value).reshape(
            batch_size, value.shape[1], self.n_heads, self.head_dim
        )

        # (batch, heads, query_len, head_dim) @ (batch, heads, head_dim, key_len)
        # -> (batch, heads, query_len, key_len), scaled by sqrt(head_dim).
        energy = torch.matmul(
            query.transpose(1, 2),
            key.permute(0, 2, 3, 1)
        ) / self.head_dim ** 0.5

        if mask is not None:
            # Use the most negative finite value of the working dtype. A
            # literal such as -1e20 is not representable in float16 and makes
            # masked_fill raise under half precision.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        normalized_energy = torch.softmax(energy, dim=3)

        # (batch, heads, query_len, key_len) @ (batch, heads, key_len, head_dim)
        attention = torch.matmul(normalized_energy, value.transpose(1, 2))

        # Merge the heads back: (batch, query_len, embedding_size).
        out = attention.transpose(1, 2).reshape(
            batch_size, query_len, self.embed_size
        )

        return self.fc(out)

In [4]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer output is added to its input, layer-normalised, and then
    passed through dropout (in that order).

    Args:
        embedding_size: Width of the inputs and outputs.
        expantion_factor: Multiplier for the hidden width of the feed-forward
            network.
        dropout: Dropout probability applied after each normalisation.
        n_heads: Number of attention heads.
    """

    def __init__(self, embedding_size, expantion_factor, dropout, n_heads=8):
        super().__init__()

        hidden_size = embedding_size * expantion_factor

        self.attention = SelfAttention(embedding_size, n_heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        )

    def forward(self, query, key, values, mask=None):
        """Run attention and feed-forward, each with a residual connection.

        Args:
            query: Input the attention attends from; also the first residual.
            key: Keys handed to the attention module.
            values: Values handed to the attention module.
            mask: Optional attention mask forwarded unchanged.

        Returns:
            The block output, same shape as ``query``.
        """
        attended = self.attention(query, key, values, mask)
        residual = self.dropout(self.norm1(attended + query))

        transformed = self.feed_forward(residual)
        return self.dropout(self.norm2(transformed + residual))

In [5]:
class Encoder(nn.Module):
    """Stack of ``TransformerBlock`` layers applied as self-attention.

    Args:
        num_layers: Number of blocks in the stack.
        embedding_size: Width of the inputs and outputs.
        expantion_factor: Feed-forward expansion multiplier for each block.
        dropout: Dropout probability for each block.
        n_heads: Number of attention heads for each block.
    """

    def __init__(self, num_layers, embedding_size,expantion_factor,dropout=0.2, n_heads=8):
        super().__init__()

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                TransformerBlock(embedding_size, expantion_factor, dropout, n_heads)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        """Pass ``x`` through every block, using it as query, key and value."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)
        return hidden

In [6]:
class DecoderBlock(nn.Module):
    """Masked self-attention followed by an encoder-decoder ``TransformerBlock``.

    Args:
        embedding_size: Width of the inputs and outputs.
        expantion_factor: Feed-forward expansion multiplier of the inner block.
        dropout: Dropout probability.
        n_heads: Number of attention heads.
    """

    def __init__(self,embedding_size,expantion_factor, dropout=0.2,n_heads=8):
        super().__init__()

        self.attention = SelfAttention(embedding_size, n_heads)
        self.transformer_block = TransformerBlock(
            embedding_size, expantion_factor, dropout, n_heads
        )
        self.norm1 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, src_mask=None, trg_mask=None):
        """Decode one layer.

        Args:
            query: Decoder-side input; attends to itself under ``trg_mask``.
            key: Keys for the encoder-decoder attention.
            value: Values for the encoder-decoder attention.
            src_mask: Mask forwarded to the encoder-decoder attention.
            trg_mask: Mask applied to the decoder self-attention.

        Returns:
            Output of the inner ``TransformerBlock``.
        """
        self_attended = self.attention(query, query, query, trg_mask)
        decoder_query = self.dropout(self.norm1(self_attended + query))
        return self.transformer_block(decoder_query, key, value, src_mask)

In [7]:
class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers followed by a vocabulary projection.

    Args:
        num_layers: Number of decoder blocks.
        embedding_size: Width of the decoder inputs.
        expantion_factor: Feed-forward expansion multiplier for each block.
        target_vocab_size: Size of the output projection (number of logits).
        dropout: Dropout probability for each block.
        n_heads: Number of attention heads for each block.
    """

    def __init__(self, num_layers, embedding_size,expantion_factor,target_vocab_size, dropout=0.2,n_heads=8):
        super(Decoder, self).__init__()

        self.layers = nn.ModuleList([
            DecoderBlock(embedding_size,expantion_factor, dropout, n_heads)
            for _ in range(num_layers)
        ])

        self.fc = nn.Linear(embedding_size, target_vocab_size)

    def forward(self, query, key, value, src_mask=None, trg_mask=None):
        """Run the decoder stack and project to vocabulary logits.

        Args:
            query: Decoder-side input fed to the first block.
            key: Keys for the encoder-decoder attention of every block.
            value: Values for the encoder-decoder attention of every block.
            src_mask: Mask for the encoder-decoder attention.
            trg_mask: Mask for the decoder self-attention.

        Returns:
            ``self.fc`` applied to the output of the last block.
        """
        out = query
        for layer in self.layers:
            # Feed each block the previous block's output. Passing the
            # original `query` to every block would discard all but the last
            # layer's computation.
            out = layer(out, key, value, src_mask, trg_mask)

        return self.fc(out)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer whose inputs are embedded by a fixed,
    externally supplied embedding model.

    Token ids are looked up in ``embedding_model`` (not in a trainable
    ``nn.Embedding``), so ``embedding_size`` must equal the dimensionality of
    the vectors that model returns.

    Args:
        num_layers: Number of encoder blocks and of decoder blocks.
        embedding_size: Model width.
        expantion_factor: Feed-forward expansion multiplier.
        target_vocab_size: Number of output logits per position.
        dropout: Dropout probability.
        n_heads: Number of attention heads.
        embedding_model: Object supporting ``embedding_model[list_of_ids]``
            (returning an array of vectors) and exposing ``key_to_index``.
        max_len: Number of decoding steps performed by ``forward_sequence``.
    """

    def __init__(self, num_layers,
                 embedding_size,
                 expantion_factor,
                 target_vocab_size,
                 dropout, n_heads,
                 embedding_model,
                 max_len=100):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, embedding_size, expantion_factor, dropout, n_heads)
        self.decoder = Decoder(num_layers, embedding_size, expantion_factor,target_vocab_size, dropout, n_heads)
        self.embedding_model = embedding_model
        self.max_len = max_len

    def make_src_mask(self, src):
        """Return a boolean mask of shape (N, 1, 1, src_len), False where ``src == 0``.

        NOTE(review): this treats token id 0 as padding - confirm that id 0 is
        not a meaningful token in the embedding vocabulary.
        """
        return (src != 0).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a lower-triangular (causal) mask of shape (N, 1, trg_len, trg_len)."""
        N, trg_len = trg.shape[:2]
        # Build the mask on the same device as the tokens it will be used with.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def tokenizer(self, batch_sentence):
        """Embed a batch of token ids.

        Args:
            batch_sentence: 2-D integer tensor (batch_size, sentence_len).

        Returns:
            Tensor of shape (batch_size, sentence_len, vector_dim) holding the
            vectors returned by ``self.embedding_model`` for each id.

        Raises:
            ValueError: If ``batch_sentence`` is not 2-D.
        """
        if len(batch_sentence.shape) != 2:
            raise ValueError("ndim must be 2 [batch_size, sentence]")

        embedded = np.array([
            self.embedding_model[sentence.tolist()]
            for sentence in batch_sentence
        ])
        return torch.tensor(embedded, device=batch_sentence.device)

    @staticmethod
    def _pad_index_sequences(sequences, pad_index=0):
        """Right-pad lists of token ids to a common length and stack them."""
        longest = max((len(seq) for seq in sequences), default=0)
        padded = [
            list(seq) + [pad_index] * (longest - len(seq))
            for seq in sequences
        ]
        return torch.tensor(padded, dtype=torch.long)

    def _encode_text_batch(self, batch_text, fallback_token):
        """Map raw strings to a padded id tensor using the embedding vocabulary."""
        vocab = self.embedding_model.key_to_index
        # dict.get must fall back to an integer id: a string default would end
        # up inside torch.tensor(...) and fail. Use the placeholder token's id
        # when the vocabulary has one, otherwise id 0 (masked as padding).
        fallback_index = vocab.get(fallback_token, 0)
        sequences = [
            [vocab.get(token, fallback_index)
             for token in tokenize(text, lowercase=True)]
            for text in batch_text
        ]
        return self._pad_index_sequences(sequences)

    def from_text(self, batch_text, batch_target_text):
        """Run ``forward`` on raw text.

        Args:
            batch_text: A source string or a list of source strings.
            batch_target_text: A target string or a list of target strings.

        Returns:
            The output of ``forward`` for the encoded batches. Sentences of
            different lengths are right-padded with id 0.
        """
        # Wrap each argument independently so a bare string is never iterated
        # character by character.
        if isinstance(batch_text, str):
            batch_text = [batch_text]
        if isinstance(batch_target_text, str):
            batch_target_text = [batch_target_text]

        text = self._encode_text_batch(batch_text, "UNK")
        target_text = self._encode_text_batch(batch_target_text, "<pad>")

        return self.forward(text, target_text)

    def forward_sequence(self, src, trg):
        """Greedy decoding for ``self.max_len`` steps.

        Starting from ``trg``, repeatedly runs the decoder, takes the arg-max
        token at the last position and appends it to ``trg``.

        Args:
            src: 2-D tensor of source token ids.
            trg: 2-D tensor of initial target token ids.

        Returns:
            The decoder output of the final step. ``self.max_len`` must be at
            least 1, otherwise no decoder output exists to return.
        """
        src_mask = self.make_src_mask(src)
        enc_out = self.encoder(self.tokenizer(src), src_mask)

        for _ in range(self.max_len):
            trg_mask = self.make_trg_mask(trg)
            embedded_trg = self.tokenizer(trg)

            out = self.decoder(embedded_trg, enc_out, enc_out, src_mask, trg_mask)

            # arg-max over logits equals arg-max over softmax probabilities.
            next_token = torch.argmax(out[:, -1, :], dim=-1)
            trg = torch.cat([trg, next_token.unsqueeze(1)], dim=1)
            # TODO: stop early once an end-of-sequence token is produced.
        return out

    def forward(self, src, trg):
        """Teacher-forced pass.

        Args:
            src: 2-D tensor of source token ids.
            trg: 2-D tensor of target token ids.

        Returns:
            Decoder output for every target position.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        src = self.tokenizer(src)
        trg = self.tokenizer(trg)

        enc_out = self.encoder(src, src_mask)

        return self.decoder(trg, enc_out, enc_out, src_mask, trg_mask)

# --- Smoke test: train and evaluate on random token ids ---------------------
torch.manual_seed(0)  # make the synthetic data and initial weights reproducible

num_layers = 1
target_vocab_size = 67
# The encoder/decoder consume the pre-trained vectors directly, so the model
# width must equal the dimensionality of the embedding model's vectors.
embedding_size = embedding_model.vector_size
expantion_factor = 2
dropout = 0.2
n_heads = 5  # must divide embedding_size

# Kept tiny on purpose (one optimisation step, one evaluation); raise these
# for an actual training run.
NUM_EPOCHS = 1
MAX_BATCHES_PER_EPOCH = 1

# Synthetic data: (num_batches, batch_size, seq_len) for training and
# (batch_size, seq_len) for evaluation. Ids are drawn from the target range.
train_x = torch.randint(0, target_vocab_size, (10, 500, 15))
train_trg = torch.randint(0, target_vocab_size, (10, 500, 11))
test_x = torch.randint(0, target_vocab_size, (500, 15))
test_trg = torch.randint(0, target_vocab_size, (500, 11))

model = Transformer(num_layers, embedding_size, expantion_factor, target_vocab_size, dropout, n_heads, embedding_model)
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(NUM_EPOCHS):
    model.train()
    for i, (x, trg) in enumerate(zip(train_x, train_trg)):
        if i >= MAX_BATCHES_PER_EPOCH:
            break

        optim.zero_grad()
        # The model sees every target token except the last one; the logits
        # at the final position are scored against that held-out token.
        out = model(x, trg[:, :-1])
        loss = loss_fn(out[:, -1, :], trg[:, -1])
        print(loss.item())

        loss.backward()
        optim.step()

    model.eval()
    with torch.no_grad():
        out = model(test_x, test_trg[:, :-1])
        test_loss = loss_fn(out[:, -1, :], test_trg[:, -1])
    print(f"test: {test_loss.item()}")

# Greedy decoding on the first training batch; the cell displays the shape of
# the decoder output from the final step.
model.forward_sequence(train_x[0], train_trg[0]).shape