<a href="https://colab.research.google.com/github/kapoor-a/nlp/blob/main/next_char_precdiction_deep_gru.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install trax
!pip install nltk

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting trax
  Downloading trax-1.4.1-py2.py3-none-any.whl (637 kB)
[K     |████████████████████████████████| 637 kB 28.6 MB/s 
Collecting funcsigs
  Downloading funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)
Collecting tensorflow-text
  Downloading tensorflow_text-2.10.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 53.5 MB/s 
Collecting tensorflow<2.11,>=2.10.0
  Downloading tensorflow-2.10.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (578.0 MB)
[K     |████████████████████████████████| 578.0 MB 15 kB/s 
[?25hCollecting keras<2.11,>=2.10.0
  Downloading keras-2.10.0-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 49.5 MB/s 
Collecting tensorboard<2.11,>=2.10
  Downloading tensorboard-2.10.1-py3-none-any.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 46

In [3]:
from nltk.corpus import gutenberg
import nltk
import random as rnd
import numpy as np
import trax.layers as tl
import trax

In [4]:
# Fetch the 'gutenberg' corpus package so that the `gutenberg` corpus reader
# imported above can read its files.
nltk.download('gutenberg')

[nltk_data] Downloading package gutenberg to /root/nltk_data...
[nltk_data]   Unzipping corpora/gutenberg.zip.


True

In [5]:
def get_play(name):
  """Load one Gutenberg file as a list of cleaned, non-empty lines.

  The raw text is split on newline characters; every line is stripped of
  surrounding whitespace and lowercased, and lines that end up empty are
  dropped. A one-line summary (file id, first kept line, number of kept
  lines) is printed as a side effect.

  Args:
    name: file id understood by the `gutenberg` corpus reader.

  Returns:
    List of cleaned line strings, in file order.
  """
  normalized = (row.strip().lower() for row in gutenberg.raw(name).split("\n"))
  lines = [row for row in normalized if row != ""]
  print(name, lines[0], len(lines))
  return lines

In [6]:
# Concatenate the cleaned lines of every file in the Gutenberg corpus into one
# flat list, with a separator entry appended after each file.
lines = []
for f in gutenberg.fileids():
  lines.extend(get_play(f))
  # The separator is a real newline character. The previous "/n" was a typo
  # that inserted the two literal characters '/' and 'n' as a training line.
  lines.append("\n")

austen-emma.txt [emma by jane austen 1816] 14283
austen-persuasion.txt [persuasion by jane austen 1818] 7356
austen-sense.txt [sense and sensibility by jane austen 1811] 12773
bible-kjv.txt [the king james bible] 74645
blake-poems.txt [poems by william blake 1789] 1094
bryant-stories.txt [stories to tell to children by sara cone bryant 1918] 4123
burgess-busterbrown.txt [the adventures of buster bear by thornton w. burgess 1920] 1315
carroll-alice.txt [alice's adventures in wonderland by lewis carroll 1865] 2479
chesterton-ball.txt [the ball and the cross by g.k. chesterton 1909] 7890
chesterton-brown.txt [the wisdom of father brown by g. k. chesterton 1914] 6443
chesterton-thursday.txt [the man who was thursday by g. k. chesterton 1908] 5475
edgeworth-parents.txt [the parent's assistant, by maria edgeworth] 14492
melville-moby_dick.txt [moby dick by herman melville 1851] 19651
milton-paradise.txt [paradise lost by john milton 1667] 10572
shakespeare-caesar.txt [the tragedie of julius 

In [7]:
# Positional 95% / 5% split of the corpus lines: the first part is used for
# training and the tail is held out for evaluation (no shuffling happens here).
# NOTE: the name `eval` shadows Python's builtin of the same name; it is kept
# because later cells refer to it.
split = int(len(lines) * 0.95)
train, eval = lines[:split], lines[split:]

In [8]:
def line_to_tensor(line, EOS_int=1):
    """Encode a string as a list of integer character codes plus an end marker.

    Args:
        line: the text to encode; each character is mapped with `ord`.
        EOS_int: integer appended after the last character to mark the end
            of the sequence.

    Returns:
        A new Python list of ints of length `len(line) + 1`.
    """
    codes = list(map(ord, line))
    return codes + [EOS_int]

In [9]:
def data_generator(batch_size, max_length, data_lines, shuffle=True):
    """Endlessly yield padded batches of encoded lines.

    Lines are visited in (optionally shuffled) order, and the order is
    reshuffled every time the data is exhausted. Only lines strictly shorter
    than `max_length` are used, so that the encoded line plus its end marker
    fits into `max_length` positions. Each encoded line is right-padded with
    zeros up to `max_length`.

    Args:
        batch_size: number of lines per batch; must be at least 1.
        max_length: width of every row in the produced arrays.
        data_lines: sequence of strings to draw from.
        shuffle: when True, visit the lines in random order (uses the
            module-level `rnd` generator).

    Yields:
        A tuple `(inputs, targets, mask)` of numpy arrays, each of shape
        `(batch_size, max_length)`. `inputs` and `targets` are the same
        array; `mask` holds 1 for real tokens (including the end marker) and
        0 for padding.

    Raises:
        ValueError: if `batch_size < 1` or if no line in `data_lines` is
            shorter than `max_length`. Without this check the loop below
            could never fill a batch and would spin forever.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not any(len(candidate) < max_length for candidate in data_lines):
        raise ValueError(
            f"data_lines contains no line shorter than max_length={max_length}"
        )

    index = 0
    cur_batch = []
    num_lines = len(data_lines)
    lines_index = [*range(num_lines)]
    if shuffle:
        rnd.shuffle(lines_index)
    while True:
        # Wrap around (and reshuffle) once every line has been visited.
        if index >= num_lines:
            index = 0
            if shuffle:
                rnd.shuffle(lines_index)
        line = data_lines[lines_index[index]]
        # Strictly shorter: one slot is reserved for the end marker.
        if len(line) < max_length:
            cur_batch.append(line)
        index += 1
        if len(cur_batch) == batch_size:
            batch = []
            mask = []
            for li in cur_batch:
                tensor = line_to_tensor(li)
                n_pad = max_length - len(tensor)
                batch.append(tensor + [0] * n_pad)
                mask.append([1] * len(tensor) + [0] * n_pad)
            batch_np_arr = np.array(batch)
            # The same array serves as inputs and targets.
            yield batch_np_arr, batch_np_arr, np.array(mask)
            cur_batch = []

In [39]:
def char_predict_model(vocab_size=256, d_model=512, gru_layers=2, mode="train"):
  """Build a stacked-GRU character-level language model.

  The layers are, in order: a right shift of the input (so each position is
  predicted from the characters before it), an embedding, `gru_layers` GRU
  layers, a dense projection, and a log-softmax.

  Args:
    vocab_size: number of distinct character ids; used both for the
      embedding table and for the size of the output distribution.
    d_model: embedding size and GRU hidden size.
    gru_layers: number of stacked GRU layers.
    mode: passed to `tl.ShiftRight`.

  Returns:
    A `tl.Serial` model producing log-probabilities over `vocab_size` ids
    for every input position.
  """
  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size, d_model),
      [tl.GRU(d_model) for _ in range(gru_layers)],
      # Project to one logit per vocabulary entry. This was previously
      # Dense(d_model), which produced d_model output classes, not vocab_size.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )

In [24]:
from trax.supervised import training

def create_training_loop(model, train_stream, eval_stream, output_dir="/content/model/",
                         learning_rate=0.0005, n_steps_per_checkpoint=10):
  """Assemble a Trax training loop for the character model.

  Args:
    model: the Trax model to train.
    train_stream: generator yielding training batches (inputs, targets, mask).
    eval_stream: generator yielding evaluation batches in the same format.
    output_dir: directory handed to `training.Loop` for its outputs.
    learning_rate: learning rate passed to the Adam optimizer. The default
      matches the value that used to be hard-coded.
    n_steps_per_checkpoint: checkpoint interval passed to the train task.
      The default matches the value that used to be hard-coded.

  Returns:
    A `training.Loop` with one train task (cross-entropy loss, Adam) and one
    eval task (cross-entropy loss and accuracy). Call `.run(n_steps=...)` on
    it to train.
  """
  train_task = training.TrainTask(
      labeled_data=train_stream,
      loss_layer=tl.CrossEntropyLoss(),
      optimizer=trax.optimizers.Adam(learning_rate),
      n_steps_per_checkpoint=n_steps_per_checkpoint,
  )
  eval_task = training.EvalTask(
      labeled_data=eval_stream,
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
  )
  return training.Loop(
      model,
      train_task,
      eval_tasks=[eval_task],
      output_dir=output_dir,
  )

In [25]:
# Batching configuration: 32 lines per batch, each padded to 64 positions
# (lines of 64 or more characters are skipped by data_generator).
batch_size, max_length = 32, 64

# Endless batch generators over the training and held-out lines.
train_stream = data_generator(batch_size, max_length, data_lines=train)
eval_stream = data_generator(batch_size, max_length, data_lines=eval)

# Model built with the default hyperparameters of char_predict_model.
model = char_predict_model()

In [None]:
# Delete the output directory left by any previous run; this is the same path
# as the default `output_dir` of create_training_loop.
!rm -rf /content/model/
# Build the training loop around the global `model` and train for 300 steps.
loop = create_training_loop(model, train_stream, eval_stream)
loop.run(n_steps=300)

In [46]:
def gumbel_sample(log_probs, temperature=0.5):
    """Draw a sample from a categorical distribution using Gumbel noise.

    Gumbel noise, scaled by `temperature`, is added to `log_probs` and the
    index of the largest perturbed value along the last axis is returned.
    With `temperature=0` this is a plain argmax.

    Args:
        log_probs: array of log-probabilities; the last axis indexes classes.
        temperature: multiplier applied to the noise.

    Returns:
        The argmax index (or array of indices) over the last axis.
    """
    # Uniform draws are kept away from 0 and 1 so both logs stay finite.
    uniform = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    noise = -np.log(-np.log(uniform))
    perturbed = log_probs + noise * temperature
    return np.argmax(perturbed, axis=-1)

def predict(num_chars, prefix):
    """Extend `prefix` by sampling up to `num_chars` characters from `model`.

    At every step the ids generated so far are zero-padded to the final
    length, run through the global `model`, and the next id is sampled with
    `gumbel_sample` from the output at the first not-yet-filled position.
    Generation stops early when the sampled id is 1 (the end marker).

    Args:
        num_chars: maximum number of characters to generate.
        prefix: starting text; its characters are encoded with `ord`.

    Returns:
        `prefix` followed by the generated characters, as one string.
    """
    token_ids = list(map(ord, prefix))
    chars = list(prefix)
    total_len = len(prefix) + num_chars
    for _ in range(num_chars):
        padding = [0] * (total_len - len(token_ids))
        padded = np.array(token_ids + padding)
        log_probs = model(padded[None, :])  # Add batch dim.
        sampled = int(gumbel_sample(log_probs[0, len(token_ids)]))
        token_ids.append(sampled)
        if sampled == 1:
            break  # EOS
        chars.append(chr(sampled))
    return "".join(chars)

In [48]:
# Sample up to 5 more characters after the prompt. Note that get_play
# lowercases all training text, so the capital 'M' in this prompt is a
# character the model never saw during training.
predict(5, "My n")

'My nictee'