## T5-Plex Demo

*Licensed under the Apache License, Version 2.0.*

\

<a href="https://colab.research.google.com/github/google/uncertainty-baselines/blob/main/experimental/plex/plex_t5_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this public colab, please use its `Connect to a local runtime` option by following the **Setup Guide** below.

\
\
This notebook demonstrates how to load the released **T5-Plex** checkpoints from the paper *Plex: Towards Reliability using Pretrained Large Model Extensions* using [JAX](https://jax.readthedocs.io/) and run inference on a single example.

For more advanced usage, full training and fine-tuning scripts can be found at https://github.com/google/uncertainty-baselines/tree/main/baselines/t5.

## Setup Guide for Colab Local Runtime
(This setup guide is adapted from the [T5X](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md) public Colab.)

Currently the [default public Colab](https://colab.research.google.com/) doesn't support Python versions higher than 3.7, but Plex-T5 models need a newer version. As an alternative, create a custom Jupyter kernel/runtime on a local machine (or a cloud machine), and then use Colab's `Connect to a local runtime` option to run this notebook. [This page](https://research.google.com/colaboratory/local-runtimes.html) contains additional details on how to set up and use a local runtime.


### Prepare Python env

On a local machine (or a cloud machine), create a Python environment via

```
sudo apt update
sudo apt install python3-pip

sudo apt install -y python3.10 python3.10-venv
python3.10 -m venv plex_t5_venv
```
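
Once the environment is activated, a quick check can confirm that its interpreter is new enough. The snippet below is a minimal sketch, not part of the original guide; it assumes that "higher than 3.7" means Python 3.8 or newer, and `python_is_supported` is a hypothetical helper name.

```python
import sys

# Assumption: "higher than 3.7" is read as Python 3.8 or newer.
MINIMUM_PYTHON = (3, 8)


def python_is_supported(version_info=None, minimum=MINIMUM_PYTHON):
    """Return True if the (major, minor) version is at least `minimum`."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    status = "ok" if python_is_supported() else "too old"
    print(f"Python {current}: {status}")
```

Running it with the `python3` inside `plex_t5_venv` should report the 3.10 interpreter installed above.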

### Install T5X with its dependencies

```
source plex_t5_venv/bin/activate
python3 -m pip install -U pip setuptools wheel ipython
python3 -m pip install flax
git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e .
cd -
```

### Install `uncertainty-baselines` with its dependencies

```
rm -rf uncertainty-baselines
git clone https://github.com/google/uncertainty-baselines.git
cp -r uncertainty-baselines/baselines/t5/* .

python3 -m pip install ./uncertainty-baselines[models,datasets]
```

Installing the packages above downgrades `jax` and `jaxlib`, so we need to reinstall the newer versions:

```
python3 -m pip install jax==0.3.23 jaxlib==0.3.22
```
Note that you might get warnings about `tensorflow-federated` having an incompatible version. This does not affect the usage of this colab.
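
To confirm that the reinstall took effect, the installed versions can be compared against the pins used in the command above. This is a minimal sketch using only the standard library; `installed_version` and `find_mismatches` are hypothetical helper names, not part of any of the installed packages.

```python
from importlib import metadata

# Pins taken from the install command above.
EXPECTED_VERSIONS = {"jax": "0.3.23", "jaxlib": "0.3.22"}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in find_mismatches(EXPECTED_VERSIONS).items():
        print(f"{package}: expected {wanted}, found {found}")
```

If the script prints nothing, both packages match the pinned versions.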

### Install and launch Jupyter
Finally, we prepare a local Jupyter runtime that our Colab notebook can connect to. For more detailed official instructions, see [here](https://research.google.com/colaboratory/local-runtimes.html).

```
python3 -m pip install notebook
python3 -m pip install --upgrade 'jupyter_http_over_ws>=0.0.7'
jupyter serverextension enable --py jupyter_http_over_ws
```

Use the command below to **launch** the prepared runtime.

```
jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0
```

Note that, depending on where `jupyter` is installed, you might need to replace `jupyter` with `~/plex_t5_venv/bin/jupyter` in the above commands. To find the installation path of `jupyter` in your virtual environment, run `which jupyter`.

\

You could also swap `allow_origin='https://colab.research.google.com'` with `allow_origin='https://colab.sandbox.google.com'` if needed.

\

In the log of the above command, you will see an HTTP link starting with `http://localhost:8888/?token`. Copy and paste it into the `Connect to a local runtime` option, and you should now be able to run this colab.
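
If the server log is long, the link can also be picked out programmatically. The following is a minimal sketch, assuming the server listens on `localhost:8888` as configured above; `find_runtime_url` is a hypothetical helper, and the log excerpt and token are made up for illustration.

```python
import re

# Matches the tokenized URL that `jupyter notebook` prints on startup.
# Assumption: the server listens on localhost:8888, as configured above.
RUNTIME_URL = re.compile(r"http://localhost:8888/\?token=[0-9a-zA-Z]+")


def find_runtime_url(log_text):
    """Return the first tokenized localhost URL in `log_text`, or None."""
    match = RUNTIME_URL.search(log_text)
    return match.group(0) if match else None


# Hypothetical log excerpt; the token below is made up.
sample_log = """
[I 10:00:00.000 NotebookApp] Serving notebooks from local directory: /home/user
[I 10:00:00.001 NotebookApp] http://localhost:8888/?token=abc123def456
"""
print(find_runtime_url(sample_log))
```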


## Imports

In [None]:
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import random

import seqio
import t5x
from t5x import utils as t5_utils
from t5x import partitioning
from t5x.examples.t5 import network
import nest_asyncio
nest_asyncio.apply()

import decoding
import utils
import uncertainty_baselines as ub
from data.tasks import nalue as nalue_task
from uncertainty_baselines.models import t5_be_gp
from models import be_models

## Define model

In [None]:
# Define transformer module.
DROPOUT_RATE = 0.0
NUM_EMBEDDINGS = 32128
COVMAT_MOMENTUM = -1.0
BE_ENS_SIZE = 5
MEAN_FIELD_FACTOR = 0.0001
NORMALIZE_INPUT = True
RANDOM_SIGN = 0.5
STEPS_PER_EPOCH = None

t5_config = network.T5Config(
    vocab_size=NUM_EMBEDDINGS,
    dropout_rate=DROPOUT_RATE,
    dtype='bfloat16',
    emb_dim=1024,
    head_dim=64,
    logits_via_embedding=False,
    mlp_activations=('gelu', 'linear'),
    mlp_dim=2816,
    num_decoder_layers=24,
    num_encoder_layers=24,
    num_heads=16)
module = t5_be_gp.TransformerBEGp(
    config=t5_config,
    be_decoder_layers=(-1,),
    covmat_momentum=COVMAT_MOMENTUM,
    ens_size=BE_ENS_SIZE,
    mean_field_factor=MEAN_FIELD_FACTOR,
    normalize_input=NORMALIZE_INPUT,
    random_sign_init=RANDOM_SIGN,
    ridge_penalty=1.0,
    steps_per_epoch=STEPS_PER_EPOCH)
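
As a quick sanity check on the configuration values, the combined width of the attention heads can be compared with the embedding width. This is a minimal sketch with the numbers copied by hand from the `T5Config` call; the two widths are not required to match in every T5 variant, but they do for these values. `attention_width` is a hypothetical helper.

```python
# Values copied from the T5Config call above.
EMB_DIM = 1024
HEAD_DIM = 64
NUM_HEADS = 16
MLP_DIM = 2816


def attention_width(num_heads, head_dim):
    """Total width of the concatenated attention heads."""
    return num_heads * head_dim


inner_dim = attention_width(NUM_HEADS, HEAD_DIM)
print(f"attention width: {inner_dim}, embedding width: {EMB_DIM}")
print(f"MLP expansion ratio: {MLP_DIM / EMB_DIM}")
```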

In [None]:
# Define vocab.
sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model"
VOCABULARY = seqio.SentencePieceVocabulary(sentencepiece_model_file=sentencepiece_model_file)
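
The vocabulary file name and `NUM_EMBEDDINGS = 32128` appear to be related by a small calculation. The sketch below is an interpretation rather than something stated in the notebook: it assumes the path encodes 32000 SentencePiece pieces plus 100 extra ids, and that the embedding table is padded up to a multiple of 128. `round_up` is a hypothetical helper.

```python
def round_up(value, multiple):
    """Round `value` up to the nearest multiple of `multiple`."""
    return -(-value // multiple) * multiple


# Assumption: the vocabulary path encodes 32000 learned pieces plus 100 extras.
VOCAB_PIECES = 32000
EXTRA_IDS = 100
# Assumption: the embedding table is padded to a multiple of 128.
PAD_MULTIPLE = 128

vocab_size = VOCAB_PIECES + EXTRA_IDS
print(vocab_size, round_up(vocab_size, PAD_MULTIPLE))
```

Under these assumptions the padded size equals the `NUM_EMBEDDINGS` value used in the model definition.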

In [None]:
# Define optimizer.
OPTIMIZER = utils.AdafactorGP(decay_rate=0.8, step_offset=0)

In [None]:
# Define loss HParam.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = 233472
LABEL_TOKENS = nalue_task.get_nalue_intent_tokens()

In [None]:
model = be_models.EncoderDecoderBEGpClassifierModel(
    module=module,
    decode_fn=functools.partial(
        decoding.beam_search, alpha=0., return_token_scores=True),
    input_vocabulary=VOCABULARY,
    label_smoothing=LABEL_SMOOTHING,
    label_tokens=LABEL_TOKENS,
    loss_normalizing_factor=LOSS_NORMALIZING_FACTOR,
    optimizer_def=OPTIMIZER,
    output_vocabulary=VOCABULARY,
    z_loss=Z_LOSS)

## Load checkpoint

In [None]:
# Path of the released T5-Plex checkpoint (no formatting needed, so a plain string).
CHECKPOINT_PATH = 'gs://plex-paper/plex_t5_large_c4_to_nalue/'
restore_checkpoint_cfg = t5_utils.RestoreCheckpointConfig(
    path=CHECKPOINT_PATH, mode='specific', use_gda=False)

In [None]:
partitioner = partitioning.PjitPartitioner(
    num_partitions=1,
    logical_axis_rules=partitioning.standard_logical_axis_rules())

In [None]:
batch_size = 32
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 3}
input_shapes = {
    'decoder_input_tokens': (batch_size, TASK_FEATURE_LENGTHS['targets']),
    'decoder_loss_weights': (batch_size, TASK_FEATURE_LENGTHS['targets']),
    'decoder_target_tokens': (batch_size, TASK_FEATURE_LENGTHS['targets']),
    'encoder_input_tokens': (batch_size, TASK_FEATURE_LENGTHS['inputs'])
}

# Create train state initializer.
train_state_initializer = t5_utils.TrainStateInitializer(
    optimizer_def=None,  # Do not load optimizer state.
    init_fn=model.get_initial_variables,
    input_shapes=input_shapes,
    partitioner=partitioner)

In [None]:
# Restore train state from checkpoint.
train_state = train_state_initializer.from_checkpoint([restore_checkpoint_cfg])

## Run inference

In [None]:
input_text = "can you please provide me with assistance in moving money from one account to another" #@param {type:"string"}

In [None]:
infer_step_jit = jax.jit(model.predict_batch_with_aux)

In [None]:
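input_tokenized = VOCABULARY.encode(input_text)
# Right-pad the token ids with zeros to the fixed encoder length.
pad_width = TASK_FEATURE_LENGTHS['inputs'] - len(input_tokenized)
input_padded = np.pad(input_tokenized, (0, pad_width))
infer_batch = {
    'encoder_input_tokens': jnp.expand_dims(input_padded, axis=0),
    # Placeholder decoder inputs for a batch of one example.
    'decoder_input_tokens': np.zeros(
        (1, TASK_FEATURE_LENGTHS['targets']), dtype=np.int32),
}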
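
The padding step assumes the tokenized input fits within the 512-token encoder length; `np.pad` rejects negative pad widths, so a longer input would need truncating first. The following is a minimal pure-Python sketch of padding that also truncates; `pad_or_truncate` is a hypothetical helper, and the token ids are made up rather than produced by `VOCABULARY.encode`.

```python
def pad_or_truncate(token_ids, length, pad_id=0):
    """Return `token_ids` cut or right-padded with `pad_id` to exactly `length`."""
    clipped = list(token_ids)[:length]
    return clipped + [pad_id] * (length - len(clipped))


# Hypothetical token ids standing in for VOCABULARY.encode(input_text).
example_ids = [54, 25, 754, 370]
print(pad_or_truncate(example_ids, 8))
print(pad_or_truncate(example_ids, 3))
```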

In [None]:
# Run inference on the batch with the jitted prediction step.
rng = jax.random.PRNGKey(0)
batch_result = infer_step_jit(train_state.params, infer_batch, rng)

In [None]:
VOCABULARY.decode_tf(batch_result[0])