In [1]:
from utils.ch09util import create_model

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pickle

In [3]:
# Load the four vocabulary lookup tables that were saved together in one
# pickle file. Later cells use en_word_dict / fr_word_dict to map tokens to
# integer indexes and fr_idx_dict to map indexes back to tokens
# (en_idx_dict is presumably the English index-to-token table -- TODO confirm).
# NOTE: pickle.load can execute arbitrary code while unpickling; only open
# 'files/dict.p' if it comes from a trusted source.
with open('files/dict.p', 'rb') as f:
    en_word_dict,en_idx_dict, fr_word_dict,fr_idx_dict = pickle.load(f)

In [4]:
# Vocabulary sizes; they are passed to create_model as the source and target
# vocabulary sizes.
src_vocab, tgt_vocab = len(en_word_dict), len(fr_word_dict)
# One print call, one line per language (same text as two separate prints).
print(
    f"there are {src_vocab} distinct English tokens",
    f"there are {tgt_vocab} distinct French tokens",
    sep="\n",
)

there are 11055 distinct English tokens
there are 11239 distinct French tokens


In [5]:
# Build the Transformer whose weights are loaded in the next cells. The
# hyperparameters have to describe the same architecture as the saved
# checkpoint, otherwise load_state_dict cannot match the keys.
model = create_model(
    src_vocab,
    tgt_vocab,
    N=6,          # number of stacked encoder / decoder layers
    d_model=256,  # width of the attention linear layers
    d_ff=1024,    # hidden width of the position-wise feed-forward block
    h=8,          # presumably the number of attention heads -- TODO confirm
    dropout=0.1,  # dropout probability
)

In [6]:
# Read the trained weights from disk.
# - map_location: use Apple's MPS backend when it is available and fall back
#   to the CPU otherwise, so this cell no longer fails on machines without
#   MPS. load_state_dict copies the values into the model's own parameters,
#   so the tensors do not need to be loaded onto the model's device.
# - weights_only=True: restrict unpickling to tensors and plain containers,
#   so a tampered checkpoint file cannot run arbitrary code on load.
state_dict = torch.load(
    'files/my_en2fr.pth',
    map_location='mps' if torch.backends.mps.is_available() else 'cpu',
    weights_only=True,
)

In [7]:
# Copy the saved weights into the freshly built model. Kept as the cell's
# last expression so the key-matching report returned by load_state_dict is
# displayed as the cell output.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [8]:
# Subword tokenizer used by translate() to split the English input before the
# tokens are looked up in en_word_dict.
# from_pretrained resolves the 'xlm-clm-enfr-1024' files by name (downloaded
# or read from the local Hugging Face cache), so the first run may need
# network access.
from transformers import XLMTokenizer
tokenizer = XLMTokenizer.from_pretrained('xlm-clm-enfr-1024')

In [9]:
from utils.ch09util import subsequent_mask
# Kept unchanged for any other cell that reads it; translate() itself places
# its tensors on the device the model's parameters live on.
DEVICE='mps'
# Index of the padding token (hidden by the source mask).
PAD=0
# Index used for tokens that are missing from en_word_dict.
UNK=1
def translate(eng, max_len=100):
    """Translate an English sentence to French with greedy decoding.

    The sentence is split with the notebook-level ``tokenizer``, wrapped in
    "BOS"/"EOS" markers and mapped to indexes through ``en_word_dict``
    (tokens that are not in the dictionary become ``UNK``). The encoder runs
    once; the decoder is then called repeatedly, each time appending the
    highest-scoring next token, until it produces "EOS" or ``max_len``
    tokens have been generated.

    Relies on the notebook globals ``model``, ``tokenizer``,
    ``en_word_dict``, ``fr_word_dict`` and ``fr_idx_dict``. The model
    contains Dropout layers, so call ``model.eval()`` before translating.

    Args:
        eng: the English text to translate.
        max_len: maximum number of tokens to generate before giving up on
            seeing "EOS". The default (100) is the limit that used to be
            hard-coded.

    Returns:
        The translated text as a string. The same string is also printed.
    """
    # tokenize the English sentence and add beginning and end tokens
    tokenized_en = ["BOS"] + tokenizer.tokenize(eng) + ["EOS"]
    # convert tokens to indexes; out-of-vocabulary tokens map to UNK
    enidx = [en_word_dict.get(tok, UNK) for tok in tokenized_en]
    # put the input on the same device as the model instead of a hard-coded
    # one, so the two can never end up on different devices
    device = next(model.parameters()).device
    src = torch.tensor(enidx, dtype=torch.long, device=device).unsqueeze(0)
    # create mask to hide padding
    src_mask = (src != PAD).unsqueeze(-2)
    start_symbol = fr_word_dict["BOS"]
    translation = []
    # inference only: skip autograd bookkeeping for every decoding step
    with torch.no_grad():
        # encode the English sentence
        memory = model.encode(src, src_mask)
        # start translation in an autoregressive fashion
        ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=device)
        for _ in range(max_len):
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, tgt_mask)
            # score the vocabulary for the last position and keep the best
            scores = model.generator(out[:, -1])
            next_idx = scores.argmax(dim=1).item()
            sym = fr_idx_dict[next_idx]
            if sym == 'EOS':
                break
            translation.append(sym)
            # feed the chosen token back in for the next step
            next_tok = torch.full((1, 1), next_idx,
                                  dtype=src.dtype, device=device)
            ys = torch.cat([ys, next_tok], dim=1)
    # convert tokens to sentences: "</w>" markers become spaces
    trans = "".join(translation)
    trans = trans.replace("</w>", " ")
    # drop the space in front of each of these punctuation characters
    for x in '''?:;.,'("-!&)%''':
        trans = trans.replace(f" {x}", f"{x}")
    print(trans)
    return trans

In [10]:
# Switch the model to evaluation mode so its Dropout layers are inactive
# during translation. eval() returns the module itself, which is why this
# cell echoes the module tree as its output.
model.eval()

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=256, out_features=256, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=256, out_features=1024, bias=True)
          (w_2): Linear(in_features=1024, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-1): 2 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=256, out_fea

In [11]:
# Try the translator on one sentence; translate() prints the result and also
# returns it.
# NOTE(review): the output recorded for this cell is repeated fragments and
# "UNK" markers rather than French, and decoding ran to the length limit
# without producing "EOS". The cause cannot be determined from this notebook
# alone -- TODO confirm that 'files/dict.p' is the vocabulary the checkpoint
# was trained with and that the checkpoint is fully trained.
eng = "Today is a beautiful day!"
translated_fr = translate(eng)

troUNKUNKplantambour troUNKUNKcontre. contre. contre. contre...... kie UNKkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKfabrique warkie UNKcontre.... contre. contre.... contre. contre. kie UNKkie UNKkie 
