In [1]:
%%capture
!pip install transformers sentencepiece flax

In [1]:
import requests
import numpy as np
from PIL import Image


url = 'http://images.cocodataset.org/val2017/000000001532.jpg'
image = Image.open(requests.get(url, stream=True).raw)
text = "A highway with a black car on it"

In [2]:
from transformers import ViTFeatureExtractor, FlaxViTModel, BertTokenizerFast, FlaxBertModel, BertConfig, ViTConfig
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
import flax.linen as nn
import jax.numpy as jnp
import jax
from jax import lax, random

In [3]:
vit_config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')
bert_config = BertConfig.from_pretrained('bert-base-uncased')

## Flax

In [4]:
# configuration = ViTConfig(hidden_size=1024)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
model_vit = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

vit_inputs = feature_extractor(images=image, return_tensors="jax")
encoder_outputs = model_vit(**vit_inputs)



In [5]:
import pprint

In [6]:
encoder_outputs['last_hidden_state'].shape

(1, 197, 768)

In [7]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
model_bert = FlaxBertModel.from_pretrained('bert-base-multilingual-uncased')

bert_inputs = tokenizer(text)

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-multilingual-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from transformers import AutoConfig
AutoConfig.for_model(**bert_config.to_dict())

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [9]:
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ViTBertConfig(PretrainedConfig):
    model_type = "vit-bert"
    is_composition = True

    def __init__(self, bert_config_dict, vit_config_dict, **kwargs):
        super().__init__(**kwargs)

        if bert_config_dict is None:
            raise ValueError("`bert_config_dict` can not be `None`.")

        if vit_config_dict is None:
            raise ValueError("`vit_config_dict` can not be `None`.")

        self.bert_config = BertConfig(**bert_config_dict)

        self.vit_config = ViTConfig(**vit_config_dict)

    @classmethod
    def from_bert_vit_configs(cls, bert_config: PretrainedConfig, vit_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
        vision model configuration.
        Returns:
            :class:`HybridCLIPConfig`: An instance of a configuration object
        """

        return cls(bert_config_dict=bert_config.to_dict(), vit_config_dict=vit_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default
        :meth:`~transformers.PretrainedConfig.to_dict`.
        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output["bert_config"] = self.bert_config.to_dict()
        output["vit_config"] = self.vit_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output

In [10]:
class FlaxViTBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: ViTBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        bert_config = self.config.bert_config
        vit_config = self.config.vit_config

        self.word_embeddings = nn.Embed(
            bert_config.vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=bert_config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            bert_config.max_position_embeddings,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=bert_config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            bert_config.type_vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=bert_config.initializer_range),
            dtype=self.dtype,
        )

        self.vit_module = FlaxViTModule(vit_config, dtype=self.dtype)
        self.visual_projection = nn.Dense(bert_config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(bert_config.initializer_range, self.dtype))

        self.visual_position_embeddings = nn.Embed(
            bert_config.max_position_embeddings,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=bert_config.initializer_range),
            dtype=self.dtype,
        )
        self.visual_token_type_embeddings = nn.Embed(
            bert_config.type_vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=bert_config.initializer_range),
            dtype=self.dtype,
        )

        self.LayerNorm = nn.LayerNorm(epsilon=bert_config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=bert_config.hidden_dropout_prob)
        
    def __call__(self, input_ids, token_type_ids, position_ids, pixel_values, visual_token_type_ids, visual_position_ids, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        word_embeddings = inputs_embeds + token_type_embeddings + position_embeds

        # Visual Embed
        visual_inputs_embeds = self.vit_module(pixel_values=pixel_values)[0]
        visual_inputs_embeds = self.visual_projection(visual_inputs_embeds)
        visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids.astype("i4"))
        visual_position_embeds = self.visual_position_embeddings(visual_position_ids.astype("i4"))

        # Sum all visual embeddings
        visual_embeddings = visual_inputs_embeds + visual_token_type_embeddings + visual_position_embeds

        # Concat
        hidden_states = jnp.concatenate((word_embeddings, visual_embeddings),axis=1)

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

In [11]:
input_ids = jnp.array([bert_inputs["input_ids"]], dtype=jnp.int32)
token_type_ids = jnp.array([bert_inputs['token_type_ids']], dtype=jnp.int32)
attention_mask = jnp.array([bert_inputs['attention_mask']], dtype=jnp.int32)
position_ids = jnp.arange(0, input_ids.shape[1], dtype=jnp.int32).reshape(1, -1)
pixel_values = vit_inputs['pixel_values']
visual_attention_mask = jnp.ones(encoder_outputs.last_hidden_state.shape[:-1], dtype=jnp.int32)
visual_token_type_ids = jnp.ones(encoder_outputs.last_hidden_state.shape[:-1], dtype=jnp.int32)
visual_position_ids = jnp.zeros(encoder_outputs.last_hidden_state.shape[:-1], dtype=jnp.int32)

In [12]:
attention_mask.shape

(1, 10)

In [13]:
visual_attention_mask.shape

(1, 197)

In [14]:
jnp.concatenate((attention_mask, visual_attention_mask), axis=1).shape

(1, 207)

In [15]:
vit_bert_config = ViTBertConfig.from_bert_vit_configs(bert_config, vit_config)
vit_bert_config

ViTBertConfig {
  "bert_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": [
      "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 

In [16]:
# flax_embedding_layer = FlaxViTBertEmbeddings(vit_bert_config)
# key = random.PRNGKey(0)
# params = flax_embedding_layer.init(key, input_ids, token_type_ids, position_ids, pixel_values, visual_token_type_ids, visual_position_ids)
# jax.tree_map(lambda x: x.shape, params)
# flax_embedding_layer.apply(params, input_ids, token_type_ids, position_ids, pixel_values, visual_token_type_ids, visual_position_ids).shape

In [17]:
from transformers.models.bert.modeling_flax_bert import FlaxPreTrainedModel, FlaxBertEncoder, FlaxBertPooler, FlaxBaseModelOutputWithPooling, FlaxBertPreTrainedModel
from typing import Tuple, Optional
from flax.core.frozen_dict import FrozenDict

In [18]:
class FlaxViTBertModule(nn.Module):
    config: ViTBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxViTBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config.bert_config, dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config.bert_config, dtype=self.dtype)
        print(self.embeddings)
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids, 
        visual_position_ids,
        deterministic = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, pixel_values, visual_token_type_ids, visual_position_ids, deterministic=deterministic
        )

        combined_attention_mask = jnp.concatenate((attention_mask, visual_attention_mask), axis=1)
        deterministic=True
        outputs = self.encoder(
            hidden_states,
            combined_attention_mask,
            deterministic=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [19]:
class FlaxViTBertModel(FlaxPreTrainedModel):
    config_class: ViTBertConfig
    module_class = FlaxViTBertModule

    def __init__(
        self, config: ViTBertConfig, input_shape: Tuple = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):

        if input_shape is None:
            input_shape = ((1, 1), (1, config.vit_config.image_size, config.vit_config.image_size, 3), (1, 197))

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape)
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])
        visual_attention_mask = jnp.ones(input_shape[2]) # TODO: Fix this
        visual_token_type_ids = jnp.ones(input_shape[2]) # TODO: Fix this
        visual_position_ids = jnp.broadcast_to(jnp.zeros(jnp.atleast_2d(input_ids).shape[-1]), input_shape[2]) # TODO: Fix this

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, pixel_values,
        visual_attention_mask,
        visual_token_type_ids, 
        visual_position_ids, return_dict=False)[
            "params"
        ]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None, 
        visual_position_ids=None,
        
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.bert_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.bert_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.bert_config.return_dict


        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(input_ids.shape) # TODO: Fix this.

        if visual_position_ids is None:
            visual_position_ids = jnp.broadcast_to(jnp.atleast_2d(input_ids).shape[-1],input_ids.shape) # TODO: Fix this.

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(input_ids.shape) # TODO: Fix this.

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_bert_vit_pretrained(
        cls,
        bert_model_name_or_path: str = None,
        vit_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:

        kwargs_bert = {
            argument[len("bert_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
        }

        kwargs_vit = {
            argument[len("vit_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
        }

        # remove text, vision kwargs from kwargs
        for key in kwargs_bert.keys():
            del kwargs["bert_" + key]
        for key in kwargs_vit.keys():
            del kwargs["vit_" + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertModel

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertModel.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        vit_model = kwargs_vit.pop("model", None)
        if vit_model is None:
            assert (
                vit_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
            from transformers import FlaxViTModel

            if "config" not in kwargs_vit:
                from transformers import ViTConfig

                vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
                kwargs_vit["config"] = vit_config

            vit_model = FlaxViTModel.from_pretrained(vit_model_name_or_path, *model_args, **kwargs_vit)

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTBertConfig.from_bert_vit_configs(bert_model.config, vit_model.config, **kwargs)

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        for key in model.params.keys():
            if key != "embeddings":
                model.params[key] = bert_model.params[key]
            else:
                model.params["embeddings"]["vit_module"] = vit_model.params
                for sub_key in bert_model.params[key]:
                    model.params[key][sub_key] = bert_model.params[key][sub_key]

        return model

In [20]:
flax_model = FlaxViTBertModel.from_bert_vit_pretrained('bert-base-uncased', 'google/vit-base-patch16-224-in21k', seed=0, dtype=jnp.float32)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertModel: {('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'seq_relationship', 'bias'), ('cls', 'seq_relationship', 'kernel'), ('cls', 'predictions', 'decoder', 'kernel'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'weight'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'bias')}
- This IS expected if you are initializing FlaxBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


FlaxViTBertEmbeddings(
    # attributes
    config = ViTBertConfig {
      "bert_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false

In [48]:
outputs = flax_model(input_ids, attention_mask,token_type_ids, position_ids, pixel_values, visual_attention_mask, visual_token_type_ids, visual_position_ids, output_hidden_states=True)

FlaxViTBertEmbeddings(
    # attributes
    config = ViTBertConfig {
      "bert_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false

In [51]:
outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])

In [31]:
from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead

class FlaxViTBertForMaskedLMModule(nn.Module):
    config: ViTBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vitbert = FlaxViTBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.cls = FlaxBertOnlyMLMHead(config=self.config.bert_config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None, 
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):

        # Model
        outputs = self.vitbert(
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                pixel_values,
                visual_attention_mask,
                visual_token_type_ids, 
                visual_position_ids,
                deterministic=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.bert_config.tie_word_embeddings:
            shared_embedding = self.vitbert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [32]:
class FlaxViTBertForMaskedLM(FlaxViTBertModel):
    module_class = FlaxViTBertForMaskedLMModule



In [38]:
flax_vitbert_mlm = FlaxViTBertForMaskedLM(vit_bert_config)

FlaxViTBertEmbeddings(
    # attributes
    config = ViTBertConfig {
      "bert_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false

In [52]:
outputs_vit = flax_vitbert_mlm (input_ids=input_ids,pixel_values=pixel_values,
                                 visual_attention_mask = visual_attention_mask,\
                                 visual_token_type_ids=visual_token_type_ids,
                                 visual_position_ids=visual_position_ids)

FlaxViTBertEmbeddings(
    # attributes
    config = ViTBertConfig {
      "bert_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false

In [53]:
outputs_vit[0].shape

(1, 207, 30522)

In [76]:
class FlaxViTMaskedLM(FlaxPreTrainedModel):
    config_class: ViTBertConfig
    module_class = FlaxViTBertForMaskedLMModule
    dtype: jnp.dtype = jnp.float32

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.vitbert = FlaxViTBertModel(config=self.config, add_pooling_layer=False,dtype=self.dtype)
        self.cls = FlaxBertOnlyMLMHead(config=self.config.bert_config, dtype=self.dtype)

        self.init_weights()
        



    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings



    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None, 
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ):

        return_dict = return_dict if return_dict is not None else self.config.bert_config.return_dict
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        outputs = self.vitbert(
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                pixel_values,
                visual_attention_mask,
                visual_token_type_ids, 
                visual_position_ids,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.bert_config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output


        return FlaxMaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )




  