This notebook sets up an ESM-2 15B model, with weights loaded from PyTorch, for inference on a TPU v2-8/v3-8, taking full advantage of model parallelism.

In [1]:
# Only necessary in non-Poetry envs, remove once pip-installable.
# import sys
# sys.path.append("..")

In [2]:
# General imports
import numpy as np

from flax.core import frozen_dict
import jax

# esmjax imports
from esmjax import io, tokenizer as esm_tokenizer
from esmjax.modules import modules

# Imports specifically for multi-device sharding
from esmjax.modules import partitioning
from flax.linen import partitioning as nn_partitioning
from jax.experimental import maps, PartitionSpec as P, pjit

## Step 0: Load model

First, we load in the model and its converted weights.

In [3]:
MODEL_NAME = "esm2_t48_15B_UR50D"
# Load in the original PyTorch state; will download if first time.
state = io.get_torch_state(MODEL_NAME)

esm, params_axes = modules.get_esm2_model(state["cfg"])
esm_params = io.convert_encoder(state["model"], state["cfg"])
esm_params = frozen_dict.FrozenDict({"params": esm_params})

Let's look at the `AxisMetadata` for layer 3, before conversion. The metadata exists only for the params that will be sharded.

In [4]:
params_axes["3"]

FrozenDict({
    fc1: {
        kernel_axes: AxisMetadata(names=('embed_kernel', 'hidden')),
    },
    fc2: {
        kernel_axes: AxisMetadata(names=('hidden', 'embed_kernel')),
    },
    self_attn: {
        k_proj: {
            kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
        },
        out_proj: {
            kernel_axes: AxisMetadata(names=('heads', None, 'embed_kernel')),
        },
        q_proj: {
            kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
        },
        v_proj: {
            kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
        },
    },
})

Next, we shard the loaded params onto the TPUs.

In [5]:
# Checking we have 8 devices
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [6]:
# Convert `params_axes` (which only has sharding specs for params
# that will be sharded) to `esm_axes` (which has a spec for ALL
# params, defaulting to None for params that are fully replicated).
esm_axes = partitioning.get_params_axes(esm_params, params_axes, rules=partitioning.DEFAULT_TPU_RULES)

Let's quickly check the sharding spec for layer 3, for example:

In [7]:
esm_axes["params"]["3"]  # looks right! Note we're only sharding the large kernels.

FrozenDict({
    fc1: {
        kernel: PartitionSpec('X', 'Y'),
        bias: None,
    },
    fc2: {
        kernel: PartitionSpec('Y', 'X'),
        bias: None,
    },
    self_attn_layer_norm: {
        scale: None,
        bias: None,
    },
    final_layer_norm: {
        scale: None,
        bias: None,
    },
    self_attn: {
        k_proj: {
            kernel: PartitionSpec('X', 'Y', None),
            bias: None,
        },
        q_proj: {
            kernel: PartitionSpec('X', 'Y', None),
            bias: None,
        },
        v_proj: {
            kernel: PartitionSpec('X', 'Y', None),
            bias: None,
        },
        out_proj: {
            kernel: PartitionSpec('Y', None, 'X'),
            bias: None,
        },
    },
})
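
The translation from logical axis names to mesh axes can be sketched in plain Python. The rule table below is an assumption inferred from the two outputs above (`embed_kernel` maps to `X`, `hidden` and `heads` map to `Y`); the actual contents of `partitioning.DEFAULT_TPU_RULES` are not shown in this notebook, and `to_mesh_axes` is a hypothetical helper, not part of esmjax or Flax.

```python
# Hypothetical rule table, inferred from the outputs above; the real
# partitioning.DEFAULT_TPU_RULES may contain more or different entries.
ASSUMED_RULES = (
    ("embed_kernel", "X"),
    ("hidden", "Y"),
    ("heads", "Y"),
)


def to_mesh_axes(logical_names, rules=ASSUMED_RULES):
    """Map each logical axis name to a mesh axis name.

    Names without a rule (including None) map to None, i.e. replicated.
    """
    lookup = dict(rules)
    return tuple(lookup.get(name) for name in logical_names)


fc1_spec = to_mesh_axes(("embed_kernel", "hidden"))
out_proj_spec = to_mesh_axes(("heads", None, "embed_kernel"))
print(fc1_spec)
print(out_proj_spec)
```

Under these assumed rules, the two results line up with the `PartitionSpec` entries shown for `fc1` and `out_proj`.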

In [8]:
# Create 2D TPU mesh
mesh_shape = (2, 4)  # X=2, Y=4, 8 TPUs total
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ("X", "Y"))

# Create fn for pre-sharding the params.
preshard_fn = pjit.pjit(
    lambda x: x,  # this function does nothing
    in_axis_resources=(esm_axes,),  # but this spec "pre-shards" the params
    out_axis_resources=esm_axes,
)

# There are two contexts: one for the mesh, the other specifying the translation
# rules for named sharding axis -> TPU mesh logical axis
with maps.Mesh(mesh.devices, mesh.axis_names), nn_partitioning.axis_rules(
    partitioning.DEFAULT_TPU_RULES
):
    esm_sharded_params = preshard_fn(esm_params)

Let's see the mesh object:

In [9]:
mesh

Mesh(array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), ('X', 'Y'))
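
The mesh layout is simply the device list reshaped in row-major order. A minimal NumPy sketch, with integer ids standing in for the `TpuDevice` objects (`mesh_coords` is a hypothetical helper introduced here for illustration):

```python
import numpy as np

mesh_shape = (2, 4)  # X=2, Y=4, as in the notebook
# Integer ids stand in for the TpuDevice objects returned by jax.devices().
device_ids = np.arange(8).reshape(mesh_shape)


def mesh_coords(device_id, shape=mesh_shape):
    """Return the (X, Y) mesh position of a device id under a row-major reshape."""
    x, y = np.unravel_index(device_id, shape)
    return int(x), int(y)


print(device_ids)
print(mesh_coords(5))
```

Devices 0 to 3 form the first row along `Y`, and devices 4 to 7 the second.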

We can access sharding specs down to individual params, if we'd like:

In [10]:
esm_sharded_params["params"]["3"]["fc1"]["kernel"].sharding_spec

ShardingSpec((Chunked(2), Chunked(4)), (ShardedAxis(axis=0), ShardedAxis(axis=1)))

In [11]:
esm_sharded_params["params"]["3"]["fc2"]["kernel"].sharding_spec

ShardingSpec((Chunked(4), Chunked(2)), (ShardedAxis(axis=1), ShardedAxis(axis=0)))
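
Reading `Chunked(n)` as "this dimension is split into n equal pieces", the per-device shard shape follows by integer division. The sketch below assumes the `fc1` kernel is 5120x20480 and that the `fc2` kernel is its transpose (20480x5120); `shard_shape` is a hypothetical helper, and the dtype of the stored weights is not stated in this notebook, so both float32 and bfloat16 footprints are shown.

```python
import math


def shard_shape(full_shape, chunks):
    """Per-device shard shape when dimension i is split into chunks[i] equal parts."""
    for dim, n in zip(full_shape, chunks):
        if dim % n != 0:
            raise ValueError(f"dimension {dim} is not divisible by {n}")
    return tuple(dim // n for dim, n in zip(full_shape, chunks))


fc1_shard = shard_shape((5120, 20480), (2, 4))  # Chunked(2), Chunked(4)
fc2_shard = shard_shape((20480, 5120), (4, 2))  # Chunked(4), Chunked(2); shape assumed

# Memory held per device for one fc1 kernel shard, for two candidate dtypes.
elements = math.prod(fc1_shard)
fp32_mib = elements * 4 / 2**20
bf16_mib = elements * 2 / 2**20
print(fc1_shard, fc2_shard)
print(fp32_mib, "MiB in float32;", bf16_mib, "MiB in bfloat16")
```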

We can also see the exact index ranges held by each TPU core. For example, here we see each core holds a unique 2560x5120 slice of the 5120x20480 weight matrix.

In [12]:
esm_sharded_params["params"]["3"]["fc1"]["kernel"].indices

((slice(0, 2560, None), slice(0, 5120, None)),
 (slice(0, 2560, None), slice(5120, 10240, None)),
 (slice(0, 2560, None), slice(10240, 15360, None)),
 (slice(0, 2560, None), slice(15360, 20480, None)),
 (slice(2560, 5120, None), slice(0, 5120, None)),
 (slice(2560, 5120, None), slice(5120, 10240, None)),
 (slice(2560, 5120, None), slice(10240, 15360, None)),
 (slice(2560, 5120, None), slice(15360, 20480, None)))
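
These index ranges can be reproduced without a TPU. The following is a sketch only, assuming equal-sized shards and row-major ordering of mesh coordinates; `shard_indices` is a hypothetical helper and not how JAX computes `.indices` internally.

```python
from itertools import product


def shard_indices(full_shape, mesh_shape, spec):
    """Slices owned by each device, in row-major mesh order.

    spec[i] is the position of the mesh axis that splits dimension i,
    or None if that dimension is not split.
    """
    result = []
    for coords in product(*(range(n) for n in mesh_shape)):
        slices = []
        for dim, mesh_axis in zip(full_shape, spec):
            if mesh_axis is None:
                slices.append(slice(None, None, None))
            else:
                size = dim // mesh_shape[mesh_axis]
                start = coords[mesh_axis] * size
                slices.append(slice(start, start + size, None))
        result.append(tuple(slices))
    return tuple(result)


# fc1 kernel: dimension 0 split over X (mesh axis 0), dimension 1 over Y (mesh axis 1).
fc1_indices = shard_indices((5120, 20480), (2, 4), (0, 1))
for idx in fc1_indices:
    print(idx)
```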

## Step 1: Tokenize input protein

For this example, we use the sequences for [p53](https://en.wikipedia.org/wiki/P53) (one of the most extensively studied proteins in cancer biology) and insulin. The sequences of the human orthologs are:

In [13]:
# Adjacent string literals are concatenated, so no stray whitespace
# from line continuation ends up inside the sequences.
p53_seq = (
    "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGP"
    "DEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAK"
    "SVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHE"
    "RCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNS"
    "SCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELP"
    "PGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPG"
    "GSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"
)

insulin_seq = (
    "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAED"
    "LQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
)

We then use our tokenizer to convert these sequences of letters into sequences of integers, with appropriate padding.

In [14]:
tokenizer = esm_tokenizer.protein_tokenizer(pad_to_multiple_of=128)
tokens = [x.ids for x in tokenizer.encode_batch([p53_seq, insulin_seq])]
batch = np.array(tokens)

The first token is 0 (`<cls>`), and each sequence ends with 2 (`<eos>`) before the trailing padding tokens (1). Note that the first actual amino acid in both sequences is 20, which is methionine.

In [15]:
batch

array([[ 0, 20,  9, ...,  1,  1,  1],
       [ 0, 20,  5, ...,  1,  1,  1]])
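
To make the ids above concrete, here is a toy encoder. It is a sketch under assumptions, not esmjax's tokenizer: the vocabulary is the standard ESM ordering restricted to the special tokens and the 20 canonical amino acids, which agrees with the ids shown (`M` is 20, `E` is 9, `A` is 5), and `toy_encode` is a hypothetical name.

```python
# Assumed vocabulary: the standard ESM ordering, restricted to the special
# tokens and the 20 canonical amino acids. esmjax's tokenizer may differ.
SPECIAL = ["<cls>", "<pad>", "<eos>", "<unk>"]
AMINO_ACIDS = list("LAGVSERTIDPKQNFYMHWC")
VOCAB = {tok: i for i, tok in enumerate(SPECIAL + AMINO_ACIDS)}


def toy_encode(seq, pad_to_multiple_of=128):
    """Wrap seq in <cls>/<eos>, then pad with <pad> to the next multiple."""
    ids = [VOCAB["<cls>"]]
    ids += [VOCAB.get(aa, VOCAB["<unk>"]) for aa in seq]
    ids.append(VOCAB["<eos>"])
    # Ceiling division, then scale back up to the padded length.
    padded_len = -(-len(ids) // pad_to_multiple_of) * pad_to_multiple_of
    return ids + [VOCAB["<pad>"]] * (padded_len - len(ids))


ids = toy_encode("MEEPQSDPSV")  # first ten residues of p53
print(ids[:13])
print(len(ids))
```

By the same arithmetic, the 393-residue p53 sequence plus two special tokens gives 395 ids, which rounds up to 512.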

## Step 2: Get embeddings

We then create a `pjit`'ted function for inference, and call it just like the parameter sharding function above.

In [16]:
# Create fn for inference.
apply_fn = pjit.pjit(
    esm.apply,
    in_axis_resources=(esm_axes, P("X", None)),
    out_axis_resources=P("X", None, "Y"),
)

In [17]:
# Note that the first call takes a *while*, about 50s on a TPUv2-8
with maps.Mesh(mesh.devices, mesh.axis_names), nn_partitioning.axis_rules(
    partitioning.DEFAULT_TPU_RULES
):
    embeds = apply_fn(esm_sharded_params, batch)

`embeds` is a 2x512x5120 tensor, corresponding to batch x seq x features.

In [18]:
embeds.shape

(2, 512, 5120)
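
Because the sequence axis includes `<cls>`, `<eos>`, and padding positions, per-residue embeddings are usually obtained by masking those out. The sketch below uses small NumPy stand-ins for `batch` and `embeds`, and assumes the real result has first been copied to the host (for example with `np.asarray`); `residue_embeddings` is a hypothetical helper.

```python
import numpy as np

CLS, PAD, EOS = 0, 1, 2  # special token ids, as seen in `batch` above

# Small stand-ins for `batch` (2 x 8) and `embeds` (2 x 8 x 4).
toy_batch = np.array([
    [CLS, 20, 9, 9, 14, EOS, PAD, PAD],
    [CLS, 20, 5, EOS, PAD, PAD, PAD, PAD],
])
toy_embeds = np.arange(2 * 8 * 4, dtype=np.float32).reshape(2, 8, 4)


def residue_embeddings(tokens, embeddings):
    """Return one (num_residues, features) array per sequence, special tokens dropped."""
    keep = ~np.isin(tokens, [CLS, PAD, EOS])
    return [emb[mask] for emb, mask in zip(embeddings, keep)]


per_protein = residue_embeddings(toy_batch, toy_embeds)
print([p.shape for p in per_protein])
```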

We can also see its sharding pattern: the batch axis is sharded across the X mesh axis, and the embedding axis is sharded over the Y mesh axis.

In [19]:
embeds.sharding_spec, embeds.indices

(ShardingSpec((Chunked(2), NoSharding(), Chunked(4)), (ShardedAxis(axis=0), ShardedAxis(axis=1))),
 ((slice(0, 1, None), slice(None, None, None), slice(0, 1280, None)),
  (slice(0, 1, None), slice(None, None, None), slice(1280, 2560, None)),
  (slice(0, 1, None), slice(None, None, None), slice(2560, 3840, None)),
  (slice(0, 1, None), slice(None, None, None), slice(3840, 5120, None)),
  (slice(1, 2, None), slice(None, None, None), slice(0, 1280, None)),
  (slice(1, 2, None), slice(None, None, None), slice(1280, 2560, None)),
  (slice(1, 2, None), slice(None, None, None), slice(2560, 3840, None)),
  (slice(1, 2, None), slice(None, None, None), slice(3840, 5120, None))))