# Re-map FLAVA checkpoint

Refactoring FLAVA components to share code with the rest of the codebase can leave existing model checkpoints out of sync with the updated architecture, because parameter names in the saved `state_dict` no longer match the layers in the model. This notebook shows how to load the existing checkpoint, re-map the old layer names to the new ones, and save the result as a new checkpoint.

If you wish to save a new checkpoint, you must have access to the PyTorch AWS S3 account.

### Load original model

Load the existing checkpoint into the FLAVA class to see what the architecture currently is.

In [3]:
import torch
from torchmultimodal.models.flava.flava_model import flava_model_for_classification, flava_model_for_pretraining

flava_classification = flava_model_for_classification(num_classes=3)

### Print summary

In [None]:
flava_classification
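
The printed module tree can be long. As a rough complement, the sketch below counts parameter names by their leading prefix, which makes it easier to see where layers live before writing a mapping. The helper `count_by_prefix` and the toy names are hypothetical and are not taken from the FLAVA checkpoint.

```python
from collections import Counter

def count_by_prefix(names, depth=2):
    """Count parameter names grouped by their first `depth` dot-separated parts."""
    return Counter(".".join(name.split(".")[:depth]) for name in names)

# Toy parameter names, made up for illustration only.
toy_names = [
    "image_encoder.encoder.layer.0.attention.attention.query.weight",
    "image_encoder.encoder.layer.0.attention.attention.query.bias",
    "text_encoder.embeddings.word_embeddings.weight",
]

counts = count_by_prefix(toy_names)
for prefix, n in sorted(counts.items()):
    print(f"{prefix}: {n}")
```

In the notebook, the same helper could be called on `flava_classification.state_dict().keys()`.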

### Mapping function

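Replace this function with the code needed to map the old layer names to the new layer names. The version below handles one case: it collapses a duplicated `attention.attention` segment in a parameter name into a single `attention` and leaves all other names unchanged.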

In [4]:
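import re

def map_state_dict(state_dict):
    """Return a new dict with old parameter names mapped to new ones."""
    mapped_state_dict = {}
    for param, val in state_dict.items():
        # Escape the dot so only the literal "attention.attention" matches
        res = re.search(r'attention\.attention', param)
        if res:
            idx = res.start()
            # Drop the first "attention." and keep the second
            new_param = param[:idx] + param[idx + len('attention.'):]
        else:
            new_param = param
        mapped_state_dict[new_param] = val
    return mapped_state_dict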
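
To make the renaming concrete, here is a minimal sketch that applies the same rule to a few toy parameter names. The helper `strip_duplicate_attention` and the key names are hypothetical, chosen only to illustrate the pattern. Real checkpoint keys may differ.

```python
import re

def strip_duplicate_attention(name):
    """Collapse the first literal 'attention.attention' in a name to 'attention'."""
    return re.sub(r"attention\.attention", "attention", name, count=1)

# Toy keys, made up for illustration only.
toy_keys = [
    "encoder.layer.0.attention.attention.query.weight",
    "encoder.layer.0.attention.output.dense.bias",
    "encoder.layer.0.layernorm_before.weight",
]

remapped = [strip_duplicate_attention(k) for k in toy_keys]
for old, new in zip(toy_keys, remapped):
    print(f"{old} -> {new}")
```

Only the first key changes. The other two contain no duplicated `attention` segment and pass through untouched.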

### Load checkpoint

Load the checkpoint again, but this time keep it as a raw `state_dict` instead of loading it into the FLAVA class.

In [5]:
# Update this URL if the checkpoint moves
old_model_url = 'https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin'

old_state_dict = torch.hub.load_state_dict_from_url(old_model_url)

### Perform re-mapping

In [6]:
new_state_dict = map_state_dict(old_state_dict)
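
Before saving, it can help to check that the re-mapped names line up with what the model expects. The sketch below compares two collections of key names and reports the differences. The helper `diff_keys` and the toy keys are hypothetical and only illustrate the check.

```python
def diff_keys(expected_keys, actual_keys):
    """Return (missing, unexpected) key names, each as a sorted list."""
    expected, actual = set(expected_keys), set(actual_keys)
    missing = sorted(expected - actual)      # expected by the model, absent from the checkpoint
    unexpected = sorted(actual - expected)   # present in the checkpoint, unknown to the model
    return missing, unexpected

# Toy keys, made up for illustration only.
expected = ["encoder.attention.query.weight", "encoder.attention.query.bias"]
actual = ["encoder.attention.query.weight", "encoder.attention.attention.query.bias"]

missing, unexpected = diff_keys(expected, actual)
print("missing:", missing)
print("unexpected:", unexpected)
```

In the notebook, this could be called as `diff_keys(flava_classification.state_dict().keys(), new_state_dict.keys())`. Some differences may be legitimate, for example if the checkpoint covers only part of the model. PyTorch's `load_state_dict(..., strict=False)` reports the same two lists as `missing_keys` and `unexpected_keys`.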

### Save updated checkpoint

In [7]:
# Replace this with a path on your own machine
save_path = '/Users/rafiayub/flava.bin'
torch.save(new_state_dict, save_path)