In [None]:
#@title Install dependencies
!pip install omegaconf torch numpy jax termcolor

In [None]:
#@title Load checkpoint
path_to_checkpoint_folder = "./en_dense_lm_355m"  #@param {type:"string"}

# Reads a fairseq dense-LM checkpoint folder (model.pt + dict.txt), builds the
# row permutation that the next cell applies to the embedding matrix, and
# derives the JSON config written next to the converted JAX checkpoint.
import os
import torch
from termcolor import colored

path_to_checkpoint = os.path.join(path_to_checkpoint_folder, "model.pt")
path_to_dict = os.path.join(path_to_checkpoint_folder, "dict.txt")

print("Reading from dict.txt...")
# The first field of each dict.txt line is either a numeric token id or a
# "madeupword<N>" filler entry; both lists keep file order.
indices = []
dummy_indices = []
with open(path_to_dict) as f:
    for line in f:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of failing with IndexError
        index = fields[0]
        if not index.startswith("madeupword"):
            indices.append(int(index))
        else:
            dummy_indices.append(int(index[len("madeupword"):]))

# mapping[new_row] = fairseq_row. Fairseq rows are counted in this order:
# 4 leading rows (presumably fairseq's built-in special symbols - TODO confirm),
# then the dict.txt token ids in file order, then the madeupword fillers.
# In the new layout token id t lands on row t, the 4 leading rows follow the
# len(indices) token rows, and madeupword<N> lands on len(indices) + 4 + N.
mapping = [None] * (len(indices) + len(dummy_indices) + 4)
fairseq_row = 0
for special in range(4):
    mapping[len(indices) + special] = fairseq_row
    fairseq_row += 1
for token_id in indices:
    mapping[token_id] = fairseq_row
    fairseq_row += 1
for dummy in dummy_indices:
    mapping[len(indices) + 4 + dummy] = fairseq_row
    fairseq_row += 1
# Every row must be filled exactly once, i.e. mapping is a permutation of 0..n-1.
assert None not in mapping, "dict.txt does not cover every embedding row"
assert sorted(mapping) == list(range(len(mapping))), "dict.txt does not describe a row permutation"
mapping = torch.tensor(mapping)
print(f"Computed embedding map of length {len(mapping)}.")
print("Reading from checkpoint...")
# model.pt is a pickle: only convert checkpoints from a trusted source.
# weights_only=False (keyword available since torch 1.13) is required because
# torch>=2.6 defaults to weights_only=True, which refuses the non-tensor
# objects stored next to the weights (e.g. the "cfg" entry read below,
# presumably an omegaconf config given the install cell - TODO confirm).
torch_checkpoint = torch.load(path_to_checkpoint, map_location='cpu', weights_only=False)
config = {
    "compat": "fairseq_lm",
    "n_vocab": 51200,
    "n_vocab_padding": 0,
    "norm": "layernorm",
    "pe": "fairseq_sinusoidal",
    "seq": 2048,
    "layers": torch_checkpoint["cfg"]["model"]["decoder_layers"],
    "d_model": torch_checkpoint["cfg"]["model"]["decoder_embed_dim"],
    "n_heads": torch_checkpoint["cfg"]["model"]["decoder_attention_heads"],
}
if len(mapping) != config["n_vocab"]:
    # The converted embedding has len(mapping) (+ padding) rows, so a mismatch
    # here means the written config will not describe the written weights.
    print(colored(f"Warning: dict.txt describes {len(mapping)} embedding rows, but n_vocab is {config['n_vocab']}.", "yellow"))
# Use the largest candidate shard count that divides both n_heads and d_model;
# 1 always qualifies, so cores_per_replica is always set.
for c in (8, 6, 4, 2, 1):
    if config["n_heads"] % c == 0 and config["d_model"] % c == 0:
        config["cores_per_replica"] = c
        break
pieces = 16  # number of .npz files written per shard by save() in the next cell
layers = config["layers"]
d_model = config["d_model"]
total_shards = config["cores_per_replica"]
# Zero rows appended so the vocab divides evenly across shards: -(n % -k) is
# the distance from n up to the next multiple of k (0 if already divisible).
config["n_vocab_padding"] = padding_rows = -(config["n_vocab"] % -total_shards)
for i in range(total_shards):
    # exist_ok=True so re-running the cell does not raise FileExistsError.
    os.makedirs(f"jax_checkpoint/shard_{i}", exist_ok=True)
print(f"Detected {layers} layers, {config['n_heads']} heads.  Embed dim {d_model}.  {total_shards} shards.  {padding_rows} embedding matrix padding rows.")
print("Done.")

In [None]:
#@title Convert checkpoint to be JAX-compatible { display-mode: "form" }
from termcolor import colored
import torch
import numpy as np
import jax.numpy as jnp
import json

def reshard_reverse(x, old_shape, is_shard_bias=False):
    """Lay out one unsharded parameter as per-shard slices stacked on axis 0.

    Args:
        x: array with a leading axis of size 1 (the caller adds it with
            ``params[None]``); 2-D for biases and layer norms, 3-D for
            weight matrices.
        old_shape: target shape ``(n_shards,) + per_shard_shape``; the shard
            count is read from ``old_shape[0]``.
        is_shard_bias: for a 2-D input that is replicated on every shard,
            divide each copy by the shard count (presumably because shard
            outputs are summed downstream - TODO confirm against the model).

    Returns:
        numpy array of shape ``old_shape``.

    Raises:
        ValueError: if ``x`` has a rank/shape combination with no known layout.
    """
    n_shards = old_shape[0]
    rank = len(x.shape)

    if rank == 2:
        if old_shape[1] == x.shape[1]:
            # Same width per shard: replicate the single row on every shard.
            out = np.tile(x[0:1], (n_shards, 1))
            if is_shard_bias:
                out = out / n_shards
        else:
            # Narrower per-shard width: cut the row into contiguous chunks.
            out = x.reshape(old_shape)
        return out

    if rank == 3:
        if x.shape[0] * x.shape[2] == old_shape[2]:
            # Last axis unchanged: each shard gets a contiguous block of rows.
            return x.reshape(old_shape)
        if x.shape[0] * x.shape[1] == old_shape[1]:
            # Middle axis unchanged: each shard gets a contiguous block of
            # columns. np (not jnp) keeps every branch returning numpy and
            # avoids a round trip through the JAX device.
            return np.transpose(x.reshape((old_shape[1], old_shape[0], old_shape[2])), (1, 0, 2))
        raise ValueError(f"unimplemented, {x.shape}, {old_shape}")

    # Rank 1 and rank > 3 inputs have no supported layout. Report the shape
    # only; formatting the whole tensor would flood the output.
    raise ValueError(f"unimplemented, {x.shape}")

def get_old_shape(t, dim=2, shards=None):
    """Return the shape of one shard of ``t`` when split across shards.

    Args:
        t: tensor/array with a ``shape`` attribute; must be 1-D or 2-D.
        dim: for 2-D input, 1 splits axis 0 and 2 splits axis 1. Ignored for
            1-D input, which is always split along its only axis.
        shards: number of shards; ``None`` (default) uses the notebook-level
            ``total_shards``, looked up at call time.

    Returns:
        Tuple with the per-shard shape (without the shard axis).

    Raises:
        ValueError: if the split axis is not divisible by the shard count,
            ``dim`` is not 1 or 2 for 2-D input, or ``t`` is not 1-D/2-D.
    """
    if shards is None:
        shards = total_shards
    shape = tuple(t.shape)

    if len(shape) == 2:
        if dim not in (1, 2):
            raise ValueError(f"unsupported dim {dim}")
        axis = dim - 1
        # Explicit raise (not assert) so the check survives `python -O`.
        if shape[axis] % shards != 0:
            raise ValueError(f"axis {axis} of shape {shape} is not divisible by {shards} shards")
        per_shard = list(shape)
        per_shard[axis] //= shards
        return tuple(per_shard)

    if len(shape) == 1:
        if shape[0] % shards != 0:
            raise ValueError(f"length {shape[0]} is not divisible by {shards} shards")
        return (shape[0] // shards,)

    raise ValueError(f"unsupported shape {t.shape}")


def split(a, n):
    """Lazily cut sequence ``a`` into ``n`` contiguous, nearly equal slices.

    The first ``len(a) % n`` slices are one element longer than the rest;
    slices may be empty when ``n > len(a)``. Returns a generator of slices.
    """
    size, remainder = divmod(len(a), n)
    # starts[i] is where slice i begins; starts[i + 1] is where it ends.
    starts = [idx * size + min(idx, remainder) for idx in range(n + 1)]
    return (a[lo:hi] for lo, hi in zip(starts, starts[1:]))

def save(cpu_flattened):
    """Write the flattened checkpoint to disk, one directory per shard.

    ``cpu_flattened`` is cut into ``pieces`` chunks with ``split``. For every
    shard ``s`` and chunk ``j``, ``jax_checkpoint/shard_<s>/<j>.npz`` stores
    the ``s``-th slice of each entry in chunk ``j`` as positional arrays
    (``arr_0``, ``arr_1``, ...).
    """
    chunks = list(split(cpu_flattened, pieces))
    for shard in range(total_shards):
        for piece_idx, chunk in enumerate(chunks):
            with open(f"jax_checkpoint/shard_{shard}/{piece_idx}.npz", "wb") as handle:
                np.savez(handle, *[entry[shard] for entry in chunk])


# Ordered list of (fairseq parameter name, is_shard_bias, shard dim):
#   shard dim 1 / 2 -> split axis 0 / 1 of the (transposed) tensor across
#                      shards via get_old_shape;
#   shard dim None  -> replicate the tensor on every shard;
#   is_shard_bias   -> replicated copies are divided by the shard count
#                      (see reshard_reverse).
# Arrays are saved in exactly this order, so it must match the flattened
# parameter order the JAX-side loader expects - do not reorder.
transforms = [
    ("decoder.embed_tokens.weight", False, 1)
]

# Layer indices are sorted as *strings* ("0", "1", "10", "11", ..., "2"),
# presumably to match a sorted-key flattening of the JAX parameter tree
# (TODO confirm against the loader); keep this lexicographic order.
layer_names = sorted(map(str, range(layers)))
for layer in layer_names:
    transforms.extend([
        (f"decoder.layers.{layer}.self_attn.q_proj.bias", False, 1),
        (f"decoder.layers.{layer}.self_attn.q_proj.weight", False, 2),
        (f"decoder.layers.{layer}.self_attn.v_proj.bias", False, 1),
        (f"decoder.layers.{layer}.self_attn.v_proj.weight", False, 2),
        (f"decoder.layers.{layer}.self_attn.k_proj.bias", False, 1),
        (f"decoder.layers.{layer}.self_attn.k_proj.weight", False, 2),
        (f"decoder.layers.{layer}.self_attn.out_proj.bias", True, None),
        (f"decoder.layers.{layer}.self_attn.out_proj.weight", False, 1),
        (f"decoder.layers.{layer}.fc1.bias", False, 1),
        (f"decoder.layers.{layer}.fc1.weight", False, 2),
        (f"decoder.layers.{layer}.fc2.bias", True, None),
        (f"decoder.layers.{layer}.fc2.weight", False, 1),
        (f"decoder.layers.{layer}.self_attn_layer_norm.bias", False, None),
        (f"decoder.layers.{layer}.self_attn_layer_norm.weight", False, None),
        (f"decoder.layers.{layer}.final_layer_norm.bias", False, None),
        (f"decoder.layers.{layer}.final_layer_norm.weight", False, None),
    ])
transforms.extend([
    ("decoder.layer_norm.bias", False, None),
    ("decoder.layer_norm.weight", False, None),
])

checkpoint = []
# Iterate directly (the list is left intact) instead of pop(0), which is
# quadratic and destroys `transforms`.
for name, is_shard_bias, shard_dim in transforms:
    params = torch_checkpoint["model"][name]

    # Need to unscramble fairseq-style embedding matrices: reorder rows with
    # `mapping`, then append zero rows so the vocab divides across shards.
    # ("decoder.output_projection.weight" is not listed in `transforms`.)
    if name in ("decoder.embed_tokens.weight", "decoder.output_projection.weight"):
        params = params[mapping]
        padding = torch.zeros(padding_rows, params.shape[-1], dtype=params.dtype, device=params.device)
        params = torch.cat((params, padding), dim=0)

    # torch.nn.Linear uses a transposed version of the equivalent tensor that
    # haiku.Linear uses, so we have to un-transpose the tensor first. Only 2-D
    # tensors are transposed: `.T` is a no-op on 1-D tensors and deprecated.
    if "decoder.embed_tokens.weight" not in name and params.ndim == 2:
        params = params.T

    if shard_dim is not None:
        old_shape = (total_shards,) + get_old_shape(params, shard_dim)
    else:
        old_shape = (total_shards, params.shape[0],)
    print(f"< [{name}] {params.shape} to {old_shape}")

    # Add a leading size-1 axis and convert to bfloat16. Upcast to float32
    # first: torch bfloat16 tensors cannot be handed to numpy directly, and
    # fp16/fp32 values convert to float32 exactly, so rounding is unchanged.
    params = np.asarray(params[None].float(), dtype=jnp.bfloat16)
    params = reshard_reverse(params, old_shape, is_shard_bias=is_shard_bias)

    if np.isnan(params).any() or np.isinf(params).any():
        raise ValueError(f"bfloat16 overflow/underflow")

    print(f"> [{name}] {params.shape}")
    assert params.shape == old_shape
    checkpoint.append(params)

# Append the checkpoint step number (can be set to an arbitrary value, in this
# case 0, as long as we're only using inference and not training the model)
checkpoint.append(np.zeros(total_shards, dtype=np.int32))

print("saving")
save(checkpoint)
del checkpoint
with open("jax_checkpoint/config.json", "w") as f:
    json.dump(config, f, indent=2)
print(colored("DONE! The JAX checkpoint is now stored at ./jax_checkpoint", "green"))