In this file we'll initialize, create, and train the simple RNN model.


In [27]:
from tensorflow.keras.preprocessing.text import one_hot
from tensorflow.keras.utils import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np

# Example list of sentences for RNN training
sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming many industries.",
    "She enjoys reading books on rainy afternoons.",
    "The RNN model learns patterns from sequential data.",
    "Python is a popular language for machine learning.",
    "The sun rises in the east and sets in the west.",
    "Training neural networks requires a lot of data.",
    "He drank a cup of coffee before starting work.",
    "The cat slept peacefully on the windowsill.",
    "LangChain and Huggingface are useful AI frameworks."
]

Now we start with the text preprocessing.

In [28]:
sentences = [s.lower() for s in sentences]  # Convert to lowercase
sentences = [s.replace('.', '') for s in sentences]  # Remove punctuation  
sentences = [s.replace(',', '') for s in sentences]  # Remove punctuation

# Consider vocabulary size
vocab_size = 10000  # Size of the vocabulary
for word in sentences:
    one_hot_encoded = one_hot(word, vocab_size)

# Define maximum length of sequences
max_length = 10  # Maximum length of sequences

# Pad sequences to ensure uniform input size
padded_sequences = pad_sequences([one_hot(word, vocab_size) for word in sentences],
                                 maxlen=max_length, padding='pre')

print(padded_sequences)

feature_representation_size = 10 # Size of the feature representation
model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=feature_representation_size, input_length=max_length))
model.compile('rmsprop', 'mse')
print(model.predict(padded_sequences[0:1]))  # Predict for the first padded sequence
model.summary()

[[   0 7767 6120 1209 4368 9705 9962 7767 7983 2071]
 [   0    0    0    0 9614 8059 5210 6930 6241 5238]
 [   0    0    0 6871 1172 2807 5176 5068 9364 2143]
 [   0    0 7767 6908 1065 6134 6260 7857 5225  755]
 [   0    0 4959 5210 4005 8632 4753 1413 3477 2250]
 [6017 3761 7465 7767 5712 1823 3288 7465 7767 2074]
 [   0    0 1871 6438 2966 6985 4005 2203 8561  755]
 [   0 6144 6328 4005  922 8561 8824 4821 3332 5875]
 [   0    0    0 7767 3747 3439 3834 5068 7767 5726]
 [   0    0    0 4852 1823  396 4693 5960 4112 5143]]
[[[ 4.4232044e-02  3.3626232e-02  4.3028865e-02 -4.9760826e-03
   -4.9930599e-02  1.9013677e-02  4.5692850e-02 -2.8543925e-02
   -1.3918057e-03 -3.1095529e-02]
  [ 1.3099853e-02 -1.2134980e-02  2.7034853e-02 -4.8502397e-02
   -4.8449170e-02  4.0927995e-02  1.0860957e-02 -1.6312040e-02
    4.1677546e-02  2.0131204e-02]
  [ 2.3839559e-02  1.1894595e-02 -4.1097678e-02 -4.4358280e-02
    2.0778064e-02  4.6904314e-02  4.4753764e-02 -1.5068509e-02
    7.1938150e-03  9.81