In [1]:
import os
import numpy as np
import pandas as pd
import pickle
import json

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
import tensorflow_hub as hub
import tensorflow_datasets as tfds

# Report the library versions and device availability for this run.
for label, value in (
    ("Version: ", tf.__version__),
    ("Eager mode: ", tf.executing_eagerly()),
    ("Hub version: ", hub.__version__),
):
    print(label, value)

gpu_devices = tf.config.list_physical_devices("GPU")
print("GPU is", "available" if gpu_devices else "NOT AVAILABLE")

Init Plugin
Init Graph Optimizer
Init Kernel
Version:  2.5.0
Eager mode:  True
Hub version:  0.12.0
GPU is available


# Constants

In [2]:
# Directory holding the pickled dataset and vocabulary files loaded below.
d_name = '20210812_wmt19_cs_en'
# Output directory for this run's model checkpoint and CSV training log.
folder_name = '20210812_translate_mle'

# Embedding width; getG() also uses it as the LSTM unit count.
wv_dim = 32
# Input sequence lengths of the encoder and decoder model inputs.
en_que_pad = 30
de_que_pad = 32

# Load data

In [3]:
def _load_pickle(stem):
    """Unpickle `{d_name}/{stem}.pkl` and close the file handle afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only load dataset
    # files that you produced yourself or otherwise trust.
    with open(f'{d_name}/{stem}.pkl', 'rb') as fh:
        return pickle.load(fh)


# Model inputs / targets for the training and validation splits.
encoder_train = _load_pickle('encoder_train')
decoder_train = _load_pickle('decoder_train')
teacher_train = _load_pickle('teacher_train')
encoder_vali  = _load_pickle('encoder_vali')
decoder_vali  = _load_pickle('decoder_vali')
teacher_vali  = _load_pickle('teacher_vali')

# Vocabulary lookups: the decoder side uses the `en_*` files, the encoder
# side uses the `cs_*` files.
decoder_idx2word   = _load_pickle('en_idx2word')
decoder_word2idx   = _load_pickle('en_word2idx')
encoder_idx2word   = _load_pickle('cs_idx2word')
encoder_word2idx   = _load_pickle('cs_word2idx')

In [4]:
def _vocab_size(*id_arrays):
    """Return one more than the largest token id found across `id_arrays`."""
    return np.max([np.max(ids) for ids in id_arrays]) + 1


# Embedding / output sizes must cover every id seen in either split.
num_decoder_words = _vocab_size(decoder_train, decoder_vali)
num_encoder_words = _vocab_size(encoder_train, encoder_vali)

print(num_decoder_words)
print(num_encoder_words)

52575
52573


In [5]:
def seq2word(seq_tensor, idx2word):
    """Map sequences of token ids to their words.

    Args:
        seq_tensor: iterable of sequences of token ids.
        idx2word: mapping whose keys are the *string* form of each token id.

    Returns:
        A numpy array holding the looked-up word for every id, row by row.
    """
    decoded_rows = []
    for token_ids in seq_tensor:
        words = []
        for token_id in token_ids:
            # The lookup table is keyed by str(id), not by the integer itself.
            words.append(idx2word[str(token_id)])
        decoded_rows.append(words)
    return np.array(decoded_rows)

In [6]:
# Sanity check: decode the first ten rows of decoder_vali back into words.
seq2word(decoder_vali[:10], decoder_idx2word)

array([['bos', '10', '000', 'gold', 'eos', 'pad', 'pad', 'pad', 'pad',
        'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad',
        'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad',
        'pad', 'pad', 'pad', 'pad', 'pad'],
       ['bos', 'san', 'francisco', '–', 'it', 'has', 'never', 'been',
        'easy', 'to', 'have', 'a', 'rational', 'conversation', 'about',
        'the', 'value', 'of', 'gold', 'eos', 'pad', 'pad', 'pad', 'pad',
        'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad'],
       ['bos', 'lately', 'with', 'gold', 'prices', 'up', 'more', 'than',
        '300', 'over', 'the', 'last', 'decade', 'it', 'is', 'harder',
        'than', 'ever', 'eos', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad',
        'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad'],
       ['bos', 'just', 'last', 'december', 'fellow', 'economists',
        'martin', 'feldstein', 'and', 'nouriel', 'roubini', 'each',
        'penned', 'op', 'eds', 'bravely', 'questioning

# Generator

Background references (note: the MLE baseline built by `getG` below is a plain LSTM encoder–decoder and does not itself use these techniques):
<br>
[OneHot relaxation](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical)
based on [Jang+16](https://arxiv.org/abs/1611.01144) and [Maddison+16](https://arxiv.org/abs/1611.00712)
<br>
[Multi-Head Attention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention)
and [Positional Encoding](https://www.tensorflow.org/text/tutorials/transformer)
based on [Vaswani+17, Attention is All You Need](https://arxiv.org/abs/1706.03762)

In [7]:
def getG():
    """Build the LSTM encoder-decoder generator.

    The encoder embeds its `en_que_pad` input ids and runs them through an
    LSTM whose final hidden/cell states initialise the decoder LSTM. The
    decoder embeds its `de_que_pad` input ids, and each decoder step is
    projected to a softmax over `num_decoder_words`.

    Returns:
        An uncompiled Keras `Model` taking `[encoder_ids, decoder_ids]` and
        returning per-step word probabilities.
    """
    # Encoder: token ids -> embeddings.
    encoder_ids = Input((en_que_pad,))
    encoder_embedding = Embedding(
        num_encoder_words,
        wv_dim,
        mask_zero=False,
        input_length=en_que_pad,
        trainable=True,
    )
    encoder_vectors = encoder_embedding(encoder_ids)

    # Decoder: token ids -> embeddings.
    decoder_ids = Input((de_que_pad,))
    decoder_embedding = Embedding(
        num_decoder_words,
        wv_dim,
        mask_zero=False,
        input_length=de_que_pad,
        trainable=True,
    )
    decoder_vectors = decoder_embedding(decoder_ids)

    # Only the encoder's final states are used; its output is discarded.
    encoder_lstm = LSTM(wv_dim, return_state=True)
    _, state_h, state_c = encoder_lstm(encoder_vectors)

    # The decoder starts from the encoder states and emits one vector per step.
    decoder_lstm = LSTM(wv_dim, return_sequences=True)
    decoder_steps = decoder_lstm(
        decoder_vectors,
        initial_state=[state_h, state_c],
    )

    # Per-step projection onto the decoder vocabulary, then softmax.
    word_scores = Dense(num_decoder_words)(decoder_steps)
    word_probs = Activation('softmax')(word_scores)

    return Model([encoder_ids, decoder_ids], word_probs)


In [8]:
# Build the generator and set it up for maximum-likelihood training:
# the model outputs softmax probabilities, scored with sparse cross-entropy.
mleG = getG()
compile_options = dict(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy'],
)
mleG.compile(**compile_options)
mleG.summary()

Metal device set to: Apple M1
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 30)]         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 32)]         0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 30, 32)       1682336     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 32, 32)       1682400     input_2[0][0]                    
________________________________________________________________

2021-08-14 16:02:47.834596: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2021-08-14 16:02:47.834678: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [9]:
# The checkpoint and CSV log are written into `folder_name`; create it first
# so a fresh checkout does not fail when the callbacks open their files.
os.makedirs(folder_name, exist_ok=True)

mleG.fit(
    [encoder_train, decoder_train], 
    teacher_train, 
    batch_size=64, 
    epochs=300,  # upper bound only; EarlyStopping below ends training sooner
    shuffle=True, 
    validation_data = (
        [encoder_vali, decoder_vali], 
        teacher_vali
    ), 
    callbacks=[
        # Keep only the model with the best validation loss seen so far.
        ModelCheckpoint(
            f'./{folder_name}/mleG.h5', 
            save_best_only=True, 
            monitor = "val_loss"
        ),
        # Stop once val_loss has not improved for 5 consecutive epochs.
        EarlyStopping(monitor='val_loss', patience=5),
        # Per-epoch metrics log.
        CSVLogger(f'{folder_name}/mleG.csv'),
    ]
)

Epoch 1/300


2021-08-14 16:02:48.057736: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-08-14 16:02:48.057916: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2021-08-14 16:02:48.826842: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-14 16:02:49.007855: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-14 16:02:49.159475: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-14 16:02:49.628379: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


   1/2716 [..............................] - ETA: 1:30:40 - loss: 10.8699 - accuracy: 0.0000e+00

2021-08-14 16:02:49.892891: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.




2021-08-14 16:21:11.742186: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-14 16:21:11.830678: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-14 16:21:11.995634: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


Epoch 2/300
Epoch 3/300
  11/2716 [..............................] - ETA: 17:52 - loss: 3.9195 - accuracy: 0.4812

KeyboardInterrupt: 

## Inference train D (False) 

In [None]:
def inference(
    model,
    enData = None,
    inpData = None,
    start_on = 0,
    end_on = 10,
    num_data = 10, 
    batch_size = 10,
    que_pad = 10,
):
    """Sample decoder tokens from `model` position by position, in batches.

    For every position `i` in `range(start_on, end_on)` the model is run on
    `[en_batch, in_batch]`, a token is sampled from the distribution the model
    returns at position `i`, stored in the prediction, and written to position
    `i + 1` of the decoder input (when that position exists). After sampling,
    everything after the first EOS id in a predicted row is zeroed in both the
    prediction and the decoder input.

    Args:
        model: callable taking `[encoder_batch, decoder_batch]` and returning
            per-position probabilities indexed as `[:, i]`.
        enData: encoder inputs, sliced along the first axis into batches.
            Required.
        inpData: decoder inputs used as the prefix when `start_on != 0`;
            assumed to be a 2-D numpy array -- TODO confirm against callers.
            It is copied, never modified.
        start_on: first position to generate. With 0 the decoder input is
            initialised to zeros with the BOS id in column 0.
        end_on: one past the last position to generate.
        num_data: number of rows to return.
        batch_size: rows per model call.
        que_pad: width of the prediction array (and of the decoder input when
            `start_on == 0`).

    Returns:
        `(resp_pred_list, in_batch_list)`: the sampled ids and the decoder
        inputs fed to the model, each with at most `num_data` rows. When
        `start_on != 0` the leading columns are filled from `inpData`.

    Raises:
        ValueError: if `enData` is missing, or `inpData` is missing while
            `start_on != 0`.
    """
    if enData is None:
        raise ValueError("inference() needs enData: the model takes encoder input")
    if start_on != 0 and inpData is None:
        raise ValueError("inference() needs inpData when start_on != 0")

    # NOTE(review): `word2idx` is not defined in the visible cells; it
    # presumably should be `decoder_word2idx`, and the vocabulary printed
    # earlier uses lowercase 'bos'/'eos' -- confirm before relying on this.
    # Resolved up-front so a missing name/key fails loudly instead of being
    # swallowed and silently disabling the EOS truncation.
    eos_id = word2idx['EOS']

    num_batch = (num_data - 1) // batch_size + 1
    resp_chunks = []
    in_chunks = []

    for b in range(num_batch):
        lo, hi = b * batch_size, (b + 1) * batch_size
        en_batch = enData[lo:hi]
        rows = len(en_batch)
        if rows == 0:
            # Fewer encoder rows than num_data: nothing left to generate.
            break

        if start_on == 0:
            in_batch = np.zeros((rows, que_pad), dtype = int)
            in_batch[:, 0] = word2idx['BOS']
        else:
            # Copy so that the writes below never leak into the caller's data.
            in_batch = np.array(inpData[lo:hi])
        resp_pred = np.zeros((rows, que_pad), dtype = int)
        in_width = in_batch.shape[1]

        # Generate the sequence one position at a time.
        for i in range(start_on, end_on):
            step_probs = model([en_batch, in_batch])[:, i]
            # Sample (rather than argmax) one token id per row.
            sampled = tf.reshape(
                tf.random.categorical(tf.math.log(step_probs), 1),
                [-1],
            ).numpy()
            resp_pred[:, i] = sampled
            # Feed the sample back as the next decoder input, if there is room.
            if i + 1 < in_width:
                in_batch[:, i + 1] = sampled

        # Blank out everything after the first EOS in each predicted row.
        for row in range(rows):
            eos_hits = np.flatnonzero(resp_pred[row] == eos_id)
            if eos_hits.size == 0:
                continue
            cut = eos_hits[0] + 1
            resp_pred[row, cut:] = 0
            in_batch[row, cut:] = 0

        resp_chunks.append(resp_pred)
        in_chunks.append(in_batch)

    if not resp_chunks:
        # num_data <= 0 or no encoder rows: return empty results.
        return (
            np.zeros((0, que_pad), dtype = int),
            np.zeros((0, que_pad), dtype = int),
        )

    resp_pred_list = np.concatenate(resp_chunks)[:num_data]
    in_batch_list = np.concatenate(in_chunks)[:num_data]

    if start_on != 0:
        # Restore the given prefix; slice inpData to the rows actually
        # returned so that a longer inpData does not cause a shape mismatch.
        n_rows = len(resp_pred_list)
        resp_pred_list[:, :start_on] = inpData[:n_rows, 1:start_on+1]
        in_batch_list[:, :start_on+1] = inpData[:n_rows, :start_on+1]
    return resp_pred_list, in_batch_list