In [1]:
from model import Transformer
import pickle
import torch
from decoding import GreedyDecoder
import spacy

In [2]:
def read_vocab(path):
    """Deserialize and return the vocabulary object pickled at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading,
    so only point this at vocabulary files from a trusted source.
    """
    with open(path, mode='rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    return vocab

# Source-side (Russian) and target-side (English) vocabularies stored next to
# the model weights.
ru_vocab = read_vocab('weights/Yandex/ru_vocab.pkl')
en_vocab = read_vocab('weights/Yandex/en_vocab.pkl')

In [3]:
# Architecture hyperparameters. Values that determine parameter shapes
# (vocabulary sizes, DIM, PW_NET_DIM, block counts) must agree with the
# checkpoint loaded later, or load_state_dict raises on mismatched parameters.
SEQ_LEN = 64                     # also passed to the decoder as max_seq_length
SRC_VOCAB_SIZE = len(ru_vocab)   # source language: Russian
TRG_VOCAB_SIZE = len(en_vocab)   # target language: English
DIM, N_HEADS = 512, 8
PW_NET_DIM = 2048                # presumably the position-wise feed-forward width
N_EN_BLOCKS = N_DE_BLOCKS = 6    # encoder / decoder depth
DROPOUT_P = 0.1

model = Transformer(
    seq_len=SEQ_LEN,
    src_vocab_size=SRC_VOCAB_SIZE,
    trg_vocab_size=TRG_VOCAB_SIZE,
    dim=DIM,
    n_heads=N_HEADS,
    pw_net_dim=PW_NET_DIM,
    n_en_blocks=N_EN_BLOCKS,
    n_de_blocks=N_DE_BLOCKS,
    dropout_p=DROPOUT_P,
)

In [4]:
# Move the parameters to the GPU and switch to inference mode. A torch
# nn.Module starts in training mode, so without .eval() any dropout layers
# (DROPOUT_P = 0.1 above) would stay active while decoding, making
# translations stochastic. Both .cuda() and .eval() return the module itself;
# assigning to `_` keeps the cell from echoing the module repr.
_ = model.cuda().eval()

In [5]:
# Restore the trained parameters. map_location='cuda' materializes the saved
# tensors directly on the GPU, where the model already lives.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; newer PyTorch releases accept weights_only=True to restrict it.
model.load_state_dict(
    torch.load('weights/Yandex/yandex_weights.pth', map_location='cuda')
)

<All keys matched successfully>

In [6]:
# spaCy's Russian pipeline; only its tokenizer output (token.text) is used below.
ru_tokenizer = spacy.load('ru_core_news_sm')
# Disabling those pipes increases processing speed by an order of magnitude.
# `select_pipes(disable=...)` replaces `disable_pipes`, which is deprecated in
# spaCy v3 and emits a deprecation warning. The pipes stay disabled because the
# returned object is not used as a context manager.
ru_tokenizer.select_pipes(disable=['tok2vec', 'morphologizer', 'parser', 'attribute_ruler', 'lemmatizer', 'ner'])

decoder = GreedyDecoder(
    model=model,
    # Split a raw Russian sentence into a list of token strings.
    tokenizer=lambda sent: [token.text for token in ru_tokenizer(sent)],
    src_vocab=ru_vocab,
    trg_vocab=en_vocab,
    # NOTE(review): special-token ids are hard-coded; they presumably match the
    # indices assigned when the vocabularies were built — verify against them.
    sos_token_id=2,
    eos_token_id=3,
    pad_token_id=1,
    max_seq_length=SEQ_LEN,
)

In [11]:
# Translate a sample Russian sentence into English with greedy decoding;
# the returned string is displayed as the cell output.
decoder.decode(
    'Русский язык невероятно сложен.'
)

'this is a complex and complex set of words that are hard to translate .'