This notebook contains the code used to convert the official MaskGIT transformer weights from JAX to PyTorch.

If you want to use it, you will have to install two additional requirements: flax and tensorflow.

In [1]:
import flax
import tensorflow.compat.v1 as tf
import torch

from maskgit.nets.bidirectional_transformer import BidirectionalTransformer

In [2]:
def restore_from_path(path):
  """Read a serialized Flax checkpoint from ``path`` and deserialize it.

  The file is opened in binary mode through ``tf.io.gfile.GFile`` and its
  whole content is handed to ``flax.serialization.from_bytes`` with ``None``
  as the target, i.e. without a template object to restore into.

  Args:
    path: Location of the checkpoint file.

  Returns:
    Whatever ``flax.serialization.from_bytes`` produces for the file content.
  """
  with tf.io.gfile.GFile(path, "rb") as ckpt_file:
    raw_bytes = ckpt_file.read()
  return flax.serialization.from_bytes(None, raw_bytes)

In [3]:
def get_params(jax_params, ckpt, name):
    """Flatten the nested dict ``ckpt`` into ``jax_params`` using dotted keys.

    Args:
        jax_params: Output dict, filled in place. Each leaf of ``ckpt`` is
            stored under the '.'-joined path of keys that leads to it.
        ckpt: (Possibly nested) dict to walk.
        name: List of the keys already traversed; pass ``[]`` at the top level.
            It is never mutated.

    Returns:
        None. The result is written into ``jax_params``.
    """
    for key, value in ckpt.items():
        path = name + [key]  # new list, so the caller's prefix stays untouched
        if isinstance(value, dict):
            # Sub-tree: descend with the extended prefix.
            get_params(jax_params, value, path)
            continue
        jax_params['.'.join(path)] = value

In [4]:
# Deserialize the checkpoint file and keep only its 'params' entry.
path = 'checkpoints/maskgit_imagenet256_checkpoint'
ckpt = restore_from_path(path=path)['params']

In [5]:
# Flatten the nested parameter tree into {dotted.name: array} and list every
# entry with its shape, so the names can be matched with the PyTorch model.
jax_params = {}
get_params(jax_params, ckpt, name=[])
for param_name, array in jax_params.items():
    print(param_name, array.shape)

Embed_0.embeddings_ln.bias (768,)
Embed_0.embeddings_ln.scale (768,)
Embed_0.position_embeddings.embedding (257, 768)
Embed_0.word_embeddings.embedding (2025, 768)
TransformerLayer_0.Attention_0.attention_output_ln.bias (768,)
TransformerLayer_0.Attention_0.attention_output_ln.scale (768,)
TransformerLayer_0.Attention_0.self_attention.key.bias (16, 48)
TransformerLayer_0.Attention_0.self_attention.key.kernel (768, 16, 48)
TransformerLayer_0.Attention_0.self_attention.out.bias (768,)
TransformerLayer_0.Attention_0.self_attention.out.kernel (16, 48, 768)
TransformerLayer_0.Attention_0.self_attention.query.bias (16, 48)
TransformerLayer_0.Attention_0.self_attention.query.kernel (768, 16, 48)
TransformerLayer_0.Attention_0.self_attention.value.bias (16, 48)
TransformerLayer_0.Attention_0.self_attention.value.kernel (768, 16, 48)
TransformerLayer_0.Mlp_0.intermediate_output.bias (3072,)
TransformerLayer_0.Mlp_0.intermediate_output.kernel (768, 3072)
TransformerLayer_0.Mlp_0.layer_output.bia

In [6]:
class Args:
    """Hyperparameters used to instantiate the PyTorch transformer.

    The attribute names are exactly the keyword arguments accepted by
    ``BidirectionalTransformer``, so ``vars(Args())`` can be unpacked into the
    constructor and the values are written down only once.
    """

    def __init__(self):
        self.num_image_tokens = 257       # rows of Embed_0.position_embeddings in the checkpoint
        self.num_codebook_vectors = 2025  # rows of Embed_0.word_embeddings in the checkpoint
        self.dim = 768
        self.n_layers = 24
        self.hidden_dim = 3072
        # The checkpoint's query/key/value kernels have shape (768, 16, 48),
        # i.e. 16 heads of width 48, so the head count must be 16 (it was 8).
        # NOTE(review): the parameter names (in_proj_weight / out_proj) suggest
        # torch.nn.MultiheadAttention, whose fused weight shapes do not depend
        # on the head count -- a wrong value would load without error but split
        # the heads differently. Confirm against BidirectionalTransformer.
        self.num_heads = 16
        self.attention_dropout = 0.1
        self.hidden_dropout = 0.1


# Build the target model from the single set of hyperparameters above and list
# its parameters; these are the names/shapes the converted weights must match.
transformer = BidirectionalTransformer(**vars(Args()))
for n, p in transformer.named_parameters():
    print(n, p.shape)

pos_emb torch.Size([257, 768])
bias torch.Size([2025])
tok_emb.weight torch.Size([2025, 768])
blocks.0.MultiHeadAttention.in_proj_weight torch.Size([2304, 768])
blocks.0.MultiHeadAttention.in_proj_bias torch.Size([2304])
blocks.0.MultiHeadAttention.out_proj.weight torch.Size([768, 768])
blocks.0.MultiHeadAttention.out_proj.bias torch.Size([768])
blocks.0.AttentionLN.weight torch.Size([768])
blocks.0.AttentionLN.bias torch.Size([768])
blocks.0.MlpLN.weight torch.Size([768])
blocks.0.MlpLN.bias torch.Size([768])
blocks.0.MLP.0.weight torch.Size([3072, 768])
blocks.0.MLP.0.bias torch.Size([3072])
blocks.0.MLP.2.weight torch.Size([768, 3072])
blocks.0.MLP.2.bias torch.Size([768])
blocks.1.MultiHeadAttention.in_proj_weight torch.Size([2304, 768])
blocks.1.MultiHeadAttention.in_proj_bias torch.Size([2304])
blocks.1.MultiHeadAttention.out_proj.weight torch.Size([768, 768])
blocks.1.MultiHeadAttention.out_proj.bias torch.Size([768])
blocks.1.AttentionLN.weight torch.Size([768])
blocks.1.Attent

In [7]:
# Convert the flattened JAX parameters into the PyTorch state-dict layout in
# three stages: (1) rename, (2) re-layout every tensor, (3) fuse the separate
# query/key/value projections of each block into one in_proj tensor.

# Substring rewrites applied to every parameter name, in insertion order.
# The order matters: some patterns only match text produced by an earlier
# rewrite (e.g. 'MLP.0.layer_output_ln' exists only after 'Mlp_' -> 'MLP.'),
# and 'MLP.0.layer_output_ln' has to be consumed before the shorter
# '0.layer_output' pattern gets a chance to match it.
convert_names = {
    '.kernel': '.weight',
    '.scale': '.weight',
    'Embed_0.position_embeddings.embedding': 'pos_emb',
    'Embed_0.word_embeddings.embedding': 'tok_emb.weight',
    'Embed_0.embeddings_ln': 'emb_ln',
    'MlmLayer_0.mlm_bias.bias': 'bias',
    'MlmLayer_0.mlm_dense': 'Token_Prediction.0',
    'MlmLayer_0.mlm_ln': 'Token_Prediction.2',
    'TransformerLayer_': 'blocks.',
    'Mlp_': 'MLP.',
    'MLP.0.layer_output_ln': 'MlpLN',
    '0.intermediate_output': '0',
    '0.layer_output': '2',
    'Attention_0.attention_output_ln': 'AttentionLN',
    'Attention_0.self_attention': 'MultiHeadAttention',
    'MultiHeadAttention.out': 'MultiHeadAttention.out_proj',
}


def _rename(jax_name):
    """Return ``jax_name`` after applying every ``convert_names`` rewrite in order."""
    for in_pat, out_pat in convert_names.items():
        jax_name = jax_name.replace(in_pat, out_pat)
    return jax_name


def _to_torch_layout(name, tensor):
    """Return ``tensor`` rearranged the way the PyTorch parameter ``name`` is stored.

    Args:
        name: Parameter name *after* renaming.
        tensor: ``torch.Tensor`` holding the JAX values in their original layout.

    Returns:
        The tensor in PyTorch layout. Tensors that need no change are returned
        as they are.
    """
    if name.startswith('Token_Prediction'):
        # 2-D kernels are transposed; 1-D tensors pass through.
        return tensor.permute(1, 0) if tensor.ndim == 2 else tensor
    if not name.startswith('block'):
        # Embeddings, embedding LayerNorm and the output bias keep their layout.
        return tensor
    if 'bias' in name:
        # Attention biases come per head as 2-D (16, 48 in this checkpoint);
        # PyTorch wants them flat. 1-D biases are already fine.
        return tensor.reshape(-1) if tensor.ndim == 2 else tensor
    if tensor.ndim == 2:
        # MLP kernels: (768, 3072) in the checkpoint vs [3072, 768] in the
        # PyTorch model, hence the transpose.
        return tensor.permute(1, 0)
    if tensor.ndim == 3:
        if '.out_proj.' in name:
            # Output projection kernel, (16, 48, 768) in this checkpoint:
            # move the last axis first and merge the two head axes -> (768, 768).
            out_features = tensor.shape[-1]
            return tensor.permute(2, 0, 1).reshape(out_features, -1)
        # Query/key/value kernel, (768, 16, 48) in this checkpoint:
        # move the first axis last and merge the two head axes -> (768, 768).
        in_features = tensor.shape[0]
        return tensor.permute(1, 2, 0).reshape(-1, in_features)
    return tensor


def _fuse_qkv(named_tensors):
    """Concatenate each block's query/key/value tensors into ``in_proj_*``.

    The block indices are discovered from the parameter names instead of being
    hard-coded, so every block that has attention projections is fused.

    Args:
        named_tensors: Iterable of ``[name, tensor]`` pairs in PyTorch layout.

    Returns:
        A pair ``(fused, consumed)``: ``fused`` is a list of
        ``(in_proj name, tensor)`` tuples (per block: bias first, then weight),
        ``consumed`` is the set of query/key/value names that went into them.

    Raises:
        KeyError: If a block lacks one of its query/key/value tensors; the
            message is the missing parameter name.
    """
    by_name = {}
    for name, tensor in named_tensors:
        by_name.setdefault(name, tensor)  # first occurrence wins

    marker = '.MultiHeadAttention.query.weight'
    block_ids = sorted(
        int(name.split('.')[1])
        for name in by_name
        if name.startswith('blocks.') and name.endswith(marker)
    )

    fused, consumed = [], set()
    for i in block_ids:
        base_name = f'blocks.{i}.MultiHeadAttention'
        for w in ['bias', 'weight']:
            # torch.nn.MultiheadAttention stacks its fused projection in the
            # order query, key, value along dim 0 (see PyTorch docs).
            parts = [f'{base_name}.{n}.{w}' for n in ['query', 'key', 'value']]
            stacked = torch.cat([by_name[part] for part in parts], dim=0)
            fused.append((f'{base_name}.in_proj_{w}', stacked))
            consumed.update(parts)
    return fused, consumed


# Stages 1 + 2: rename and re-layout. `.copy()` detaches the tensor from the
# array held in `jax_params`.
torch_params = []
for k, v in jax_params.items():
    name = _rename(k)
    torch_params.append([name, _to_torch_layout(name, torch.from_numpy(v.copy()))])

# Stage 3: fuse q/k/v and drop the now-redundant separate projections.
cat_torch_params, fused_names = _fuse_qkv(torch_params)

final_torch_params = [p for p in torch_params if p[0] not in fused_names]
final_torch_params.extend(cat_torch_params)

# Quick visual sanity check: name, shape and value range of every tensor.
for n, p in final_torch_params:
    print(n, p.shape, p.min(), p.max())

emb_ln.bias torch.Size([768]) tensor(-0.7091) tensor(0.6671)
emb_ln.weight torch.Size([768]) tensor(0.4073) tensor(1.9009)
pos_emb torch.Size([257, 768]) tensor(-0.2600) tensor(0.4943)
tok_emb.weight torch.Size([2025, 768]) tensor(-1.3872) tensor(0.9978)
blocks.0.AttentionLN.bias torch.Size([768]) tensor(-2.9947) tensor(2.4949)
blocks.0.AttentionLN.weight torch.Size([768]) tensor(0.7839) tensor(1.8952)
blocks.0.MultiHeadAttention.out_proj.bias torch.Size([768]) tensor(-0.3167) tensor(0.4164)
blocks.0.MultiHeadAttention.out_proj.weight torch.Size([768, 768]) tensor(-0.4847) tensor(0.3539)
blocks.0.MLP.0.bias torch.Size([3072]) tensor(-0.5490) tensor(0.0892)
blocks.0.MLP.0.weight torch.Size([3072, 768]) tensor(-0.3889) tensor(0.3238)
blocks.0.MLP.2.bias torch.Size([768]) tensor(-0.2196) tensor(0.1671)
blocks.0.MLP.2.weight torch.Size([768, 3072]) tensor(-1.1138) tensor(1.0315)
blocks.0.MlpLN.bias torch.Size([768]) tensor(-0.9987) tensor(0.4279)
blocks.0.MlpLN.weight torch.Size([768]) ten

In [8]:
# Turn the (name, tensor) pairs into a mapping and write it to disk, wrapped
# under a 'state_dict' key.
state_dict = dict(final_torch_params)
checkpoint = {'state_dict': state_dict}
torch.save(checkpoint, 'checkpoints/maskgit_imagenet256.ckpt')