In [1]:
# Shell step: copy the saved weights file `next-word.h5` from the Drive path into the
# current working directory (the source path is under /content/drive, presumably a
# mounted Google Drive — confirm the drive is mounted before running).
!cp "/content/drive/MyDrive/models/Next Word Prediction/next-word.h5" "./"

In [6]:
import numpy as np
import tensorflow as tf

### Building the inference model

In [18]:
# Path of the weights file that is later passed to `model.load_weights`.
model_weights = "/content/next-word.h5"

In [12]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Assemble the character-level network: Embedding -> stateful GRU -> Dense.

  Args:
    vocab_size: Number of entries in the vocabulary; used both as the embedding
      input size and as the width of the final Dense layer.
    embedding_dim: Size of each embedding vector.
    rnn_units: Number of units in the GRU layer.
    batch_size: Fixed batch size baked into the input shape; the sequence
      length is left open (``None``).

  Returns:
    A ``tf.keras.Sequential`` model. The Dense layer is created without an
    activation argument, and the GRU returns an output for every timestep.
  """
  embedding = tf.keras.layers.Embedding(
      vocab_size,
      embedding_dim,
      batch_input_shape=[batch_size, None],
  )
  # stateful=True: the recurrent state is carried over between successive calls.
  recurrent = tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True)
  projection = tf.keras.layers.Dense(vocab_size)
  return tf.keras.Sequential([embedding, recurrent, projection])

In [4]:
# Special tokens. PAD always sits at index 0; UNK and EOS are the last two entries.
EOS = '<EOS>'
UNK = "<UNK>"
PAD = "<PAD>"

# Vocabulary: PAD, then every distinct character of the literal below in sorted
# order, then UNK and EOS. Repeated characters in the literal are collapsed by set().
# The backslash is written as an explicit escape ("\\|") — the bare "\|" form is an
# invalid escape sequence that newer Python versions warn about; the resulting
# string is the same.
# NOTE(review): the "ˆ" and "˜" characters are non-ASCII look-alikes of "^" and "~".
# They are kept exactly as they were because the index of every character
# presumably has to match the vocabulary the weights were trained with — confirm
# against the training notebook before changing anything here.
chars = [
    PAD,
    *sorted(set("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'''/\\|_@#$%ˆ&*˜'+-=()[]{}' ABCDEFGHIJKLMNOPQRSTUVWXYZ")),
    UNK,
    EOS,
]

In [7]:
# Lookup tables: character -> integer id (its position in `chars`) and the
# reverse direction as a NumPy array indexed by id.
char2idx = dict(zip(chars, range(len(chars))))
idx2char = np.array(chars)

In [8]:
def char_idx(c):
    """Return the vocabulary id of ``c``; anything not in ``chars`` maps to the ``UNK`` id."""
    key = c if c in chars else UNK
    return char2idx[key]

In [10]:
# Hyperparameters
# NOTE(review): these values presumably have to equal the ones used when the
# weights were trained, otherwise loading the weights file will not fit — confirm.
vocab_size = len(chars)  # one id per entry of `chars`, special tokens included
embedding_dim = 256  # embedding vector size handed to build_model
rnn_units = 1024  # GRU unit count handed to build_model
BATCH_SIZE=1  # batch size handed to build_model for the inference model

In [14]:
# Build the network with the hyperparameters above; its weights are not loaded yet.
model = build_model(vocab_size, embedding_dim, rnn_units, BATCH_SIZE)

In [19]:
# Load the weights stored at `model_weights` into the model built above.
model.load_weights(model_weights)

In [125]:
def generate_text(model, start_string, temperature=0.4, num_generate=30):
  """Sample a continuation of ``start_string`` one character at a time.

  Args:
    model: Stateful character model (as produced by ``build_model``) that
      accepts a batch of one id sequence.
    start_string: Non-empty prompt. Characters that are not in the vocabulary
      are encoded as ``UNK`` instead of raising ``KeyError``.
    temperature: Positive number the model outputs are divided by before
      sampling; smaller values make the sampling more predictable.
    num_generate: Maximum number of characters to generate.

  Returns:
    The generated characters joined into one string (the prompt itself is not
    included). Generation stops early when the ``EOS`` id is sampled.

  Raises:
    ValueError: If ``start_string`` is empty or ``temperature`` is not positive.
  """
  if not start_string:
    raise ValueError("start_string must not be empty")
  if temperature <= 0:
    raise ValueError("temperature must be positive")

  # The GRU is stateful, so without a reset every call would continue from the
  # hidden state left behind by the previous prompt.
  model.reset_states()

  unk_id = char2idx[UNK]
  eos_id = char2idx[EOS]

  input_eval = [char2idx.get(s, unk_id) for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
  for _ in range(num_generate):
    # verbose=0 keeps predict from printing a progress bar on every step.
    predictions = model.predict(input_eval, verbose=0)
    predictions = tf.squeeze(predictions, 0)
    predictions = predictions / temperature
    # Sample from the distribution of the last timestep only.
    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
    if predicted_id == eos_id:
      break
    # Feed just the sampled id back in; the earlier context lives in the GRU state.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])
  return ''.join(text_generated)

In [126]:
def get_text(text):
  """Return up to three lower-cased words generated as a continuation of ``text``.

  Args:
    text: Prompt string passed to ``generate_text`` together with the
      notebook-level ``model``.

  Returns:
    A list with at most three whitespace-separated words from the generated
    text; it can be shorter when fewer words were generated.
  """
  # Use the `text` argument itself — not a notebook global — as the prompt.
  gen_txt = generate_text(model, text)
  # split() with no arguments already ignores leading/trailing whitespace.
  words = gen_txt.lower().split()
  return words[:3]

In [139]:
# Example call; the cell displays the word list returned by get_text.
text1 = "I want to"
get_text(text1)

['save', 'lives']

In [140]:
# Second example call with a shorter prompt.
text2= "I like"
get_text(text2)

['experience', 'the', 'foreclosure']

In [141]:
# Third example call; note that this prompt string ends with a trailing space.
text3= "He likes to "
get_text(text3)

['be', 'healthy']