In [6]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, SimpleRNN, Dense
from datasets import load_dataset


In [7]:
# 1. Load the raw WikiText-2 corpus from the Hugging Face hub and keep only the
#    training split's text column (one entry per line of the source articles).
dataset = load_dataset("wikitext", name="wikitext-2-raw-v1")
raw_text = dataset["train"]["text"]


In [8]:
# 2. Clean and preprocess text: drop blank/whitespace-only lines, lowercase the rest,
#    and keep only the first MAX_CORPUS_LINES of them so vocabulary and training-set
#    size stay manageable for this demo.
MAX_CORPUS_LINES = 2000
corpus = [line.lower() for line in raw_text if line.strip()][:MAX_CORPUS_LINES]

In [9]:

# 3. Tokenize the text: learn a word -> integer index vocabulary from the corpus.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)
# The Keras Tokenizer never assigns index 0 to a word (it is reserved for padding),
# so the output layer needs one slot more than the number of distinct words.
total_words = 1 + len(tokenizer.word_index)

In [10]:
# 4. Create input sequences with sequence length limit.
#    For each line, emit its prefixes of length 2..max_len_allowed (the last token of
#    each prefix is the prediction target). Lines longer than max_len_allowed only
#    contribute their first max_len_allowed tokens.
input_sequences = []
max_len_allowed = 20  # Keep sequences short for memory efficiency

# Tokenize the whole corpus in one call instead of once per line.
for token_list in tokenizer.texts_to_sequences(corpus):
    # Stop at the cap instead of building (and discarding) every longer prefix.
    longest = min(len(token_list), max_len_allowed)
    input_sequences.extend(token_list[:end] for end in range(2, longest + 1))

In [11]:

# 5. Pad sequences and split into input (X) and label (y)
max_sequence_len = max([len(seq) for seq in input_sequences])
input_sequences = pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre')

X = input_sequences[:, :-1]
y = input_sequences[:, -1]
y = tf.keras.utils.to_categorical(y, num_classes=total_words)

In [12]:
# 6. Build the RNN model
model = Sequential()
model.add(Embedding(input_dim=total_words, output_dim=64, input_length=max_sequence_len - 1))
model.add(SimpleRNN(100))
model.add(Dense(total_words, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 19, 64)            1027648   
                                                                 
 simple_rnn (SimpleRNN)      (None, 100)               16500     
                                                                 
 dense (Dense)               (None, 16057)             1621757   
                                                                 
Total params: 2665905 (10.17 MB)
Trainable params: 2665905 (10.17 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [13]:

# 7. Train the model on the full (X, y) set; the returned History object is left as
#    the cell's final expression so it is displayed.
model.fit(x=X, y=y, batch_size=128, epochs=50, verbose=1)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.src.callbacks.History at 0x7fe6e8494430>

In [None]:
def predict_next_word(seed_text, next_words=1):
    """Greedily extend ``seed_text`` by up to ``next_words`` words using ``model``.

    Each step tokenizes the current text with the fitted ``tokenizer``, left-pads it
    to the model's input length (``max_sequence_len - 1``; ``pad_sequences`` keeps the
    most recent tokens when the text is longer), and appends the highest-probability
    word.

    Args:
        seed_text: Prompt to continue. Words not in the tokenizer's vocabulary are
            dropped by ``texts_to_sequences`` (the tokenizer has no OOV token).
        next_words: Maximum number of words to generate.

    Returns:
        ``seed_text`` followed by the generated words, separated by single spaces.
        Generation stops early if the model predicts the padding index, which has
        no associated word.
    """
    for _ in range(next_words):
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
        predicted_probs = model.predict(token_list, verbose=0)
        predicted_index = int(np.argmax(predicted_probs, axis=-1)[0])
        predicted_word = tokenizer.index_word.get(predicted_index)
        if predicted_word is None:
            # Index 0 is the padding slot; appending "" would only add a stray space
            # and feed the same context back in, so stop instead.
            break
        seed_text += " " + predicted_word
    return seed_text


In [23]:
# Print a one-word greedy continuation for each sample prompt.
for prompt in ("Devan is", "Growing"):
    print(predict_next_word(prompt, next_words=1))


Devan is early
Growing season
