In [1]:
# Enable 64-bit floats in JAX up front: the BF inputs built later in this notebook
# are created with dtype=jnp.float64, which JAX only honours when x64 mode is on.
from jax.config import config
config.update("jax_enable_x64", True)
from jax import numpy as jnp
import transformers
import torch
# NOTE(review): of the names imported below, only BertConfig, BertIntermediate and
# BfBertIntermediate are referenced in the visible cells; the rest look like leftovers
# from sibling notebooks — confirm before pruning.
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertIntermediate,
    BfBertIntermediate
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
# Local helpers that compare BF parameters / outputs / grads against their torch counterparts.
from utils import check_bf_param_weights_match_torch, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
# Seed torch's global RNG so the torch.randn inputs drawn below are reproducible across runs.
torch.manual_seed(0)


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f1ea3b236f0>

In [2]:
# Init torch and bf models from the same BERT configuration.
# The configuration is bound to `bert_config` (not `config`) so that it does not
# shadow the `config` object imported from `jax.config` in the imports cell;
# reusing that name would make any later `config.update(...)` call hit the
# BertConfig instead of JAX's config.
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config.json")
torch_model = BertIntermediate(bert_config)
bf_model = BfBertIntermediate(bert_config)

2022-12-23 23:51:08.681952: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


In [3]:
# Build one random activation tensor and mirror it for both frameworks:
# the torch model consumes the tensor as drawn, the bf model gets the same
# values as a float64 JAX array.
hidden_states_torch = torch.randn(2, 19, 768)

hidden_states_np = hidden_states_torch.numpy()
hidden_states = jnp.array(hidden_states_np, dtype=jnp.float64)

In [4]:
%time
outputs_torch = torch_model(hidden_states_torch)


CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 6.2 µs


In [5]:
%time
outputs_bf = bf_model(hidden_states)


CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 6.68 µs


In [6]:
# Sanity check: the bf forward pass produced output whose shape(s) agree with torch
if not isinstance(outputs_bf, (list, tuple)):
    # Single-output case: compare the two shapes directly
    assert outputs_torch.shape == outputs_bf.shape
else:
    # Multi-output case: same number of outputs, then compare shapes pairwise
    assert len(outputs_bf) == len(outputs_torch)
    for idx, out_bf in enumerate(outputs_bf):
        out_torch = outputs_torch[idx]
        assert out_torch.shape == out_bf.shape


In [7]:
# Save the torch BertIntermediate weights (its state dict) to file.
# The path is relative, so the file is written to the notebook's working directory.
save_path = "bertintermediate_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [8]:
# Load the saved torch BertIntermediate state dict into the BF model so that both
# models start from the same weights (the weight, output and grad checks follow in later cells).
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [9]:
# Check weights match: helper from the local `utils` module that compares each
# parameter of bf_model against the corresponding torch_model parameter and
# prints one line per parameter.
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight dense.bias for bf and torch are equal? True
Value of param weight dense.weight for bf and torch are equal? True


### Check model output after forward pass matches for BF and Torch

In [10]:
# Run both models on the same input and verify the outputs agree within 1e-6
torch_model.eval()  # put the torch module in evaluation mode
outputs_bf = bf_model(hidden_states=hidden_states)
outputs_torch = torch_model(hidden_states=hidden_states_torch)

if not isinstance(outputs_bf, (list, tuple)):
    # Single-output case: one direct comparison
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-6)
else:
    # Multi-output case: same number of outputs, then compare them pairwise
    assert len(outputs_bf) == len(outputs_torch)
    for idx, out_bf in enumerate(outputs_bf):
        out_torch = outputs_torch[idx]
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)

Output of bf and torch are within 1e-06? True


### Check grads after backward pass match for BF and Torch

In [11]:
%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_torch[0].backward(gradient=torch.ones_like(outputs_torch[0]))
else:
    outputs_torch.backward(gradient=torch.ones_like(outputs_torch))

CPU times: user 4 µs, sys: 1 µs, total: 5 µs
Wall time: 12.4 µs


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [12]:
%time 
# BF backward pass
if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_bf[0].backprop(values_to_compute=("grad",))
else:
    outputs_bf.backprop(values_to_compute=("grad",))

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 11 µs


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 54509174784 bytes.

In [None]:
# Run the actual check: compare the grads of bf_model's parameters against
# torch_model's with atol=1e-5, printing one line per parameter.
# NOTE(review): this cell has no execution count and its saved output lists
# self.query/key/value and output.LayerNorm parameters, whereas the weight check
# earlier in this notebook shows this model only has dense.bias and dense.weight —
# the saved output looks stale (from a different module); re-run once the BF
# backward pass succeeds.
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-5, print_output=True)

Grad of param self.query.bias for bf and torch are within 1e-05? True
Grad of param self.query.weight for bf and torch are within 1e-05? True
Grad of param self.key.bias for bf and torch are within 1e-05? True
Grad of param self.key.weight for bf and torch are within 1e-05? True
Grad of param self.value.bias for bf and torch are within 1e-05? True
Grad of param self.value.weight for bf and torch are within 1e-05? True
Grad of param output.dense.bias for bf and torch are within 1e-05? True
Grad of param output.dense.weight for bf and torch are within 1e-05? True
Grad of param output.LayerNorm.weight for bf and torch are within 1e-05? True
Grad of param output.LayerNorm.bias for bf and torch are within 1e-05? True
