<a href="https://colab.research.google.com/github/mohammadreza-mohammadi94/Deep-Learning-Projects/blob/main/Next%20Word%20Prediction%20-%20Wikipedia%20Dataset%20(LSTM)/Next_Word_Prediction_RNNs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Frameworks and Setup Enviorment

In [1]:
# Install libs
!pip install -q datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/491.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m491.5/491.5 kB[0m [31m20.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [9]:
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding
from tensorflow.keras.optimizers import Adam
import datasets
from datasets import load_dataset
import re
import numpy as np
import matplotlib.pyplot as plt

# Setup warnings
import warnings
warnings.filterwarnings('ignore')

# Setup logging
import logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler('training.log'),
                        logging.StreamHandler()
                    ])

# Load Dataset

In [3]:
def load_wiki_dataset():
    dataset = load_dataset("wikipedia", "20220301.simple")
    texts = dataset['train']['text']
    cleaned_texts = []
    for text in texts[:1000]:
        text = re.sub(r'\[.*?\]|\n|\t', ' ', text)
        text = re.sub(r'[^a-zA-Z\s]', ' ', text)
        sentences = text.lower().split('.')
        cleaned_texts.extend([s.strip() for s in sentences if len(s.strip().split()) > 2])
    return cleaned_texts

# Preparing Data

In [4]:
def prepare_data(texts):
    # tokenization
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(texts)
    sequences = tokenizer.texts_to_sequences(texts)

    # X, y
    X, y = [], []
    max_len = 5     # Maximum len of sequence
    vocab_size = len(tokenizer.word_index) + 1

    for seq in sequences:
        for i in range(1, len(seq)):
            n_grams = seq[max(0, i - max_len): i]
            X.append(pad_sequences([n_grams], maxlen=max_len, padding='pre')[0])
            y.append(seq[i])

    X = np.array(X)
    y = np.array(y)
    y = np.expand_dims(y, -1)

    print(f"Count of Sequences: {len(X)}")
    print(f"Vocab Size: {vocab_size}")
    print(f"Max Length: {max_len}")

    return X, y, vocab_size, tokenizer, max_len

# Build Model

In [5]:
def build_model(vocab_size, max_len, embedding_dim, lstm_units=128):
    inputs = Input(shape=(max_len,))
    x = Embedding(vocab_size, embedding_dim)(inputs)
    x = LSTM(lstm_units, return_sequences=False)(x)
    x = Dense(vocab_size, activation='softmax')(x)

    model = Model(inputs, x)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

In [6]:
# Prediction function
def predict_next_word(model, tokenizer, text, max_len):
    sequence = tokenizer.texts_to_sequences([text.lower().split()[-max_len:]])
    padded_sequence = pad_sequences(sequence, maxlen=max_len, padding='pre')
    prediction = model.predict(padded_sequence, verbose=0)
    predicted_word_index = np.argmax(prediction[0])
    predicted_word = [word for word, index in tokenizer.word_index.items() if index == predicted_word_index][0]
    return predicted_word

# Run Model

In [7]:
# Load and preparing data
texts = load_wiki_dataset()
X, y, vocab_size, tokenizer, max_len = prepare_data(texts)

# Creating and fitting the model
model = build_model(vocab_size, max_len, embedding_dim=50)
model.summary()

# Fit
history = model.fit(X, y, batch_size=32, epochs=50, validation_split=0.2, verbose=1)
model.save('next_word_predictor.h5')

# Prediction
test_sentences = ["i go to", "the cat is", "she likes to"]
for sentence in test_sentences:
    next_word = predict_next_word(model, tokenizer, sentence, max_len)
    print(f"Input: {sentence} -> Prediction: {next_word}")

README.md:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

wikipedia.py:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

The repository for wikipedia contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/wikipedia.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


train-00000-of-00001.parquet:   0%|          | 0.00/134M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/205328 [00:00<?, ? examples/s]

Count of Sequences: 610062
Vocab Size: 31719
Max Length: 5


Epoch 1/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m167s[0m 11ms/step - accuracy: 0.0888 - loss: 7.2192 - val_accuracy: 0.1302 - val_loss: 6.9187
Epoch 2/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 11ms/step - accuracy: 0.1570 - loss: 6.0873 - val_accuracy: 0.1441 - val_loss: 6.8331
Epoch 3/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 11ms/step - accuracy: 0.1843 - loss: 5.6476 - val_accuracy: 0.1490 - val_loss: 6.8796
Epoch 4/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 11ms/step - accuracy: 0.2043 - loss: 5.3102 - val_accuracy: 0.1509 - val_loss: 6.9263
Epoch 5/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m202s[0m 11ms/step - accuracy: 0.2227 - loss: 5.0317 - val_accuracy: 0.1490 - val_loss: 7.0142
Epoch 6/50
[1m15252/15252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 11ms/step - accuracy: 0.2398 - loss: 4.7750 - val_accuracy: 0.1473 - val



Input: i go to -> Prediction: the
Input: the cat is -> Prediction: a
Input: she likes to -> Prediction: do
