Bharath Gunasekaran

This colab is to user Perciever IO model by Deep Minds to perform an interesting ML task. Implemented Perciever IO to predict a masked word in a sentence is. 

References:

https://colab.research.google.com/github/2796gaurav/code_examples/blob/main/Perceiver/Perceiver_masked_language_modelling.ipynb#scrollTo=ipZs6p0Xk3lb

https://medium.com/analytics-vidhya/perceiver-io-a-general-architecture-for-structured-inputs-outputs-4ad669315e7f

In [1]:
# Install dependencies for Google Colab.
# If you want to run this notebook on your own machine, you can skip this cell
!pip install dm-haiku
!pip install einops

!mkdir /content/perceiver
!touch /content/perceiver/__init__.py
!wget -O /content/perceiver/bytes_tokenizer.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/bytes_tokenizer.py
!wget -O /content/perceiver/io_processors.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/io_processors.py
!wget -O /content/perceiver/perceiver.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/perceiver.py
!wget -O /content/perceiver/position_encoding.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/position_encoding.py

Collecting dm-haiku
  Downloading dm_haiku-0.0.4-py3-none-any.whl (284 kB)
[?25l[K     |█▏                              | 10 kB 27.7 MB/s eta 0:00:01[K     |██▎                             | 20 kB 27.3 MB/s eta 0:00:01[K     |███▌                            | 30 kB 19.5 MB/s eta 0:00:01[K     |████▋                           | 40 kB 16.8 MB/s eta 0:00:01[K     |█████▊                          | 51 kB 8.0 MB/s eta 0:00:01[K     |███████                         | 61 kB 9.3 MB/s eta 0:00:01[K     |████████                        | 71 kB 8.6 MB/s eta 0:00:01[K     |█████████▏                      | 81 kB 9.6 MB/s eta 0:00:01[K     |██████████▍                     | 92 kB 10.3 MB/s eta 0:00:01[K     |███████████▌                    | 102 kB 8.0 MB/s eta 0:00:01[K     |████████████▊                   | 112 kB 8.0 MB/s eta 0:00:01[K     |█████████████▉                  | 122 kB 8.0 MB/s eta 0:00:01[K     |███████████████                 | 133 kB 8.0 MB/s eta 0:00:01

In [2]:
#@title Import
from typing import Union

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import pickle

from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer

In [3]:

#@title Load parameters from checkpoint
!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle

with open("language_perceiver_io_bytes.pickle", "rb") as f:
  params = pickle.loads(f.read())

--2021-10-09 19:37:25--  https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.135.128, 74.125.142.128, 74.125.195.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.135.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 804479532 (767M) [application/octet-stream]
Saving to: ‘language_perceiver_io_bytes.pickle’


2021-10-09 19:37:30 (166 MB/s) - ‘language_perceiver_io_bytes.pickle’ saved [804479532/804479532]



In [4]:

#@title Model config
D_MODEL = 768
D_LATENTS = 1280
MAX_SEQ_LEN = 2048

encoder_config = dict(
    num_self_attends_per_block=26,
    num_blocks=1,
    z_index_dim=256,
    num_z_channels=D_LATENTS,
    num_self_attend_heads=8,
    num_cross_attend_heads=8,
    qk_channels=8 * 32,
    v_channels=D_LATENTS,
    use_query_residual=True,
    cross_attend_widening_factor=1,
    self_attend_widening_factor=1)

decoder_config = dict(
    output_num_channels=D_LATENTS,
    position_encoding_type='trainable',
    output_index_dims=MAX_SEQ_LEN,
    num_z_channels=D_LATENTS,
    qk_channels=8 * 32,
    v_channels=D_MODEL,
    num_heads=8,
    final_project=False,
    use_query_residual=False,
    trainable_position_encoding_kwargs=dict(num_channels=D_MODEL))



In [5]:
#@title Decoding Perceiver Model
def apply_perceiver(
    inputs: jnp.ndarray, input_mask: jnp.ndarray) -> jnp.ndarray:
  """Runs a forward pass on the Perceiver.

  Args:
    inputs: input bytes, an int array of shape [B, T]
    input_mask: Array of shape indicating which entries are valid and which are
      masked. A truthy value indicates that the entry is valid.

  Returns:
    The output logits, an array of shape [B, T, vocab_size].
  """
  assert inputs.shape[1] == MAX_SEQ_LEN

  embedding_layer = hk.Embed(
      vocab_size=tokenizer.vocab_size,
      embed_dim=D_MODEL)
  embedded_inputs = embedding_layer(inputs)

  batch_size = embedded_inputs.shape[0]

  input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(
      index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)
  embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)
  perceiver_mod = perceiver.Perceiver(
      encoder=perceiver.PerceiverEncoder(**encoder_config),
      decoder=perceiver.BasicDecoder(**decoder_config))
  output_embeddings = perceiver_mod(
      embedded_inputs, is_training=False, input_mask=input_mask, query_mask=input_mask)

  logits = io_processors.EmbeddingDecoder(
      embedding_matrix=embedding_layer.embeddings)(output_embeddings)
  return logits

apply_perceiver = hk.transform(apply_perceiver).apply

In [7]:

#@title Pad and reshape inputs
inputs = input_tokens[None]
input_mask = np.ones_like(inputs)

def pad(max_sequence_length: int, inputs, input_mask):
  input_len = inputs.shape[1]
  assert input_len <= max_sequence_length
  pad_len = max_sequence_length - input_len
  padded_inputs = np.pad(
      inputs,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=tokenizer.pad_token)
  padded_mask = np.pad(
      input_mask,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=0)
  return padded_inputs, padded_mask

inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)

In [100]:
sentences = [
'This is the missing word in this sentence',
'Situps are a terrible way to end your day',
'As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something',
'Hang on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits',
'On a scale from one to ten, what is your favorite flavor of random grammar',
'He had a wall full of masks so she could wear a different face every day'
'She could not decide of the glass was half empty or half full so she drank it'
'The knives were out and she was sharpening hers',
'She could not understand why nobody else could see that the sky is full of cotton candy',
'The blinking lights of the antenna tower came into focus just as I heard a loud snap',
'He wondered if it could be called a beach if there was no sand',
'The boy ran up the hill',
'What you stay focused on will grow',
'Onward and Upward! To Narnia and the North!',
'Write while the heat is in you',
'The writer who postpones the recording of his thoughts uses an iron which has cooled to burn a hole with. He cannot inflame the minds of his audience',
'If you dare nothing, then when the day is over, nothing is all you will have gained',
]

sentences

['This is the missing word in this sentence',
 'Situps are a terrible way to end your day.',
 'As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.',
 'Hang on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits.',
 'On a scale from one to ten, what is your favorite flavor of random grammar?',
 'He had a wall full of masks so she could wear a different face every day.She could not decide of the glass was half empty or half full so she drank it.The knives were out and she was sharpening hers.',
 'She could not understand why nobody else could see that the sky is full of cotton candy.',
 'The blinking lights of the antenna tower came into focus just as I heard a loud snap.',
 'He wondered if it could be called a beach if there was no sand.']

In [134]:
def mask_word(sentences, index):
  incomplete_sentences = []
  missing_word = []
  for text in sentences:
      entry = {}
      incomplete_text = ''
      entry['word'] = text.split(" ")[index]
      if index == -1 :
        text_preprocessed = text.split(" ")[:-1]
      elif index > 0 and index < len(text):
        text_preprocessed = text.split(" ")[:index] +  text.split(" ")[index+1:]
      else: 
         text_preprocessed = text.split(" ")[index:]
      incomplete_text = " ".join(text_preprocessed)
      entry['bi'] = text.index(entry['word'])
      entry['ei'] = entry['bi'] + len(entry['word']) 
      missing_word.append(entry)
      incomplete_sentences.append(incomplete_text)
  return incomplete_sentences, missing_word

In [140]:
def validatePredictions(predictions, expected):
  correct = 0
  for i in range(len(sentences)):
    missing = expected[i]
    if predictions[i].lower() == missing['word'].lower():
      correct = correct +1

    print("Actual Sentence")
    print(sentences[i])

    print("Sentence with Predicted Word")
    print(sentences[i].replace(missing['word'], "["+predictions[i]+"]"))
    print('\n')
  print("Accuracy {}".format(correct/len(sentences)))
    

In [137]:
def runExperiment(sentences, index):
  incomplete_sentences, missing_word = mask_word(sentences, index)

  tokenizer = bytes_tokenizer.BytesTokenizer()

  # Encode Sentences
  encoded_sentences = []
  for text in sentences:
    input_tokens = tokenizer.to_int(text)
    encoded_sentences.append(input_tokens)

  # Apply Missing Mask to sentences
  for i in range(len(sentences)):
    encoded_sentences[i][missing_word[i]['bi']:missing_word[i]['ei']] = tokenizer.mask_token

  input_sentences = [text[None] for text in encoded_sentences]
  input_sentence_mask = [np.ones_like(inputs) for inputs in input_sentences] 

  # Adding Paddings
  input_sentence_pad = []
  input_sentence_mask_pad = []
  for i in range(len(sentences)):
    inputs, input_mask = pad(MAX_SEQ_LEN, input_sentences[i], input_sentence_mask[i])
    input_sentence_pad.append(inputs)
    input_sentence_mask_pad.append(input_mask)

  # Run Predictions
  rng = jax.random.PRNGKey(1)  # Unused
  predictions = []
  for i in range(len(sentences)):
    out = apply_perceiver(params, rng=rng, inputs=input_sentence_pad[i], input_mask=input_sentence_mask_pad[i])
    missing = missing_word[i]
    masked_tokens_predictions = out[0, missing['bi']:missing['ei']].argmax(axis=-1)
    predictions.append(tokenizer.to_string(masked_tokens_predictions))  

  validatePredictions(predictions,missing_word)

In [141]:
# What is the accuracy when front of sentence is missing?
runExperiment(sentences, 0)

Actual Sentence
This is the missing word in this sentence
Sentence with Predicted Word
[What] is the missing word in this sentence


Actual Sentence
Situps are a terrible way to end your day.
Sentence with Predicted Word
[ tiees] are a terrible way to end your day.


Actual Sentence
As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.
Sentence with Predicted Word
[as] time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.


Actual Sentence
Hang on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits.
Sentence with Predicted Word
[  no] on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits.


Actual Sentence
On a scale from one to ten, what is your favorite flavor of random grammar?
Sentence with Predicted Word
[on] a scale from one to ten, what is your favorite flavor of random grammar?


Actual Sentence
He had 

In [142]:
runExperiment(sentences,-1)

Actual Sentence
This is the missing word in this sentence
Sentence with Predicted Word
This is the missing word in this [ perase.]


Actual Sentence
Situps are a terrible way to end your day.
Sentence with Predicted Word
Situps are a terrible way to end your [ Rd.]


Actual Sentence
As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.
Sentence with Predicted Word
As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do [ see   the]


Actual Sentence
Hang on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits.
Sentence with Predicted Word
Hang on, my kittens are scratching at the bathtub and they are upset by the lack of [ wateee..]


Actual Sentence
On a scale from one to ten, what is your favorite flavor of random grammar?
Sentence with Predicted Word
On a scale from one to ten, what is your favorite flavor of random [ coenes?]


Actual Sentence
He had 

In [144]:
runExperiment(sentences,5)

Actual Sentence
This is the missing word in this sentence
Sentence with Predicted Word
This is the miss[ a]g word [ a] this sentence


Actual Sentence
Situps are a terrible way to end your day.
Sentence with Predicted Word
Situps are a terrible way [  ] end your day.


Actual Sentence
As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.
Sentence with Predicted Word
As time wore on, simple [ of] commands turned into full paragraphs explaining why the [ of] couldn’t do something.


Actual Sentence
Hang on, my kittens are scratching at the bathtub and they are upset by the lack of biscuits.
Sentence with Predicted Word
Hang on, my kittens are [ soi ng up] at the bathtub and they are upset by the lack of biscuits.


Actual Sentence
On a scale from one to ten, what is your favorite flavor of random grammar?
Sentence with Predicted Word
On a scale from one [ o] ten, what is your favorite flavor of random grammar?


Actual Sentence
He 