# PerceiverIO language modeling

An example of a **masked-language model** pretrained using a large text corpus obtained by combining English Wikipedia and C4.

```
@article{Jaegle2021PerceiverIA,
  title={Perceiver IO: A General Architecture for Structured Inputs \& Outputs},
  author={Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier J. H{\'e}naff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and Jo{\~a}o Carreira},
  journal={ArXiv},
  year={2021},
  volume={abs/2107.14795}
}
@article{Raffel2020ExploringTL,
  title={Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
  author={Colin Raffel and Noam M. Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
  journal={ArXiv},
  year={2020},
  volume={abs/1910.10683}
}
```

In [1]:
%%bash

# dm-haiku is required only for converting the weights of the original model.
pip install --upgrade dm-haiku



In [2]:
from typing import Any, List, Optional, Union
from functools import partial
import jax.numpy as jnp

from jax import numpy as jnp
from flax import linen as nn
from flax_extra import random
from flax_extra import combinator as cb
from flax_extra.layer import (
    FeedForward,
    FeedForwardCt,
    KVQAttention,
    KVQAttentionCt,
    SelfAttention,
    SelfAttentionCt,
    Encoding,
    MultimodalEncodingCt,
    MultimodalPositionalEncodingCt,
    Decoding,
    TrainablePositionalEncoding,
    EmbedDecoding,
)
from flax_extra import data
from flax_extra.layer.io import (
    input_encoding,
    target_encoding,
    query_encoding,
    output_decoding,
)
from flax_extra.model import PerceiverIO
from util.original_model import variables

# Shorthand for JAX arrays, used in the annotations of this notebook.
Array = jnp.ndarray
# Left as `Any`: the notebook does not constrain the precision type.
Precision = Any
# A list of integer positions.
Positions = List[int]
# `Positions`, or `None` when no positions are given.
MaybePositions = Optional[Positions]

In [3]:
class PerceiverMLM(nn.Module):
    """Masked-language model assembled around a `PerceiverIO` module.

    One embedding module (named ``"EmbeddingEncoder"``) is instantiated per
    call and shared by three consumers: the input encoding, the target
    encoding (only when targets are supplied), and the output decoding,
    which reuses the module's ``embedding`` table. Every remaining field is
    handed to `PerceiverIO` unchanged.
    """

    # Constructor of the embedding module shared by inputs and outputs.
    input_embedding: MultimodalEncodingCt
    # Positional encoding combined with the embedded inputs via `cb.add()`.
    input_positional_encoding: Union[MultimodalPositionalEncodingCt, MultimodalEncodingCt]
    # Passed to `PerceiverIO` as its encoder query encoding, as is.
    encoder_query_encoding: MultimodalPositionalEncodingCt
    # Used directly as the decoder query encoding when no targets are given;
    # otherwise it becomes the positional part of the target encoding.
    decoder_query_encoding: MultimodalPositionalEncodingCt
    # The fields below are forwarded verbatim to `PerceiverIO`.
    n_processor_shards: int = 8
    n_processor_blocks: int = 6
    processor_attention: SelfAttentionCt = SelfAttention
    processor_feed_forward: FeedForwardCt = FeedForward
    encoder_attention: KVQAttentionCt = KVQAttention
    encoder_feed_forward: FeedForwardCt = FeedForward
    use_encoder_q_residual: bool = True
    decoder_attention: KVQAttentionCt = KVQAttention
    decoder_feed_forward: FeedForwardCt = FeedForward
    use_decoder_q_residual: bool = False
    deterministic: bool = True
    precision: Optional[Precision] = None

    @nn.compact
    def __call__(
        self,
        inputs: Union[Array, List[Array]],
        input_mask: Optional[Array] = None,
        targets: Optional[Union[Array, List[Array]]] = None,
        target_mask: Optional[Array] = None,
        output_positions: MaybePositions = None,
    ) -> Array:
        """Build the `PerceiverIO` module and apply it to the arguments.

        Args:
            inputs: forwarded to `PerceiverIO` as its first argument.
            input_mask: forwarded to `PerceiverIO`.
            targets: forwarded to `PerceiverIO`. When not ``None``, the
                decoder queries are built from the embedded targets
                (teacher forcing) instead of the plain decoder query
                encoding.
            target_mask: forwarded to `PerceiverIO`.
            output_positions: forwarded to `PerceiverIO`.

        Returns:
            The result of calling the configured `PerceiverIO` module.
        """
        token_embedding = self.input_embedding(name="EmbeddingEncoder")

        # Inputs and outputs share one vocabulary, so every consumer is
        # given this same module instance.
        def shared_embedding():
            return token_embedding

        encode_inputs = input_encoding(
            Encoding,
            preprocessing=shared_embedding,
            aggregation=cb.add(),
            positional_encoding=self.input_positional_encoding,
        )
        encode_encoder_queries = self.encoder_query_encoding

        # Teacher forcing is enabled exactly when targets are provided.
        if targets is not None:
            encode_decoder_queries = target_encoding(
                Encoding,
                preprocessing=shared_embedding,
                aggregation=cb.add(),
                positional_encoding=self.decoder_query_encoding,
            )
        else:
            encode_decoder_queries = self.decoder_query_encoding

        # Decoding reuses the embedding table of the shared module.
        decode_outputs = output_decoding(
            Decoding,
            embedding_decoding=partial(
                EmbedDecoding,
                embedding=token_embedding.embedding,
            ),
        )

        perceiver = PerceiverIO(
            input_encoding=encode_inputs,
            encoder_query_encoding=encode_encoder_queries,
            decoder_query_encoding=encode_decoder_queries,
            output_decoding=decode_outputs,
            n_processor_shards=self.n_processor_shards,
            n_processor_blocks=self.n_processor_blocks,
            processor_attention=self.processor_attention,
            processor_feed_forward=self.processor_feed_forward,
            encoder_attention=self.encoder_attention,
            encoder_feed_forward=self.encoder_feed_forward,
            use_encoder_q_residual=self.use_encoder_q_residual,
            decoder_attention=self.decoder_attention,
            decoder_feed_forward=self.decoder_feed_forward,
            use_decoder_q_residual=self.use_decoder_q_residual,
            deterministic=self.deterministic,
            precision=self.precision,
            name="PerceiverIO",
        )
        return perceiver(inputs, input_mask, targets, target_mask, output_positions)

In [4]:
MAX_INPUT_LENGTH = 2048
D_INPUT = 768

# Bytes tokenizer configured with the listed special tokens.
tokenizer = data.bytes_tokenizer(["PAD", "BOS", "EOS", "MASK", "CLS", "SEP"])

# All three attention layers use the same head count and query/key width;
# only the value width differs between them.
attention_shape = dict(n_heads=8, d_qk=256)

model = PerceiverMLM(
    # Token embedding sized to the tokenizer's vocabulary.
    input_embedding=partial(
        nn.Embed,
        num_embeddings=tokenizer.vocab_size,
        features=D_INPUT,
    ),
    input_positional_encoding=partial(
        TrainablePositionalEncoding,
        seqlen=MAX_INPUT_LENGTH,
        dimension=D_INPUT,
    ),
    encoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=256,
        dimension=1280,
    ),
    # Decoder queries have the same length and width as the inputs.
    decoder_query_encoding=query_encoding(
        TrainablePositionalEncoding,
        seqlen=MAX_INPUT_LENGTH,
        dimension=D_INPUT,
    ),
    n_processor_shards=1,
    n_processor_blocks=26,
    processor_attention=partial(SelfAttention, **attention_shape, d_v=1280),
    encoder_attention=partial(KVQAttention, **attention_shape, d_v=1280),
    decoder_attention=partial(KVQAttention, **attention_shape, d_v=768),
)
# Module-level shortcuts to the model's init/apply entry points.
model_init, model_apply = model.init, model.apply

In [None]:
%%bash

# Download the pickle that is later loaded as the model's initial variables;
# --continue resumes a partially downloaded file.
wget --continue --output-document "/tmp/perceiver_mlm_bytes.pickle" "https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle"

In [7]:
# NOTE(review): presumably this unpickles the downloaded file — only load
# checkpoints obtained from a trusted source.
initial_variables = variables("/tmp/perceiver_mlm_bytes.pickle")
rng = random.sequence(seed=0)
collections = ["params"]

input_tokens = "This is an incomplete sentence where some words are missing."
input_ids = tokenizer.to_ids(input_tokens)

# Ids 51..59 correspond to " missing."; the same span is masked in the
# input and read back from the output. Note that the model performs much
# better if the masked chunk starts with a space.
masked_span = slice(51, 60)
input_ids[masked_span] = tokenizer.reserved_ids.get("MASK")
print(f"Input sequence without masked text:\n`{tokenizer.to_tokens(input_ids)}`")

# Add a leading batch axis, then pad ids and mask to the maximum length.
pad_to_max_length = partial(tokenizer.pad, max_length=MAX_INPUT_LENGTH)
inputs = pad_to_max_length(input_ids[None])
input_mask = pad_to_max_length(jnp.ones_like(input_ids)[None])

rngs = random.into_collection(key=next(rng), labels=collections)
outputs = model_apply(
    initial_variables,
    inputs=inputs,
    input_mask=input_mask,
    rngs=rngs,
)

# For each masked position, take the index with the highest output value.
output_ids = outputs[0, masked_span].argmax(axis=-1)
print(f"Predicted text:\n`{tokenizer.to_tokens(output_ids)}` <- {output_ids}")

Input sequence without masked text:
`This is an incomplete sentence where some words are`
Predicted text:
` missing.` <- [ 38 115 111 121 121 111 116 109  52]
