# Compare MUGEN's Video VQVAE with TorchMultimodal's

This notebook loads the public MUGEN checkpoint for the Video VQVAE, remaps its `state_dict` to TorchMultimodal's parameter names, loads it into TorchMultimodal's Video VQVAE, and verifies that both models produce matching outputs.

### Set directories

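Replace these with your local directories. `repo_dir` should be a folder named `mugen` directly inside `home_dir`, because the notebook adds `home_dir` to `sys.path` and then runs `import mugen`.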

In [2]:
import os

# Replace these with your local directories. The defaults live under your home
# directory instead of a user-specific absolute path. `repo_dir` must sit directly
# inside `home_dir`: a later cell appends `home_dir` to sys.path and then imports
# the cloned repo as the `mugen` package.
home_dir = os.path.expanduser('~')
checkpoint_dir = os.path.join(home_dir, 'checkpoints')
repo_dir = os.path.join(home_dir, 'mugen')

### Clone MUGEN's repo

In [None]:
!git clone https://github.com/mugen-org/MUGEN_baseline.git $repo_dir

### Download and unzip checkpoints

This will take some time.

In [None]:
!wget https://dl.noahmt.com/creativity/data/MUGEN_release/checkpoints.zip -P $checkpoint_dir

In [None]:
import os

# Unzip checkpoints
zip_location = os.path.join(checkpoint_dir, 'checkpoints.zip')
!unzip $zip_location -d $checkpoint_dir

### Load checkpoint into MUGEN model

In [3]:
import sys
import os

# Make the cloned MUGEN repo importable as `mugen` (it is cloned inside `home_dir`).
# The membership check keeps re-runs of this cell from stacking duplicate entries.
if home_dir not in sys.path:
    sys.path.append(home_dir)

import torch
from torch import nn
import mugen

# NOTE: torch.load unpickles the file, which can execute arbitrary code; only load
# checkpoints from a source you trust (here, the MUGEN release archive).
ckpt = torch.load(
    os.path.join(checkpoint_dir, 'generation/video_vqvae/L32/epoch=54-step=599999.ckpt'),
    map_location=torch.device('cpu'),
)

The arguments below are copied from MUGEN's training script for this checkpoint: https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh

In [4]:
class Namespace:
    """Lightweight attribute bag standing in for the parsed CLI args MUGEN expects.

    Every keyword argument passed to the constructor becomes an instance attribute.
    """

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)


# Hyperparameters copied from MUGEN's VideoVQVAE_L32.sh. They must describe the same
# architecture as the checkpoint, otherwise the strict load_state_dict below fails.
vqvae_args=Namespace(
    embedding_dim=256,
    n_codes=2048,
    n_hiddens=240,
    n_res_layers=4,
    lr=0.0003,
    # NOTE(review): presumably per-axis (t, h, w) downsampling factors — confirm in MUGEN's VQVAE.
    downsample=(4, 32, 32),
    kernel_size=3,
    sequence_length=16,
    resolution=256,
)
vv_mugen = mugen.VQVAE(vqvae_args)

In [5]:
# Strict loading (PyTorch's default, spelled out here) raises if any key is missing or unexpected.
vv_mugen.load_state_dict(ckpt['state_dict'], strict=True)

<All keys matched successfully>

### Create TorchMultimodal's Video VQVAE

In [6]:
from examples.mugen.generation.video_vqvae import video_vqvae_mugen

# Build TorchMultimodal's Video VQVAE. `pretrained_model_key=None` presumably skips loading
# any released weights; the remapped MUGEN weights are loaded into it further below.
vv_torchmm = video_vqvae_mugen(pretrained_model_key=None)

### Remap MUGEN's state_dict and load into new model

In [7]:
import re

# Raw-string patterns with escaped dots so "." matches only a literal key separator.
_ENCODER_CONV_RE = re.compile(r'encoder\.convs\.(\d+)')
_ENCODER_CONV_LAST_RE = re.compile(r'encoder\.conv_last')
_DECODER_CONVT_RE = re.compile(r'decoder\.convts\.(\d+)')
# "attn_<dim>.<projection>." where <projection> is a single key component.
_ATTN_RE = re.compile(r'attn_([wht])\.([^.]+)\.')

# MUGEN axial-attention dimension suffix -> index into TorchMultimodal's `mha_attns`.
_ATTN_DIM_MAP = {'w': '2', 'h': '1', 't': '0'}
# MUGEN attention projection name -> TorchMultimodal projection name.
_ATTN_LAYER_MAP = {'w_qs': 'query', 'w_ks': 'key', 'w_vs': 'value', 'fc': 'output'}
# Substring renames for the (pre|post)-quantization convs; any prefix before the
# matched name is dropped and only the suffix is kept. Checked in this order.
_VQ_CONV_RENAMES = (
    ('pre_vq_conv', 'encoder.conv_out'),
    ('post_vq_conv', 'decoder.conv_in'),
)
# Exact-key renames for the codebook tensors.
_CODEBOOK_RENAMES = {
    'codebook.N': 'codebook.code_usage',
    'codebook.z_avg': 'codebook.code_avg',
    'codebook.embeddings': 'codebook.embedding',
}


def _double_index(param, match):
    """Return `param` with the (multi-digit) layer index in `match` group 1 doubled."""
    # MUGEN conv i sits at position 2*i in TorchMultimodal's container
    # (presumably because an activation module follows each conv).
    doubled = int(match.group(1)) * 2
    return param[:match.start(1)] + str(doubled) + param[match.end(1):]


def _map_param_name(param, conv_last_index):
    """Translate one MUGEN state_dict key into its TorchMultimodal name (rules tried in order)."""
    match = _ENCODER_CONV_RE.search(param)
    if match:
        return _double_index(param, match)

    match = _ENCODER_CONV_LAST_RE.search(param)
    if match:
        return param[:match.start()] + f'encoder.convs.{conv_last_index}' + param[match.end():]

    match = _ATTN_RE.search(param)
    if match:
        dim, layer = match.groups()
        if layer not in _ATTN_LAYER_MAP:
            raise KeyError(f'Unknown attention projection {layer!r} in key {param!r}')
        return (
            param[:match.start()]
            + f'mha_attns.{_ATTN_DIM_MAP[dim]}.{_ATTN_LAYER_MAP[layer]}.'
            + param[match.end():]
        )

    for old_name, new_name in _VQ_CONV_RENAMES:
        _, found, suffix = param.partition(old_name)
        if found:
            return new_name + suffix

    match = _DECODER_CONVT_RE.search(param)
    if match:
        return _double_index(param, match)

    # Codebook tensors are renamed by exact key; every other key passes through unchanged.
    return _CODEBOOK_RENAMES.get(param, param)


def map_state_dict(state_dict, conv_last_index=10):
    """Rename a MUGEN Video VQVAE state_dict's keys to TorchMultimodal's naming.

    Values are carried over unchanged; only keys are rewritten:

    * ``encoder.convs.<i>`` and ``decoder.convts.<i>`` -> index ``2 * i``
    * ``encoder.conv_last`` -> ``encoder.convs.<conv_last_index>``
    * ``attn_{w,h,t}.{w_qs,w_ks,w_vs,fc}.`` -> ``mha_attns.{2,1,0}.{query,key,value,output}.``
    * ``pre_vq_conv`` -> ``encoder.conv_out`` and ``post_vq_conv`` -> ``decoder.conv_in``
      (anything before these names is dropped)
    * ``codebook.N`` / ``codebook.z_avg`` / ``codebook.embeddings`` ->
      ``codebook.code_usage`` / ``codebook.code_avg`` / ``codebook.embedding``
    * all other keys are kept as-is.

    Args:
        state_dict: Mapping from MUGEN parameter/buffer names to values.
        conv_last_index: Target index for ``encoder.conv_last``. The default of 10
            matches the L32 checkpoint (presumably 5 encoder convs remapped to 0, 2, ..., 8).

    Returns:
        A new dict with remapped keys, in the same order as ``state_dict``.

    Raises:
        KeyError: If an axial-attention key names an unknown projection.
        ValueError: If two source keys map to the same target key, which would
            otherwise silently drop a tensor.
    """
    mapped_state_dict = {}
    source_of = {}
    for param, val in state_dict.items():
        new_param = _map_param_name(param, conv_last_index)
        if new_param in mapped_state_dict:
            raise ValueError(
                f'Keys {source_of[new_param]!r} and {param!r} both map to {new_param!r}'
            )
        mapped_state_dict[new_param] = val
        source_of[new_param] = param
    return mapped_state_dict

In [8]:
# Rename MUGEN's keys to TorchMultimodal's naming; the tensor values are reused unchanged.
new_state_dict = map_state_dict(ckpt['state_dict'])

In [9]:
# Strict loading (PyTorch's default, spelled out here) proves the remapped keys match exactly.
vv_torchmm.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

### Compare outputs with a random input

In [10]:
torch.manual_seed(4)
video = torch.randn(1, 3, 32, 256, 256)  # b, c, t, h, w

vv_mugen.eval()
vv_torchmm.eval()

# Inference only: no_grad avoids building autograd graphs for two full forward passes
# on a large video tensor; it does not change the computed values.
with torch.no_grad():
    loss, x_recon, codebook_output = vv_mugen(video)
    output = vv_torchmm(video)

diff = torch.abs(output.decoded - x_recon)
max_diff = torch.max(diff).item()
print(f'Max difference between outputs: {max_diff}')
print(f'Mean difference between outputs: {torch.mean(diff).item()}')

# The observed max difference is ~3e-5 (presumably float32 rounding from differing op
# order). Fail loudly if the ported model drifts well beyond that.
MAX_ABS_TOLERANCE = 1e-4
assert max_diff < MAX_ABS_TOLERANCE, f'Outputs diverge: max abs difference {max_diff}'

Max difference between outputs: 3.0875205993652344e-05
Mean difference between outputs: 1.7353995929170196e-07


### Save mapped checkpoint

In [9]:
# Derive the output path from `checkpoint_dir` instead of a hard-coded, user-specific
# absolute path, so it follows whatever directory was configured at the top.
save_path = os.path.join(checkpoint_dir, 'generation/video_vqvae/mugen_video_vqvae_L32.pt')
torch.save(new_state_dict, save_path)