# Fun with language

> <p><small>Copyright 2021 DeepMind Technologies Limited.</small></p>
> <p><small>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at </small></p>
> <p><small><a href="https://www.apache.org/licenses/LICENSE-2.0">https://www.apache.org/licenses/LICENSE-2.0</a></small></p>
> <p><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></p>


**Aim**
To introduce you to how Artificial Intelligence can help us process natural language, for instance to automatically generate text and decrypt secret messages.

**Disclaimer**

This code is intended for educational purposes, and in the name of readability for a non-technical audience does not always follow best practices for software engineering.

**Links to resources**
- [What is Colab?](https://colab.research.google.com/notebooks/intro.ipynb) If you have never used Colab before, get started here!
- [The MCMC revolution](https://math.uchicago.edu/~shmuel/Network-course-readings/MCMCRev.pdf) A technical report on the mathematics underlying the second part of this notebook.
- [Text generation with an RNN](https://www.tensorflow.org/tutorials/text/text_generation) A text generation tutorial.

# Natural (human) language and computers

Language, specifically human language, is endlessly fascinating. We use it to communicate our needs, desires, opinions, ideas, feelings ... essentially anything we can formulate in words, which seems to be ... everything.

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/romeo.png" alt="drawing" height="250"/>
</center>


We can describe anything in the world around us, and at the same time we can create new, imaginary worlds and fascinating stories that never happened. We write books about mythical creatures and poems about the feelings of characters who never existed. We arrange words like Lego bricks every day, almost effortlessly and often without much thought: that is how easy language is for us.

Computers, on the other hand, don't have an easy time with language.

When we see the word `pancake`, we almost immediately form a mental image of a pancake. We think about what we might put on top of it and whether it is a small or a large one, and we might even feel a bit peckish just thinking about it. To a computer, however, the word `pancake` has no prior meaning. Even worse, since computers encode symbols as numbers, a computer sees `pancake` as nothing more than a sequence of numbers like these:

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/numbers.png" alt="drawing" height="220"/>
</center>

If by any chance we made an error and added 1 to each of these numbers, we would get the word `qbodblf`, which means absolutely nothing to us. One of these words means something to us and the other doesn't, but to a computer both are just particular sequences of numbers, and that's it.
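You can check this yourself in Python. This sketch assumes the numbers are Unicode code points, as returned by Python's built-in `ord` function; the figure may use a different numbering, but the idea is the same:

```python
# Each character is stored as a number (here: its Unicode code point).
word = 'pancake'
codes = [ord(character) for character in word]
print(codes)  # [112, 97, 110, 99, 97, 107, 101]

# Adding 1 to every number and converting back gives a meaningless word.
shifted = ''.join(chr(code + 1) for code in codes)
print(shifted)  # qbodblf
```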


So, since computers crunch numbers and we can represent language with numbers, can we get computers to 'crunch' language in a way that shows they somehow 'understand' it?

In this colab we will have fun with simple models called **Language Models**. These models do not 'deeply understand' language, but they can be used to solve tasks that require 'shallow' language understanding.

---


We will use them on two tasks:

1. **Message decoding!** Can we use a language model to decode a cryptic message?


<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/decrypt-ohno.png" alt="drawing" height="220"/>
</center>



2. **Text generation!** Let's use a language model as an artificial writer/poet!


<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/robeo.png" alt="drawing" height="280"/>
</center>



# Installation and Setup

We start by installing and importing all of the necessary Python libraries and defining helper functions used in the colab.




Some of the code we will run to build a language model is computationally intensive, so please use a colab runtime with a GPU (Graphics Processing Unit; yes, <a href="https://en.wikipedia.org/wiki/Graphics_processing_unit#Computational_functions">we're using specialised gaming hardware to speed things up</a> ^_^) by doing the following:

 > `Edit` -> `Notebook settings` -> select GPU under `Hardware accelerator` -> `Save`

like so:

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/gpu_instructions.png" alt="drawing" height="400"/>
</center>

 

In [None]:
#@title Setting up
#@markdown Installing and importing dependencies, as well as defining helper functions (used for model building, visualisation and text completion). The output of this cell will warn you if you're not using a GPU, which will lead to slow model training.


# Installing dependencies

print("Installing dependencies...", end='')


from IPython.utils import io

with io.capture_output() as captured:
  # Add all the pip installs of modules necessary to run the colab
  %reset -f
  !apt-get update

  # Install helper libraries and upgrade TensorFlow to the current version
  !pip install columnize
  !pip install tabulate
  !pip install tensorflow --upgrade
  !pip install transformers

  !mkdir -p data

print("DONE!")

print("Importing dependencies...", end='')

# Importing dependencies

import base64
import collections
import columnize
import copy
import functools
import io
import IPython
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import re
import tensorflow as tf
import transformers
import warnings

from tqdm.autonotebook import tqdm
from google.colab import html
from matplotlib import ticker
from tabulate import tabulate
from tensorflow import keras

print("DONE!")

%matplotlib inline

np.random.seed(42)

warnings.filterwarnings('ignore')

print("Defining helper functions...", end='')

# Defining helper functions

def split_input_target(data):
  """Splits the data into input and target."""
  return data[:-1], data[1:]


def generate_text(text, length=100, sampling='random'):
  """A function used to generate a text from starting text.

  Args:
    text: starting text
    length: the length of the generated text (default: 100)
    sampling: 'greedy' or 'random' (default: random)
  Returns:
    A string containing the generated text
  """
  assert sampling in ['greedy', 'random']
  for _ in range(length):
    data = tf.expand_dims([character_to_id[char] for char in text], 0)
    predicted = model(data)
    last_predicted = predicted[0, -1]
    if sampling == 'greedy':
      next_character = id_to_character[np.argmax(last_predicted)]
    elif sampling == 'random':
      rand_char = tf.random.categorical(np.expand_dims(last_predicted, axis=0),
                                        num_samples=1).numpy()[0, 0]
      next_character = id_to_character[rand_char]
    text = text + next_character
  return text


def print_generated_text(start, text):
  """Prints generated text so that it stands out.

  Denotes the starting text by enveloping it in brackets, and prints an array
  of dashes below and after the text, for example:

  -------------------------------
  [starting text:] generated text
  -------------------------------

  Args:
    start: starting text
    text: text to print
  """

  print('\n\n' + '-' * len(text))
  print('[{}]{}'.format(start, text[len(start):]))
  print('-' * len(text) + '\n')

print_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: print_generated_text('ROMEO:', generate_text('ROMEO:', length=100, sampling='random'))
    )


def get_probabilities(text):
  """Get probabilities of next character, given a list of starting ones.

  Args:
    text: a list of characters
  Returns:
    A dictionary of character : probability values
  """
  assert len(text) > 0
  data = tf.expand_dims([character_to_id[char] for char in text], 0)
  predicted = model(data)
  last_predicted = predicted[0, -1]
  probabilities = tf.nn.softmax(last_predicted).numpy()
  character_probabilities = {id_to_character[i]: probabilities[i]
                             for i in range(probabilities.shape[0])}
  return character_probabilities


def nn_seq_log_likelihood(seq):
  """Calculates the log-likelihood of the sequence for the neural model.

  Args:
    seq: string containing the input sequence / message
  Returns:
    Log-likelihood of the sequence
  """
  seq_of_ids = [character_to_id[character] for character in seq]
  text_as_numbers = np.array(seq_of_ids).astype(np.int32)
  ground_truth = text_as_numbers[1:]
  inp = text_as_numbers[np.newaxis, ...]
  predicted = model(inp)
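  # Log-probabilities of every character at every position
  predicted = tf.nn.log_softmax(predicted[0])
  indices = tf.transpose(tf.stack([tf.range(predicted.shape[0] - 1),
                                   ground_truth]))
  # Sum the log-probabilities of the characters that actually came next
  log_likelihood = tf.reduce_sum(tf.gather_nd(predicted, indices))
  return log_likelihood.numpy()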


def seq_log_likelihood(seq, initial, transition, alphabet, first_order=False):
  """Calculates the log-likelihood of the sequence for the simple (quick) model.

  Args:
    seq: string containing the input sequence / message
    initial: initial probabilities for each character
    transition: probabilities for character-to-character transition
    alphabet: list of characters
    first_order: first-order estimate (default: False)
  Returns:
    Log-likelihood of the sequence
  """
  # log(p(s_0))
  ll = np.log(initial[alphabet.index(seq[0])])

  if not first_order:
    bigrams = zip(seq, seq[1:])
    for b in bigrams:
      idx1 = alphabet.index(b[0])
      idx2 = alphabet.index(b[1])
      # log(p(s_{t+1}|s_t))
      ll += np.log(transition[idx1, idx2])
  else:
    # This assumes independence,
    # i.e. p(s_0, ..., s_T) = \prod_{t=0}^T p(s_t)
    idx = list(map(alphabet.index, seq[1:]))
    ll += np.sum(np.log(initial[idx]))

  return ll


def dirichlet_map(prior_as, n_successes):
  """Get the dirichlet maximum a posteriori (MAP) estimate

  Args:
    prior_as: list of dirichlet alphas for all elements
    n_successes: occurrences for each element
  Returns:
    A maximum a posteriori (MAP) estimate for all elements
  """
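  # Total number of observations in each row
  N = np.sum(n_successes, axis=1)
  # Dimensionality of the simplex (number of categories)
  K = len(prior_as)
  # Normaliser of the MAP estimate: N + sum(alpha) - K
  Z = N + np.sum(prior_as) - K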
  map_ = n_successes + prior_as - 1
  return map_ / Z[:, np.newaxis]


def decode(seq, mapping):
  """Decodes a sequence with a given mapping.

  Args:
    seq: a list of characters
    mapping: a character : character mapping
  Returns:
    A string containing characters from seq mapped with the mapping dict
  """
  decoded_seq = (mapping[ch] for ch in seq)
  return ''.join(decoded_seq)


def state_transition(mapping, rng):
  """Randomly swaps two mappings from the map.

  Args:
    mapping: dictionary containing character : character mapping
    rng: RandomState object
  Returns:
    A dictionary containing the new mapping, with two keys containing swapped
    values
  """
  idx = rng.choice(range(0, len(mapping)), size=2, replace=False)

  keys = list(mapping.keys())
  k1 = keys[idx[0]]
  k2 = keys[idx[1]]

  new_mapping = copy.deepcopy(mapping)
  new_mapping[k1] = mapping[k2]
  new_mapping[k2] = mapping[k1]

  return new_mapping


def random_permutation(alphabet_1, alphabet_2=None, rng=None):
  """Randomly maps one alphabet (list of characters) to another one.

  Args:
    alphabet_1: list of characters
    alphabet_2: list of characters. If None, it equals alphabet_1
                (default: None)
    rng: numpy RandomState (default: None)
  Returns:
    A dictionary of the random character : character mapping denoting which
    character from one alphabet is mapped to which character in the other one
  """
  # Assume equal alphabets
  if alphabet_2 is None:
    alphabet_2 = alphabet_1
  else:
    assert len(alphabet_1) == len(alphabet_2)

  permutation_fn = rng.permutation if rng else np.random.permutation
  permutation = permutation_fn([a for a in alphabet_2])
  mapping = {k: v for k, v in zip(alphabet_1, permutation)}
  return mapping


def preprocess_whitespace_character(x):
  """Preprocesses whitespace characters for prettier printing."""
  if x == '\t': return '\\t'
  if x == '\n': return '\\n'
  if x == ' ': return "' '"
  return x

def pretty_print_characters(characters):
  """Pretty prints characters into nicely aligned rows of characters."""
  sorted_characters = [preprocess_whitespace_character(char)
                       for char in sorted(characters)]
  columnized = columnize.columnize(sorted_characters, displaywidth=80, ljust=False)
  print(columnized)


def pretty_print_mappings(character_to_id):
  """Pretty-prints character to id mappings.

  Prints the character to id mappings into four column-categories:
  'numbers', 'lower case letters', 'upper case letters' and 'other character'

  Args:
    character_to_id: a dictionary of character : id values.
  """
  pretty_character_to_id = {preprocess_whitespace_character(k): v for k, v in character_to_id.items()}

  numbers = ['{:>2} -> {:>3}'.format(k, v) for k, v in sorted(pretty_character_to_id.items()) if k.isnumeric()]
  alpha_small = ['{:>2} -> {:>3}'.format(k, v) for k, v in sorted(pretty_character_to_id.items()) if k.isalpha() and k.islower()]
  alpha_big = ['{:>2} -> {:>3}'.format(k, v) for k, v in sorted(pretty_character_to_id.items()) if k.isalpha() and k.isupper()]
  other = ['{:>2} -> {:>3}'.format(k, v) for k, v in sorted(pretty_character_to_id.items()) if not k.isnumeric() and not k.isalpha()]

  max_len = max(len(numbers), len(alpha_small), len(alpha_big), len(other))
  headers = ['Numbers', 'Lower case\nletters',
             'Upper case\nletters', 'Other\ncharacters']
  table = []
  for i in range(max_len):
    row = []
    for lst in [numbers, alpha_small, alpha_big, other]:
      if len(lst) > i:
        row.append(lst[i])
      else:
        row.append('')
    table.append(row)
  print(tabulate(table, headers=headers, stralign='right'))


class Img(html.Element):
  """Class for constructing the 'Character occurrences' graph image."""
  def __init__(self):
    super(Img, self).__init__('img')

  def update_img(self, character, data):
    plt.figure(figsize=(12, 4))
    plt.bar(*zip(*sorted(data.items(), key=lambda x: -x[1])))
    plt.title('Occurrences of character following {}'.format(character))
    plt.ylabel('Probability of occurrence')
    plt.xlabel('Character')
    plt.savefig('next_char.png')
    plt.close()

    img = plt.imread('next_char.png')
    img = (img * 255).astype('uint8')
    img = PIL.Image.fromarray(img).convert('RGB')
    buf = io.BytesIO()
    img.save(buf, format='JPEG',)
    content = buf.getvalue()
    url = 'data:image/jpeg;base64,'+base64.b64encode(content).decode('utf-8')
    self.set_property('src', url)


class Div(html.Element):
  """Class for constructing an input field where the user will press a key."""
  def __init__(self):
    super(Div, self).__init__('input')

  @property
  def text_content(self):
    return self.get_property('textContent')

  @text_content.setter
  def text_content(self, value):
    return self.set_property('textContent', value)


class ProcessKeyboardInput():
  """Class for processing user-made inputs."""
  def __init__(self, model_type='bigram'):
    self._input = []
    assert model_type in ['bigram', 'rnn']
    self._model_type = model_type

  def process_key(self, key):
    """Function that processes a pressed key.

    Args:
      key: key from the 'keydown' event listener
    """
    if self._model_type == 'bigram':
      if key['key'] not in alphabet:
        print('Pressed key "{}" not in the vocabulary!'.format(key['key']))
      else:
        char_index = alphabet.index(key['key'])
        transition_counter = {elem: transition[char_index][alphabet.index(elem)]
                              for elem in alphabet}
        graph_img.update_img(key['key'], transition_counter)
    elif self._model_type == 'rnn':
      if key['key'] not in character_counter.keys():
        print('Pressed key "{}" not in the vocabulary!'.format(key['key']))
      else:
        self._input.append(key['key'])
        character_probabilities = get_probabilities(self._input)
        graph_img.update_img(key['key'], character_probabilities)


print("DONE!")

print('\nChecking for GPU...', end='')

# from https://colab.research.google.com/notebooks/gpu.ipynb#scrollTo=Y04m-jvKRDsJ
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  print(' NOT using a GPU. Running this code will be quite SLOW.')
else:
  print(" Using a GPU. You're good to go.")


print("\nSetup DONE!")

# Step 1: Learning about human language with Artificial Intelligence: Language Models

**Language Models** are Artificial Intelligence (AI) algorithms that learn simple statistics of a language. In other words, these models learn which characters (or words) are likely to follow a given piece of text. For instance, you probably know that the character following `pancak` is most likely `e` and not `s`, or that `How are` will probably be followed by `you?` or `they` and not `lovely`. This is exactly what we will teach our language models to do.

Although this might sound rather simple, given a good algorithm, enough text and some time on a fast computer, these language models can learn a surprising amount of knowledge about language.

To learn which characters go together, we look at a large amount of text, such as a book, a collection of books, or even the vast amount of text on the Internet.
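As a rough illustration of what 'learning which characters go together' means, the sketch below counts which characters follow a given piece of context in a tiny made-up text. The helper `next_character_counts` and the example text are hypothetical, for illustration only; the models in this colab are built differently.

```python
import collections

def next_character_counts(text, context):
  """Hypothetical helper: counts the characters that follow `context` in `text`."""
  counts = collections.Counter()
  start = text.find(context)
  while start != -1:
    next_index = start + len(context)
    if next_index < len(text):
      counts[text[next_index]] += 1
    # Look for the next occurrence of the context
    start = text.find(context, start + 1)
  return counts

sample_text = 'a pancake, two pancakes and a pancake stack'
print(next_character_counts(sample_text, 'pancak'))   # Counter({'e': 3})
print(next_character_counts(sample_text, 'pancake'))  # one each of ',', 's' and ' '
```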

If you want to see how your results would change with a different book, feel free to re-run the code in this section with another choice.

In [None]:
#@title Choosing a suitable data source

#@markdown As our first step towards building a language model, we will pick a book. We picked a selection of freely available books from Project Gutenberg that you can choose from:

book = 'The Wonderful Wizard of Oz (L. Frank Baum)'  #@param ["The Complete Works of William Shakespeare (William Shakespeare)", "A Christmas Carol (Charles Dickens)", "Alice’s Adventures in Wonderland (Lewis Carroll)", "Gulliver's Travels (Jonathan Swift)", "The Wonderful Wizard of Oz (L. Frank Baum)", "Treasure Island (Robert Louis Stevenson)"] {allow-input: false}

book_data = {
    'The Complete Works of William Shakespeare (William Shakespeare)':
        ['http://www.gutenberg.org/files/100/100-0.txt',
         'http://www.gutenberg.org/files/100/100-h/images/cover.jpg'],
    'A Christmas Carol (Charles Dickens)':
        ['http://www.gutenberg.org/files/46/46-0.txt',
         'https://www.gutenberg.org/cache/epub/46/pg46.cover.medium.jpg'],
    'Alice’s Adventures in Wonderland (Lewis Carroll)':
        ['http://www.gutenberg.org/files/11/11-0.txt',
         'https://www.gutenberg.org/cache/epub/11/pg11.cover.medium.jpg'],
    'Gulliver\'s Travels (Jonathan Swift)':
        ['http://www.gutenberg.org/files/829/829-0.txt',
         'https://www.gutenberg.org/cache/epub/829/pg829.cover.medium.jpg'],
    'The Wonderful Wizard of Oz (L. Frank Baum)':
        ['https://www.gutenberg.org/files/55/55-0.txt',
         'http://www.gutenberg.org/files/43936/43936-h/images/i001_edit.jpg'],
    'Treasure Island (Robert Louis Stevenson)':
        ['http://www.gutenberg.org/files/120/120-0.txt',
         'https://www.gutenberg.org/files/120/120-h/images/0010m.jpg'],
}

#@markdown > Upon execution, this code will download the book and display its cover.

def download_book(book):
  text_url, cover_url = book_data[book]
  text_path = tf.keras.utils.get_file(os.path.basename(text_url), text_url)
  cover_path = tf.keras.utils.get_file(os.path.basename(cover_url), cover_url)
  return text_path, cover_path 

with IPython.utils.io.capture_output() as captured:
  text_path, cover_path = download_book(book)

plt.figure(figsize=(10, 10))
plt.imshow(plt.imread(cover_path))
plt.axis('off')
plt.title('BOOK COVER')

txt = []
with open(text_path, 'r', encoding='utf-8-sig') as f:
  txt = f.readlines()

In [None]:
#@title Let's scroll through the text we're working with { run: "auto" }

text_scroller = 370  #@param {type:"slider", min:300, max:500, step:10}

print(''.join(txt[text_scroller:text_scroller+50]))

Language models learn to predict the next segment of a text from the preceding one. We will build a character language model: a model that predicts the next character from the characters before it. It is also possible (with some changes we will not cover in this colab) to predict the next word from the preceding words, as we will see later.
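To make the difference concrete, here is a minimal sketch of how the same sentence is split into units for a character model and for a naive word model. Splitting words on whitespace is a simplifying assumption; real word models usually use more careful splitting.

```python
sentence = 'How are you?'

# A character language model predicts one character at a time.
characters = list(sentence)
print(characters)  # ['H', 'o', 'w', ' ', 'a', 'r', 'e', ' ', 'y', 'o', 'u', '?']

# A naive word language model predicts one whitespace-separated word at a time.
words = sentence.split()
print(words)  # ['How', 'are', 'you?']

# In both cases, each unit is paired with the unit that follows it.
print(list(zip(words[:-1], words[1:])))  # [('How', 'are'), ('are', 'you?')]
```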

But for now, before starting to model, let's first take a look at the statistics of the text we want to model.

Let's see which unique characters (letters, numbers, punctuation, etc.) appear in the book and how often each one occurs.

In [None]:
#@title Defining the vocabulary

#@markdown These unique characters (or words, in the case of a word language model) are what we call **the vocabulary**. Once we define a vocabulary, we will train a model to predict which element of the vocabulary (a letter or a word) comes after the previous one(s).

#@markdown > Running this code will build and display the character vocabulary.

text = ''.join(txt)

character_counter = collections.Counter()
character_counter.update(text)
vocabulary_size = len(character_counter.keys())

print('Number of unique characters in the text: {}\n'.format(vocabulary_size))
print('All the unique characters in the text (our vocabulary):\n')

pretty_print_characters(character_counter.keys())

Which characters (or letters) do you think are the most frequent in our book? Let's find out:

In [None]:
plot_which_characters = 'letters only'  #@param ['all characters', 'letters only']

if plot_which_characters == 'all characters':
  title = 'Number of occurrences of each character in the book'
  xlabel = 'Characters'
  characters_to_plot = {k: v for k, v in character_counter.items()}
elif plot_which_characters == 'letters only':
  title = 'Number of occurrences of each letter in the book'
  xlabel = 'Letters'
  characters_to_plot = collections.defaultdict(lambda: 0)
  for k, v in character_counter.items():
    if k.isalpha():
      characters_to_plot[k.lower()] += v


fig = plt.figure(figsize=(20, 8))
plt.gca().ticklabel_format(axis='y', style='plain')
plt.gca().get_yaxis().set_major_formatter(
    ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
plt.bar(*zip(*sorted(characters_to_plot.items(), key=lambda x: -x[1])))
plt.title(title)
plt.ylabel('Number of occurrences')
plt.xlabel(xlabel)
plt.show()

As you can see, the distribution is highly skewed (some letters occur very often, others rarely), which is typical of language data such as characters or words (you probably use some words much more often than others). Such statistics are a strong indicator of the language our book is written in. Compare them with [the most frequent letters in English](http://letterfrequency.org/) to see whether our text statistics are typical of English text.
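The plot shows raw counts, while letter-frequency tables are usually given as percentages. A minimal sketch of that normalisation, run on a short made-up sentence (with the real book you could pass the notebook's `text` variable instead); the helper name `letter_frequencies` is hypothetical:

```python
import collections

def letter_frequencies(text):
  """Hypothetical helper: percentage of each letter (case-insensitive), most common first."""
  counts = collections.Counter(ch.lower() for ch in text if ch.isalpha())
  total = sum(counts.values())
  return [(letter, 100.0 * count / total) for letter, count in counts.most_common()]

sample = 'The quick brown fox jumps over the lazy dog'
for letter, percent in letter_frequencies(sample)[:3]:
  print('{}: {:.1f}%'.format(letter, percent))
```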

# Step 2: Building a Language Model



We previously said that for a computer to 'digest' language, we need to **convert language to a numerical representation**. Here we give every character a unique number and map the entire text to numbers.

In [None]:
#@title Mapping characters to numbers

#@markdown > We build maps from characters to numbers and back, so we can both translate text into numbers and numbers into text.

character_to_id = {c: i for i, c in enumerate(sorted(character_counter.keys()))}
id_to_character = {v: k for k, v in character_to_id.items()}

pretty_print_mappings(character_to_id)

Let's see what a fragment of text looks like when represented with numbers:

In [None]:
#@title Translating text to numbers

example_text = 'Fun with language!'  #@param {type:"string"}

text_as_numbers = [character_to_id[character] for character in example_text]

print('"{}"\n'.format(example_text))
print('is translated to:\n')
print(text_as_numbers)
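
Because the mapping works in both directions, we can also turn the numbers back into text. The self-contained sketch below builds the two dictionaries the same way as the notebook, but only from a short example string (and under new, hypothetical names, so it does not overwrite the book-wide mappings); its numbers will therefore differ from the ones printed above:

```python
sample_text = 'Fun with language!'

# Same construction as the notebook: sort the unique characters, then number them.
small_char_to_id = {c: i for i, c in enumerate(sorted(set(sample_text)))}
small_id_to_char = {i: c for c, i in small_char_to_id.items()}

encoded = [small_char_to_id[c] for c in sample_text]
decoded = ''.join(small_id_to_char[i] for i in encoded)

print(encoded)
print(decoded)  # Fun with language!
```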

We are already at a stage where we can try a very simple first language model: we simply look at each character and keep track of which characters are likely to follow it.


*Easy, right?*


To keep things simple, this model only uses the most common characters we observed in the plots earlier, and it treats upper- and lower-case letters as the same character.
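In code, the idea boils down to counting pairs of neighbouring characters and turning each row of counts into probabilities. Below is a minimal sketch on a tiny made-up text; the helper name is hypothetical. It adds a constant `smoothing` to every count so that unseen pairs keep a small probability; setting `smoothing` to `alpha - 1` gives the same formula as the Dirichlet MAP estimate used in the next cell.

```python
import collections

def bigram_probabilities(text, vocabulary, smoothing=1.0):
  """Hypothetical helper: estimates p(next character | current character).

  `smoothing` is added to every pair count before normalising each row.
  """
  pair_counts = collections.Counter(zip(text[:-1], text[1:]))
  probabilities = {}
  for current in vocabulary:
    row = {nxt: pair_counts[(current, nxt)] + smoothing for nxt in vocabulary}
    total = sum(row.values())
    probabilities[current] = {nxt: count / total for nxt, count in row.items()}
  return probabilities

toy_text = 'the cat sat on the mat'
toy_vocabulary = sorted(set(toy_text))
probs = bigram_probabilities(toy_text, toy_vocabulary)

# Most likely character after 'h' in this toy text:
print(max(probs['h'], key=probs['h'].get))  # e
```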


In [None]:
#@title Run the code below to set up this simple model

# Dirichlet prior strength
alpha = 5

# Reduced alphabet
alphabet = [' ', '-', ',', ';', ':', '!', '?', '/', '.', "'", '"', '(', ')',
            '[', ']', '*', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
alphabet.append('#')  # Used to denote OOV tokens

# Collapse runs of whitespace and new lines into single spaces, and lowercase
txt_bigram = re.sub(r'\s+', ' ', ''.join(text)).lower()
txt_bigram = list(filter(lambda x: x in alphabet, txt_bigram))

alphabet_size = len(alphabet)
initial = np.random.dirichlet(alphabet_size*[1], size=1).flatten()
transition = np.random.dirichlet(alphabet_size*[1], size=alphabet_size)

unigram_counts = np.zeros([alphabet_size], dtype=np.float64)
bigram_counts = np.zeros(2*[alphabet_size], dtype=np.float64)

unigram_counter = collections.Counter(txt_bigram)
for char, count in unigram_counter.items():
  idx = alphabet.index(char)
  # Unnormalised
  unigram_counts[idx] += float(count)

# Find maximum a posteriori probability (MAP)
initial = dirichlet_map(alphabet_size*[alpha],
                        unigram_counts[np.newaxis, :]).flatten()
# Sanity check that things add up correctly
assert np.isclose(1.0, np.sum(initial))

bigrams = list(zip(txt_bigram[:-1], txt_bigram[1:]))
bigram_counter = collections.Counter(bigrams)

for chars, count in bigram_counter.items():
  idx = (alphabet.index(chars[0]), alphabet.index(chars[1]))
  # Unnormalised counts
  bigram_counts[idx] += float(count)

# Find maximum a posteriori probability (MAP)
transition = dirichlet_map(alphabet_size*[alpha], bigram_counts)
assert np.isclose(1.0, np.sum(transition, axis=1)).all()
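
For reference, this is the estimate that `dirichlet_map` computes for each row of `transition`. Write $n_{ij}$ for the number of times character $j$ follows character $i$ in the book, $N_i = \sum_j n_{ij}$, $K$ for `alphabet_size` and $\alpha$ for the prior strength `alpha`:

$$\hat{p}(j \mid i) = \frac{n_{ij} + \alpha - 1}{N_i + K\alpha - K}$$

With `alpha = 5`, this is the same as adding 4 imaginary observations of every character pair before normalising, so pairs that never appear in the book still get a small, non-zero probability. The `initial` probabilities are computed the same way from single-character counts.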

Let's see what the model has learned!

First let's look at what the model thinks should come next in a sequence.

Run the cell below and start by typing a single letter to visualise which characters most frequently come next.

In [None]:
#@markdown We can play around with this to string together longer words that our model thinks are likely! Start with 'a': which letter is likely to come next? And what is likely to come after that? Can you find some common three-letter words?

#@markdown > NOTE: Do not use any other keys, such as backspace or arrows. If you make a mistake, please re-run this cell.

print('Start typing character by character:')
process = ProcessKeyboardInput(model_type='bigram')
graph_img = Img()
input_field = Div()
input_field.add_event_listener('keydown', process.process_key)
display(input_field)
graph_img

Now, while this was fast and easy, looking only at the single character that comes just before the one we want to predict has its limitations.

What do you think is the next character in `badminto`? If we asked this language model to complete the word for us, all it could do is look at the last character (`o`) and pick whatever is most likely to follow it, say `u`. `badmintou` is obviously not the correct answer, but the only way to know this is to look at the characters that come before the `o`.
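The limitation is easy to see in code: a bigram model only ever looks at the last character of its input, so every context ending in the same character gets the same prediction. A minimal sketch with a hypothetical, hand-written table of most likely next characters (the `u` entry just mirrors the example above and is not computed from the book):

```python
# Hypothetical table: the most likely character after each character.
most_likely_next = {'o': 'u', 'n': 'g', 't': 'h'}

def bigram_complete(context):
  """Predicts the next character using only the last character of `context`."""
  return most_likely_next[context[-1]]

# Very different contexts, identical predictions, because only 'o' is used.
for context in ['badminto', 'yo', 'hello']:
  print(context + bigram_complete(context))
# badmintou
# you
# hellou
```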

Hence, we will now work our way towards a much more complex language model that can, in theory, take into account an **unlimited** number of previous characters. Pretty cool!



## Building a prediction dataset

We've got the text, and we know which number each character maps to. Before we can proceed with a more complex model, we need to build a prediction **dataset**. A dataset is just a collection of examples we want to model. In our case, each example consists of a sequence of characters and the same sequence shifted by one position, so that every character is paired with the character that follows it (for example, for the chunk `hello` the input is `hell` and the target is `ello`). We call these the **input** characters and the **target** characters. The model we will create will take the input and predict the target.
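The next cell does this with `tf.data`. Here is the same chunking and splitting in plain Python on a tiny list of numbers, using a sequence length of 4 instead of 100 for readability; the helper `make_examples` is hypothetical and, unlike the real pipeline, does not shuffle or group examples into batches:

```python
def make_examples(numbers, sequence_length):
  """Hypothetical helper: cuts `numbers` into (input, target) pairs."""
  chunk_size = sequence_length + 1
  examples = []
  # Keep only full chunks, mirroring batch(..., drop_remainder=True).
  for start in range(0, len(numbers) - chunk_size + 1, chunk_size):
    chunk = numbers[start:start + chunk_size]
    # Same split as split_input_target: input drops the last, target the first.
    examples.append((chunk[:-1], chunk[1:]))
  return examples

toy_numbers = list(range(12))
for inputs, targets in make_examples(toy_numbers, sequence_length=4):
  print(inputs, '->', targets)
# [0, 1, 2, 3] -> [1, 2, 3, 4]
# [5, 6, 7, 8] -> [6, 7, 8, 9]
```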

In [None]:
SEQUENCE_LENGTH = 100  #@param {type:"integer"}
BATCH_SIZE = 32  #@param {type:"integer"}

text_as_numbers = [character_to_id[character] for character in text]

dataset = tf.data.Dataset.from_tensor_slices(np.array(text_as_numbers))
dataset = dataset.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)
dataset = dataset.map(split_input_target)
dataset = dataset.shuffle(10000)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

## Defining the model

We will create a model called a **Recurrent Neural Network** (RNN), which you can think of as a rather complex function with a lot of small elements called **parameters** that can be tweaked. By tweaking these parameters in a particular way on a particular dataset (what we call **learning**), we can make this function predict elements of a sequence, such as speech or text. This type of model is particularly well suited to predicting the next characters (targets) from previous ones (inputs). The specific kind of RNN we use is an LSTM (Long Short-Term Memory) network.
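To give a feel for what 'recurrent' means, here is a minimal NumPy sketch of a plain (Elman-style) RNN. This is a simplification: the model in the next cell uses an LSTM layer, which adds gates on top of this idea, and the weights below are random rather than learned. The key point is that the same parameters are reused at every step, and the hidden state `h` carries information from all previous characters forward.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_size, hidden_size, vocab_size = 8, 16, 30  # small illustrative sizes

# Parameters (random here; training would adjust them).
W_x = rng.normal(scale=0.1, size=(hidden_size, embedding_size))
W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)
W_out = rng.normal(scale=0.1, size=(vocab_size, hidden_size))

def rnn_forward(inputs):
  """Runs a simple RNN over a sequence of input vectors, returning logits per step."""
  h = np.zeros(hidden_size)  # the 'memory' starts empty
  all_logits = []
  for x in inputs:
    # The new state depends on the current input AND the previous state.
    h = np.tanh(W_x @ x + W_h @ h + b)
    # One score (logit) per vocabulary entry for the next character.
    all_logits.append(W_out @ h)
  return np.stack(all_logits)

sequence = rng.normal(size=(5, embedding_size))  # stand-in for 5 embedded characters
logits = rnn_forward(sequence)
print(logits.shape)  # (5, 30)
```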

In [None]:
#@title Model definition

#@markdown We added some basic parameters below that you can modify to change the training behaviour. If this is your first time running training, we suggest using the preset values. You can modify them later and see their effect on the output.

WORD_ENCODING_SIZE = 64  #@param  [16, 32, 64, 128, 256] {allow-input: false}
NUMBER_OF_UNITS = 256  #@param [32, 64, 128, 256, 512] {allow-input: false}
LEARNING_RATE = 0.001  #@param [0.00001, 0.0001, 0.001, 0.01, 0.1] {allow-input: false}

model = keras.models.Sequential([
    keras.layers.Embedding(input_dim=vocabulary_size,
                           output_dim=WORD_ENCODING_SIZE),
    keras.layers.LSTM(NUMBER_OF_UNITS, return_sequences=True),
    keras.layers.Dense(vocabulary_size)
])

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

def get_loss(labels, logits):
  return keras.losses.sparse_categorical_crossentropy(labels, logits,
                                                      from_logits=True)

model.compile(loss=get_loss, optimizer=optimizer, metrics=['accuracy'])

The model we defined above is not trained yet: its parameters are set to random values, so its output should look random too. Let's check that:

In [None]:
INITIAL_TEXT = 'ROMEO:'
generated_text = generate_text(INITIAL_TEXT, length=100, sampling='random')
print_generated_text(INITIAL_TEXT, generated_text)

OK, so our model doesn't do anything useful yet. Let's train it. During training we will show two statistics, `loss` and `accuracy`. Loosely speaking, the `loss` tells us how far off the model's predictions are on the training data (lower is better), while the `accuracy` tells us what fraction of the next characters the model predicted correctly.
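For a single prediction, the loss used here (sparse categorical cross-entropy computed from logits) is the negative log of the probability the model gave to the correct next character, and accuracy checks whether the top-scoring character is the correct one. A small NumPy sketch with made-up numbers; the helper name is hypothetical. It also shows that guessing uniformly over a vocabulary of size V gives a loss of log(V), which is roughly where an untrained model typically starts:

```python
import numpy as np

def cross_entropy_from_logits(logits, target_index):
  """Hypothetical helper: negative log-probability of the correct class."""
  log_probabilities = logits - np.max(logits)  # shift for numerical stability
  log_probabilities = log_probabilities - np.log(np.sum(np.exp(log_probabilities)))
  return -log_probabilities[target_index]

logits = np.array([[2.0, 0.5, 0.1],   # model favours class 0
                   [0.2, 0.3, 3.0]])  # model favours class 2
targets = np.array([0, 1])            # the correct next characters

losses = [cross_entropy_from_logits(row, t) for row, t in zip(logits, targets)]
accuracy = np.mean(np.argmax(logits, axis=1) == targets)

print('loss: {:.3f}'.format(np.mean(losses)))
print('accuracy: {:.2f}'.format(accuracy))  # accuracy: 0.50

# Uniform guessing over V characters (V = 80 is an arbitrary example) gives log(V).
V = 80
print(cross_entropy_from_logits(np.zeros(V), 0), np.log(V))
```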

In [None]:
#@title Model Training

EPOCHS = 5  #@param {type:"slider", min:5, max:30, step:1}

#@markdown > After training a model for an epoch (one pass through the data), we print out samples of generated language. Do they visibly improve?

# The dataset is already batched, so batch_size is not passed to fit()
history = model.fit(dataset, epochs=EPOCHS, callbacks=[print_callback])

Let's now take a look at how this model predicts next characters:

In [None]:
#@markdown We can play around with this to string together longer words that our model thinks are likely! Start with 'a': which letter is likely to come next? And what is likely to come after that? Can you find some common three-letter words?

#@markdown > NOTE: Do not use any other keys, such as backspace or arrows. If you make a mistake, please re-run this cell.

print('Start typing character by character:')
process = ProcessKeyboardInput(model_type='rnn')
graph_img = Img()
input_field = Div()
input_field.add_event_listener('keydown', process.process_key)
display(input_field)
graph_img

In the final section, you will explore some cool applications of language models. 

You can choose between decoding secret messages and generating even more sophisticated text!

# Application: Decoding secret messages!

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/decrypt-ohno.png" alt="drawing" height="220"/>
</center>


On a sunny afternoon, a distressed police officer storms into the mathematics department at the local university. In bitter frustration, he explains that he spent all night trying, to no avail, to decode cryptic messages he found on his usual patrol yesterday evening. He shows you an example message:

<img src="https://i.imgur.com/nZYOhnT.png" alt="An encrypted message" width=700>

*Pretty cryptic, right?*

Unsure of what to do, the mathematicians ask you for help, having heard that you have recently become interested in AI and language.

So far, they have only worked out that the messages are encrypted by consistently replacing each character with another, randomly chosen one. For example, to encrypt `pancake`, we might start by replacing every `a` with `s`, ending up with `psncske`. Then, replacing `p` with `[`, we get `[sncske`, and we continue until the message is no longer legible. You can try this below:



In [None]:
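#@title Encrypt your own message

your_message = 'pancake'  #@param {type:"string"}
# Lower-case the message and replace characters outside the alphabet with ''.
your_message = [ch if ch in alphabet else '' for ch in your_message.lower()]
# A random mapping that pairs each character with its replacement.
mapping = random_permutation(alphabet, alphabet)
inv_mapping = {v: k for k, v in mapping.items()}

# Applying the mapping to the plain text encrypts it.
encoded_message = decode(your_message, mapping)
print('Your message encrypted: {}'.format(encoded_message))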
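The cell above relies on helpers (`random_permutation`, `decode`) that are defined elsewhere in the notebook. For readers who want to experiment outside it, here is a self-contained sketch of the same kind of substitution cipher. The names `make_key` and `apply_key` and the lowercase-plus-space alphabet are hypothetical choices for illustration, not the notebook's code.

```python
import random
import string

def make_key(alphabet, seed=None):
    """Randomly pair every symbol with a replacement (a permutation,
    so each replacement symbol is used exactly once)."""
    shuffled = list(alphabet)
    random.Random(seed).shuffle(shuffled)
    return dict(zip(alphabet, shuffled))

def apply_key(text, key):
    """Replace each character using the key; unknown characters are dropped."""
    return ''.join(key[ch] for ch in text if ch in key)

alphabet = string.ascii_lowercase + ' '  # assumed alphabet for this sketch
key = make_key(alphabet, seed=0)
secret = apply_key('pancake', key)

# If we knew the key, inverting it would undo the encryption.
inverse_key = {v: k for k, v in key.items()}
print(secret, '->', apply_key(secret, inverse_key))
```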

The key question is how to reverse this procedure.

**Can you come up with a solution?**

## Reversing the encryption process



If we knew which symbol each character had been replaced with, this would be quite easy: we would simply go through the encrypted message and replace every symbol with its original counterpart, so `[sncske` becomes `pancake` again. Could we simply guess a mapping and try our luck? Unfortunately, the mathematicians have calculated that this could mean trying up to **403291461126605635584000000** possible mappings before finding the right one, so we will have to be a bit smarter.
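The mathematicians' number is exactly 26! (26 factorial). That is the count of one-to-one mappings between two sets of 26 symbols. We assume this is where the figure comes from. An alphabet that also includes space and punctuation would give even more mappings.

```python
import math

# Number of ways to pair 26 symbols one-to-one with 26 characters.
n_mappings = math.factorial(26)
print(n_mappings)  # 403291461126605635584000000

# Even checking a million mappings per second would take about 4e20 seconds,
# roughly a thousand times the age of the universe (about 4.4e17 seconds).
seconds = n_mappings / 1_000_000
print(f'{seconds:.2e} seconds')
```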

*You have an idea*: say we have correctly guessed that `a` was replaced with `s`, but wrongly think that `[` is the encrypted symbol for `?`. Reversing the process then gives `?ancake`. That is not the correct solution, of course, but it looks a lot more like an English word than `[sncske` does.

This suggests that we got part of the mapping right. So instead of guessing a whole new mapping at random, we could make a tiny change (such as changing what `[` maps back to) and check whether the decoded text has improved.
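One simple way to make such a "tiny change" is to swap what two symbols map to. Swapping keeps the mapping one-to-one, so it is still a valid decryption key. This is a sketch with hypothetical helper names. Whether the notebook's own `state_transition` works exactly this way is not shown in this section.

```python
import random

def swap_targets(mapping, a, b):
    """Return a copy of `mapping` in which symbols a and b trade targets."""
    new_mapping = dict(mapping)
    new_mapping[a], new_mapping[b] = new_mapping[b], new_mapping[a]
    return new_mapping

def propose_swap(mapping, rng):
    """Propose a new mapping by swapping the targets of two random symbols."""
    a, b = rng.sample(sorted(mapping), 2)
    return swap_targets(mapping, a, b)

def decode_with(text, mapping):
    return ''.join(mapping[ch] for ch in text)

# The partly-correct guess from the text: 's' -> 'a' is right, '[' -> '?' is wrong.
guess = {'s': 'a', '[': '?', 'n': 'n', 'c': 'c', 'k': 'k', 'e': 'e', 'p': 'p'}
print(decode_with('[sncske', guess))   # ?ancake
fixed = swap_targets(guess, '[', 'p')  # one well-chosen swap fixes it
print(decode_with('[sncske', fixed))   # pancake

random_change = propose_swap(guess, random.Random(1))
```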


Here's the key idea: **We can use a language model to judge how similar a piece of text looks to commonly found English.**

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/decrypt-success.png" alt="drawing" height="220"/>
</center>



We've seen that we can use language models to generate novel texts. However, we can also use them to tell language apart from gibberish. For each character of the message, the language model tells us the probability of that character given the previous characters. Multiplying these probabilities together gives the probability of the whole sequence of characters. The code below works with the logarithms of these probabilities and adds them up instead, which avoids numbers too small for a computer to store. Because our language model has been trained on English, this implies: higher probability → better English.
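For the simple model, this calculation can be sketched with a character "bigram" model: a distribution `initial` over first characters and a matrix `transition` of next-character probabilities. The names match those passed to `seq_log_likelihood` in the decoding cell further down. However, the exact storage format and the helper's internals are assumptions here, and the three-character toy "language" is invented for illustration.

```python
import numpy as np

def sequence_log_prob(text, alphabet, initial, transition):
    """Log-probability of `text` under a character bigram model.

    initial[i]       = P(first character is alphabet[i])
    transition[i, j] = P(next character is alphabet[j] | current is alphabet[i])
    Adding log-probabilities is equivalent to multiplying probabilities.
    """
    idx = [alphabet.index(ch) for ch in text]
    log_p = np.log(initial[idx[0]])
    for prev, nxt in zip(idx[:-1], idx[1:]):
        log_p += np.log(transition[prev, nxt])
    return log_p

# Toy 3-character "language" in which 'a' is usually followed by 'b'.
alphabet = 'ab '
initial = np.array([0.6, 0.2, 0.2])
transition = np.array([[0.1, 0.8, 0.1],
                       [0.3, 0.1, 0.6],
                       [0.7, 0.2, 0.1]])
good = sequence_log_prob('ab ab', alphabet, initial, transition)
bad = sequence_log_prob('ba ba', alphabet, initial, transition)
print(good, bad)  # the "language-like" text scores higher
```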

A sentence written in English will have a higher probability than a gibberish sentence, so let's use the probability of a sentence to guide us in the decoding process! 

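The code you will run now uses a language model and a clever technique called Metropolis-Hastings to decide when we should move from the current mapping to a new one. A change that makes the decoded text more likely is always kept. A change that makes it less likely is sometimes kept too, so that we don't get stuck. Give it a try!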
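The acceptance decision can be sketched in a few lines. It works on log-likelihoods, as the decoding cell does. Because swapping two entries is a symmetric proposal, the Hastings correction cancels, and the rule reduces to the Metropolis test `min(1, exp(proposal_ll - current_ll))` that appears in the decoding cell. The function name `accept` is hypothetical.

```python
import math
import random

def accept(current_ll, proposal_ll, rng):
    """Metropolis acceptance test on log-likelihoods.

    A more likely proposal is always accepted; a less likely one is accepted
    with probability exp(proposal_ll - current_ll).
    """
    log_ratio = proposal_ll - current_ll
    if log_ratio >= 0:
        return True
    return rng.random() < math.exp(log_ratio)

rng = random.Random(0)
# A proposal that is worse by 1 (in log-likelihood) should be accepted
# about exp(-1), i.e. roughly 37%, of the time.
trials = 10_000
rate = sum(accept(-100.0, -101.0, rng) for _ in range(trials)) / trials
print(round(rate, 2))
```

Occasionally accepting worse mappings lets the search escape mappings that are only "locally" good.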

## Running the decryption algorithm



Next, we will run the decryption procedure. You can choose between the two language models we have trained so far. Because the neural network is slower to evaluate than the simple model, the code below might take a bit longer to run with it. You can also try both and see which works better!

Below, you can also set the number of decryption steps, i.e. how many mappings we will try out before we give up. A higher number is likely to lead to better results, but means we will have to wait longer. The currently decoded message is printed every 1000 steps, so you can see how well you are doing along the way. If the quality isn't good enough, try increasing the number of decryption steps or switching to the other model.

In [None]:
#@title Choose a message to decrypt.

message_choice = '2: A famous novel'  #@param ["1: The police message", "2: A famous novel", "3: An important speech"] {allow-input: false}

if int(message_choice[0]) == 1:
  message = 'g??#zg?zn(94ojz?w?3938z!jzk?3#(otz!o3*pz[?o(938zi3tjzoz!tok*z#(ok*ba9#pz8tiw?bzo34z!otoktowosz9zuow?zo((o38?4z?w?(j#u938zni(zabz#iz?g)#jz#u?zwoat#sz#u?zkt?(*bzo(?zkig)t9k9#zo34z[9ttz3i#zbia34z#u?zoto(gsz[?z?5)?k#zgo3jz)ia34bz#iz!?z93z#u?zbon?z#uo#z398u#szgo*?zba(?z#iz!?zi3z#9g?pz#u?z4(9w?(z[9ttzi3tjz[o9#zni(zuotnzo3zuia(za3#9tz[?z3??4z3??4z#iz8?#zia#z2a9k*sz4i3#zni(8?#z#izab?z#u?z?3k(j)#9i3z4?w9k?z9zuow?zjiazni(zo3jzna(#u?(zkigga39ko#9i3zo34z*??)zjia(z3?w?(bzkotgsz[?z[9ttzuow?zozn93otz!(9?n938zo3zuia(z!?ni(?z?w?(j#u938z9bzbku?4at?4z#iz#o*?z)tok?szi3k?z[?(?zia#zjiaz8?#zjia(zbuo(?zo34z9gg?49o#?tjzu?o4z#iz#u?zo9()i(#zni(zjia(znt98u#zbiz3izi3?zko3zko#kuzjiaszia(z)to33938z9bz)?(n?k#pz#u9bz898zko3#z)ibb9!tjz8iz[(i38s'
elif int(message_choice[0]) == 2:
  message = '[:p1lpl]x:r?!pw:np1]!?pmx;:?!w.;?pl?w!gp1lp/ws-?!prwm?p1?pg]1?pwnm[2?ps-wsp[ m?p.??:psx!:[:rp]m?!p[:p1lp1[:np?m?!pg[:2?5p6z-?:?m?!pl]xp/??;p;[7?p2![s[2[0[:rpw:lp]:?,6p-?ps];np1?,p6vxgsp!?1?1.?!ps-wspw;;ps-?pu?]u;?p[:ps-[gpz]!;np-wm?: sp-wnps-?pwnmw:swr?gps-wspl]x m?p-wn56p-?pn[n: spgwlpw:lp1]!?p.xspz? m?pw;zwlgp.??:px:xgxw;;lp2]11x:[2ws[m?p[:pwp!?g?!m?npzwl,pw:np[px:n?!gs]]nps-wsp-?p1?w:spwpr!?wspn?w;p1]!?ps-w:ps-ws5p[:p2]:g?9x?:2?p[ 1p[:2;[:?nps]p!?g?!m?pw;;pvxnr1?:sg,pwp-w.[sps-wsp-wgp]u?:?npxup1w:lp2x![]xgp:wsx!?gps]p1?pw:npw;g]p1wn?p1?ps-?pm[2s[1p]/p:]spwp/?zpm?s?!w:p.]!?g5ps-?pw.:]!1w;p1[:np[gp9x[27ps]pn?s?2spw:npwssw2-p[sg?;/ps]ps-[gp9xw;[slpz-?:p[spwuu?w!gp[:pwp:]!1w;pu?!g]:,pw:npg]p[sp2w1?pw.]xsps-wsp[:p2];;?r?p[pzwgpx:vxgs;lpw22xg?np]/p.?[:rpwpu];[s[2[w:,p.?2wxg?p[pzwgpu![mlps]ps-?pg?2!?spr![?/gp]/pz[;n,px:7:]z:p1?:5p1]gsp]/ps-?p2]:/[n?:2?gpz?!?px:g]xr-scc/!?9x?:s;lp[p-wm?p/?[r:?npg;??u,pu!?]22xuws[]:,p]!pwp-]gs[;?p;?m[slpz-?:p[p!?w;[0?np.lpg]1?px:1[gsw7w.;?pg[r:ps-wspw:p[:s[1ws?p!?m?;ws[]:pzwgp9x[m?![:rp]:ps-?p-]![0]:cc/]!ps-?p[:s[1ws?p!?m?;ws[]:gp]/pl]x:rp1?:p]!pwsp;?wgsps-?ps?!1gp[:pz-[2-ps-?lp?4u!?ggps-?1pw!?pxgxw;;lpu;wr[w![gs[2pw:np1w!!?np.lp].m[]xgpgxuu!?gg[]:g5p!?g?!m[:rpvxnr1?:sgp[gpwp1wss?!p]/p[:/[:[s?p-]u?5p[pw1pgs[;;pwp;[ss;?pw/!w[np]/p1[gg[:rpg]1?s-[:rp[/p[p/]!r?sps-ws,pwgp1lp/ws-?!pg:]..[g-;lpgxrr?gs?n,pw:np[pg:]..[g-;lp!?u?wspwpg?:g?p]/ps-?p/x:nw1?:sw;pn?2?:2[?gp[gpuw!2?;;?np]xspx:?9xw;;lpwsp.[!s-5'
elif int(message_choice[0]) == 3:
  message = 'bc(v(3cvx4wvi]vbc4kwmw2,v(3cvx4wvi]vbc((ckc5vbcmnzkc,v:cl4kcvi]v(34(vozcckv4w5vh3c49v(cx9(4(miwv(iv9incv(iv3mxncb]v4w5v(ivi(3cknv4nv(3cvh#wmh,v4nv(3cvx4wvl3iv34nviz(2kilwvcxi(miwnv4w5v:cbmc]n,v(3cvx4wv(ivl3ixv2ii5v4w5vc;mbv4kcv4nviwc)v(3cv9iikcn(vl4#v(iv]4hcvbm]cvmnv(iv]4hcvm(vlm(3v4vnwcck)v(3ckcv4kcvx4w#vxcwvl3iv]ccbv4v?mw5vi]v(lmn(c5v9km5cvmwvh#wmhmnx,v(3ckcv4kcvx4w#vl3ivhiw]mwcv(3cxncb;cnv(ivhkm(mhmnxvi]v(3cvl4#vi(3cknv5ivl34(v(3c#v(3cxncb;cnv54kcvwi(vc;cwv4((cx9()v(3ckcvmnvwivxikcvzw3c4b(3#v:cmw2,vwivx4wvbcnnvlik(3#vi]vkcn9ch(,v(34wv3cvl3ivcm(3ckvkc4bb#v3ib5n,vikv]cm2wnv(iv3ib5,v4wv4((m(z5cvi]vnwcckmw2v5mn:cbmc]v(il4k5v4bbv(34(vmnv2kc4(v4w5vbi](#,vl3c(3ckvmwv4h3mc;cxcw(vikvmwv(34(vwi:bcvc]]ik(vl3mh3,vc;cwvm]vm(v]4mbn,vhixcnvnchiw5v(iv4h3mc;cxcw()v4vh#wmh4bv34:m(vi]v(3iz23(v4w5vn9cch3,v4vkc45mwcnnv(ivhkm(mhmncvlik?vl3mh3v(3cvhkm(mhv3mxncb]vwc;ckv(kmcnv(iv9ck]ikx,v4wvmw(cbbch(z4bv4bii]wcnnvl3mh3vlmbbvwi(v4hhc9(vhiw(4h(vlm(3vbm]cnvkc4bm(mcn,v4bbv(3cncv4kcvx4k?n,vwi(,v4nv(3cv9inncnnikvlizb5v]4mwv(3mw?,vi]vnz9ckmikm(#,v:z(vi]vlc4?wcnn)v(3c#vx4k?v(3cvxcwvzw]m(v(iv:c4kv(3cmkv94k(vx4w]zbb#vmwv(3cvn(ckwvn(km]cvi]vbm;mw2,vl3ivncc?,vmwv(3cv4]]ch(4(miwvi]vhiw(cx9(v]ikv(3cv4h3mc;cxcw(nvi]vi(3ckn,v(iv3m5cv]kixvi(3cknv4w5v]kixv(3cxncb;cnv(3cmkvilwvlc4?wcnn)v(3cvkibcvmnvc4n#,v(3ckcvmnvwiwcvc4nmck,vn4;cviwb#v(3cvkibcvi]v(3cvx4wvl3ivnwccknv4bm?cv4(v:i(3vhkm(mhmnxv4w5v9ck]ikx4whc)vm(vmnvwi(v(3cvhkm(mhvl3ivhizw(n,vwi(v(3cvx4wvl3iv9imw(nviz(v3ilv(3cvn(kiw2vx4wvn(zx:bcn,vikvl3ckcv(3cv5ickvi]v5cc5nvhizb5v34;cv5iwcv(3cxv:c((ck)v(3cvhkc5m(v:cbiw2nv(iv(3cvx4wvl3ivmnv4h(z4bb#vmwv(3cv4kcw4,vl3incv]4hcvmnvx4kkc5v:#v5zn(v4w5vnlc4(v4w5v:bii5,vl3ivn(km;cnv;4bm4w(b#,vl3ivckkn,v4w5vhixcnvn3ik(v424mwv4w5v424mw,v:ch4zncv(3ckcvmnvwivc]]ik(vlm(3iz(vckkikv4w5vn3ik(hixmw2,v:z(vl3iv5icnv4h(z4bb#vn(km;cv(iv5iv(3cv5cc5n,vl3iv?wilnv(3cv2kc4(vcw(3znm4nxn,v(3cv2kc4(v5c;i(miwn,vl3ivn9cw5nv3mxncb]vmwv4vlik(3#vh4znc,vl3iv4(v(3cv:cn(v?wilnvmwv(3cvcw5v(3cv(kmzx93vi]v3m23v4h3mc;cxcw(,v4w5vl3iv4(v(3cvlikn(,vm]v3cv]4mbn,v4(vbc4n(v]4mbnvl3mbcv54kmw2v2kc4(b#,vniv(34(v3mnv9b4hcvn34bbvwc;ckv:cvlm(3v(3incvhib5v4w5v(mxm5vnizbnvl3iv?wilvwcm(3ckv;mh(ik#vwikv5c]c4()'

message_alphabet = alphabet

In [None]:
#@title Decode message

rng = np.random.RandomState(42)
first_order = False
init_type = 'max_freq'
decryption_steps = 15000  #@param {type:"slider", min:1000, max:50000, step:1}
use_which_model = 'simple (quick)'  #@param ['simple (quick)', 'neural (slow)']

#@markdown Once the decoding has finished, a figure will be shown that allows you to judge by how much the language quality (y-axis) improves as you increase the number of decoding steps (x-axis). You might find this useful when choosing an appropriate value.

if use_which_model == 'simple (quick)':
  seq_ll = functools.partial(seq_log_likelihood,
                             initial=initial,
                             transition=transition,
                             alphabet=alphabet,
                             first_order=first_order)
else:
  seq_ll = nn_seq_log_likelihood


lls = []

if init_type == 'max_freq':
  # Map symbols to each other by first ordering them by decreasing frequency
  # and then reading off the pairs.
  message_counter = collections.Counter(message)
  initial_freq = np.argsort(initial)[::-1]
  mapping = {}

  for ((e, _), d_id) in zip(message_counter.most_common(), initial_freq):
    mapping[e] = alphabet[d_id]

  # Take care of any remaining symbols not present in message
  remaining_keys = list(set(alphabet) - set(mapping.keys()))
  remaining_values = list(set(alphabet) - set(mapping.values()))
  mapping.update(random_permutation(remaining_keys, remaining_values, rng))

  initial_ll = seq_ll(decode(message, mapping))
elif init_type == 'random':
  message_alphabet = alphabet
  mapping = random_permutation(message_alphabet, alphabet, rng)
  initial_ll = seq_ll(decode(message, mapping))
else:
  assert False, 'Not available'

lls = [initial_ll]

print('Initial decoding guess:\t{}'.format(decode(message, mapping)))

for i in tqdm(range(decryption_steps), ncols=1100):
  if 0 == (i+1) % 1000:
    print('Step {}:\t{}'.format(i+1, decode(message, mapping)))
  new_mapping = state_transition(mapping, rng)
  proposal_ll = seq_ll(decode(message, new_mapping))

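  # Metropolis-Hastings acceptance probability: always accept a more likely
  # decoding; accept a less likely one with probability exp(difference).
  # Equivalent to min(1, exp(difference)), but np.exp can never overflow.
  A = np.exp(min(0.0, proposal_ll - lls[-1]))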

  if rng.uniform() < A:
    lls.append(proposal_ll)
    mapping = new_mapping
  else:
    lls.append(lls[-1])


plt.figure(figsize=(15, 7.5))
plt.plot(lls, lw=3)
plt.grid(True)
plt.xlabel('Number of steps', fontsize=15)
plt.ylabel('Text quality', fontsize=15)
plt.tick_params(labelsize=15)
plt.locator_params(nbins=10)

ax = plt.gca()
ax.axes.yaxis.set_ticklabels([])

plt.show()
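The `max_freq` branch of the cell above builds its starting mapping by frequency rank. The most common symbol in the message is mapped to the character ranked first by the array `initial`, the second most common to the second, and so on. Symbols that do not occur in the message are then filled in randomly. Here is a standalone sketch of the ranking step. A hard-coded, approximate English ranking stands in for `initial`, and it is a hypothetical illustration only.

```python
import collections

def frequency_init(message, symbols_by_frequency):
    """Initial guess: pair cipher symbols with characters by frequency rank."""
    counts = collections.Counter(message)
    return {sym: plain
            for (sym, _), plain in zip(counts.most_common(), symbols_by_frequency)}

# Approximate English ranking (space first), for illustration only.
english_rank = ' etaoinshrdlucmfwypvbgkjqxz'
guess = frequency_init('xxxxqqqzzv', english_rank)
print(guess)
```

This guess is usually far from perfect, but it gives the random search a much better starting point than a completely random mapping.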

In [None]:
#@title Final decoded message

#@markdown As you enthusiastically hand over the decoded message, the mathematicians and the police officer are in awe. Once again, AI has saved the day:

print(decode(message, mapping))

# Application: Generating text!

<center>
<img src="https://storage.googleapis.com/dm-educational/assets/fun-with-language/robeo.png" alt="drawing" height="280"/>
</center>




Now we have a trained language model: given a piece of text, it predicts the next character. We can predict the next character, append it to the text, predict the character after that, and so on, which makes the model generate text!
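This predict-append-repeat loop is often called autoregressive generation. Here is a minimal sketch of it. The toy "model" (a hand-written table that only looks at the last character) and the function names are hypothetical stand-ins for the trained network and the notebook's `generate_text`.

```python
import random

def generate(predict_next, start, length, rng):
    """Autoregressive generation: predict a character, append it, repeat.

    `predict_next(text)` must return a dict {character: probability}.
    """
    text = start
    for _ in range(length):
        probs = predict_next(text)
        chars, weights = zip(*probs.items())
        text += rng.choices(chars, weights=weights)[0]
    return text

def toy_model(text):
    """Toy stand-in for a trained model: looks only at the last character."""
    table = {'a': {'b': 0.9, 'a': 0.1},
             'b': {'a': 0.5, ' ': 0.5},
             ' ': {'a': 1.0}}
    return table[text[-1]]

print(generate(toy_model, 'a', 20, random.Random(0)))
```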


Let's generate some texts now:

In [None]:
#@title Let's apply the model on your piece of text

#@markdown Different training settings can result in text that is more or less interesting. One simple thing that changes the resulting text is how we use the character probabilities to pick the next character. With **random** sampling, the model picks the next character at random, according to each character's probability. With **greedy** sampling, we always pick the character with the highest probability. Greedy choice always produces the same, often less interesting, text, while random sampling produces a new text every time. Try it out!

TEXT_START = 'Education means '  #@param {type:"string"}
SAMPLING = 'random'  #@param ["random", "greedy"]
LENGTH = 450  #@param {type:"slider", min:50, max:500, step:50}

print(generate_text(TEXT_START, length=LENGTH, sampling=SAMPLING))
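The two sampling options in the cell above can be sketched as follows. This is an illustration of the idea, not the code inside the notebook's `generate_text`, whose internals are not shown in this section. The probability table is made up.

```python
import random

def pick_next(probs, sampling, rng):
    """Choose the next character from a dict {character: probability}."""
    if sampling == 'greedy':
        # Always the single most likely character, so the result never changes.
        return max(probs, key=probs.get)
    if sampling == 'random':
        # Draw in proportion to the probabilities, so results vary between runs.
        chars, weights = zip(*probs.items())
        return rng.choices(chars, weights=weights)[0]
    raise ValueError(f'unknown sampling: {sampling}')

probs = {'e': 0.5, 'a': 0.3, 'o': 0.2}
rng = random.Random(0)
print([pick_next(probs, 'greedy', rng) for _ in range(5)])  # ['e', 'e', 'e', 'e', 'e']
print([pick_next(probs, 'random', rng) for _ in range(5)])
```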

Does the generated language 'look like English'?

Are the generated words meaningful?

Is the text as a whole meaningful? If not, how many 'words in a sequence' are meaningful?

## Language Models on words

We've seen that Language Models trained on characters can (surprisingly?) learn to generate somewhat meaningful text. Now, let's take a look at how a much larger Language Model called GPT-2, released by OpenAI in 2019, generates text. Instead of single characters, GPT-2 works with words and pieces of words. It comes in a few versions of different sizes, so why don't you try them out and see which one you like best?

In [None]:
model_version = 'gpt2-medium'  #@param ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']

#@markdown You can choose the size of the pretrained GPT-2 model. The larger the model, the longer it takes to download and to generate text.

gpt2_tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_version)
gpt2_model = transformers.TFGPT2LMHeadModel.from_pretrained(
    model_version, pad_token_id=gpt2_tokenizer.eos_token_id)

In [None]:
#@title Let's apply the model on your piece of text (try choosing your own!)

TEXT_START = 'ROMEO:'  #@param {type:"string"}
# LENGTH is measured in GPT-2 tokens (words or word pieces) and includes the prompt.
LENGTH = 100  #@param {type:"slider", min:50, max:300, step:50}

input_as_numbers = gpt2_tokenizer.encode(TEXT_START, return_tensors='tf')
generated_output = gpt2_model.generate(input_as_numbers, max_length=LENGTH,
                                       do_sample=True)

generated_text = gpt2_tokenizer.decode(generated_output[0],
                                       skip_special_tokens=True)

print_generated_text(TEXT_START, generated_text)


## Discussion:

*   How does the quality of the generated text compare to our character-based model?
*   Does this model produce meaningless words?
*   Is the whole text meaningful? If not, how many 'words in a sequence' are meaningful?
*   Does the quality of generated text improve with the size of the model?