In [1]:
import random
import string
import re
import pickle
import numpy as np
from tensorflow.keras import layers
import tensorflow as tf

from data_processing import text_standardization


2024-01-18 17:20:14.488106: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Data processing

In [2]:
# Goal: translate the "language" in input_text into the "language" in output_text.
# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
# `with` guarantees the file handles are closed even if unpickling fails.
with open('data/Train_input', 'rb') as input_file:
    input_text = pickle.load(input_file)
with open('data/Train_output', 'rb') as output_file:
    output_text = pickle.load(output_file)

In [3]:
# Wrap every target text in [start]/[end] tokens and pair it with its source text.
# The two corpora are parallel, so they must have the same number of entries;
# fail loudly instead of silently dropping or mis-indexing rows.
assert len(input_text) == len(output_text), (
    f"input/output corpora differ in size: {len(input_text)} vs {len(output_text)}"
)

count = 0  # number of pairs whose input and output strings differ in character length
text_pairs = []
for src_text, tgt_text in zip(input_text, output_text):
    # rstrip() + an explicit space keeps "[end]" a separate token even when the raw
    # target has no trailing whitespace. For targets ending in one space, this is
    # the same string the old `text + "[end]"` produced.
    out = "[start] " + tgt_text.rstrip() + " [end]"
    if len(tgt_text) != len(src_text):
        count += 1
    text_pairs.append((src_text, out))
print("number of input and output texts with different lengths: %d" % count)

number of input and output texts with different lengths: 112000


In [4]:
# Data shuffle. A fixed seed makes the shuffled order reproducible across kernel
# restarts. The train/val/test split in the next cell depends on this order, so
# without a seed the pickled splits would change on every run. A dedicated Random
# instance is used so the global `random` state is not reseeded.
SHUFFLE_SEED = 3
shuffle_rng = random.Random(SHUFFLE_SEED)
print(f"randomly selected input and output text pairs:\n{shuffle_rng.choice(text_pairs)}")
shuffle_rng.shuffle(text_pairs)

randomly selected input and output text pairs:
('a f b d a g a g c g a f a f a h c d c f a f c f c g b d b d b d ', '[start] b d c g c d c f c f c g a f h i a h f g j b d a f k l b d a f m ed a g e ee b d a g ef eg a f d eh [end]')


In [5]:
from sklearn.model_selection import train_test_split

# Split sizes: train 70%, val 15%, test 15% of the whole dataset.
TEST_FRACTION = 0.15
# The second split only sees the 85% train+val remainder, so the validation share
# is rescaled (0.15 / 0.85) to end up as 15% of the whole dataset.
VAL_FRACTION_OF_TRAINVAL = 0.15 / .85
SPLIT_SEED = 3

trainval_pairs, test_pairs = train_test_split(
    text_pairs, test_size=TEST_FRACTION, random_state=SPLIT_SEED
)
train_pairs, val_pairs = train_test_split(
    trainval_pairs, test_size=VAL_FRACTION_OF_TRAINVAL, random_state=SPLIT_SEED
)

# Persist each split. `with` closes (and flushes) every file deterministically.
for split_pairs, split_path in (
    (train_pairs, "data/train_pairs.pkl"),
    (val_pairs, "data/val_pairs.pkl"),
    (test_pairs, "data/test_pairs.pkl"),
):
    with open(split_path, "wb") as split_file:
        pickle.dump(split_pairs, split_file)

In [7]:
# Sequence and vocabulary dimensions, based on the dataset analysis.

# The max text length in both input and output text is 47. The limit is raised
# to 55 to leave head-room for longer texts that may appear in unseen data.
# Source and target share the same maximum length.
output_seq_len = 55
input_seq_len = output_seq_len

# TextVectorization's max_tokens also counts two reserved entries on top of the
# real vocabulary: "" (padding/mask) and "[UNK]" (unknown / out-of-vocabulary).
input_vocab_size = 8 + 2
# The target side needs two more entries, for the [start] and [end] tokens.
output_vocab_size = 18 + 2 + 2

In [8]:
def _make_vectorizer(max_tokens, sequence_length):
    """Create an integer-output TextVectorization layer using the project standardizer.

    Args:
        max_tokens: vocabulary cap, including the reserved "" and "[UNK]" entries.
        sequence_length: each output is padded/truncated to this many token ids.

    Returns:
        A new, not-yet-adapted ``layers.TextVectorization`` instance.
    """
    return layers.TextVectorization(
        max_tokens=max_tokens,
        output_mode="int",
        output_sequence_length=sequence_length,
        standardize=text_standardization,
    )


# Source (input language) vectorizer.
source_vectorization = _make_vectorizer(input_vocab_size, input_seq_len)
# Target (output language) vectorizer with one extra position. Presumably this lets
# the target later be offset by one token into decoder input / label -- confirm
# against the training code.
target_vectorization = _make_vectorizer(output_vocab_size, output_seq_len + 1)

2024-01-18 17:20:28.551098: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
# Separate the training pairs into their source and target sides.
train_input_texts = [src for src, _tgt in train_pairs]
train_output_texts = [tgt for _src, tgt in train_pairs]

# Learn each language's vocabulary from the training split only. After adapt(),
# each layer maps a text to a vector of shape (seq_len,). In that vector, id 0 is
# the padding/mask token (""), id 1 is the out-of-vocabulary token ("[UNK]"), and
# the learned tokens take the remaining ids below max_tokens.
source_vectorization.adapt(train_input_texts)
target_vectorization.adapt(train_output_texts)

In [10]:
def _save_vectorization(vectorizer, path):
    """Pickle a fitted TextVectorization layer so it can be rebuilt on model reload.

    Args:
        vectorizer: an adapted ``TextVectorization`` layer.
        path: destination ``.pkl`` file; its directory must already exist.

    The dict keeps the original 'config' and 'weights' keys, so existing loaders
    keep working. 'vocabulary' is added alongside them.
    """
    state = {
        'config': vectorizer.get_config(),
        'weights': vectorizer.get_weights(),
        # NOTE(review): stored explicitly in case get_weights() does not carry the
        # learned lookup table in the installed TF version -- confirm, and drop if
        # it is redundant.
        'vocabulary': vectorizer.get_vocabulary(),
    }
    # `with` closes and flushes the file deterministically.
    with open(path, "wb") as out_file:
        pickle.dump(state, out_file)


# Save the text vectorizations so they can be loaded again when reloading the model.
_save_vectorization(target_vectorization, "model/target_vectorization8.pkl")
_save_vectorization(source_vectorization, "model/source_vectorization8.pkl")