# Trying to load Torch BERT weights into Jax

We'd like to load the fine-tuned weights from 

In [1]:
import torch
import torch.nn as nn
import jax.numpy as jnp
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import BertForSequenceClassification, FlaxBertForSequenceClassification

## Load Torch Fine-tuned weights

They intialize their model like this

In [3]:
bert_theirs = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

And then read in this state dict


In [4]:
state_dict = torch.load('../data/classifier/classifier.pt', map_location=torch.device('cpu'))

In [5]:
bert_theirs.load_state_dict(state_dict)

<All keys matched successfully>

## Trying to load the state dict into JAX

There is a [github dicussion](https://github.com/google/flax/discussions/927) from around two years ago that indicates that this is possible. But I am afraid that the details have changed a bit since then, and I can't find anything more recent.  

In [6]:
# instantiate the complimentary Flax BERT pretrained model
bert_ours = FlaxBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-unca

### Attempt #1 

First I try a slightly adapted version of the method from the first link from the above github discussion https://github.com/nikitakit/flax_bert/blob/master/import_weights.py. It appears to map the torch parameter keywords to their jax counterparts. 

In [7]:
def load_params_from_hf(pt_params, hidden_size, num_attention_heads):
    jax_params = {}
    # mapping between HuggingFace PyTorch BERT and JAX model
    pt_key_to_jax_key = [
        # Output heads
        ('cls.seq_relationship', 'classification'),
        ('cls.predictions.transform.LayerNorm', 'predictions_transform_layernorm'),
        ('cls.predictions.transform.dense', 'predictions_transform_dense'),
        ('cls.predictions.bias', 'predictions_output.bias'),
        ('cls.predictions.decoder.weight', 'UNUSED'),
        ('cls.predictions.decoder.bias', 'UNUSED'),
        # Embeddings
        ('embeddings.position_ids', 'UNUSED'),
        ('embeddings.word_embeddings.weight', 'word_embeddings.embedding'),
        ('embeddings.token_type_embeddings.weight', 'type_embeddings.embedding'),
        ('embeddings.position_embeddings.weight', 'position_embeddings.embedding'),
        ('embeddings.LayerNorm', 'embeddings_layer_norm'),
        # Pooler
        ('pooler.dense.', 'pooler.'),
        # Layers
        ('bert.encoder.layer.', 'bert.encoder_layer_'),
        # ('bert/encoder/layer_', 'bert/encoder_layer_'),
        ('attention.self', 'self_attention.attn'),
        ('attention.output.dense', 'self_attention.attn.output'),
        ('attention.output.LayerNorm', 'self_attention_layer_norm'),
        ('output.LayerNorm', 'output_layer_norm'),
        ('intermediate.dense', 'feed_forward.intermediate'),
        ('output.dense', 'feed_forward.output'),
        # Parameter names
        ('weight', 'kernel'),
        ('beta', 'bias'),
        ('gamma', 'scale'),
        ('layer_norm.kernel', 'layer_norm.scale'),
        ('layernorm.kernel', 'layernorm.scale'),
        ]
    pt_keys_to_transpose = (
            "dense.weight",
            "attention.self.query",
            "attention.self.key",
            "attention.self.value"
            )
    for pt_key, val in pt_params.items():
        jax_key = pt_key
        for pt_name, jax_name in pt_key_to_jax_key:
            jax_key = jax_key.replace(pt_name, jax_name)

        if 'UNUSED' in jax_key:
                continue

        if any([x in pt_key for x in pt_keys_to_transpose]):
                val = val.T
        val = np.asarray(val)

        # Reshape kernels if necessary
        reshape_params = ['key', 'query', 'value']
        for key in reshape_params:
            if f'self_attention.attn.{key}.kernel' in jax_key:
                val = np.swapaxes(
                        val.reshape((hidden_size, num_attention_heads, -1)), 0, 1)
            elif f'self_attention.attn.{key}.bias' in jax_key:
                val = val.reshape((num_attention_heads, -1))
        if 'self_attention.attn.output.kernel' in jax_key:
            val = val.reshape((num_attention_heads, -1, hidden_size))
        elif 'self_attention.attn.output.bias' in jax_key:
            # The multihead attention implementation we use creates a bias vector for
            # each head, even though this is highly redundant.
            val = np.stack(
                    [val] + [np.zeros_like(val)] * (num_attention_heads - 1), axis=0)

        jax_params[jax_key] = val

    # jax position embedding kernel has additional dimension
    pos_embedding = jax_params[
            'bert.position_embeddings.embedding']
    jax_params[
            'bert.position_embeddings.embedding'] = pos_embedding[
                    np.newaxis, ...]

    # this layer doesn't have parameters, but key is required to be present
    jax_params['GatherIndexes_0'] = {}

    # convert flat param dict into nested dict using `/` as delimeter
    outer_dict = {}
    for key, val in jax_params.items():
        tokens = key.split('.')
        inner_dict = outer_dict
        # each token except the very last should add a layer to the nested dict
        for token in tokens[:-1]:
            if token not in inner_dict:
                inner_dict[token] = {}
            inner_dict = inner_dict[token]
        inner_dict[tokens[-1]] = val

    if 'global_step' in outer_dict:
        del outer_dict['global_step']

    return outer_dict

In [8]:
#their_params_fl = pt2fl(bert_theirs)
their_params_fl = load_params_from_hf(state_dict, hidden_size=768, num_attention_heads=12)

  val = val.T


(The above seems like an ignorable warning)

The `FlaxBertForSequenceClassification.params` property acts as a setter, so it will complain if the new value isnt formatted correctly. So it seems that the above method does not produce a valid jax state dict. Maybe I can transform it so its correct?

In [9]:
bert_ours.params = their_params_fl

ValueError: Some parameters are missing. Make sure that `params` include the following parameters {('bert', 'encoder', 'layer', '10', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '3', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '1', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'query', 'kernel'), ('bert', 'embeddings', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '10', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '5', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '5', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '2', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '3', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '0', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '4', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '0', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '3', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '9', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '8', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '6', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '6', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '2', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '1', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '11', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '1', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '7', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '5', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'key', 'kernel'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '1', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '0', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '4', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '8', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '6', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '11', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '5', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '3', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '9', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '11', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '9', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '11', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '10', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '0', 'attention', 'output', 'dense', 'bias'), ('bert', 'embeddings', 'word_embeddings', 'embedding'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '2', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '2', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '4', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '9', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '5', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '11', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '7', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '1', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '0', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '1', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '6', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '7', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '3', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '4', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '2', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '5', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '11', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '5', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '2', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '7', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '11', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '9', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'embeddings', 'token_type_embeddings', 'embedding'), ('bert', 'encoder', 'layer', '0', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '2', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '5', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '8', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '10', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '6', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '11', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '5', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '3', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '6', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '4', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '7', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '1', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '2', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '8', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '9', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '8', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '8', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '7', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '10', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '1', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '11', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '1', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'embeddings', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '4', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '0', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '10', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '0', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '8', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '8', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '2', 'attention', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '3', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'value', 'bias'), ('bert', 'encoder', 'layer', '0', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '2', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '5', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '10', 'output', 'LayerNorm', 'bias'), ('bert', 'embeddings', 'position_embeddings', 'embedding'), ('bert', 'encoder', 'layer', '8', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '6', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '5', 'attention', 'self', 'query', 'bias'), ('bert', 'encoder', 'layer', '4', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '7', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '7', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '11', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '4', 'output', 'LayerNorm', 'scale'), ('bert', 'encoder', 'layer', '4', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '9', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '10', 'attention', 'self', 'value', 'kernel'), ('bert', 'pooler', 'dense', 'bias'), ('bert', 'encoder', 'layer', '7', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '1', 'attention', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '3', 'attention', 'self', 'key', 'kernel'), ('bert', 'encoder', 'layer', '0', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '9', 'output', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '4', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '1', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '8', 'attention', 'output', 'LayerNorm', 'bias'), ('bert', 'encoder', 'layer', '8', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'key', 'bias'), ('bert', 'encoder', 'layer', '6', 'intermediate', 'dense', 'bias'), ('bert', 'encoder', 'layer', '0', 'attention', 'self', 'value', 'kernel'), ('bert', 'encoder', 'layer', '3', 'output', 'dense', 'bias'), ('bert', 'encoder', 'layer', '4', 'intermediate', 'dense', 'kernel'), ('bert', 'encoder', 'layer', '11', 'attention', 'self', 'query', 'kernel'), ('bert', 'encoder', 'layer', '2', 'attention', 'self', 'key', 'bias')}

In [10]:
from flax.traverse_util import flatten_dict, unflatten_dict

In [11]:
required_keys = bert_ours.required_params

In [12]:
flattened_keys_transformed = set(flatten_dict(their_params_fl).keys())

In [13]:
def postprocess_flattned_keys(keys): 
    keys = list(keys)
    for i, key in enumerate(keys): 
        new_key = list(key)
        if 'encoder_layer_' in key[1]: 
            new_key = [new_key[0]] + new_key[1].split('_') + new_key[2:]
        if new_key[1] == 'embeddings_layer_norm': 
            new_key = [new_key[0]] + ['embeddings', 'LayerNorm'] + [new_key[-1]]
        if new_key[1] in ['position_embeddings', 
                          'word_embeddings', 
                          'type_embeddings']: 
            if new_key[1] == 'type_embeddings': 
                new_key[1] = 'token_type_embeddings'
            new_key = [new_key[0]] + ['embeddings'] + new_key[1:]
        if new_key[1] == 'pooler': 
            new_key.insert(-1, 'dense')
        if len(key) > 3:
            if new_key[4] == 'feed_forward': 
                new_key[4] = 'dense'
            if new_key[4] == 'self_attention': 
                new_key[4] = 'attention'
                new_key[5] = 'self'
                if new_key[6] == 'output':
                    new_key[5] = 'output'
                    new_key[6] = 'dense'
            if new_key[-3:-1] == ['dense', 'intermediate']: 
                new_key[-3] = 'intermediate'
                new_key[-2] = 'dense'
            if new_key[-2] == 'output_layer_norm': 
                new_key = new_key[:-2] + ['output', 'LayerNorm'] + new_key[-1:]
            if new_key[-2] == 'self_attention_layer_norm': 
                new_key = new_key[:-2] + ['attention', 'output', 'LayerNorm'] + new_key[-1:]
        if len(new_key) == 7: 
            if '.'.join(new_key[:3]) == 'bert.encoder.layer' and \
            '.'.join(new_key[-3:]) in ['dense.output.bias', 'dense.output.kernel']: 
                new_key[-3] = 'output'
                new_key[-2] = 'dense'
        keys[i] = tuple(new_key)
    return set(keys)

In [14]:
flattened_keys_cleaned = postprocess_flattned_keys(flattened_keys_transformed)

In [15]:
len(flattened_keys_cleaned)

201

In [16]:
required_keys - flattened_keys_cleaned

set()

In [17]:
flattened_keys_cleaned - required_keys

set()

Aww yeah, but now i have to use these keys to line up the dictionaries

In [18]:
def clean_state_dict(processed_state_dict):
    state_dict_flat = flatten_dict(processed_state_dict)
    clean_state_dict = {}
    for key in state_dict_flat.keys(): 
        new_key = list(key)
        if 'encoder_layer_' in key[1]: 
            new_key = [new_key[0]] + new_key[1].split('_') + new_key[2:]
        if new_key[1] == 'embeddings_layer_norm': 
            new_key = [new_key[0]] + ['embeddings', 'LayerNorm'] + [new_key[-1]]
        if new_key[1] in ['position_embeddings', 
                          'word_embeddings', 
                          'type_embeddings']: 
            if new_key[1] == 'type_embeddings': 
                new_key[1] = 'token_type_embeddings'
            new_key = [new_key[0]] + ['embeddings'] + new_key[1:]
        if new_key[1] == 'pooler': 
            new_key.insert(-1, 'dense')
        if len(key) > 3:
            if new_key[4] == 'feed_forward': 
                new_key[4] = 'dense'
            if new_key[4] == 'self_attention': 
                new_key[4] = 'attention'
                new_key[5] = 'self'
                if new_key[6] == 'output':
                    new_key[5] = 'output'
                    new_key[6] = 'dense'
            if new_key[-3:-1] == ['dense', 'intermediate']: 
                new_key[-3] = 'intermediate'
                new_key[-2] = 'dense'
            if new_key[-2] == 'output_layer_norm': 
                new_key = new_key[:-2] + ['output', 'LayerNorm'] + new_key[-1:]
            if new_key[-2] == 'self_attention_layer_norm': 
                new_key = new_key[:-2] + ['attention', 'output', 'LayerNorm'] + new_key[-1:]
        if len(new_key) == 7: 
            if '.'.join(new_key[:3]) == 'bert.encoder.layer' and \
            '.'.join(new_key[-3:]) in ['dense.output.bias', 'dense.output.kernel']: 
                new_key[-3] = 'output'
                new_key[-2] = 'dense'
        clean_state_dict[tuple(new_key)] = state_dict_flat[key]
    return unflatten_dict(clean_state_dict)

In [19]:
new_sd = clean_state_dict(their_params_fl)

In [20]:
bert_ours.params = new_sd

In [21]:
bert_ours

<transformers.models.bert.modeling_flax_bert.FlaxBertForSequenceClassification at 0x7f02b4db6fe0>

Well, hopefully this works!