This notebook contains the code used to convert the official MaskGIT tokenizer weights from JAX to PyTorch.

If you want to use it, you will have to install two additional requirements: flax and tensorflow.

In [1]:
import flax
import tensorflow.compat.v1 as tf
import torch

from maskgit.nets.vqgan_tokenizer import VQVAE
from maskgit.configs import vqgan_config

In [2]:
def restore_from_path(path):
    """Read a serialized Flax checkpoint file and deserialize its contents.

    Args:
        path: Location of the checkpoint file. It is opened in binary mode
            through ``tf.io.gfile.GFile``.

    Returns:
        The object produced by ``flax.serialization.from_bytes`` when it is
        given ``None`` as the target and the raw file contents as payload.
    """
    with tf.io.gfile.GFile(path, "rb") as checkpoint_file:
        serialized = checkpoint_file.read()
    # The bytes are fully in memory at this point, so the file can already be
    # closed before deserialization starts.
    return flax.serialization.from_bytes(None, serialized)

In [3]:
def get_params(jax_params, ckpt, name):
    """Flatten a nested dict into dot-separated keys.

    Walks ``ckpt`` depth-first, in iteration order, and stores every value
    that is not itself a dict in ``jax_params`` under the key obtained by
    joining the path to that value with ``'.'``.

    Args:
        jax_params: Dict that receives the flattened entries (mutated in place).
        ckpt: Dict to flatten; its dict-valued entries are descended into.
        name: List of keys leading to ``ckpt``; pass ``[]`` at the top level.

    Returns:
        None. All results are written into ``jax_params``.
    """
    for key, value in ckpt.items():
        # Concatenation builds a new list, so the caller's prefix is untouched.
        path = name + [key]
        if not isinstance(value, dict):
            jax_params['.'.join(path)] = value
            continue
        get_params(jax_params, value, path)

In [4]:
# Location of the JAX tokenizer checkpoint to convert (a relative path, so it
# is resolved against the notebook's working directory).
path = 'checkpoints/tokenizer_imagenet256_checkpoint'
# Only the 'params' entry of the restored state is kept; it is the nested dict
# that the following cell flattens with get_params.
ckpt = restore_from_path(path)['params']

In [5]:
# Flatten the nested checkpoint into a single dict keyed by dotted names.
jax_params = {}
get_params(jax_params, ckpt, [])

# List every JAX parameter with its shape, for comparison with the PyTorch
# model printed further down.
for param_name in jax_params:
    print(param_name, jax_params[param_name].shape)

decoder.Conv_0.bias (512,)
decoder.Conv_0.kernel (3, 3, 256, 512)
decoder.Conv_1.bias (512,)
decoder.Conv_1.kernel (3, 3, 512, 512)
decoder.Conv_2.bias (256,)
decoder.Conv_2.kernel (3, 3, 256, 256)
decoder.Conv_3.bias (256,)
decoder.Conv_3.kernel (3, 3, 256, 256)
decoder.Conv_4.bias (128,)
decoder.Conv_4.kernel (3, 3, 128, 128)
decoder.Conv_5.bias (3,)
decoder.Conv_5.kernel (3, 3, 128, 3)
decoder.GroupNorm_0.bias (128,)
decoder.GroupNorm_0.scale (128,)
decoder.ResBlock_0.Conv_0.kernel (3, 3, 512, 512)
decoder.ResBlock_0.Conv_1.kernel (3, 3, 512, 512)
decoder.ResBlock_0.GroupNorm_0.bias (512,)
decoder.ResBlock_0.GroupNorm_0.scale (512,)
decoder.ResBlock_0.GroupNorm_1.bias (512,)
decoder.ResBlock_0.GroupNorm_1.scale (512,)
decoder.ResBlock_1.Conv_0.kernel (3, 3, 512, 512)
decoder.ResBlock_1.Conv_1.kernel (3, 3, 512, 512)
decoder.ResBlock_1.GroupNorm_0.bias (512,)
decoder.ResBlock_1.GroupNorm_0.scale (512,)
decoder.ResBlock_1.GroupNorm_1.bias (512,)
decoder.ResBlock_1.GroupNorm_1.scale (5

In [6]:
# Build the PyTorch tokenizer from the project's VQGAN config.
config = vqgan_config.get_config()
vq = VQVAE(config)

# Print the PyTorch parameter names and shapes: these are the names the
# converted JAX parameters have to end up with.
for torch_name, torch_param in vq.named_parameters():
    print(torch_name, torch_param.shape)

quantizer.codebook.weight torch.Size([1024, 256])
encoder.conv_in.weight torch.Size([128, 3, 3, 3])
encoder.res_blocks.0.0.norm0.weight torch.Size([128])
encoder.res_blocks.0.0.norm0.bias torch.Size([128])
encoder.res_blocks.0.0.conv0.weight torch.Size([128, 128, 3, 3])
encoder.res_blocks.0.0.norm1.weight torch.Size([128])
encoder.res_blocks.0.0.norm1.bias torch.Size([128])
encoder.res_blocks.0.0.conv1.weight torch.Size([128, 128, 3, 3])
encoder.res_blocks.0.1.norm0.weight torch.Size([128])
encoder.res_blocks.0.1.norm0.bias torch.Size([128])
encoder.res_blocks.0.1.conv0.weight torch.Size([128, 128, 3, 3])
encoder.res_blocks.0.1.norm1.weight torch.Size([128])
encoder.res_blocks.0.1.norm1.bias torch.Size([128])
encoder.res_blocks.0.1.conv1.weight torch.Size([128, 128, 3, 3])
encoder.res_blocks.1.0.norm0.weight torch.Size([128])
encoder.res_blocks.1.0.norm0.bias torch.Size([128])
encoder.res_blocks.1.0.conv0.weight torch.Size([128, 128, 3, 3])
encoder.res_blocks.1.0.norm1.weight torch.Siz

In [7]:
# JAX -> PyTorch substring substitutions. They are applied to every parameter
# name in this exact order: the generic layer rewrites ('.Conv_' -> '.conv',
# ...) must run first because the special cases at the bottom match their
# *output* (e.g. 'decoder.conv0').
convert_names = {
    '.kernel': '.weight',
    '.scale': '.weight',
    '.ResBlock_': '.res_blocks.',
    '.GroupNorm_': '.norm',
    '.Conv_': '.conv',
    'quantizer.codebook': 'quantizer.codebook.weight',
    'decoder.conv0': 'decoder.conv_in',
    'decoder.conv5': 'decoder.conv_out',
    'decoder.norm0': 'decoder.norm_out',
    'encoder.conv0': 'encoder.conv_in',
    'encoder.conv1': 'encoder.conv_out',
    'encoder.norm0': 'encoder.norm_out',
}

# JAX numbers its ResBlocks with one flat index per encoder/decoder, while the
# PyTorch names use a (stage, block) pair; this is the divisor of that split.
RES_BLOCKS_PER_STAGE = 2
# Position inside a decoder stage that receives the numbered decoder convs.
# NOTE(review): presumably the conv that follows the upsampling step of the
# stage - confirm against the VQVAE decoder definition.
STAGE_CONV_SLOT = 3


def jax_to_torch_name(jax_name: str,
                      blocks_per_stage: int = RES_BLOCKS_PER_STAGE,
                      stage_conv_slot: int = STAGE_CONV_SLOT) -> str:
    """Translate one flattened JAX parameter name into its PyTorch name.

    Args:
        jax_name: Dotted JAX name, e.g. ``'decoder.ResBlock_3.Conv_0.kernel'``.
        blocks_per_stage: Number of consecutive flat ResBlock indices that are
            grouped into one PyTorch stage.
        stage_conv_slot: Index inside ``decoder.res_blocks.<stage>`` given to
            the numbered decoder convolutions.

    Returns:
        The converted name, e.g. ``'decoder.res_blocks.1.1.conv0.weight'``.

    Raises:
        ValueError: If a ``*.res_blocks.<idx>`` component is not an integer.
        IndexError: If a name is shorter than the pattern it matches.
    """
    name = jax_name
    for in_pat, out_pat in convert_names.items():
        name = name.replace(in_pat, out_pat)
    tokens = name.split('.')

    if name.startswith('decoder.conv'):
        # 'conv_in' / 'conv_out' carry no number and are already final. Only
        # the numbered convs are moved into a decoder stage.
        try:
            conv_idx = int(tokens[1][4:])
        except ValueError:
            return name
        # Returning here matters: the rewritten name starts with
        # 'decoder.res_blocks' but must NOT go through the stage/block split
        # below.
        return f'decoder.res_blocks.{conv_idx}.{stage_conv_slot}.{tokens[2]}'

    if name.startswith(('decoder.res_blocks', 'encoder.res_blocks')):
        stage, block = divmod(int(tokens[2]), blocks_per_stage)
        tokens[2:3] = [str(stage), str(block)]
        # The third conv of a ResBlock is named 'conv_res' on the PyTorch side.
        return '.'.join(tokens).replace('.conv2.', '.conv_res.')

    return name


def jax_to_torch_tensor(array) -> torch.Tensor:
    """Copy a NumPy array into a tensor, converting conv kernels to OIHW.

    4-D arrays are treated as convolution kernels stored as
    (height, width, in, out) and are reordered to (out, in, height, width);
    every other array keeps its layout.
    """
    tensor = torch.from_numpy(array.copy())
    if tensor.ndim == 4:
        # permute() only returns a strided view; materialise it so the saved
        # checkpoint holds tensors in the standard contiguous layout.
        tensor = tensor.permute(3, 2, 0, 1).contiguous()
    return tensor


# Same structure as before: a list of [name, tensor] pairs in checkpoint order.
torch_params = [
    [jax_to_torch_name(k), jax_to_torch_tensor(v)]
    for k, v in jax_params.items()
]

for n, p in torch_params:
    print(n, p.shape)

decoder.conv_in.bias torch.Size([512])
decoder.conv_in.weight torch.Size([512, 256, 3, 3])
decoder.res_blocks.1.3.bias torch.Size([512])
decoder.res_blocks.1.3.weight torch.Size([512, 512, 3, 3])
decoder.res_blocks.2.3.bias torch.Size([256])
decoder.res_blocks.2.3.weight torch.Size([256, 256, 3, 3])
decoder.res_blocks.3.3.bias torch.Size([256])
decoder.res_blocks.3.3.weight torch.Size([256, 256, 3, 3])
decoder.res_blocks.4.3.bias torch.Size([128])
decoder.res_blocks.4.3.weight torch.Size([128, 128, 3, 3])
decoder.conv_out.bias torch.Size([3])
decoder.conv_out.weight torch.Size([3, 128, 3, 3])
decoder.norm_out.bias torch.Size([128])
decoder.norm_out.weight torch.Size([128])
decoder.res_blocks.0.0.conv0.weight torch.Size([512, 512, 3, 3])
decoder.res_blocks.0.0.conv1.weight torch.Size([512, 512, 3, 3])
decoder.res_blocks.0.0.norm0.bias torch.Size([512])
decoder.res_blocks.0.0.norm0.weight torch.Size([512])
decoder.res_blocks.0.0.norm1.bias torch.Size([512])
decoder.res_blocks.0.0.norm1.w

In [8]:
state_dict = {n: p for n, p in torch_params}

# If two JAX parameters were mapped to the same PyTorch name, the dict above
# silently kept only the last one. Refuse to write a checkpoint with missing
# weights instead of failing later at load time.
if len(state_dict) != len(torch_params):
    all_names = [n for n, _ in torch_params]
    duplicates = sorted({n for n in all_names if all_names.count(n) > 1})
    raise ValueError(f'name conversion produced duplicate parameter names: {duplicates}')

# The weights are stored under a 'state_dict' key.
torch.save({'state_dict': state_dict}, 'checkpoints/tokenizer_imagenet256.ckpt')