In [1]:
import pickle
# Restore the English/German vocabularies and tokenizers [src -> trg] from disk.
# NOTE(review): unpickling executes arbitrary code; only load a file you trust.
with open('.data/en_to_de_vocab_token.pkl', 'rb') as vocab_file:
    (vocab_en, vocab_de, en_tokenizer, de_tokenizer) = pickle.load(vocab_file)

In [2]:
import torch
from attention_transformer import Transformer
import torch.nn as nn
# Vocabulary sizes and padding-token indices for source (English) and target (German).
SRC_VOCAB_SIZE = len(vocab_en)
TRG_VOCAB_SIZE = len(vocab_de)
SRC_PAD_IDX = vocab_en['<PAD>']
TRG_PAD_IDX = vocab_de['<PAD>']

# Model hyperparameters; these must match the values the saved checkpoint was trained with.
MAX_SENTENCE_LENGTH = 256
EMBED_SIZE = 256
NUM_LAYERS = 3
FORWARD_EXPANSION = 2
HEADS = 8
DROPOUT = 0.1

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect the constructor arguments in one place, then build the Transformer
# and move it onto DEVICE.
model_kwargs = dict(
    src_vocab_len=SRC_VOCAB_SIZE,
    trg_vocab_len=TRG_VOCAB_SIZE,
    src_pad_idx=SRC_PAD_IDX,
    trg_pad_idx=TRG_PAD_IDX,
    src_max_sentence_len=MAX_SENTENCE_LENGTH,
    trg_max_sentence_len=MAX_SENTENCE_LENGTH,
    hid_dim=EMBED_SIZE,
    n_layers=NUM_LAYERS,
    n_heads=HEADS,
    ff_dim_multiplier=FORWARD_EXPANSION,
    dropout=DROPOUT,
    device=DEVICE,
)
model = Transformer(**model_kwargs).to(DEVICE)


# Load the trained parameters into the model.
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved from
# a CUDA run still loads when DEVICE has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load('model_dicts/en_ger_attention_model.pt', map_location=DEVICE)
)

<All keys matched successfully>

In [3]:
def translate_sentence(sentence, src_tokenizer, src_vocab, trg_vocab, model, device, max_len=50):
    """Greedily translate ``sentence`` with ``model`` and return the target tokens.

    The source sentence is tokenized, lower-cased, wrapped in ``<BOS>``/``<EOS>``
    and mapped to indices. The target sequence starts as ``[<BOS>]`` and grows by
    one token per step (the arg-max of the model output at the last position)
    until ``<EOS>`` is produced or ``max_len`` tokens have been generated.

    Args:
        sentence: Source sentence as a string.
        src_tokenizer: Callable returning an iterable of tokens that expose ``.text``.
        src_vocab: Mapping from source token string to index (indexed with ``[]``).
            NOTE(review): out-of-vocabulary behavior depends on this object.
        trg_vocab: Target vocabulary supporting ``trg_vocab[token]`` and
            ``trg_vocab.lookup_token(index)``.
        model: Callable as ``model(src_tensor, trg_tensor)``; it is switched to
            eval mode. Its output is assumed to be
            ``(1, trg_len, trg_vocab_size)`` -- TODO confirm against the model.
        device: Device on which the index tensors are created.
        max_len: Maximum number of tokens to generate after ``<BOS>``.

    Returns:
        List of target token strings, starting with ``<BOS>`` and ending with
        ``<EOS>`` when the model emitted it within ``max_len`` steps. The list is
        also printed.
    """
    model.eval()

    src_tokens = ['<BOS>'] + [token.text.lower() for token in src_tokenizer(sentence)] + ['<EOS>']
    src_indexes = [src_vocab[token] for token in src_tokens]
    # Leading batch dimension of size 1.
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=device).unsqueeze(0)

    eos_index = trg_vocab['<EOS>']
    trg_indexes = [trg_vocab['<BOS>']]

    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            output = model(src_tensor, trg_tensor)

            # Greedy choice at the last decoded position. .item() yields a plain
            # Python int so trg_indexes never mixes ints with 1-element tensors.
            next_index = output[:, -1].argmax(dim=-1).item()
            trg_indexes.append(next_index)

            if next_index == eos_index:
                break

    trg_tokens = [trg_vocab.lookup_token(index) for index in trg_indexes]
    print(trg_tokens)
    return trg_tokens


In [4]:
# Run the model on a sample English sentence; decoding stops after at most 30
# generated tokens (or earlier if <EOS> is produced).
translation = translate_sentence(
    sentence="Fuck your mother",
    src_tokenizer=en_tokenizer,
    src_vocab=vocab_en,
    trg_vocab=vocab_de,
    model=model,
    device=DEVICE,
    max_len=30,
)

['<BOS>', 'scheiß', 'ihre', 'mutter', 'an', '.', '\n', '<EOS>']
