In [1]:
import os

from jax import numpy as jnp
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    BertForMaskedLM,
    BertTokenizer,
    BertTokenizerFast,
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the toy batch: two sentences of different length, the first one
# containing two [MASK] placeholders.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text = [
    "hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!",
    "yo yo what's up",
]

# Tokenize the batch as PyTorch tensors; padding=True pads the shorter
# sentence so both rows have the same length. Only the token ids are kept
# for the embedding comparison.
tokens = tokenizer(text, padding=True, return_tensors="pt")
input_ids = tokens["input_ids"]
print(input_ids)

tensor([[  101,  7592,  1045,  2215,  2000,  4521,  2070,   103,  6240,  2651,
          1012,  2009,  1005,  1055, 15060,   103,  2035,   999,   102],
        [  101, 10930, 10930,  2054,  1005,  1055,  2039,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0]])


In [3]:
# Create BfBertEmbeddings and BertEmbeddings from one shared BertConfig.
#
# The config file location used to be a hard-coded absolute path on one
# developer's machine. It can now be overridden through the BERT_CONFIG_PATH
# environment variable; when the variable is unset (or empty) the original
# path is used, so existing setups behave exactly as before.
config_path = (
    os.environ.get("BERT_CONFIG_PATH")
    or "/home/kevin/code/rycolab/brunoflow/models/bert/config.json"
)
config = BertConfig.from_pretrained(pretrained_model_name_or_path=config_path)

# Both embedding modules are built from the same config object.
# NOTE(review): each module presumably initializes its own weights, so their
# outputs are only comparable if the weights are synchronized — confirm.
bf_embs = BfBertEmbeddings(config)
torch_embs = BertEmbeddings(config)

2022-12-21 14:58:45.381768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
  self.token_type_ids = jnp.zeros(self.position_ids.shape, dtype=jnp.int64) # todo is this 64 bit necessary?


In [4]:
# Compare output of BfBertEmbeddings and BertEmbeddings on the text.
#
# Switch the torch module to inference mode first. A torch nn.Module starts
# in training mode, where dropout layers randomly zero activations, which
# would make the printed reference output differ from run to run.
# NOTE(review): assumes BertEmbeddings is a torch nn.Module that contains a
# dropout layer, as in upstream transformers — confirm for this fork.
torch_embs.eval()

# The tokenizer produced PyTorch tensors; convert the ids to a JAX integer
# array for the brunoflow module and feed the original tensor to torch.
jax_input_ids = jnp.array(input_ids.numpy(), dtype=int)
print(bf_embs(input_ids=jax_input_ids).val)
print(torch_embs(input_ids=input_ids))

[[[-0.34472668  0.5307575  -0.6256574  ... -0.74307156 -0.7948409
    0.410834  ]
  [-0.12443667 -0.8550836  -0.33784032 ... -0.03916439 -1.4545143
    0.26670295]
  [-2.1477165  -0.90835226 -1.105538   ...  0.75945294 -1.8616817
   -0.16157249]
  ...
  [-1.8152208   0.4767846  -0.3528018  ... -0.03726343 -2.2972047
    0.14477256]
  [-0.9735971  -2.01189    -0.08289954 ... -0.51852256 -1.0732709
    0.9564046 ]
  [-1.0029035  -1.2219137   0.8737193  ... -0.09658846 -1.1882741
    0.27896053]]

 [[-0.34472668  0.5307575  -0.6256574  ... -0.74307156 -0.7948409
    0.410834  ]
  [-0.21959372 -1.3216001  -0.11712483 ... -0.14018822 -1.1034871
    0.9674713 ]
  [-2.029127   -1.2888647  -1.9155287  ...  0.5731998  -0.8514466
    0.7274356 ]
  ...
  [-1.2329972   0.89194095 -0.20066781 ... -0.42031294 -0.9897826
   -0.20450695]
  [-2.481228   -2.043665   -0.8651028  ... -0.5931333  -0.5373075
   -0.54956347]
  [-1.3868319  -0.63329226 -0.24705154 ... -0.3211176  -0.5438302
   -0.5995097 ]]]
