In [None]:
import tensorflow as tf
from transformer_custom import Transformer, Encoder, Decoder
from model_configs import create_config
from tf_text_preprocess import EnglishTexts, RussianTexts
import pymorphy2

# Instantiate the English and Russian text processors.
english_texts = EnglishTexts('english', text_files_paths='')
russian_texts = RussianTexts('russian', text_files_paths='')

# Load and then preprocess each corpus in turn, English first.
for text_processor in (english_texts, russian_texts):
    text_processor.load_datasets()
    text_processor.preprocess_datasets()

# Fetch the prepared datasets that the training loop iterates over.
datasets_en, datasets_ru = (
    english_texts.get_datasets(),
    russian_texts.get_datasets(),
)

# Hyperparameters come from the project's transformer config.
config = create_config.transformer_config()

# Encoder and decoder share depth, width, head count, feed-forward size,
# dropout rate and maximum position encoding; they differ only in which
# vocabulary-size key they read and in the positional-encoding keyword.
encoder = Encoder(
    config["num_layers"],
    config["d_model"],
    config["num_heads"],
    config["dff"],
    config["input_vocab_size"],
    pe_input=config["maximum_position_encoding"],
    rate=config["dropout_rate"],
)

decoder = Decoder(
    config["num_layers"],
    config["d_model"],
    config["num_heads"],
    config["dff"],
    config["target_vocab_size"],
    pe_target=config["maximum_position_encoding"],
    rate=config["dropout_rate"],
)

# Wrap both halves into the full model that train_step calls.
transformer = Transformer(encoder, decoder)

# reduction='none' keeps the loss unreduced (one value per target element),
# so any averaging has to be done by the code that consumes it.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction='none',
)

# Adam with a fixed learning rate (no schedule).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Morphological analyzer used by the training loop to take normal forms.
morph = pymorphy2.MorphAnalyzer()

@tf.function
def train_step(inp, tar):
    """Run one teacher-forced optimisation step and return the mean loss.

    The target batch is split into the decoder input (every position except
    the last) and the expected output (every position except the first).

    Args:
        inp: Source batch, passed to ``transformer`` as its first argument.
        tar: Target batch. It is sliced as ``tar[:, :-1]`` and
            ``tar[:, 1:]``, so it must have at least two dimensions.

    Returns:
        A scalar tensor: the mean of the unreduced losses produced by
        ``loss_object`` for this batch.
    """
    decoder_input = tar[:, :-1]
    expected = tar[:, 1:]

    with tf.GradientTape() as tape:
        # NOTE(review): every mask is None. Unless Transformer builds its
        # own masks when given None, the decoder can attend to future
        # target positions and to padding -- confirm against
        # transformer_custom.
        predictions, _ = transformer(
            inp,
            decoder_input,
            training=True,
            enc_padding_mask=None,
            look_ahead_mask=None,
            dec_padding_mask=None,
        )
        # loss_object is configured with reduction='none', so it yields one
        # value per target element. Reduce to a scalar here so the returned
        # loss can be logged and formatted as a single number.
        # NOTE(review): the mean includes padding positions, if any --
        # consider masking them out once the padding id is confirmed.
        per_element_loss = loss_object(expected, predictions)
        loss = tf.reduce_mean(per_element_loss)

    trainable = transformer.trainable_variables
    gradients = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(gradients, trainable))

    return loss

# Train for 10 epochs, taking one English and one Russian step per
# iteration. zip() stops at the shorter of the two datasets.
for epoch in range(10):
    for (batch_en, batch_ru) in zip(datasets_en, datasets_ru):
        inp_en, tar_en = batch_en
        inp_ru, tar_ru = batch_ru
        # NOTE(review): this replaces the Russian input batch with a Python
        # list of normal forms. That only works if the batch elements are
        # plain strings, and the list is then passed to a tf.function.
        # Presumably lemmatization belongs in the preprocessing stage,
        # before tokenization -- confirm against tf_text_preprocess.
        inp_ru = [morph.parse(word)[0].normal_form for word in inp_ru]
        loss_en = train_step(inp_en, tar_en)
        loss_ru = train_step(inp_ru, tar_ru)
        # loss_object uses reduction='none', so the loss may arrive as a
        # per-element tensor; the ':.4f' format spec needs a single number.
        mean_en = float(tf.reduce_mean(loss_en))
        mean_ru = float(tf.reduce_mean(loss_ru))
        print(f'Epoch {epoch + 1} Loss EN {mean_en:.4f} Loss RU {mean_ru:.4f}')

    # Checkpoint the weights every fifth epoch.
    if (epoch + 1) % 5 == 0:
        transformer.save_weights(f'transformer_weights_epoch_{epoch+1}.h5')
