In [1]:
import os
import time
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pprint import pprint
from IPython.display import clear_output

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
tf.version.VERSION  # installed TensorFlow version (the recorded run used 2.4.1)

'2.4.1'

In [3]:
import logging
# Print NumPy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)
# Configure the root logger at ERROR level (a no-op if the root logger
# already has handlers).
logging.basicConfig(level=logging.ERROR)

In [4]:
# All artefacts used later in the notebook (vocabulary files, checkpoints,
# logs) live under a single output directory.
output_dir = "nmt"
en_vocab_file = os.path.join(output_dir, "en_vocab")
zh_vocab_file = os.path.join(output_dir, "zh_vocab")
checkpoint_path = os.path.join(output_dir, "checkpoints")
log_dir = os.path.join(output_dir, 'logs')
# exist_ok=True replaces the exists()/makedirs() check-then-act pair: it is
# idempotent on re-run and cannot fail with FileExistsError if the directory
# appears between the check and the creation.
os.makedirs(output_dir, exist_ok=True)

以下是 WMT19 zh-en 訓練集中的其中三個子集（本筆記只使用新聞評論 newscommentary_v14 來訓練）：

* 新聞評論：newscommentary_v14
* 維基百科標題：wikititles_v1
* 聯合國數據：uncorpus_v1

In [5]:
# List the source corpora that the stock WMT19 zh-en builder uses for each split.
tmp_builder = tfds.builder("wmt19_translate/zh-en")
tmp_builder.subsets

{Split('train'): ['newscommentary_v14',
  'wikititles_v1',
  'uncorpus_v1',
  'casia2015',
  'casict2011',
  'casict2015',
  'datum2015',
  'datum2017',
  'neu2017'],
 Split('validation'): ['newstest2018']}

In [6]:
# Custom WMT zh-en config: train only on News Commentary v14 and validate on
# newstest2018, instead of the full set of WMT19 training corpora.
config = tfds.translate.wmt.WmtConfig(
    version=tfds.core.Version(
        '0.0.3',
        # NOTE(review): `experiments` keys are normally tfds.core.Experiment
        # members; passing the ReadInstruction class looks like a workaround
        # for a tfds version change -- confirm it has the intended effect.
        experiments={tfds.core.ReadInstruction: False},
    ),
    language_pair=("zh", "en"),
    subsets={
        tfds.Split.TRAIN: ["newscommentary_v14"],
        tfds.Split.VALIDATION: ["newstest2018"],
    },
)
builder = tfds.builder("wmt_translate", config=config)
builder.download_and_prepare()
# Clear the cell output produced by the download/prepare step.
clear_output()

In [7]:
# as_supervised=True yields 2-tuples of scalar string tensors instead of
# feature dicts; the result is a dict with one dataset per split.
examples = builder.as_dataset(as_supervised=True)
examples

{'train': <PrefetchDataset shapes: ((), ()), types: (tf.string, tf.string)>,
 'validation': <PrefetchDataset shapes: ((), ()), types: (tf.string, tf.string)>}

In [8]:
# Training set: the first 300,000 pairs of the training split.
train_data = examples["train"].take(300000)
train_data

<TakeDataset shapes: ((), ()), types: (tf.string, tf.string)>

In [9]:
# Test set: the 11,556 pairs that follow the first 300,000 in the *training*
# split, so it does not overlap train_data.
# NOTE(review): the no-overlap property assumes examples["train"] iterates in
# the same order on every pass (no file shuffling) -- confirm.
test_data = examples["train"].skip(300000).take(11556)
test_data

<TakeDataset shapes: ((), ()), types: (tf.string, tf.string)>

In [10]:
# Validation set: the newstest2018 split selected in the WMT config.
valid_data = examples["validation"]
valid_data

<PrefetchDataset shapes: ((), ()), types: (tf.string, tf.string)>

In [11]:
# Number of training pairs (len() needs a known, finite dataset cardinality).
len(train_data)

300000

In [12]:
# Number of test pairs.
len(test_data)

11556

In [13]:
# Number of validation pairs.
len(valid_data)

3981

In [14]:
# Peek at one raw training pair. The printed sample shows that element 0 is the
# English sentence and element 1 is the Chinese one, both as UTF-8 encoded bytes.
# Note: the loop rebinds the notebook-level names `en` and `zh`.
for en, zh in train_data.take(1):
    print(en)
    print(zh)
    print('-' * 10)

tf.Tensor(b'The fear is real and visceral, and politicians ignore it at their peril.', shape=(), dtype=string)
tf.Tensor(b'\xe8\xbf\x99\xe7\xa7\x8d\xe6\x81\x90\xe6\x83\xa7\xe6\x98\xaf\xe7\x9c\x9f\xe5\xae\x9e\xe8\x80\x8c\xe5\x86\x85\xe5\x9c\xa8\xe7\x9a\x84\xe3\x80\x82 \xe5\xbf\xbd\xe8\xa7\x86\xe5\xae\x83\xe7\x9a\x84\xe6\x94\xbf\xe6\xb2\xbb\xe5\xae\xb6\xe4\xbb\xac\xe5\x89\x8d\xe9\x80\x94\xe5\xa0\xaa\xe5\xbf\xa7\xe3\x80\x82', shape=(), dtype=string)
----------


In [15]:
# Same pair as above, decoded from UTF-8 bytes into readable Python strings.
for en_t, zh_t in train_data.take(1):
    en, zh = (t.numpy().decode("utf-8") for t in (en_t, zh_t))
    print(en, zh, '-' * 10, sep="\n")

The fear is real and visceral, and politicians ignore it at their peril.
这种恐惧是真实而内在的。 忽视它的政治家们前途堪忧。
----------


## 建立中文與英文字典

In [65]:
def _build_subword_encoder(data, lang, target_vocab_size):
    """Build a SubwordTextEncoder from one language side of ``data``.

    Raises:
        ValueError: if ``lang`` is not ``"en"`` or ``"zh"``.
    """
    if lang == "en":
        print("Build English Dict")
        column, extra_kwargs = 0, {}
    elif lang == "zh":
        print("Build Chinese Dict")
        # max_subword_length=1 makes every Chinese character its own vocabulary unit.
        column, extra_kwargs = 1, {"max_subword_length": 1}
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(f"lang must be 'en' or 'zh', got {lang!r}")

    corpus = (pair[column].numpy().decode("utf-8") for pair in data)
    return tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
        corpus, target_vocab_size=target_vocab_size, **extra_kwargs
    )


def build_dict(data, path, lang, target_vocab_size=2**13):
    """Load a saved subword vocabulary, or build and save a new one.

    If ``SubwordTextEncoder.load_from_file(path)`` succeeds, that encoder is
    returned (warm start) and the file is not rewritten. Otherwise a new
    encoder is built from the English (``lang="en"``) or Chinese
    (``lang="zh"``) side of ``data`` and saved to ``path`` for next time.

    Args:
        data: iterable of ``(en, zh)`` pairs whose elements have ``.numpy()``
            returning UTF-8 bytes; element 0 is English, element 1 Chinese.
        path: filename prefix passed to ``load_from_file`` / ``save_to_file``.
        lang: ``"en"`` or ``"zh"``; selects which side of each pair to use.
            Only consulted when a new vocabulary has to be built.
        target_vocab_size: approximate vocabulary size to aim for when
            building (default ``2**13``, the previously hard-coded value).

    Returns:
        The loaded or newly built encoder.

    Raises:
        ValueError: if no saved vocabulary could be loaded and ``lang`` is not
            ``"en"`` or ``"zh"``.
    """
    try:
        subword_encoder = tfds.deprecated.text.SubwordTextEncoder.load_from_file(path)
    except Exception:
        # Any load failure (typically a missing file) means "build from scratch".
        # Catching Exception rather than using a bare `except:` lets
        # KeyboardInterrupt / SystemExit propagate.
        subword_encoder = None

    if subword_encoder is None:
        print("沒有已建立的字典，從頭建立。")
        subword_encoder = _build_subword_encoder(data, lang, target_vocab_size)
        # Save the freshly built vocabulary so the next run can warm-start from it.
        subword_encoder.save_to_file(path)
    else:
        print(f"載入已建立的字典： {path}")

    print(f"字典大小：{subword_encoder.vocab_size}")
    print(f"前 10 個 subwords：{subword_encoder.subwords[:10]}")
    return subword_encoder

In [66]:
# English vocabulary: loaded from en_vocab_file, or built from all of train_data.
en_encoder = build_dict(train_data, en_vocab_file, lang="en")

載入已建立的字典： nmt\en_vocab
字典大小：8179
前 10 個 subwords：[', ', 'the_', 'to_', 'of_', 'and_', 's_', 'in_', 'a_', 'is_', 'that_']


### 建立中文字典

中文字典雖然收錄約 4 萬個漢字，但只需其中約 1 萬個就幾乎能表達所有意思，常用字更只有約 2000 個。因此下面只取前 30000 筆訓練資料、以每個漢字為一個單位來建立中文字典，應已足以涵蓋大部分常用字。

In [79]:
# Only the first 30,000 training pairs are used to build the Chinese vocabulary.
# cache() keeps the elements after the first full pass, so later passes do not
# re-read them from the source dataset.
t_zh = train_data.take(30000).cache()

In [80]:
# Chinese vocabulary: loaded from zh_vocab_file, or built from the 30,000-pair subset.
zh_encoder = build_dict(t_zh, zh_vocab_file, lang="zh")

沒有已建立的字典，從頭建立。
Build Chinese Dict
字典大小：3879
前 10 個 subwords：['的', '，', '。', '国', '在', '是', '一', '和', '不', '这']


### 試試看轉成索引

In [81]:
# Encode an English sentence into subword indices with the English encoder.
sample_string = 'Taiwan is beautiful.'
indices = en_encoder.encode(sample_string)
indices

[3096, 7955, 9, 3023, 3544, 1281, 7969]

In [82]:
# Decode each index on its own to see which subword it stands for.
# Inlined on purpose: index_back_word is only defined in a later cell, so
# calling it here would raise NameError on a fresh "Restart & Run All".
[en_encoder.decode([idx]) for idx in indices]

['Taiwan', ' ', 'is ', 'bea', 'uti', 'ful', '.']

In [88]:
# Encode a Chinese phrase into indices with the Chinese encoder.
sample_string = '欧元区的瓦解'
indices = zh_encoder.encode(sample_string)
indices

[44, 199, 174, 1, 841, 206]

In [89]:
def index_back_word(indices, encoder):
    """Decode every token index separately into its subword string.

    Args:
        indices: iterable of integer token ids (e.g. from ``encoder.encode``).
        encoder: object exposing ``decode(list_of_ids) -> str``.

    Returns:
        A list holding one decoded string per index, in input order.
    """
    return [encoder.decode([token_id]) for token_id in indices]

In [90]:
# Because the Chinese encoder was built with max_subword_length=1, each index
# decodes to a single character.
index_back_word(indices,zh_encoder)

['欧', '元', '区', '的', '瓦', '解']