<a href="https://colab.research.google.com/github/dabingooo/Trax-examples/blob/master/trax_transformer_%E5%B0%86%E5%90%84%E7%A7%8D%E6%97%A5%E6%9C%9F%E6%A0%BC%E5%BC%8F%E8%BD%AC%E4%B8%BA%E6%A0%87%E5%87%86%E6%A0%BC%E5%BC%8F.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 使用 Trax Transformer 模型将各种日期格式转为标准格式
本代码是使用 Google Trax 库的一个示例。原本打算做机器翻译（中文→英文），但由于计算资源不足、训练时间长，且很可能得不到好的结果，所以把问题简化为：训练模型将各种格式的日期转换为标准日期格式。

**转化示例**:
> 各种日期格式&nbsp;&nbsp;&nbsp;->&nbsp;&nbsp;标准日期格式<br/>
'10.11.19'&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;->&nbsp;&nbsp;'2019-11-10'<br/>
'1970/9/10'&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;->&nbsp;&nbsp;'1970-09-10'<br/>
'1990年4月28日星期六'&nbsp;&nbsp;->&nbsp;&nbsp;'1990-04-28'<br/>


In [75]:
import os
import numpy as np
!pip install -q -U trax
import trax
from trax import layers as tl
!pip install -q faker
!pip install -q tqdm
!pip install -q babel

下面的函数用于生成数据。此代码源自吴恩达的 Deep-Learning-Specialization-Coursera 课程，拷贝自[这里](https://github.com/AdalbertoCq/Deep-Learning-Specialization-Coursera/blob/master/Sequence%20Models/week3/Neural%20machine%20translation%20with%20attention/nmt_utils.py)，并做了较大改动。

In [76]:
import numpy as np
from faker import Faker
import random
from tqdm import tqdm
from babel.dates import format_date

# Module-level Faker instance; load_date() draws its random dates from it.
fake = Faker()
# Seed Faker (dates) and the stdlib `random` module (used by load_date() to
# pick a format) so the generated data is the same from run to run.
Faker.seed(12345)
random.seed(12345)

# Date formats used to render the human-readable side of each training pair.
# 'short'/'medium'/'long'/'full' are the locale's predefined formats; the rest
# are CLDR patterns. The year field uses lowercase 'y' (calendar year): the
# uppercase 'Y' is the week-based year, which for some dates around New Year
# differs from the calendar year and would contradict the ISO target label.
# 'full' is listed 10 times, so random.choice() picks it with probability
# 10/24; 'd MMMM yyy' also appears twice.
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM yyy', 
           'd MMMM yyy',
           'dd MMM yyy',
           'd MMM, yyy',
           'd MMMM, yyy',
           'dd, MMM yyy',
           'd MM yy',
           'd MMMM yyy',
           'MMMM d yyy',
           'MMMM d, yyy',
           'dd.MM.yy']

# Candidate locales for rendering dates. NOTE: currently unused - load_date()
# does not read this list; it formats dates with the 'zh_CN' locale. Also,
# human_vocab holds only digits, a few separators and some Chinese characters,
# so other locales would likely produce characters tok() cannot map.
LOCALES = ['en_US']

def load_date(locale='zh_CN'):
    """
        Generate one random fake date in a randomly chosen human-readable format.

        :param locale: Babel locale used to render the human-readable string.
            Defaults to 'zh_CN'. human_vocab only contains digits, a few
            separators and some Chinese characters, so other locales will
            likely produce characters that tok() cannot map.
        :returns: tuple (human_readable, machine_readable, date_object) where
            human_readable is lower-cased with commas removed and
            machine_readable is dt.isoformat() (YYYY-MM-DD). Returns
            (None, None, None) if formatting raises AttributeError.
    """
    dt = fake.date_object()

    try:
        human_readable = format_date(dt, format=random.choice(FORMATS), locale=locale)
    except AttributeError:
        # Kept from the original data generator: treat formatting failures as
        # "no sample" instead of aborting the stream.
        return None, None, None

    # Commas are stripped because ',' is not part of human_vocab.
    human_readable = human_readable.lower().replace(',', '')
    machine_readable = dt.isoformat()
    return human_readable, machine_readable, dt

def load_dataset_yield():
  """
    Endless generator of (human_readable, machine_readable) training pairs.

    Each pair comes from load_date(); '#' is appended to both strings as the
    end-of-sequence marker. Draws for which load_date() returned None are
    skipped, so the generator never yields None and never stops on its own.
  """
  while True:
    human, machine, _ = load_date()
    if human is None:
      # load_date() signals a formatting failure with (None, None, None).
      continue
    yield (human + '#', machine + '#')

# Show three sample pairs; each call builds a fresh generator and takes the
# next random draw from it.
print('示例数据(井号键(#)表示结束符):')
for _ in range(3):
  print(next(load_dataset_yield()))

#========human-readable date vocab (char -> id)=========
# '@' (id 0) is the padding id (masked via AddLossWeights(id_to_mask=0)),
# '#' (id 1) is the end-of-sequence marker appended by load_dataset_yield().
human_vocab = {'@': 0, '#': 1, ' ': 2, '.': 3, '/': 4, '0': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '9': 14, '一': 15, '七': 16, '三': 17, '九': 18, '二': 19, '五': 20, '八': 21, '六': 22, '十': 23, '四': 24, '年': 25, '日': 26, '星': 27, '月': 28, '期': 29}
#========machine (ISO) date vocab (char -> id)=========
machine_vocab = {'@': 0, '#': 1, '-': 2, '0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12}
#========inverse machine vocab (id -> char)=========
# Derived from machine_vocab so the two mappings can never drift apart.
inv_machine_vocab = {idx: ch for ch, idx in machine_vocab.items()}

# Maximum sequence length in tokens (including '#'); longer examples are
# dropped by the input pipeline and shorter ones are padded to this length.
max_len = 32
batch_size = 64

示例数据(井号键(#)表示结束符):
('9 5月 1998#', '1998-05-09#')
('10.11.19#', '2019-11-10#')
('1970/9/10#', '1970-09-10#')


由于 trax 中的 tokenize() 和 detokenize() 函数在词汇表类型（vocab_type）为 'char' 时无法指定自定义词汇表（即使指定了，内部也不会使用），为了使用自己的词汇表，这里自行简单实现了 tokenize() 和 detokenize() 的功能：

In [77]:
def tok(data, dic):
  '''
  Convert a string into integer token ids, one id per character.

  data: the string to tokenize.
  dic: vocabulary dict mapping each character to its id.
  return: 1-D integer np.array of length len(data); an empty string gives an
          empty integer array (not float64).
  raises: KeyError naming the first character of data that is missing from dic.
  '''
  try:
    ids = [dic[ch] for ch in data]
  except KeyError as err:
    raise KeyError(f'character {err.args[0]!r} is not in the vocabulary') from err
  # dtype=int matches numpy's default for non-empty int lists and keeps the
  # empty case integer-typed too.
  return np.array(ids, dtype=int)

def detok(data, dic):
  '''
  Convert token ids back into strings using an id -> character dict.

  data: token ids, 1-D (seq_length,) or 2-D (batch_size, seq_length). Lists
        and 0-d arrays are accepted as well; anything below 2-D is treated
        as a single row.
  dic: inverse vocabulary dict mapping each id to its character.
  return: np.squeeze() of the list of decoded rows, i.e. a 0-d string array
          such as array('abc') for a single row, or a 1-D string array such
          as array(['abc', 'def']) for several rows.
  raises: ValueError if data has more than 2 dimensions.
  '''
  data = np.asarray(data)
  if data.ndim > 2:
    raise ValueError(f'The dim of input can NOT > 2. the dim of input is {data.ndim}')
  rows = np.atleast_2d(data)
  # ''.join builds each string in one pass instead of repeated concatenation.
  decoded = [''.join(dic[tok_id] for tok_id in row) for row in rows]
  return np.squeeze(decoded)

数据预处理:

In [78]:
def tok_tuple_yield(data, dic, axis=0):
  '''
  Tokenize one field of every tuple in a stream.

  data: iterable of tuples, e.g. (human_str, machine_str) pairs.
  dic: vocabulary dict passed to tok().
  axis: index of the tuple element to tokenize (negative indices allowed).
  yields: a tuple equal to the input with element `axis` replaced by
          tok(element, dic); every other element, including any beyond the
          first two, is passed through unchanged.
  '''
  for example in data:
    fields = list(example)
    fields[axis] = tok(fields[axis], dic)
    yield tuple(fields)

input_pip = trax.data.Serial(
    # 1) endless stream of (human_str, machine_str) pairs; the argument is ignored
    lambda _: load_dataset_yield(),
    # 2) tokenize the human-readable side (tuple element 0)
    lambda stream: tok_tuple_yield(stream, human_vocab, 0),
    # 3) tokenize the machine-readable side (tuple element 1)
    lambda stream: tok_tuple_yield(stream, machine_vocab, 1),
    # 4) drop examples where either side is longer than max_len tokens
    trax.data.FilterByLength(max_length=max_len, length_keys=[0, 1]),
    # 5) batch examples, padding both sides to the max_len boundary
    trax.data.BucketByLength(
        boundaries=[max_len],
        batch_sizes=[batch_size, 1],
        length_keys=[0, 1],
        strict_pad_on_len=True,
    ),
    # 6) append loss weights that mask positions with id 0 (padding)
    trax.data.AddLossWeights(id_to_mask=0),
)

# Separate iterators for training and evaluation; both draw from the same
# global Faker / random generators used by load_date().
train_batches_stream = input_pip()
eval_batches_stream = input_pip()

o = next(train_batches_stream)
print(f'第一批数据:{o}')
print(f'第一批数据tuple中每个数据shape: {[arr.shape for arr in o]}')

第一批数据:(array([[ 6, 14, 14, ...,  0,  0,  0],
       [ 6, 14, 14, ...,  0,  0,  0],
       [ 6, 14, 13, ...,  0,  0,  0],
       ...,
       [24, 28,  2, ...,  0,  0,  0],
       [ 6, 14, 14, ...,  0,  0,  0],
       [ 6, 14, 12, ...,  0,  0,  0]]), array([[ 4, 12, 12, ...,  0,  0,  0],
       [ 4, 12, 12, ...,  0,  0,  0],
       [ 4, 12, 11, ...,  0,  0,  0],
       ...,
       [ 5,  3,  4, ...,  0,  0,  0],
       [ 4, 12, 12, ...,  0,  0,  0],
       [ 4, 12, 10, ...,  0,  0,  0]]), array([[1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       ...,
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.]], dtype=float32))
第一批数据tuple中每个数据shape: [(64, 32), (64, 32), (64, 32)]


创建Transformer训练模型:

In [79]:
# MODEL
def create_model(mode = 'train'):
  """Build the small encoder-decoder Transformer used for date normalisation.

  mode: trax mode string passed straight through ('train' for training,
        'predict' for autoregressive decoding).
  """
  hparams = dict(
      d_model=32,
      d_ff=128,
      n_heads=8,
      n_encoder_layers=2,
      n_decoder_layers=2,
  )
  return trax.models.Transformer(
      input_vocab_size=len(human_vocab),
      output_vocab_size=len(machine_vocab),
      max_len=max_len,
      mode=mode,
      **hparams)

# UNUSED in this notebook. The default vocab sizes and max_len (8269, 8185,
# 258) do not match this task's vocabularies; they presumably come from a
# different setup (possibly the original translation idea) - TODO confirm.
# Pass e.g. input_vocab_size=len(human_vocab) to use it for this task.
def create_model_reformer(mode = 'train', input_vocab_size=8269,
                          output_vocab_size=8185, max_len=258):
  """Build a 2+2-layer trax Reformer encoder-decoder.

  mode: trax mode string ('train', 'eval' or 'predict').
  input_vocab_size / output_vocab_size: source / target vocabulary sizes.
  max_len: maximum sequence length for the positional encodings.
  """
  return trax.models.Reformer(input_vocab_size=input_vocab_size,
              output_vocab_size=output_vocab_size,
              d_model=256,
              d_ff=1024,
              n_encoder_layers=2,
              n_decoder_layers=2,
              n_heads=8,
              dropout=0.1,
              max_len=max_len,
              ff_activation=tl.Relu,
              ff_dropout=None,
              mode=mode)
          

训练(注:使用GPU可以加快训练速度):

In [80]:
# TRAIN
import shutil

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adafactor(0.02),
    # Alternative optimizer that can be swapped in:
    #optimizer=trax.optimizers.Adam(learning_rate=0.1, weight_decay_rate=1e-05, b1=0.9, b2=0.98, eps=1e-06, clip_grad_norm=None),
    n_steps_per_checkpoint=300,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
print(output_dir)
# Start from an empty directory, presumably so the loop does not pick up an
# old checkpoint. This is the pure-Python equivalent of the IPython-only
# `!rm -rf`, so the cell also runs as a plain script.
shutil.rmtree(output_dir, ignore_errors=True)
training_loop = training.Loop(create_model(),
                train_task,
                eval_tasks=[eval_task],
                output_dir=output_dir)


# Run
training_loop.run(1200)

/root/output_dir/

Step      1: Ran 1 train steps in 18.79 secs
Step      1: train CrossEntropyLoss |  2.84239006
Step      1: eval  CrossEntropyLoss |  2.52507530
Step      1: eval          Accuracy |  0.18465910

Step    300: Ran 299 train steps in 35.78 secs
Step    300: train CrossEntropyLoss |  0.83705121
Step    300: eval  CrossEntropyLoss |  0.51364623
Step    300: eval          Accuracy |  0.79950286

Step    600: Ran 300 train steps in 17.25 secs
Step    600: train CrossEntropyLoss |  0.23055305
Step    600: eval  CrossEntropyLoss |  0.07081541
Step    600: eval          Accuracy |  0.97890628

Step    900: Ran 300 train steps in 17.21 secs
Step    900: train CrossEntropyLoss |  0.04028004
Step    900: eval  CrossEntropyLoss |  0.02603501
Step    900: eval          Accuracy |  0.99140628

Step   1200: Ran 300 train steps in 17.20 secs
Step   1200: train CrossEntropyLoss |  0.01672226
Step   1200: eval  CrossEntropyLoss |  0.01874784
Step   1200: eval          Accuracy |  0.995

测试模型效果:

In [81]:
# Decode a few freshly generated samples with the trained checkpoint and
# print each prediction next to its ground truth.
checkpoint_path = os.path.join(output_dir, 'model.pkl.gz')
for _ in range(10):
  dat = next(load_dataset_yield())
  # Re-create the model for every sample, because calling
  # autoregressive_sample() changes the model's state.
  model = create_model('predict')
  model.init_from_file(checkpoint_path, weights_only=True,
            input_signature=[trax.shapes.ShapeDtype((1, 1), np.int32),
                      trax.shapes.ShapeDtype((1, 1), np.int32),
                      trax.shapes.ShapeDtype((1, 1), np.float32)])
  # Tokenize the source string and add a batch dimension of 1.
  test_source = dat[0]
  test_target = dat[1]
  test_tok = tok(test_source, human_vocab)[None, :]

  # temperature=0.0 selects greedy (argmax) decoding, so predictions are
  # deterministic. The start/end ids belong to the decoder, so they are taken
  # from the output vocabulary (machine_vocab).
  res_tok = trax.supervised.decoding.autoregressive_sample(
      model, inputs=test_tok, batch_size=1, temperature=0.0,
      start_id=machine_vocab['@'], eos_id=machine_vocab['#'],
      max_length=max_len, accelerate=False)

  res = detok(res_tok, inv_machine_vocab)
  print('======================================================')
  print(f'输入: {test_source}')
  print(f'token(输入): {test_tok}')
  print(f'token(输出): {res_tok}')
  print(f'输出: {res}')
  print(f'真值: {test_target}')

输入: 2020年7月10日星期五#
token(输入): [[ 7  5  7  5 25 12 28  6  5 26 27 29 20  1]]
token(输出): [[ 5  3  5  3  2  3 10  2  4  3  1]]
输出: 2020-07-10#
真值: 2020-07-10#
输入: 27 10月 2017#
token(输入): [[ 7 12  2  6  5 28  2  7  5  6 12  1]]
token(输出): [[ 5  3  4 10  2  4  3  2  5 10  1]]
输出: 2017-10-27#
真值: 2017-10-27#
输入: 16 06 83#
token(输入): [[ 6 11  2  5 11  2 13  8  1]]
token(输出): [[ 4 12 11  6  2  3  9  2  4  9  1]]
输出: 1983-06-16#
真值: 1983-06-16#
输入: 1978年5月14日#
token(输入): [[ 6 14 12 13 25 10 28  6  9 26  1]]
token(输出): [[ 4 12 10 11  2  3  8  2  4  7  1]]
输出: 1978-05-14#
真值: 1978-05-14#
输入: 1995年5月30日星期二#
token(输入): [[ 6 14 14 10 25 10 28  8  5 26 27 29 19  1]]
token(输出): [[ 4 12 12  8  2  3  8  2  6  3  1]]
输出: 1995-05-30#
真值: 1995-05-30#
输入: 七月 9 1981#
token(输入): [[16 28  2 14  2  6 14 13  6  1]]
token(输出): [[ 4 12 11  4  2  3 10  2  3 12  1]]
输出: 1981-07-09#
真值: 1981-07-09#
输入: 1991年4月3日星期三#
token(输入): [[ 6 14 14  6 25  9 28  8 26 27 29 17  1]]
token(输出): [[ 4 12 12  4  2  3  7  2  3  6  1]]
