In [1]:
from transformer_v2 import *

  from .autonotebook import tqdm as notebook_tqdm


### Greedy Decoding

In [2]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        out = model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)
        )
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat(
            [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1
        )
    return ys

### Load data and model for output checks

In [4]:
def check_outputs(
    valid_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
):
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(iter(valid_dataloader))
        rb = Batch(b[0], b[1], pad_idx)
        greedy_decode(model, rb.src, rb.src_mask, 64, 0)[0]

        src_tokens = [
            vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx
        ]
        tgt_tokens = [
            vocab_tgt.get_itos()[x] for x in rb.tgt[0] if x != pad_idx
        ]

        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )
        model_out = greedy_decode(model, rb.src, rb.src_mask, 72, 0)[0]
        model_txt = (
            " ".join(
                [vocab_tgt.get_itos()[x] for x in model_out if x != pad_idx]
            ).split(eos_string, 1)[0]
            + eos_string
        )
        print("Model Output               : " + model_txt.replace("\n", ""))
        results[idx] = (rb, src_tokens, tgt_tokens, model_out, model_txt)
    return results


def run_model_example(n_examples=5):
    global vocab_src, vocab_tgt, spacy_de, spacy_en

    print("Preparing Data ...")
    _, valid_dataloader = create_dataloaders(
        torch.device("cpu"),
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=1,
        is_distributed=False,
    )

    print("Loading Trained Model ...")

    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.load_state_dict(
        torch.load("multi30k_model_final.pt", map_location=torch.device("cpu"))
    )

    print("Checking Model Outputs:")
    example_data = check_outputs(
        valid_dataloader, model, vocab_src, vocab_tgt, n_examples=n_examples
    )
    return model, example_data


### pre-processing dataset

In [5]:
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)



Finished.
Vocabulary sizes:
8317
6384


## Results
epoch가 8회 수행되었을 때는 번역 품질이 좋지 못함.

-----------

``` text
Example 0 ========

Source Text (Input)        : <s> Mann beobachtet eine Frau , die auf dem Gehweg raucht . </s>
Target Text (Ground Truth) : <s> Man looking at a woman that is smoking on the sidewalk . </s>
Model Output               : <s> Man watching woman smoking on the sidewalk . </s>

Example 1 ========

Source Text (Input)        : <s> Ein Mann mit einer weißen Schürze und Hut verkauft Fleisch an einer belebten Straße . </s>
Target Text (Ground Truth) : <s> A man in a white apron and hat is selling meat on a busy street . </s>
Model Output               : <s> A man in a white apron and hat is selling meat on a busy street . </s>

Example 2 ========

Source Text (Input)        : <s> Drei weiße Männer in T-Shirts springen in die Luft . </s>
Target Text (Ground Truth) : <s> Three white men in t - shirt jump into the air . </s>
Model Output               : <s> Three white men in t - shirts jumping in the air . </s>

Example 3 ========

Source Text (Input)        : <s> Zwei Kinder springen auf einem abgeschirmten blau-schwarzen Trampolin , das von Bäumen umgeben ist . </s>
Target Text (Ground Truth) : <s> Two children jumping on a screened in blue and black trampoline while outside surrounded by trees . </s>
Model Output               : <s> Two children are jumping on a black and blue trampoline surrounded by trees . </s>

Example 4 ========

Source Text (Input)        : <s> Ein kleiner Junge trägt eine <unk> Flagge und geht neben einer Frau . </s>
Target Text (Ground Truth) : <s> A young boy carries a green , white , and red flag and walks next to a woman . </s>
```

---

In [6]:
run_model_example()

Preparing Data ...




Loading Trained Model ...
Checking Model Outputs:


Source Text (Input)        : <s> Mann beobachtet eine Frau , die auf dem Gehweg raucht . </s>
Target Text (Ground Truth) : <s> Man looking at a woman that is smoking on the sidewalk . </s>
Model Output               : <s> Man watching woman smoking on the sidewalk . </s>


Source Text (Input)        : <s> Ein Mann mit einer weißen Schürze und Hut verkauft Fleisch an einer belebten Straße . </s>
Target Text (Ground Truth) : <s> A man in a white apron and hat is selling meat on a busy street . </s>
Model Output               : <s> A man in a white apron and hat is selling meat on a busy street . </s>


Source Text (Input)        : <s> Drei weiße Männer in T-Shirts springen in die Luft . </s>
Target Text (Ground Truth) : <s> Three white men in t - shirt jump into the air . </s>
Model Output               : <s> Three white men in t - shirts jumping in the air . </s>


Source Text (Input)        : <s> Zwei Kinder springen auf einem abgesch

(EncoderDecoder(
   (encoder): Encoder(
     (layers): ModuleList(
       (0): EncoderLayer(
         (self_attn): MultiHeadedAttention(
           (linears): ModuleList(
             (0): Linear(in_features=512, out_features=512, bias=True)
             (1): Linear(in_features=512, out_features=512, bias=True)
             (2): Linear(in_features=512, out_features=512, bias=True)
             (3): Linear(in_features=512, out_features=512, bias=True)
           )
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (feed_forward): PositionwiseFeedForward(
           (w_1): Linear(in_features=512, out_features=2048, bias=True)
           (w_2): Linear(in_features=2048, out_features=512, bias=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (sublayer): ModuleList(
           (0): SublayerConnection(
             (norm): LayerNorm()
             (dropout): Dropout(p=0.1, inplace=False)
           )
           (1): SublayerConnection(
       