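This notebook tests the state dict mapping between the original MDETR repo's checkpoints (with a ResNet101 backbone, found [here](https://github.com/ashkamath/mdetr#pre-training)) and the refactored TorchMultimodal classes: it loads the same weights into both implementations, runs them on the same inputs, and reports the maximum difference between their outputs.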

In [1]:
import os

# Location of the downloaded MDETR checkpoint and base directory used for the
# MDETR repo checkout/imports. Override via environment variables rather than
# editing these absolute paths; the defaults keep the original behavior.
download_dir = os.environ.get("MDETR_DOWNLOAD_DIR", "/data/home/ebs/data/mdetr")
repo_dir = os.environ.get("MDETR_REPO_DIR", "/data/home/ebs")

In [None]:
# Install MDETR repo 
!git clone https://github.com/ashkamath/mdetr.git $repo_dir


In [None]:
# Download checkpoint
!wget https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth?download=1 -P $download_dir

In [2]:
import os
import sys
sys.path.append(repo_dir)
sys.path.append(os.path.join(repo_dir,"mdetr"))

# Load MDETR classes and ResNet101 weights
import torch
from torch import nn
from mdetr.models import build_model
from mdetr.main import get_args_parser
import argparse

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source such as the Zenodo URL above.
mdetr = torch.load(
    os.path.join(download_dir, "pretrained_resnet101_checkpoint.pth?download=1"),
    map_location=torch.device('cpu'),
)

# Build the reference model through the repo's own CLI parser, on CPU.
# 'wef' looks like a dummy --dataset_config value — TODO confirm it is unused here.
parser = argparse.ArgumentParser("DETR training and evaluation script", parents=[get_args_parser()])
args = parser.parse_args(['--dataset_config', 'wef', '--device', 'cpu'])
model, criterion, contrastive_criterion, qa_criterion, weight_dict = build_model(args)

# The checkpoint stores the model weights under the 'model' key; the returned
# key-match report is left as the cell's displayed output.
model.load_state_dict(mdetr['model'])

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [3]:
from torch import nn, Tensor
from typing import Dict, List

# Define a couple helper functions
def filter_dict(key_condition, d):
    """Return a new dict with only the items of ``d`` whose key satisfies ``key_condition``.

    Insertion order of the surviving items is preserved.
    """
    return dict(filter(lambda item: key_condition(item[0]), d.items()))

def get_params_for_layer(state_dict, i):
    """Return the keys of ``state_dict`` that belong to layer ``i``.

    Matches both Hugging Face-style (``layer.<i>.``) and TorchText-style
    (``layers.<i>.``) key segments. Both patterns require the trailing dot so
    that, e.g., layer 1 does not also pick up layers 10-19.
    """
    hf_pattern = f'layer.{i}.'
    tt_pattern = f'layers.{i}.'
    return [key for key in state_dict.keys() if hf_pattern in key or tt_pattern in key]

# Per-layer parameter names: TorchText-style target name (as loaded into the
# TorchMultimodal text encoder) -> Hugging Face RoBERTa source name.
# TorchText fuses attention into a single input projection, so its weight/bias
# equal the concatenation of Hugging Face's separate query, key and value tensors.
_qkv_projections = ('query', 'key', 'value')
param_mapping = dict([
    ('self_attn.in_proj_weight', [f'attention.self.{proj}' for proj in _qkv_projections]),
    ('self_attn.in_proj_bias', [f'attention.self.{proj}' for proj in _qkv_projections]),
    ('self_attn.out_proj', 'attention.output.dense'),
    ('norm1', 'attention.output.LayerNorm'),
    ('linear1', 'intermediate.dense'),
    ('linear2', 'output.dense'),
    ('norm2', 'output.LayerNorm'),
])

# Key prefixes of the text-encoder layers in the MDETR (Hugging Face) and
# TorchMultimodal (TorchText-style) state dicts
hf_layer_prefix = 'transformer.text_encoder.encoder.layer'
tt_layer_prefix = 'text_encoder.encoder.layers.layers'

# Parameter-name suffixes copied for every non-fused entry
postfixes = ['weight', 'bias']

# Create a state dict for the ith layer of the TorchText-style RoBERTa encoder,
# filled with tensors from the ith layer of Hugging Face's encoder
def map_layer(hf_state_dict, tt_state_dict, i):
    """Map encoder layer ``i`` of a Hugging Face RoBERTa state dict to TorchText naming.

    Args:
        hf_state_dict: source state dict whose text-encoder keys start with
            ``hf_layer_prefix``.
        tt_state_dict: target state dict. Not read; kept so the signature stays
            stable for existing callers.
        i: index of the encoder layer to map.

    Returns:
        Dict from TorchText-style keys (prefixed with ``tt_layer_prefix``) to
        tensors taken from ``hf_state_dict``. The fused input-projection entries
        are the concatenation along dim 0 of HF's query, key and value tensors,
        in that order.

    Raises:
        KeyError: if an expected Hugging Face key for layer ``i`` is missing.
    """
    mapped_state_dict = {}
    hf_layer_base = f'{hf_layer_prefix}.{i}'
    tt_layer_base = f'{tt_layer_prefix}.{i}'
    for tt_suffix, hf_suffix in param_mapping.items():
        tt_key_base = f'{tt_layer_base}.{tt_suffix}'
        if isinstance(hf_suffix, list):
            # Fused Q/K/V projection: the TorchText name ends in "_weight"/"_bias",
            # which selects whether HF's weights or biases are concatenated.
            postfix = tt_key_base.split('_')[-1]
            hf_keys = [f'{hf_layer_base}.{name}.{postfix}' for name in hf_suffix]
            if any(p in tt_key_base for p in postfixes):
                tt_key = tt_key_base
            else:
                tt_key = f'{tt_key_base}.{postfix}'
            mapped_state_dict[tt_key] = torch.cat([hf_state_dict[hf_key] for hf_key in hf_keys])
        else:
            # One-to-one rename, copied separately for weight and bias.
            hf_key_base = f'{hf_layer_base}.{hf_suffix}'
            for postfix in postfixes:
                mapped_state_dict[f'{tt_key_base}.{postfix}'] = hf_state_dict[f'{hf_key_base}.{postfix}']

    return mapped_state_dict

    
# Map every text-encoder layer by calling map_layer once per layer index
def map_text_encoders(hf_state_dict: Dict[str, Tensor], tt_state_dict: Dict[str, Tensor], n_layers: int = 12):
    """Merge the per-layer mappings for layers ``0 .. n_layers - 1`` into one dict.

    Later layers win on key collisions, and keys appear in layer order.
    """
    return {
        key: tensor
        for layer_idx in range(n_layers)
        for key, tensor in map_layer(hf_state_dict, tt_state_dict, layer_idx).items()
    }


# The main function used to map from the MDETR state dict to the TorchMultimodal one
# TODO: refactor to remove the explicit dependency on n_layers
def map_mdetr_state_dict(mdetr_state_dict, mm_state_dict, n_layers: int = 12):
    """Translate an original-MDETR state dict into TorchMultimodal MDETR key names.

    Args:
        mdetr_state_dict: state dict of the original MDETR model.
        mm_state_dict: state dict of the TorchMultimodal model; only forwarded
            to ``map_text_encoders``.
        n_layers: number of text-encoder layers to map; defaults to 12.

    Returns:
        A new dict keyed by TorchMultimodal parameter names. The input dicts are
        only read, never modified.

    Raises:
        KeyError: if an expected source key is missing.
    """
    # Perform the text encoder (transformer layer) mapping
    mapped_state_dict = map_text_encoders(
        mdetr_state_dict,
        mm_state_dict,
        n_layers=n_layers,
    )

    # Miscellaneous renaming (this can probably be cleaned up)
    mapped_state_dict = {
        k.replace('transformer.text_encoder', 'text_encoder'): v
        for k, v in mapped_state_dict.items()
        if 'embeddings' not in k
    }

    for k, v in mdetr_state_dict.items():
        # Copy everything except the text encoder, text resizer and image input
        # projection, renaming the image backbone. Later branches may re-key
        # these copies (and delete the copied entry).
        if not k.startswith('transformer.text_encoder') and not k.startswith('transformer.resizer') and 'input_proj' not in k:
            mapped_state_dict[k.replace('backbone.0', 'image_backbone')] = v
        if 'embeddings' in k:
            mapped_state_dict[k.replace('transformer.', '')] = v
        if 'input_proj' in k:
            mapped_state_dict[k.replace('input_proj', 'image_projection')] = v
        if 'resizer' in k:
            mapped_state_dict[k.replace('transformer.', '').replace('resizer', 'text_projection')] = v
        if 'embeddings.LayerNorm' in k:
            # Re-key the copy made by the 'embeddings' branch to use `layer_norm`.
            new_k = k.replace('transformer.', '')
            mapped_state_dict[new_k.replace('LayerNorm', 'layer_norm')] = v
            del mapped_state_dict[new_k]
        if 'bbox_embed' in k:
            # bbox_embed.layers.<i>.* -> bbox_embed.model.<2*i>.*
            # Only the index segment is rewritten, so other digits in the key are
            # left untouched. The 2*i spacing presumably reflects activation modules
            # between the linears in the target MLP — TODO confirm.
            parsed = k.split('.')
            idx = parsed.index('layers')
            layer_i = int(parsed[idx + 1])
            parsed[idx] = 'model'
            parsed[idx + 1] = str(2 * layer_i)
            mapped_state_dict['.'.join(parsed)] = v
            del mapped_state_dict[k]
        if all([x in k for x in ['transformer', 'layers', 'linear']]):
            # Feed-forward linear1/linear2 -> mlp.model.0 / mlp.model.3 (index 3*(i-1)).
            k_split = k.split('.')
            i = int(k_split[-2][-1])
            k_new = '.'.join(k_split[:-2] + ["mlp", "model", str(3 * (i - 1)), k_split[-1]])
            mapped_state_dict[k_new] = v
            del mapped_state_dict[k]
        if 'contrastive' in k:
            k_new = k.replace('align', 'alignment').replace('projection_image', 'image_projection').replace('projection_text', 'text_projection')
            mapped_state_dict[k_new] = v
            del mapped_state_dict[k]

    return mapped_state_dict


In [4]:
import torch
from torch import nn, Tensor
from mdetr.models.mdetr import MDETR
from mdetr.models.transformer import Transformer
import unittest
# from torchmultimodal.utils.common import NestedTensor
from transformers import RobertaTokenizerFast

from torchmultimodal.models.mdetr.image_encoder import mdetr_resnet101_backbone
from torchmultimodal.models.mdetr.text_encoder import mdetr_roberta_text_encoder
from torchmultimodal.models.mdetr.transformer import MDETRTransformer as mm_Transformer
from torchmultimodal.models.mdetr.model import mdetr_resnet101

def max_diff(x: Tensor, y: Tensor) -> Tensor:
    """Return the largest absolute element-wise difference between ``x`` and ``y`` as a 0-d tensor."""
    return (x - y).abs().max()

# Compares the original MDETR model with TorchMultimodal's mdetr_resnet101 after
# loading the latter with the mapped state dict.
class TestMDETR(unittest.TestCase):
    """Check that the mapped state dict reproduces the original MDETR outputs.

    The original model is the notebook-level ``model`` built and loaded in the
    checkpoint cell. Drive it manually (``setUp`` -> ``run_mdetr`` ->
    ``run_mm_mdetr`` -> ``compare_results``) or via ``test_state_dict_mapping``.
    """

    # Seed for the random test images, so reported differences are reproducible
    seed: int = 0

    def setUp(self):
        """Build deterministic random images, the captions and their tokenization."""
        generator = torch.Generator().manual_seed(self.seed)
        # Two random 3x64x64 images, passed to both models as a tuple of tensors
        self.test_tensors = torch.rand(2, 3, 64, 64, generator=generator).unbind(dim=0)
        self.captions = ['I can see the sun', 'But even if I cannot see the sun, I know that it exists']
        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        self.text = self.tokenizer.batch_encode_plus(self.captions, padding="longest", return_tensors="pt")
        # NOTE: relies on the notebook-global `model` (original MDETR with checkpoint loaded)
        self.mdetr = model
        self.mdetr.eval()

    def run_mdetr(self):
        """Run the original MDETR: encode once, then decode from the cached memory."""
        with torch.no_grad():
            self.memory_cache = self.mdetr(self.test_tensors, self.captions, encode_and_save=True)
            self.mdetr_out = self.mdetr(
                self.test_tensors, self.captions, encode_and_save=False, memory_cache=self.memory_cache
            )

    def run_mm_mdetr(self):
        """Build TorchMultimodal MDETR, load the mapped weights and run it on the same inputs."""
        self.mm_mdetr = mdetr_resnet101()
        self.mapped_state_dict = map_mdetr_state_dict(self.mdetr.state_dict(), self.mm_mdetr.state_dict())
        self.mm_mdetr.load_state_dict(self.mapped_state_dict)
        self.mm_mdetr.eval()
        with torch.no_grad():
            self.mm_out = self.mm_mdetr(self.test_tensors, self.text.input_ids)
        # Re-key the TorchMultimodal outputs to the original MDETR output names
        self.mm_out_dict = {
            'pred_logits': self.mm_out.pred_logits,
            'pred_boxes': self.mm_out.pred_boxes,
            'proj_queries': self.mm_out.projected_queries,
            'proj_tokens': self.mm_out.projected_tokens,
        }

    def compare_results(self, atol: float = 1e-3):
        """Print, then assert, the maximum absolute difference for each output.

        Args:
            atol: largest tolerated absolute difference per output.

        Raises:
            AssertionError: if any output differs by more than ``atol``.
        """
        for k, mm_value in self.mm_out_dict.items():
            tensor_diff = max_diff(mm_value, self.mdetr_out[k])
            print(f"Maximum difference in {k} is {tensor_diff}")
            self.assertLessEqual(tensor_diff.item(), atol, f"{k} differs by {tensor_diff} (atol={atol})")

    def test_state_dict_mapping(self):
        """unittest entry point: run both models and compare their outputs."""
        self.run_mdetr()
        self.run_mm_mdetr()
        self.compare_results()
        

In [5]:
# Drive the comparison by hand: build inputs, run both models, then report differences
tester = TestMDETR()
for stage in (tester.setUp, tester.run_mdetr, tester.run_mm_mdetr, tester.compare_results):
    stage()

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


Maximum difference in pred_logits is 1.239776611328125e-05
Maximum difference in pred_boxes is 3.337860107421875e-06
Maximum difference in proj_queries is 2.384185791015625e-07
Maximum difference in proj_tokens is 3.129243850708008e-07


  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
