In [68]:
from transformers import FlaxBertModel, BertTokenizerFast, BertConfig
from datasets import load_dataset
import jax
import jax.numpy as jnp

model_name = "bert-base-uncased"
config = BertConfig(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = FlaxBertModel.from_pretrained(model_name)

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
def formatData(t, s=0):
    if isinstance(t, dict):
        for key in t:
            print("\t"*s + str(key) + ':')
            formatData(t[key], s+1)

In [47]:
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": "bert-base-uncased"
}

In [48]:
config_dict = {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": None,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0",
  "type_vocab_size": 2,
  "use_cache": True,
  "vocab_size": 30522
}
sm_config_dict = config_dict
sm_config_dict["num_hidden_layers"] = 2

In [49]:
sm_config = BertConfig(**sm_config_dict)
sm_config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [50]:
sm_model = FlaxBertModel(sm_config)

In [52]:
formatData(model.params)

embeddings:
	LayerNorm:
		bias:
		scale:
	position_embeddings:
		embedding:
	token_type_embeddings:
		embedding:
	word_embeddings:
		embedding:
encoder:
	layer:
		0:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						kernel:
				self:
					key:
						bias:
						kernel:
					query:
						bias:
						kernel:
					value:
						bias:
						kernel:
			intermediate:
				dense:
					bias:
					kernel:
			output:
				LayerNorm:
					bias:
					scale:
				dense:
					bias:
					kernel:
		1:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						kernel:
				self:
					key:
						bias:
						kernel:
					query:
						bias:
						kernel:
					value:
						bias:
						kernel:
			intermediate:
				dense:
					bias:
					kernel:
			output:
				LayerNorm:
					bias:
					scale:
				dense:
					bias:
					kernel:
		10:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						

In [51]:
formatData(sm_model.params)

embeddings:
	LayerNorm:
		bias:
		scale:
	position_embeddings:
		embedding:
	token_type_embeddings:
		embedding:
	word_embeddings:
		embedding:
encoder:
	layer:
		0:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						kernel:
				self:
					key:
						bias:
						kernel:
					query:
						bias:
						kernel:
					value:
						bias:
						kernel:
			intermediate:
				dense:
					bias:
					kernel:
			output:
				LayerNorm:
					bias:
					scale:
				dense:
					bias:
					kernel:
		1:
			attention:
				output:
					LayerNorm:
						bias:
						scale:
					dense:
						bias:
						kernel:
				self:
					key:
						bias:
						kernel:
					query:
						bias:
						kernel:
					value:
						bias:
						kernel:
			intermediate:
				dense:
					bias:
					kernel:
			output:
				LayerNorm:
					bias:
					scale:
				dense:
					bias:
					kernel:
pooler:
	dense:
		bias:
		kernel:


In [56]:
sm_model.params['embeddings'] = model.params['embeddings']
sm_model.params['pooler'] = model.params['pooler']
for i in range(2):
    sm_model.params['encoder']['layer'][str(i)] = model.params['encoder']['layer'][str(i)]

In [69]:
# dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)

# dummy_input = next(iter(dataset))["text"]

dummy_input = "The boulder has a steep burly start, then a delicate mantle finish."
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids #[:, :10]

In [70]:
dummy_input

'The boulder has a steep burly start, then a delicate mantle finish.'

In [78]:
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
reps = model(input_ids, output_hidden_states=True)

In [72]:
reps.pooler_output.shape

(1, 768)

In [79]:
reps.last_hidden_state.shape

(1, 17, 768)

In [84]:
reps.hidden_states[-1]

Array(True, dtype=bool)

In [26]:
dummy_inputs = ["The boulder has a steep burly start, then a delicate mantle finish.",
                "Sit start is pretty scrunchy -- harder for the very tall!"]

input_ids = tokenizer.batch_encode_plus(
    dummy_inputs,
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest')

In [27]:
reps = model(jnp.array(input_ids['input_ids']), attention_mask=jnp.array(input_ids['attention_mask']))

In [28]:
reps.pooler_output.shape

(2, 768)

In [29]:
reps.last_hidden_state.shape

(2, 17, 768)

In [None]:
from transformers import AutoConfig, AutoTokenizer, \
    FlaxAutoModelForSequenceClassification


In [None]:
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=3,
    finetuning_task='sst3',
    use_auth_token=None,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    use_auth_token=None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    use_auth_token=None,
    # from_pt=True,
)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/Users/lara.thompson/.local/share/virtualenvs/lara.thompson-C83ZgnRu/lib/python3.9/site-packages/transformers/file_utils.py", line 2777, in _get_module
  File "/Users/lara.thompson/.pyenv/versions/3.9.13/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/Users/lara.thompson/.local/share/virtualenvs/lara.thompson-C83ZgnRu/lib/python3.9/site-packages/transformers/models/mbart/configuration_mbart.py", line 23, in <module>
    from ...ut