In [36]:
from termcolor import colored
import random
import numpy as np

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

import w1_unittest

In [9]:
# Both stream factories read the same 'opus/medical' dataset with (en, de) keys;
# the only difference between them is the `train` flag.
tfds_options = dict(
    data_dir='./data/',
    keys=('en', 'de'),
    eval_holdout_size=0.01,
)
train_stream_fn = trax.data.TFDS('opus/medical', train=True, **tfds_options)
eval_stream_fn = trax.data.TFDS('opus/medical', train=False, **tfds_options)



In [13]:
# Instantiate the training stream and show its first (en, de) pair.
# Note: next() consumes that example, so it is no longer in the stream afterwards.
train_stream = train_stream_fn()
print(colored('train data (en, de) tuple:', 'red'), next(train_stream))
print()

# Same preview for the evaluation stream (also consumes one example).
eval_stream = eval_stream_fn()
print(colored('eval data (en, de) tuple:', 'red'), next(eval_stream))

[31mtrain data (en, de) tuple:[0m (b'Tel: +421 2 57 103 777\n', b'Tel: +421 2 57 103 777\n')

[31meval data (en, de) tuple:[0m (b'Subcutaneous use and intravenous use.\n', b'Subkutane Anwendung und intraven\xc3\xb6se Anwendung.\n')


# Tokenize

In [26]:
# Vocabulary file name and the directory it is looked up in.
VOCAB_FILE = 'ende_32k.subword'
VOCAB_DIR = 'data/'

# Wrap the raw train/eval streams with tokenizers that share the same vocabulary.
tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

In [27]:
# Token id appended to the end of every tokenized sentence (end-of-sequence marker).
EOS = 1

def append_eos(stream):
    """Yield every (inputs, targets) pair of `stream` with EOS appended to both.

    Args:
        stream: iterable of 2-item (inputs, targets) pairs of token ids.

    Yields:
        A tuple of two numpy arrays: the input ids followed by EOS, and the
        target ids followed by EOS.
    """
    for pair in stream:
        source_ids, target_ids = pair
        yield np.array([*source_ids, EOS]), np.array([*target_ids, EOS])

# Wrap both tokenized streams so every sentence ends with EOS.
# Gotcha: the stream names are rebound in place, so re-running this cell wraps
# the streams again and appends a second EOS to each sentence.
tokenized_train_stream = append_eos(tokenized_train_stream)
tokenized_eval_stream = append_eos(tokenized_eval_stream)

## Filter long sentences

In [28]:
# Apply a 512-token length filter to both streams; the length check uses both
# fields (index 0 and 1) of every pair.
# NOTE(review): assumes FilterByLength drops pairs exceeding max_length - confirm.
filtered_train_stream = trax.data.FilterByLength(max_length=512, length_keys=[0, 1])(
    tokenized_train_stream
)
filtered_eval_stream = trax.data.FilterByLength(max_length=512, length_keys=[0, 1])(
    tokenized_eval_stream
)

# Pull one training example to inspect its token ids (this consumes it from the stream).
train_input, train_target = next(filtered_train_stream)
print(colored('Single tokenized example input:', 'red'), train_input)
print(colored('Single tokenized example target:', 'red'), train_target)

[31mSingle tokenized example input:[0m [ 5148 10018   487  1646    64  5024 24618  5748    23  3445 23306     5
 30650  4729   992     1]
[31mSingle tokenized example target:[0m [10609 24618  5748    23  3445 23306     5  3550 30650  4729   992     1]


In [29]:
def tokenize(input_str, vocab_file=None, vocab_dir=None):
    """Encode one string as a batch of token ids terminated by EOS.

    Args:
        input_str: the text to encode.
        vocab_file: vocabulary file name forwarded to `trax.data.tokenize`.
        vocab_dir: vocabulary directory forwarded to `trax.data.tokenize`.

    Returns:
        A numpy array of shape (1, n_tokens + 1): the token ids followed by the
        EOS id (1), with a leading batch dimension of size one.
    """
    EOS = 1

    # trax.data.tokenize works on a stream, so wrap the string in a one-item
    # iterator and take the single result.
    token_stream = trax.data.tokenize(iter([input_str]), vocab_file=vocab_file, vocab_dir=vocab_dir)
    token_ids = [*next(token_stream), EOS]

    # Add the leading batch dimension.
    return np.array(token_ids).reshape(1, -1)

def detokenize(integers, vocab_file=None, vocab_dir=None):
    """Decode token ids back to text, ignoring everything from the first EOS on.

    Args:
        integers: token ids, either a flat sequence or an array with extra
            size-one dimensions (e.g. a batch of one); a single token is
            accepted as well.
        vocab_file: vocabulary file name forwarded to `trax.data.detokenize`.
        vocab_dir: vocabulary directory forwarded to `trax.data.detokenize`.

    Returns:
        Whatever `trax.data.detokenize` returns for the ids that precede the
        first EOS id (1), or for all ids when no EOS is present.
    """
    EOS = 1

    # np.squeeze removes size-one dimensions; for a single-token input it
    # produces a 0-d array, which cannot be iterated. np.atleast_1d turns that
    # case back into a length-one vector and leaves other shapes untouched.
    integers = list(np.atleast_1d(np.squeeze(integers)))

    # Truncate at the first EOS so the marker and any padding after it are dropped.
    if EOS in integers:
        integers = integers[:integers.index(EOS)]

    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)

In [31]:
# Decode the example pulled from the training stream back into text.
print(
    colored('Single detokenized example input:', 'red'),
    detokenize(train_input, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),
)
print(
    colored('Single detokenized example target:', 'red'),
    detokenize(train_target, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),
)

# Round-trip check of the two helpers on a short string.
print(
    colored("tokenize('hello'): ", 'green'),
    tokenize('hello', vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),
)
print(
    colored("detokenize([17332, 140, 1]): ", 'green'),
    detokenize([17332, 140, 1], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR),
)

[31mSingle detokenized example input:[0m Each tablet contains: irbesartan 150 mg

[31mSingle detokenized example target:[0m Irbesartan 150 mg.

[32mtokenize('hello'): [0m [[17332   140     1]]
[32mdetokenize([17332, 140, 1]): [0m hello


## Bucketing

In [32]:
# Sequence-length bucket boundaries and the batch size used for each bucket.
# batch_sizes has one more entry than boundaries.
# NOTE(review): presumably the extra last entry applies to sequences longer than
# the final boundary - confirm against the BucketByLength documentation.
boundaries = [8, 16, 32, 64, 128, 256, 512]
batch_sizes = [256, 128, 64, 32, 16, 8, 4, 2]

# Group training pairs into batches; the length of both fields (0 and 1) is considered.
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)(filtered_train_stream)

# Same bucketing for the evaluation stream.
eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)(filtered_eval_stream)

# Add a loss-weights field to every batch; id_to_mask=0 targets the zero id
# that the batches use as padding.
train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)
eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)

## Exploring the data

In [34]:
# Pull one batch from the training stream: inputs, targets and the loss-weight mask.
# This consumes the batch from the stream.
input_batch, target_batch, mask_batch = next(train_batch_stream)

# Inspect the container type of the batch fields.
print('input_batch data type: ', type(input_batch))
print('target_batch data type: ', type(target_batch))

# Inspect the batch dimensions.
print("input_batch shape: ", input_batch.shape)
print("target_batch shape: ", target_batch.shape)

input_batch data type:  <class 'numpy.ndarray'>
target_batch data type:  <class 'numpy.ndarray'>
input_batch shape:  (32, 64)
target_batch shape:  (32, 64)


In [35]:
# Pick a random row index within the batch. No seed is set in this cell, so the
# example shown can differ between runs.
index = random.randrange(len(input_batch))

# Use the index to show the same entry of the input and target batch, both as
# decoded text and as raw (zero-padded) token ids.
print(colored('THIS IS THE ENGLISH SENTENCE: \n', 'red'), detokenize(input_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: \n ', 'red'), input_batch[index], '\n')
print(colored('THIS IS THE GERMAN TRANSLATION: \n', 'red'), detokenize(target_batch[index], vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR), '\n')
print(colored('THIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: \n', 'red'), target_batch[index], '\n')

[31mTHIS IS THE ENGLISH SENTENCE: 
[0m DATE OF FIRST AUTHORISATION/RENEWAL OF THE AUTHORISATION
 

[31mTHIS IS THE TOKENIZED VERSION OF THE ENGLISH SENTENCE: 
 [0m [14344   493 11544  9868 18207   693  9227  7933 10547  3337  4205 10617
   123  3234  4017 13986  1047 11544  4327  9227  7933 10547  3337  4205
 10617 30650  4729   992     1     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0] 

[31mTHIS IS THE GERMAN TRANSLATION: 
[0m DATUM DER ERTEILUNG DER ERSTZULASSUNG / VERLÄNGERUNG DER ZULASSUNG
 

[31mTHIS IS THE TOKENIZED VERSION OF THE GERMAN TRANSLATION: 
[0m [14344 11459     5  9192  6304  6999  8001 18044  9192  6304 14075  1834
 25057  8066  8462 11553  1962 10468  2474  3449  9797  6304 18044  9192
  1834 25057  8066  8462 11553 30650  4729   992     1     0     0     0
     0     0     0     0     0     0   

# NMT with Attention

## Input encoder

In [41]:
def input_encoder_fn(input_vocab_size, d_model, n_encoder_layers):
    """Build the input encoder: an embedding followed by a stack of LSTMs.

    Args:
        input_vocab_size: vocabulary size passed to the embedding layer.
        d_model: embedding depth, also used as the size of every LSTM.
        n_encoder_layers: number of LSTM layers stacked after the embedding.

    Returns:
        A `tl.Serial` combinator holding the embedding and the LSTM stack.
    """
    # Token ids -> vectors.
    embedding = tl.Embedding(input_vocab_size, d_model)

    # The embedded sequence is then fed through the LSTM layers in order.
    lstm_stack = [tl.LSTM(d_model) for _ in range(n_encoder_layers)]

    return tl.Serial(embedding, lstm_stack)

In [42]:
# Run the w1_unittest checks against input_encoder_fn.
w1_unittest.test_input_encoder_fn(input_encoder_fn)

[92m All tests passed


## Pre-attention decoder

In [43]:
def pre_attention_decoder_fn(mode, target_vocab_size, d_model):
    """Build the decoder stage that runs before attention.

    Args:
        mode: mode string forwarded to `tl.ShiftRight`.
        target_vocab_size: vocabulary size passed to the embedding layer.
        d_model: embedding depth, also used as the LSTM size.

    Returns:
        A `tl.Serial` combinator of ShiftRight, Embedding and one LSTM.
    """
    decoder_layers = [
        # Shift the target tokens to the right before they are embedded.
        tl.ShiftRight(mode=mode),
        # Token ids -> vectors.
        tl.Embedding(target_vocab_size, d_model),
        # Single recurrent layer over the embedded targets.
        tl.LSTM(d_model),
    ]
    return tl.Serial(*decoder_layers)

In [44]:
# Run the w1_unittest checks against pre_attention_decoder_fn.
w1_unittest.test_pre_attention_decoder_fn(pre_attention_decoder_fn)

[92m All tests passed


## Prepare attention input

In [51]:
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Arrange encoder/decoder activations as queries, keys, values and a mask.

    Args:
        encoder_activations: encoder output; used as both keys and values.
        decoder_activations: pre-attention decoder output; used as queries.
            Its second dimension is read as the decoder length.
        inputs: input token ids; its first two dimensions are read as
            (batch size, encoder length). Positions holding ids greater than
            zero are kept, all others are masked.

    Returns:
        A tuple (queries, keys, values, mask) where mask has shape
        [batch size, 1, decoder length, encoder length].
    """
    # 1 where the input holds a real token, 0 elsewhere.
    padding_mask = fastnp.where(inputs > 0, 1, 0)

    # Insert singleton axes for attention heads and decoder positions.
    n_batch, encoder_len = padding_mask.shape[0], padding_mask.shape[1]
    padding_mask = fastnp.reshape(padding_mask, (n_batch, 1, 1, encoder_len))

    # Adding zeros broadcasts the mask along the decoder-length axis, giving
    # [batch size, attention heads, decoder-len, encoder-len] with heads == 1.
    decoder_len = decoder_activations.shape[1]
    mask = padding_mask + fastnp.zeros((1, 1, decoder_len, 1))

    return decoder_activations, encoder_activations, encoder_activations, mask

In [52]:
# Run the w1_unittest checks against prepare_attention_input.
w1_unittest.test_prepare_attention_input(prepare_attention_input)

[92m All tests passed


## NMTAttn

In [59]:
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Assemble the LSTM sequence-to-sequence model with attention.

    Args:
        input_vocab_size: vocabulary size for the input encoder's embedding.
        target_vocab_size: vocabulary size for the decoder embedding and for
            the final dense layer.
        d_model: depth of the embeddings, LSTMs and attention layer.
        n_encoder_layers: number of LSTM layers in the input encoder.
        n_decoder_layers: number of LSTM layers after the attention layer.
        n_attention_heads: number of heads passed to `tl.AttentionQKV`.
        attention_dropout: dropout rate passed to `tl.AttentionQKV`.
        mode: mode string forwarded to the pre-attention decoder and to the
            attention layer.

    Returns:
        A `tl.Serial` model that takes (input tokens, target tokens) and ends
        in a LogSoftmax over `target_vocab_size` entries.
    """
    # Step 0: the two sub-networks that run side by side before attention.
    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    pre_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    layers = [
        # Step 2: duplicate the input and target tokens (2 items in, 4 out).
        tl.Select([0, 1, 0, 1]),

        # Step 3: encode the first copy of the inputs and of the targets in parallel.
        tl.Parallel(encoder, pre_decoder),

        # Step 4: turn the activations plus the input-token copy into
        # queries, keys, values and mask.
        tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),

        # Step 5: attention wrapped in a residual connection.
        tl.Residual(
            tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)
        ),

        # Step 6: keep stack items 0 and 2, dropping item 1
        # (presumably the attention mask - confirm).
        tl.Select([0, 2]),

        # Step 7: decoder LSTM stack.
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],

        # Step 8: project to the target vocabulary.
        tl.Dense(target_vocab_size),

        # Step 9: log-probabilities.
        tl.LogSoftmax(),
    ]

    # Step 1: chain everything into a single serial model.
    return tl.Serial(*layers)

In [61]:
# Build the model with its default hyperparameters and print its layer structure.
model = NMTAttn()
print(model)

Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]


In [62]:
# Run the w1_unittest checks against NMTAttn.
w1_unittest.test_NMTAttn(NMTAttn)

[92m All tests passed


# Training

In [71]:
def train_task_function(train_batch_stream):
    """Create the training task for the given batch stream.

    Args:
        train_batch_stream: stream of labeled training batches.

    Returns:
        A `training.TrainTask` using cross-entropy loss, an Adam optimizer
        created with 0.01, a `warmup_and_rsqrt_decay(1000, 0.01)` learning-rate
        schedule, and a checkpoint every 10 steps.
    """
    loss = tl.CrossEntropyLoss()
    optimizer = trax.optimizers.Adam(0.01)
    # NOTE(review): arguments are presumably (warmup steps, max value) - confirm.
    schedule = trax.lr.warmup_and_rsqrt_decay(1000, 0.01)

    return training.TrainTask(
        labeled_data=train_batch_stream,
        loss_layer=loss,
        optimizer=optimizer,
        lr_schedule=schedule,
        n_steps_per_checkpoint=10,
    )

In [72]:
# Build the training task on top of the bucketed training batch stream.
train_task = train_task_function(train_batch_stream)

In [73]:
# Run the w1_unittest checks against train_task_function.
w1_unittest.test_train_task(train_task_function)

[92m All tests passed


## Eval

In [75]:
# Evaluation task: reports cross-entropy loss and accuracy on the bucketed
# evaluation batch stream.
eval_task = training.EvalTask(
    labeled_data=eval_batch_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

## Loop

In [76]:
output_dir = './output_dir/'
!rm -f ~/output_dir/model.pkg.gz

training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)



In [77]:
# Run the training loop for 10 steps.
training_loop.run(10)


Step      1: Total number of trainable weights: 148492820
Step      1: Ran 1 train steps in 87.48 secs
Step      1: train CrossEntropyLoss |  10.40531635
Step      1: eval  CrossEntropyLoss |  10.39260292
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 175.00 secs
Step     10: train CrossEntropyLoss |  10.24884701
Step     10: eval  CrossEntropyLoss |  9.95088005
Step     10: eval          Accuracy |  0.02429765


# Testing

In [78]:
# Rebuild the network in eval mode.
model = NMTAttn(mode='eval')

# Load the checkpoint from the output directory; weights_only=True is passed
# so only the weights are read from the file.
model.init_from_file('./output_dir/model.pkl.gz', weights_only=True)
# NOTE(review): tl.Accelerate presumably wraps the model for faster execution - confirm.
model = tl.Accelerate(model)