In [11]:
import os
import re
from collections import Counter

import jieba

In [12]:
# Project root, expressed as "<absolute current directory>/../..".
# The path is intentionally left un-normalised (it still ends in "../..").
PROJECT_ROOT_PATH = os.sep.join(
    [os.path.abspath(os.curdir), os.path.join(os.pardir, os.pardir)]
)

In [13]:
def read_data_nmt():
    """Load the raw text of the cmn-eng dataset.

    Returns:
        str: the entire content of ``data/cmn-eng/cmn.txt`` (located under
        ``PROJECT_ROOT_PATH``), decoded as UTF-8.
    """
    corpus_file = os.path.join(PROJECT_ROOT_PATH, 'data', 'cmn-eng', 'cmn.txt')
    with open(corpus_file, mode='r', encoding='utf-8') as handle:
        raw = handle.read()
    return raw

In [14]:
def preprocess_nmt(text):
    """Normalise the raw cmn-eng corpus text before tokenisation.

    The following steps are applied in order:

    1. Remove every span matching ``<TAB>CC-BY 2.0 ... & ... )`` (the
       attribution column that follows a sentence pair on a line).
    2. Replace narrow no-break spaces (U+202F) and no-break spaces (U+00A0)
       with ordinary spaces.
    3. Lower-case the whole text.
    4. Insert a space in front of each of ``,``, ``.``, ``!`` and ``?`` that
       is not already preceded by a space, so the mark becomes its own token
       when the text is later split on spaces.

    Args:
        text (str): raw corpus text.

    Returns:
        str: the normalised text.
    """
    # Built once per call instead of once per character.
    punctuation = frozenset(',.!?')

    def no_space(char, prev_char):
        """Return True when ``char`` is punctuation not preceded by a space."""
        return char in punctuation and prev_char != ' '

    # Raw string: "\." and "\)" are invalid escapes in a normal string literal
    # (a warning today, an error in the future). The leading tab is written
    # as "\t" so it is visible in the source.
    text = re.sub(r'\tCC-BY 2\.0.+&.+\)', '', text)
    # Replace non-breaking spaces with plain spaces, then lower-case.
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert a space between a word and the punctuation that follows it.
    # The first character has no predecessor, hence the ``i > 0`` guard.
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)


In [15]:
def tokenize_nmt(text, num_examples=None):
    """Tokenise the preprocessed corpus into parallel token lists.

    Each line of ``text`` is split on tabs. Lines that yield exactly two
    fields are kept: the first field is split on single spaces and the
    second field is segmented with ``jieba.lcut``. All other lines are
    skipped.

    Args:
        text (str): preprocessed corpus, one tab-separated pair per line.
        num_examples (int | None): maximum number of lines to read from the
            start of ``text``. ``None`` (or ``0``) means no limit.

    Returns:
        tuple[list[list[str]], list[list[str]]]: ``(source, target)`` where
        ``source[k]`` and ``target[k]`` are the token lists of the k-th kept
        line.
    """
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        # ``i`` is zero-based, so stop as soon as ``num_examples`` lines have
        # been consumed (the previous ``>`` let one extra line through).
        if num_examples and i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(jieba.lcut(parts[1]))
    return source, target

In [16]:
# Run the pipeline end to end: load the raw corpus, normalise it, tokenise it.
raw_text = read_data_nmt()
text = preprocess_nmt(raw_text)
source, target = tokenize_nmt(text)
# Display the first six tokenised entries of each side as a sanity check.
source[:6], target[:6]

([['hi', '.'],
  ['hi', '.'],
  ['run', '.'],
  ['stop', '!'],
  ['wait', '!'],
  ['wait', '!']],
 [['嗨', '。'],
  ['你好', '。'],
  ['你', '用', '跑', '的', '。'],
  ['住手', '！'],
  ['等等', '！'],
  ['等', '一下', '！']])

In [17]:
# token 数量直方图
import plotly.graph_objects as go


def show_list_len_pair_hist(xlist, ylist):
    """Show a bar chart of element-length frequencies for two collections.

    For each collection, the length of every element is computed and the
    number of elements having each length is plotted as one bar trace
    (labelled ``EN`` for ``xlist`` and ``CN`` for ``ylist``).

    Args:
        xlist: iterable of sized items (e.g. token lists).
        ylist: iterable of sized items (e.g. token lists).

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    # Tally each collection on its own; walking both through a single zip()
    # would silently drop counts if the two collections differ in length.
    x_counts = Counter(len(item) for item in xlist)
    y_counts = Counter(len(item) for item in ylist)
    # Sorted lengths give a deterministic trace order.
    x_lengths = sorted(x_counts)
    y_lengths = sorted(y_counts)

    fig = go.Figure()
    fig.add_trace(go.Bar(x=x_lengths, y=[x_counts[k] for k in x_lengths], name='EN'))
    fig.add_trace(go.Bar(x=y_lengths, y=[y_counts[k] for k in y_lengths], name='CN'))
    fig.update_layout(xaxis_title='Sequence length', yaxis_title='Count')
    fig.show()

# Plot the length distributions of the tokenised source and target sentences.
show_list_len_pair_hist(source,target)

In [18]:
from torchtext import vocab
from torchtext.legacy.data import TabularDataset
# vocab.Vocab()
# Preview only the first few tokenised sentences; printing the whole corpus
# floods the cell output and bloats the saved notebook.
print(source[:6])

