In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import brunoflow as bf
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
)

# Seed torch's global RNG so that anything torch initialises randomly later in the
# notebook is reproducible across runs (presumably the freshly constructed
# BertEmbeddings weights — TODO confirm).
torch.manual_seed(0)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f298f02b4b0>

In [2]:
# Checkpoint whose tokenizer is used to turn the example sentences into token ids.
model_id = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = BertTokenizerFast.from_pretrained(model_id)

# Two sentences of different lengths, so the shorter one gets padded in the batch.
text = [
    "hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!",
    "yo yo what's up",
]

# Tokenize into a padded batch of PyTorch tensors and keep the token ids for later cells.
tokens = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
)
input_ids = tokens["input_ids"]
print(input_ids)

tensor([[  101,  7592,  1045,  2215,  2000,  4521,  2070,   103,  6240,  2651,
          1012,  2009,  1005,  1055, 15060,   103,  2035,   999,   102],
        [  101, 10930, 10930,  2054,  1005,  1055,  2039,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0]])


In [3]:
# Build the brunoflow and the torch embedding modules from one shared config file.
config_path = "../../brunoflow/models/bert/config-tiny.json"
config = BertConfig.from_pretrained(config_path)

# Keep this construction order: bf module first, torch module second.
bf_embs = BfBertEmbeddings(config)
torch_embs = BertEmbeddings(config)

# Print both module structures, bf first, so they can be compared by eye.
print(bf_embs, torch_embs, sep="\n")

BfBertEmbeddings(
  (word_embeddings): Embedding(30522, 128, padding_idx=0)
  (position_embeddings): Embedding(512, 128)
  (token_type_embeddings): Embedding(2, 128)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0)
)
BertEmbeddings(
  (word_embeddings): Embedding(30522, 128, padding_idx=0)
  (position_embeddings): Embedding(512, 128)
  (token_type_embeddings): Embedding(2, 128)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0, inplace=False)
)


  "token_type_ids", bf.Node(jnp.zeros(self.position_ids.shape, dtype=jnp.int64)), persistent=False


In [4]:
# Write the torch module's weights to a file in the working directory.
save_path = "bertembeddings_torch.pt"
torch_state = torch_embs.state_dict()
torch.save(torch_state, save_path)

In [5]:
# Read the saved torch weights back from disk and load them into the bf module.
# The load_state_dict result is left as the last expression so the key-matching
# report is displayed as the cell output.
saved_state = torch.load(save_path)
bf_embs.load_state_dict(saved_state)

<All keys matched successfully>

In [6]:
# Display the bf module again after load_state_dict; loading weights is not expected
# to change the printed module structure.
bf_embs

BfBertEmbeddings(
  (word_embeddings): Embedding(30522, 128, padding_idx=0)
  (position_embeddings): Embedding(512, 128)
  (token_type_embeddings): Embedding(2, 128)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0)
)

In [7]:
# Check that every parameter loaded into the bf module is identical to its torch
# counterpart. Iterating over named_parameters() (rather than listing attributes by
# hand) also covers LayerNorm.bias, which the hand-written list of checks left out.
# Dropout is not checked explicitly: it contributes no entries to named_parameters().
loaded_bf_params = dict(bf_embs.named_parameters())
reference_torch_params = dict(torch_embs.named_parameters())
assert set(loaded_bf_params) == set(reference_torch_params), (
    "bf and torch modules expose different parameter names."
)

for param_name, torch_param in reference_torch_params.items():
    # Exact equality is expected here: the values were copied, not recomputed.
    assert check_node_equals_tensor(loaded_bf_params[param_name], torch_param), (
        f"Parameter {param_name} differs between bf and torch after loading the state dict."
    )


In [8]:
# See how much changing the precision affects the gradient
# for name, param in bf_embs.named_parameters():
#     param.val = jnp.round(param.val, 3)
#     print(name, param)



In [9]:
# Run the same token ids through both implementations and require the outputs to
# agree within an absolute tolerance of 1e-3 (close, not bit-for-bit equal).

# Switch both modules out of training mode before the forward pass.
for module in (torch_embs, bf_embs):
    module.train(False)

# The bf module is fed a jax integer copy of the torch token ids.
jax_input_ids = jnp.array(input_ids.numpy(), dtype=int)

out_bf = bf_embs(input_ids=jax_input_ids)
out_torch = torch_embs(input_ids=input_ids)

assert check_node_allclose_tensor(out_bf, out_torch, atol=1e-3)



### Compare grads of parameters between torch and bf after a backward pass

In [10]:
%%time
# Torch backward pass: put the module back into training mode, then backpropagate
# an upstream gradient of all ones (same shape as the output) through out_torch.
torch_embs.train()
upstream_grad = torch.ones_like(out_torch)
out_torch.backward(upstream_grad)

CPU times: user 4.2 ms, sys: 69 µs, total: 4.27 ms
Wall time: 2.97 ms


In [11]:
%%time
# BF backward pass: backpropagate from the bf output node.
# values_to_compute=("grad",) presumably restricts the pass to plain gradients —
# confirm against the brunoflow backprop documentation.
# The recorded timings show this pass taking seconds, versus milliseconds for torch.
out_bf.backprop(values_to_compute=("grad",))



CPU times: user 2.26 s, sys: 342 ms, total: 2.6 s
Wall time: 4.69 s




In [12]:
# print("word_embeddings:", bf_embs.word_embeddings.weight.grad, torch_embs.word_embeddings.weight.grad)
# print("position_embeddings:", bf_embs.position_embeddings.weight.grad, torch_embs.position_embeddings.weight.grad)
# print("token_type_embeddings:", bf_embs.token_type_embeddings.weight.grad, torch_embs.token_type_embeddings.weight.grad)
# print("LayerNorm:", bf_embs.LayerNorm.weight.grad, torch_embs.LayerNorm.weight.grad)

In [13]:
# Compare, parameter by parameter, the gradients produced by the bf and the torch
# backward passes. Both backward cells must have been run before this one.
bf_emb_params = dict(bf_embs.named_parameters())
torch_emb_params = dict(torch_embs.named_parameters())
assert set(bf_emb_params.keys()) == set(torch_emb_params.keys())

for name, bf_param in bf_emb_params.items():
    torch_grad = torch_emb_params[name].grad
    # A missing torch gradient would otherwise surface as an opaque AttributeError on None.
    assert torch_grad is not None, f"torch param {name} has no grad; run the torch backward pass first."

    # Evaluate the comparison once and reuse it for both the report line and the assertion.
    grads_match = jnp.allclose(bf_param.grad, torch_grad.numpy(), atol=1e-6)
    print(f"Grad of param {name} for bf and torch are within 1e-6? {grads_match}")
    assert grads_match, f"Grad of param {name} for bf and torch are not within 1e-6."

Grad of param word_embeddings.weight for bf and torch are within 1e-6? True
Grad of param position_embeddings.weight for bf and torch are within 1e-6? True
Grad of param token_type_embeddings.weight for bf and torch are within 1e-6? True
Grad of param LayerNorm.weight for bf and torch are within 1e-6? True
Grad of param LayerNorm.bias for bf and torch are within 1e-6? True


In [14]:
# Display the bf and torch gradients of LayerNorm.weight side by side for visual inspection.
# NOTE(review): this shows both arrays in full; consider slicing (e.g. [:5]) to keep
# the saved output short.
(bf_emb_params["LayerNorm.weight"].grad, torch_emb_params["LayerNorm.weight"].grad.numpy())

(DeviceArray([ 56.536148  ,   9.873916  ,  -9.946613  ,  29.669815  ,
              -19.766806  ,   4.3356705 , -23.152435  , -29.78961   ,
              -18.999136  ,  36.90544   , -35.539238  , -37.673416  ,
               27.412798  ,   0.0662511 ,  13.515636  ,  -9.084731  ,
               -9.820608  , -51.483334  ,  10.379288  ,  26.532597  ,
               -2.8499513 ,  10.431053  ,   5.990584  , -41.563046  ,
                8.451949  ,  42.47052   ,   6.350166  ,  -6.573671  ,
               -1.4423742 ,  37.940216  ,  52.01646   ,  23.367987  ,
               -6.191135  ,   0.69063747,  -8.358557  ,  43.680916  ,
               10.069973  ,  -4.429394  , -37.8228    ,  -1.1591048 ,
               24.835108  , -44.698177  , -19.027328  , -22.159668  ,
               -3.9408004 ,   0.2023251 , -10.610362  ,   0.97283334,
               25.851631  , -22.99358   ,  -2.786594  , -51.906845  ,
               -7.4670386 ,   8.02406   , -54.158016  ,  19.590197  ,
               27.29