In [1]:
import time, re
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

from tqdm import tqdm, tqdm_notebook

In [2]:
def preprocess_sentence(w):
    w = re.sub(r"[^a-zA-Z.!,?ążźśęćńół']+", " ", w)
    w = re.sub(r"([?.!,¿])", r" \1 ", w)
    w = re.sub(r'[" "]+', " ", w)
    
    words = []
    for word in w.split():
        if word[0].isupper():
            words.append('<up>')
        words.append(word.lower())
    
    words = ['<start>'] + words + ['<end>']
    return ' '.join(words) 

with open("pol.txt", 'rb') as f:
    texts = f.read().decode('utf-8').rstrip('\n')
    
pairs = [[preprocess_sentence(x) for x in pair.split('\t')[:2]] for pair in texts.split('\n')]
en_str, pl_str = zip(*pairs)

def tokenize(sequences):
    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
    lang_tokenizer.fit_on_texts(sequences)
    sequences = lang_tokenizer.texts_to_sequences(sequences)
    return sequences, lang_tokenizer

en_seq, en_tokenizer = tokenize(en_str)
pl_seq, pl_tokenizer = tokenize(pl_str)
idx_to_take = [i for i, x in enumerate(en_seq) if len(x) < 20]

en_seq = [en_seq[i] for i in idx_to_take]
pl_seq = [pl_seq[i] for i in idx_to_take]
en_seq = tf.keras.preprocessing.sequence.pad_sequences(en_seq, padding='post')
pl_seq = tf.keras.preprocessing.sequence.pad_sequences(pl_seq, padding='post')
x_train, x_valid, y_train, y_valid = train_test_split(en_seq, pl_seq, test_size=0.1)

In [3]:
buffer_size = len(x_train)//5
batch_size = 64

steps_per_epoch = len(x_train) // batch_size
embedding_dim = 256
units = 1024

vocab_x_size = len(en_tokenizer.word_index) + 1
vocab_y_size = len(pl_tokenizer.word_index) + 1
x_valid = tf.keras.preprocessing.sequence.pad_sequences(x_valid, padding='post')
y_valid = tf.keras.preprocessing.sequence.pad_sequences(y_valid, padding='post')

dataset_tr = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size)
dataset_tr = dataset_tr.batch(batch_size, drop_remainder=True)
dataset_tr = dataset_tr.prefetch(tf.data.experimental.AUTOTUNE)
dataset_tr_len = len([x for x, y in dataset_tr])

dataset_vd = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(batch_size)
example_x, example_y = next(iter(dataset_vd))

start_token = pl_tokenizer.texts_to_sequences(['<start>'])
end_token = pl_tokenizer.texts_to_sequences(['<end>'])

# Training utils

In [4]:
helper_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def masked_loss(true, pred):
    losses = helper_loss(true, pred)
    float_mask = tf.cast(true!=0, dtype=tf.float32)
    return tf.reduce_mean(losses * float_mask)


@tf.function
def training_step(encoder_sequence, decoder_sequence):
    with tf.GradientTape() as tape:
        decoder_outputs = model(encoder_sequence, decoder_sequence)
        loss = model.loss(decoder_sequence, decoder_outputs)

    grads = tape.gradient(loss, model.trainable_variables)
    model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss, decoder_outputs

@tf.function
def validation_step(encoder_sequence, decoder_sequence):
    decoder_outputs = model(encoder_sequence, decoder_sequence)
    loss = model.loss(decoder_sequence, decoder_outputs)
    return loss, decoder_outputs

loss_met = tf.metrics.Mean()
accu_met = tf.metrics.SparseCategoricalAccuracy()

def resres(metric):
    result = metric.result()
    metric.reset_states()
    return result

def train(model, epochs, writer):    
    for epoch in range(epochs):
        time_start = time.time()
        
        for x, y in tqdm(dataset_tr, total=dataset_tr_len):
            loss, outputs = training_step(x, y)
            accu_met(y, outputs)
            loss_met(loss)
            
        with writer.as_default():
            tf.summary.scalar('training/loss', loss_met.result(), step=epoch)
            tf.summary.scalar('training/accu', accu_met.result(), step=epoch)
            
        print(f"epoch {epoch:^5} | tr loss {resres(loss_met):^8.5f} | tr accu {resres(accu_met):^8.5f} | epoch time {time.time() - time_start:^8.2f}")
        
        for x, y in dataset_vd:
            loss, outputs = validation_step(x, y)
            accu_met(y, outputs)
            loss_met(loss)
            
        with writer.as_default():
            tf.summary.scalar('validation/loss', loss_met.result(), step=epoch)
            tf.summary.scalar('validation/accu', accu_met.result(), step=epoch)
            
            outputs = tf.argmax(model(example_x), -1)
            sample_x = '  \n'.join(en_tokenizer.sequences_to_texts(example_x.numpy()))
            sample_y = '  \n'.join(pl_tokenizer.sequences_to_texts(outputs.numpy()))
            tf.summary.text('en', sample_x, step=epoch)
            tf.summary.text('pl', sample_y, step=epoch)
            
        print(f"epoch {epoch:^5} | vd loss {resres(loss_met):^8.5f} | vd accu {resres(accu_met):^8.5f} | epoch time {time.time() - time_start:^8.2f}")

# Encoder-Decoder with 2 recurrent NN

In [5]:
%run EncDec.ipynb

model = ED(vocab_x_size, vocab_y_size, embedding_dim, units=829, start_token=start_token, end_token=end_token, max_output_length=pl_seq.shape[1])
output = model(example_x, example_y)

model.compile(tf.optimizers.Adam(), masked_loss)
model.summary()

Model: "ed"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Encoder)            multiple                  4799753   
_________________________________________________________________
decoder (Decoder)            multiple                  25233525  
Total params: 30,033,278
Trainable params: 30,033,278
Non-trainable params: 0
_________________________________________________________________


In [None]:
writer = tf.summary.create_file_writer('logs/ed_2nets_1')
train(model, epochs=20, writer=writer)

# Encoder-Decoder with 1 recurrent NN

In [6]:
%run EncDec_in1.ipynb

model = EDinOne(vocab_x_size, vocab_y_size, embedding_dim, units=925, start_token=start_token, end_token=end_token, max_output_length=pl_seq.shape[1])
output = model(example_x, example_y)

model.compile(tf.optimizers.Adam(), masked_loss)
model.summary()

Model: "e_din_one"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  2096384   
_________________________________________________________________
embedding_1 (Embedding)      multiple                  5310976   
_________________________________________________________________
gru (GRU)                    multiple                  3282825   
_________________________________________________________________
dense (Dense)                multiple                  19210796  
Total params: 29,900,981
Trainable params: 29,900,981
Non-trainable params: 0
_________________________________________________________________


In [8]:
writer = tf.summary.create_file_writer('logs/ed_in1_1')
train(model, epochs=20, writer=writer)

100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:36<00:00,  5.60it/s]


epoch   0   | tr loss 2.04564  | tr accu 0.17470  | epoch time  96.15  
epoch   0   | vd loss 1.87348  | vd accu 0.19782  | epoch time  101.49 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:32<00:00,  5.79it/s]


epoch   1   | tr loss 1.73157  | tr accu 0.20796  | epoch time  92.86  
epoch   1   | vd loss 1.79945  | vd accu 0.21111  | epoch time  96.01  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:27<00:00,  6.12it/s]


epoch   2   | tr loss 1.63014  | tr accu 0.21754  | epoch time  87.92  
epoch   2   | vd loss 1.75810  | vd accu 0.21597  | epoch time  91.11  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:27<00:00,  6.12it/s]


epoch   3   | tr loss 1.54572  | tr accu 0.22237  | epoch time  87.98  
epoch   3   | vd loss 1.72934  | vd accu 0.21724  | epoch time  91.26  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:28<00:00,  6.11it/s]


epoch   4   | tr loss 1.47449  | tr accu 0.22695  | epoch time  88.13  
epoch   4   | vd loss 1.69472  | vd accu 0.22193  | epoch time  91.38  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:28<00:00,  6.09it/s]


epoch   5   | tr loss 1.35519  | tr accu 0.23133  | epoch time  88.37  
epoch   5   | vd loss 1.47107  | vd accu 0.22769  | epoch time  91.65  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:30<00:00,  5.96it/s]


epoch   6   | tr loss 1.10714  | tr accu 0.24786  | epoch time  90.29  
epoch   6   | vd loss 1.28999  | vd accu 0.25027  | epoch time  93.64  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:25<00:00,  6.32it/s]


epoch   7   | tr loss 0.86597  | tr accu 0.27372  | epoch time  85.20  
epoch   7   | vd loss 1.14808  | vd accu 0.26812  | epoch time  88.77  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:24<00:00,  6.38it/s]


epoch   8   | tr loss 0.60952  | tr accu 0.30274  | epoch time  84.29  
epoch   8   | vd loss 1.07905  | vd accu 0.27846  | epoch time  87.49  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:26<00:00,  6.22it/s]


epoch   9   | tr loss 0.39800  | tr accu 0.33600  | epoch time  86.51  
epoch   9   | vd loss 1.05574  | vd accu 0.28363  | epoch time  90.05  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:25<00:00,  6.30it/s]


epoch  10   | tr loss 0.26131  | tr accu 0.36354  | epoch time  85.40  
epoch  10   | vd loss 1.05784  | vd accu 0.28677  | epoch time  88.55  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:31<00:00,  5.89it/s]


epoch  11   | tr loss 0.17944  | tr accu 0.38193  | epoch time  91.31  
epoch  11   | vd loss 1.06844  | vd accu 0.28781  | epoch time  94.49  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:28<00:00,  6.06it/s]


epoch  12   | tr loss 0.12796  | tr accu 0.39429  | epoch time  88.83  
epoch  12   | vd loss 1.08992  | vd accu 0.28835  | epoch time  92.19  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:28<00:00,  6.09it/s]


epoch  13   | tr loss 0.09463  | tr accu 0.40260  | epoch time  88.30  
epoch  13   | vd loss 1.11685  | vd accu 0.28813  | epoch time  91.55  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:28<00:00,  6.09it/s]


epoch  14   | tr loss 0.07439  | tr accu 0.40749  | epoch time  88.28  
epoch  14   | vd loss 1.14401  | vd accu 0.28832  | epoch time  91.56  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:27<00:00,  6.18it/s]


epoch  15   | tr loss 0.06168  | tr accu 0.41070  | epoch time  87.15  
epoch  15   | vd loss 1.16926  | vd accu 0.28812  | epoch time  90.43  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:23<00:00,  6.43it/s]


epoch  16   | tr loss 0.05489  | tr accu 0.41233  | epoch time  83.71  
epoch  16   | vd loss 1.19119  | vd accu 0.28798  | epoch time  86.95  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:23<00:00,  6.47it/s]


epoch  17   | tr loss 0.04978  | tr accu 0.41340  | epoch time  83.21  
epoch  17   | vd loss 1.21111  | vd accu 0.28824  | epoch time  86.45  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:23<00:00,  6.45it/s]


epoch  18   | tr loss 0.04742  | tr accu 0.41380  | epoch time  83.41  
epoch  18   | vd loss 1.23902  | vd accu 0.28799  | epoch time  86.69  


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [01:25<00:00,  6.28it/s]


epoch  19   | tr loss 0.04661  | tr accu 0.41371  | epoch time  85.73  
epoch  19   | vd loss 1.26553  | vd accu 0.28655  | epoch time  89.35  


# Encoder-Decoder with attention

In [5]:
%run EncDec_attention.ipynb

model = ED_attention(vocab_x_size, vocab_y_size, embedding_dim, units=742, start_token=start_token, end_token=end_token, max_output_length=pl_seq.shape[1])
output = model(example_x, example_y)

model.compile(tf.optimizers.Adam(), masked_loss)
model.summary()

Model: "ed_attention"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Encoder)            multiple                  4322384   
_________________________________________________________________
decoder (Decoder)            multiple                  25706301  
Total params: 30,028,685
Trainable params: 30,028,685
Non-trainable params: 0
_________________________________________________________________


In [6]:
writer = tf.summary.create_file_writer('logs/ed_attention_1')
train(model, epochs=20, writer=writer)

100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [05:30<00:00,  1.63it/s]


epoch   0   | tr loss 2.00605  | tr accu 0.17915  | epoch time  330.63 
epoch   0   | vd loss 1.86187  | vd accu 0.19627  | epoch time  380.86 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:44<00:00,  1.89it/s]


epoch   1   | tr loss 1.77607  | tr accu 0.20154  | epoch time  284.62 
epoch   1   | vd loss 1.79043  | vd accu 0.21105  | epoch time  307.49 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:53<00:00,  1.83it/s]


epoch   2   | tr loss 1.62613  | tr accu 0.21838  | epoch time  293.46 
epoch   2   | vd loss 1.58607  | vd accu 0.22241  | epoch time  316.29 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:45<00:00,  1.88it/s]


epoch   3   | tr loss 1.36625  | tr accu 0.23445  | epoch time  285.82 
epoch   3   | vd loss 1.33839  | vd accu 0.24252  | epoch time  308.60 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:45<00:00,  1.89it/s]


epoch   4   | tr loss 1.06139  | tr accu 0.26009  | epoch time  285.33 
epoch   4   | vd loss 1.12573  | vd accu 0.26359  | epoch time  308.23 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:31<00:00,  1.98it/s]


epoch   5   | tr loss 0.75171  | tr accu 0.28614  | epoch time  271.49 
epoch   5   | vd loss 1.02695  | vd accu 0.27690  | epoch time  295.24 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:45<00:00,  1.88it/s]


epoch   6   | tr loss 0.49720  | tr accu 0.31782  | epoch time  285.43 
epoch   6   | vd loss 0.99030  | vd accu 0.28283  | epoch time  308.30 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:48<00:00,  1.87it/s]


epoch   7   | tr loss 0.32664  | tr accu 0.34875  | epoch time  288.21 
epoch   7   | vd loss 0.98802  | vd accu 0.28814  | epoch time  309.44 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:29<00:00,  2.00it/s]


epoch   8   | tr loss 0.23423  | tr accu 0.36837  | epoch time  269.62 
epoch   8   | vd loss 0.99067  | vd accu 0.28977  | epoch time  292.40 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:45<00:00,  1.88it/s]


epoch   9   | tr loss 0.17576  | tr accu 0.38125  | epoch time  285.45 
epoch   9   | vd loss 1.00962  | vd accu 0.29072  | epoch time  308.24 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:44<00:00,  1.89it/s]


epoch  10   | tr loss 0.13503  | tr accu 0.39119  | epoch time  284.61 
epoch  10   | vd loss 1.01984  | vd accu 0.29027  | epoch time  307.38 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:44<00:00,  1.89it/s]


epoch  11   | tr loss 0.10395  | tr accu 0.39892  | epoch time  284.40 
epoch  11   | vd loss 1.04203  | vd accu 0.29070  | epoch time  307.22 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:52<00:00,  1.84it/s]


epoch  12   | tr loss 0.08421  | tr accu 0.40384  | epoch time  292.46 
epoch  12   | vd loss 1.06249  | vd accu 0.29109  | epoch time  317.15 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:35<00:00,  1.96it/s]


epoch  13   | tr loss 0.07152  | tr accu 0.40706  | epoch time  275.09 
epoch  13   | vd loss 1.08758  | vd accu 0.29153  | epoch time  296.68 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:47<00:00,  1.87it/s]


epoch  14   | tr loss 0.06289  | tr accu 0.40911  | epoch time  287.92 
epoch  14   | vd loss 1.11297  | vd accu 0.29155  | epoch time  312.41 


100%|████████████████████████████████████████████████████████████████████████████████| 538/538 [04:41<00:00,  1.91it/s]


epoch  15   | tr loss 0.05795  | tr accu 0.41057  | epoch time  281.56 
epoch  15   | vd loss 1.13407  | vd accu 0.29116  | epoch time  303.60 


  2%|█▎                                                                                | 9/538 [00:04<04:48,  1.83it/s]


KeyboardInterrupt: 