In [1]:
import brunoflow as bf
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
)

# Seed torch's global RNG (presumably so the randomly initialised torch
# weights created later in the notebook are reproducible - TODO confirm).
torch.manual_seed(0)

<torch._C.Generator at 0x2b947240b3b0>

In [2]:
# Example batch: two sentences of different lengths, so the shorter one is padded.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
text = [
    "hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!",
    "yo yo what's up",
]

# Tokenize into padded PyTorch tensors and pull out the token ids used by later cells.
tokens = tokenizer(text, padding=True, return_tensors="pt")
input_ids = tokens["input_ids"]
print(input_ids)

tensor([[  101,  7592,  1045,  2215,  2000,  4521,  2070,   103,  6240,  2651,
          1012,  2009,  1005,  1055, 15060,   103,  2035,   999,   102],
        [  101, 10930, 10930,  2054,  1005,  1055,  2039,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0]])


In [3]:
# Build the brunoflow and PyTorch embedding modules from one shared BertConfig,
# then print both module trees for a side-by-side structural comparison.
config = BertConfig.from_pretrained("../../brunoflow/models/bert/config.json")
bf_embs = BfBertEmbeddings(config)
torch_embs = BertEmbeddings(config)
print(bf_embs, torch_embs, sep="\n")

2022-12-23 00:40:47.961439: W external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:193] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable
  self.register_buffer("token_type_ids", bf.Node(jnp.zeros(self.position_ids.shape, dtype=jnp.int64)), persistent=False) # todo is this 64 bit necessary?


BfBertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1)
)
BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [4]:
# Serialise the PyTorch module's state dict to disk so it can be reloaded into bf.
save_path = "bertembeddings_torch.pt"
torch.save(obj=torch_embs.state_dict(), f=save_path)

In [5]:
# Read the saved PyTorch state dict back from disk and load it into the bf module.
# The value returned by load_state_dict is the cell output (the key-matching report).
loaded_state = torch.load(save_path)
bf_embs.load_state_dict(loaded_state)

<All keys matched successfully>

In [6]:
# Display the brunoflow module tree again, now that the saved weights have been loaded.
bf_embs

BfBertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1)
)

In [7]:
# Check that every learnable tensor was loaded into bf correctly, i.e. that each
# bf value matches the corresponding torch value.
assert check_node_equals_tensor(
    bf_embs.word_embeddings.weight, torch_embs.word_embeddings.weight
), "word_embeddings.weight differs between bf and torch"
assert check_node_equals_tensor(
    bf_embs.position_embeddings.weight, torch_embs.position_embeddings.weight
), "position_embeddings.weight differs between bf and torch"
assert check_node_equals_tensor(
    bf_embs.token_type_embeddings.weight, torch_embs.token_type_embeddings.weight
), "token_type_embeddings.weight differs between bf and torch"
assert check_node_equals_tensor(
    bf_embs.LayerNorm.weight, torch_embs.LayerNorm.weight
), "LayerNorm.weight differs between bf and torch"
# LayerNorm.bias is a parameter too (it is listed by named_parameters() in the
# gradient comparison below), so it is verified here as well.
assert check_node_equals_tensor(
    bf_embs.LayerNorm.bias, torch_embs.LayerNorm.bias
), "LayerNorm.bias differs between bf and torch"
# Dropout is not checked: it has no weights to compare.


In [8]:
# Forward pass: run the same token ids through both modules and require the
# outputs to agree (allclose).
# The torch module is put in inference mode first so its dropout is disabled.
# NOTE(review): bf_embs is not explicitly switched to eval mode here - confirm
# that its dropout is inactive for this call.
torch_embs.eval()
jax_input_ids = jnp.array(input_ids.numpy(), dtype=int)
out_bf = bf_embs(input_ids=jax_input_ids)
out_torch = torch_embs(input_ids=input_ids)
# To inspect the raw values, print out_bf.val and out_torch here.
assert check_node_allclose_tensor(out_bf, out_torch)



### Compare grads of parameters between torch and bf after a backward pass

In [9]:
%time
# Torch backward pass
torch_embs.train(True)
out_torch.backward(gradient=torch.ones_like(out_torch))

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 6.91 µs


In [10]:
%time 
# BF backward pass
out_bf.backprop(values_to_compute=("grad",))

CPU times: user 5 µs, sys: 1 µs, total: 6 µs
Wall time: 12.6 µs




In [11]:
# Debugging aid (disabled): uncomment to print the bf and torch gradients side by side.
# print("word_embeddings:", bf_embs.word_embeddings.weight.grad, torch_embs.word_embeddings.weight.grad)
# print("position_embeddings:", bf_embs.position_embeddings.weight.grad, torch_embs.position_embeddings.weight.grad)
# print("token_type_embeddings:", bf_embs.token_type_embeddings.weight.grad, torch_embs.token_type_embeddings.weight.grad)
# print("LayerNorm:", bf_embs.LayerNorm.weight.grad, torch_embs.LayerNorm.weight.grad)

In [12]:
# Index the parameters of both modules by name; the two modules must expose
# exactly the same parameter names for the comparison to be meaningful.
bf_emb_params = dict(bf_embs.named_parameters())
torch_emb_params = dict(torch_embs.named_parameters())
assert set(bf_emb_params.keys()) == set(torch_emb_params.keys())

# Compare the gradient of every parameter after the two backward passes.
for name in bf_emb_params.keys():
    # Compute the comparison once and reuse it for both the report and the assert.
    grads_match = jnp.allclose(bf_emb_params[name].grad, torch_emb_params[name].grad.numpy(), atol=1e-6)
    print(f"Grad of param {name} for bf and torch are within 1e-6? {grads_match}")
    assert grads_match, f"Grad of param {name} for bf and torch are not within 1e-6."

Grad of param word_embeddings.weight for bf and torch are within 1e-6? True
Grad of param position_embeddings.weight for bf and torch are within 1e-6? True
Grad of param token_type_embeddings.weight for bf and torch are within 1e-6? True
Grad of param LayerNorm.weight for bf and torch are within 1e-6? True
Grad of param LayerNorm.bias for bf and torch are within 1e-6? True
