**Objective**: Run the Perceiver IO model code in a Colab notebook on an interesting ML task.

Sources:
* https://colab.research.google.com/github/2796gaurav/code_examples/blob/main/Perceiver/Perceiver_masked_language_modelling.ipynb#scrollTo=ipZs6p0Xk3lb 
* https://medium.com/analytics-vidhya/perceiver-io-a-general-architecture-for-structured-inputs-outputs-4ad669315e7f 

In [1]:
!pip install dm-haiku
!pip install einops

!mkdir /content/perceiver
!touch /content/perceiver/__init__.py
!wget -O /content/perceiver/bytes_tokenizer.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/bytes_tokenizer.py
!wget -O /content/perceiver/io_processors.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/io_processors.py
!wget -O /content/perceiver/perceiver.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/perceiver.py
!wget -O /content/perceiver/position_encoding.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/position_encoding.py

Collecting dm-haiku
  Downloading dm_haiku-0.0.4-py3-none-any.whl (284 kB)

In [2]:
from typing import Union

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pickle

from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer

In [3]:
!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle

with open("language_perceiver_io_bytes.pickle", "rb") as f:
  params = pickle.loads(f.read())

--2021-10-07 02:22:06--  https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.97.128, 108.177.125.128, 142.251.8.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.97.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 804479532 (767M) [application/octet-stream]
Saving to: ‘language_perceiver_io_bytes.pickle’


2021-10-07 02:22:33 (30.9 MB/s) - ‘language_perceiver_io_bytes.pickle’ saved [804479532/804479532]



# Text Dataset
This notebook uses the Kaggle [Quotes Dataset](https://www.kaggle.com/akmittal/quotes-dataset) provided by user Amit Mittal.

**Perceiver IO task**: Predict the masked last word of each quote. Of the dataset's 36,937 quotes, the cells below use the first 1,000 unique ones.

In [4]:
! pip install kaggle
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json



In [5]:
! kaggle datasets download -d akmittal/quotes-dataset

Downloading quotes-dataset.zip to /content
  0% 0.00/3.88M [00:00<?, ?B/s]
100% 3.88M/3.88M [00:00<00:00, 64.4MB/s]


In [6]:
! unzip quotes-dataset

Archive:  quotes-dataset.zip
  inflating: quotes.json             


## Preprocessing

In [47]:
df_quote = pd.read_json('quotes.json')
total_q = 1000
df_quote = df_quote['Quote'].unique()[:total_q]
print("Number of quotes =", total_q)
df_quote[:5]

Number of quotes = 1000


array(["Don't cry because it's over, smile because it happened.",
       "I'm selfish, impatient and a little insecure. I make mistakes, I am out of control and at times hard to handle. But if you can't handle me at my worst, then you sure as hell don't deserve me at my best.",
       'Be yourself; everyone else is already taken.',
       "Two things are infinite: the universe and human stupidity; and I'm not sure about the universe.",
       "Be who you are and say what you feel, because those who mind don't matter, and those who matter don't mind."],
      dtype=object)

In [48]:
df_quote_incomplete = []
missing_index = []  # length of the removed span (last word plus its leading space), used for slicing later
missing_word = []
for text in df_quote:
    missing_word.append(text.split(" ")[-1])
    text_preprocessed = text.split(" ")[:-1]
    incomplete_text = " ".join(text_preprocessed)
    missing_index.append(len(text) - len(incomplete_text))
    df_quote_incomplete.append(incomplete_text)
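
The split performed by the loop above can be checked on a single quote. This is a minimal sketch: `split_last_word` is a hypothetical helper that is not part of the notebook, and it assumes words are separated by single spaces.

```python
def split_last_word(text: str):
    """Return (prefix, last_word, span_len) for a space-separated quote.

    span_len counts the last word plus the space before it, i.e. the
    number of characters removed from the end of the quote.
    """
    parts = text.rsplit(" ", 1)
    if len(parts) == 1:
        # Single-word quote: nothing precedes the word, so there is no leading space.
        return "", text, len(text)
    prefix, last_word = parts
    return prefix, last_word, len(text) - len(prefix)


prefix, last_word, span_len = split_last_word(
    "Be yourself; everyone else is already taken.")
print(repr(prefix), repr(last_word), span_len)
```

The single-word branch matters for evaluation: such a quote has no leading space, so a comparison against `' ' + missing_word` can never succeed for it.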

In [49]:
df_quote_incomplete[:5]

["Don't cry because it's over, smile because it",
 "I'm selfish, impatient and a little insecure. I make mistakes, I am out of control and at times hard to handle. But if you can't handle me at my worst, then you sure as hell don't deserve me at my",
 'Be yourself; everyone else is already',
 "Two things are infinite: the universe and human stupidity; and I'm not sure about the",
 "Be who you are and say what you feel, because those who mind don't matter, and those who matter don't"]

# Model Config


## Encoder and Decoder

In [10]:
D_MODEL = 768
D_LATENTS = 1280
MAX_SEQ_LEN = 2048

encoder_config = dict(
    num_self_attends_per_block=26,
    num_blocks=1,
    z_index_dim=256,
    num_z_channels=D_LATENTS,
    num_self_attend_heads=8,
    num_cross_attend_heads=8,
    qk_channels=8 * 32,
    v_channels=D_LATENTS,
    use_query_residual=True,
    cross_attend_widening_factor=1,
    self_attend_widening_factor=1)

decoder_config = dict(
    output_num_channels=D_LATENTS,
    position_encoding_type='trainable',
    output_index_dims=MAX_SEQ_LEN,
    num_z_channels=D_LATENTS,
    qk_channels=8 * 32,
    v_channels=D_MODEL,
    num_heads=8,
    final_project=False,
    use_query_residual=False,
    trainable_position_encoding_kwargs=dict(num_channels=D_MODEL))
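
To get a feel for what these numbers imply, the sketch below counts the query-key pairs scored per attention head for this configuration and compares them with a hypothetical stack of the same depth that attends over all 2048 input positions. It ignores channel widths and the number of heads, so it is a rough illustration and not a FLOP count.

```python
# Values copied from the configuration above.
MAX_SEQ_LEN = 2048        # input and output length in bytes
NUM_LATENTS = 256         # z_index_dim
NUM_SELF_ATTENDS = 26     # num_self_attends_per_block, with num_blocks=1

# Query-key pairs scored per attention head (channel widths ignored).
encode_pairs = NUM_LATENTS * MAX_SEQ_LEN                      # latents attend to inputs
latent_pairs = NUM_SELF_ATTENDS * NUM_LATENTS * NUM_LATENTS   # latent self-attention
decode_pairs = MAX_SEQ_LEN * NUM_LATENTS                      # output queries attend to latents
perceiver_pairs = encode_pairs + latent_pairs + decode_pairs

# Hypothetical baseline: the same number of layers attending over all inputs.
full_attention_pairs = NUM_SELF_ATTENDS * MAX_SEQ_LEN * MAX_SEQ_LEN

print(f"Perceiver IO pairs:   {perceiver_pairs:,}")
print(f"Full-attention pairs: {full_attention_pairs:,}")
print(f"Ratio: {full_attention_pairs / perceiver_pairs:.1f}x")
```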

## Decoding Perceiver Model

In [11]:
def apply_perceiver(
    inputs: jnp.ndarray, input_mask: jnp.ndarray) -> jnp.ndarray:
  """Runs a forward pass on the Perceiver.

  Args:
    inputs: input bytes, an int array of shape [B, T]
    input_mask: Array of shape [B, T] indicating which entries are valid and which
      are masked. A truthy value indicates that the entry is valid.

  Returns:
    The output logits, an array of shape [B, T, vocab_size].
  """
  assert inputs.shape[1] == MAX_SEQ_LEN

  embedding_layer = hk.Embed(
      vocab_size=tokenizer.vocab_size,
      embed_dim=D_MODEL)
  embedded_inputs = embedding_layer(inputs)

  batch_size = embedded_inputs.shape[0]

  input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(
      index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)
  embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)
  perceiver_mod = perceiver.Perceiver(
      encoder=perceiver.PerceiverEncoder(**encoder_config),
      decoder=perceiver.BasicDecoder(**decoder_config))
  output_embeddings = perceiver_mod(
      embedded_inputs, is_training=False, input_mask=input_mask, query_mask=input_mask)

  logits = io_processors.EmbeddingDecoder(
      embedding_matrix=embedding_layer.embeddings)(output_embeddings)
  return logits

apply_perceiver = hk.transform(apply_perceiver).apply

## Masked text


In [12]:
# Pad each quote to the same length
def pad(max_sequence_length: int, inputs, input_mask):
  input_len = inputs.shape[1]
  assert input_len <= max_sequence_length
  pad_len = max_sequence_length - input_len
  padded_inputs = np.pad(
      inputs,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=tokenizer.pad_token)
  padded_mask = np.pad(
      input_mask,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=0)
  return padded_inputs, padded_mask
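
A quick way to see what the padding step produces is to run it on a toy sequence. The sketch below is a self-contained variant of the function above; `pad_to_length` and `PAD_TOKEN = 0` are assumptions standing in for `pad` and `tokenizer.pad_token`.

```python
import numpy as np

PAD_TOKEN = 0  # assumption: stand-in for tokenizer.pad_token


def pad_to_length(max_len, inputs, input_mask, pad_token=PAD_TOKEN):
    """Right-pad a [B, T] token array and its mask to [B, max_len]."""
    pad_len = max_len - inputs.shape[1]
    assert pad_len >= 0, "input is longer than max_len"
    padded_inputs = np.pad(inputs, ((0, 0), (0, pad_len)), constant_values=pad_token)
    padded_mask = np.pad(input_mask, ((0, 0), (0, pad_len)), constant_values=0)
    return padded_inputs, padded_mask


tokens = np.array([[72, 111, 39]], dtype=np.int32)  # one toy sequence, T = 3
mask = np.ones_like(tokens)                         # every real token is valid
padded_tokens, padded_mask = pad_to_length(8, tokens, mask)
print(padded_tokens)
print(padded_mask)
```

The mask is what lets the model tell real bytes from padding, so it is padded with zeros while the tokens are padded with the pad ID.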

In [50]:
tokenizer = bytes_tokenizer.BytesTokenizer()

# Encode quotes
quote_tokens = []
for text in df_quote:
  input_tokens = tokenizer.to_int(text)
  quote_tokens.append(input_tokens)

# Mask the last word together with its leading space. Note that the model performs much better if the masked chunk starts with a space.
# Offsets are in UTF-8 bytes because the tokenizer emits one token per byte.
for i in range(total_q):
  start = len(df_quote_incomplete[i].encode("utf-8"))
  quote_tokens[i][start:] = tokenizer.mask_token
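
Because the tokenizer works on bytes, positions in the token array are UTF-8 byte offsets, which differ from Python's `len()` on a string as soon as a quote contains a non-ASCII character. The sketch below illustrates the difference with a toy tokenizer; `to_byte_tokens`, `last_word_byte_span`, `NUM_RESERVED` and `MASK_TOKEN` are illustrative assumptions and not the notebook's `BytesTokenizer`.

```python
NUM_RESERVED = 6  # assumption: IDs below this are special tokens
MASK_TOKEN = 3    # assumption: stand-in for tokenizer.mask_token


def to_byte_tokens(text: str) -> list:
    """Toy byte tokenizer: one ID per UTF-8 byte, shifted past the reserved IDs."""
    return [b + NUM_RESERVED for b in text.encode("utf-8")]


def last_word_byte_span(text: str):
    """Return (start, end) byte offsets of the last word plus its leading space."""
    prefix = text.rsplit(" ", 1)[0] if " " in text else ""
    return len(prefix.encode("utf-8")), len(text.encode("utf-8"))


quote = "caf\u00e9 society"  # "café society"; the accented letter takes two bytes
start, end = last_word_byte_span(quote)
tokens = to_byte_tokens(quote)
tokens[start:end] = [MASK_TOKEN] * (end - start)

print(len("caf\u00e9"), start)  # character count vs. byte offset of the prefix
print(tokens)
```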

In [51]:
input_quotes = [text[None] for text in quote_tokens]
input_quotes_mask = [np.ones_like(inputs) for inputs in input_quotes] 

input_quotes_pad = []
input_quotes_mask_pad = []
for i in range(total_q):
  inputs, input_mask = pad(MAX_SEQ_LEN, input_quotes[i], input_quotes_mask[i])
  input_quotes_pad.append(inputs)
  input_quotes_mask_pad.append(input_mask)

# Prediction

In [52]:
rng = jax.random.PRNGKey(1)  # Unused

prediction = []
for i in range(total_q):
  if i % 100 == 0:
    print("Getting prediction -", i)
  out = apply_perceiver(params, rng=rng, inputs=input_quotes_pad[i], input_mask=input_quotes_mask_pad[i])
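  # Offsets are in UTF-8 bytes because the tokenizer emits one token per byte.
  start = len(df_quote_incomplete[i].encode("utf-8"))
  end = len(df_quote[i].encode("utf-8"))
  masked_tokens_predictions = out[0, start:end].argmax(axis=-1)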
  prediction.append(tokenizer.to_string(masked_tokens_predictions))

Getting prediction - 0
Getting prediction - 100
Getting prediction - 200
Getting prediction - 300
Getting prediction - 400
Getting prediction - 500
Getting prediction - 600
Getting prediction - 700
Getting prediction - 800
Getting prediction - 900
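
The decoding step in the loop above is a greedy argmax over the logits at the masked positions, followed by a conversion from byte IDs back to text. A minimal sketch on toy logits follows; `decode_span` and the constant `NUM_RESERVED = 6` are assumptions made for illustration and do not call the real model or tokenizer.

```python
import numpy as np

NUM_RESERVED = 6                  # assumption: IDs below this are special tokens
VOCAB_SIZE = 256 + NUM_RESERVED   # one ID per byte value plus the special tokens


def decode_span(logits, start, end):
    """Greedy-decode logits[start:end] (shape [T, vocab]) into a string."""
    ids = logits[start:end].argmax(axis=-1)
    ids = ids[ids >= NUM_RESERVED] - NUM_RESERVED  # drop special tokens
    return ids.astype(np.uint8).tobytes().decode("utf-8", errors="replace")


# Toy logits that put all weight on the bytes of "don't mind."
text = "don't mind."
target_ids = (np.frombuffer(text.encode("utf-8"), dtype=np.uint8).astype(np.int32)
              + NUM_RESERVED)
toy_logits = np.zeros((len(target_ids), VOCAB_SIZE), dtype=np.float32)
toy_logits[np.arange(len(target_ids)), target_ids] = 1.0

# Positions 5 onward cover the last word together with its leading space.
print(repr(decode_span(toy_logits, 5, len(target_ids))))
```

Each masked byte is decoded independently, which is one reason a prediction can come out as a non-word such as "bood." in the output further down.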


In [53]:
correct = 0
for i in range(total_q):
  if prediction[i] == ' ' + missing_word[i]:
    correct += 1
  else:
    print("Real quote:")
    print(df_quote[i])
    print("Perceiver IO quote:")
    print(df_quote_incomplete[i] + prediction[i])
    print()

Real quote:
Be yourself; everyone else is already taken.
Perceiver IO quote:
Be yourself; everyone else is already there.

Real quote:
Be who you are and say what you feel, because those who mind don't matter, and those who matter don't mind.
Perceiver IO quote:
Be who you are and say what you feel, because those who mind don't matter, and those who matter don't care.

Real quote:
A room without books is like a body without a soul.
Perceiver IO quote:
A room without books is like a body without a bood.

Real quote:
Friendship ... is born at the moment when one man says to another "What! You too? I thought that no one but myself . . .
Perceiver IO quote:
Friendship ... is born at the moment when one man says to another "What! You too? I thought that no one but myself . .  

Real quote:
If you want to know what a man's like, take a good look at how he treats his inferiors, not his equals.
Perceiver IO quote:
If you want to know what a man's like, take a good look at how he treats his inf

In [54]:
print("Perceiver IO accuracy on predicting missing word =", "{0:.0%}".format(correct/total_q))

Perceiver IO accuracy on predicting missing word = 35%
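
The reported figure is a strict exact-match score: a prediction counts only if it reproduces the last word, including punctuation, with a leading space. The sketch below restates that metric as a function; `exact_match_accuracy` is a hypothetical helper, and the toy inputs reuse words from the printed examples, with the first prediction assumed correct for illustration.

```python
def exact_match_accuracy(predictions, targets):
    """Fraction of predictions equal to the target word with a leading space."""
    if not targets:
        return 0.0
    hits = sum(pred == " " + word for pred, word in zip(predictions, targets))
    return hits / len(targets)


# Toy inputs: one hit and two plausible but non-matching completions.
preds = [" happened.", " there.", " care."]
words = ["happened.", "taken.", "mind."]
acc = exact_match_accuracy(preds, words)
print("Accuracy = {0:.0%}".format(acc))
```

Under this criterion, sensible completions such as "care." for "mind." count as misses, so the score understates how often the model produces a reasonable word.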
