In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", True)
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertOutput,
    BfBertOutput
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
torch.manual_seed(0)


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fdb960d06b0>

In [2]:
# Init torch and bf models from the same BERT config.
# Named `bert_config` rather than the generic `config` so it cannot clobber a JAX `config` handle
# (re-running a `config.update("jax_enable_x64", ...)` line would otherwise hit BertConfig.update).
bert_config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config.json")
torch_model = BertOutput(bert_config)
bf_model = BfBertOutput(bert_config)

2022-12-24 00:37:37.275919: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


In [3]:
# Random inputs shared by both implementations: the torch tensors feed the torch model, and
# float64 JAX copies of the very same values feed the brunoflow model.
# NOTE(review): the sizes 3072 / 768 presumably match the config's intermediate/hidden sizes — confirm.
hidden_states_torch = torch.randn(size=(2, 19, 3072))
input_tensor_torch = torch.randn(size=(2, 19, 768))

hidden_states, input_tensor = (
    jnp.array(src.numpy(), dtype=jnp.float64)
    for src in (hidden_states_torch, input_tensor_torch)
)

In [4]:
# Confirm jax_enable_x64 took effect: without it JAX truncates float64 requests to float32
# (with only a warning), which would silently weaken the precision of the comparison below.
assert hidden_states.dtype == jnp.float64 and input_tensor.dtype == jnp.float64, "JAX x64 mode is not enabled"
hidden_states.dtype

dtype('float64')

In [5]:
%time
outputs_torch = torch_model(hidden_states_torch, input_tensor_torch)


CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.01 µs


In [6]:
%time
outputs_bf = bf_model(hidden_states, input_tensor)


CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.29 µs


In [7]:
# Check that forward pass for bf works and matches output shape with torch.
# Plain `assert cond, msg` form: parenthesising an assert invites `assert (cond, msg)`, which
# asserts a non-empty tuple and therefore always passes.
assert outputs_torch.shape == outputs_bf.shape, (
    f"output shape mismatch: torch {tuple(outputs_torch.shape)} vs bf {tuple(outputs_bf.shape)}"
)


In [8]:
# Save the torch BertOutput state dict to file so the exact same weights can be loaded into bf
save_path = "bertoutput_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [9]:
# Load the saved torch BertOutput state dict into the BF model, so the weight, output and backprop
# checks below compare identically-initialised models.
# torch.load unpickles the file; that is acceptable here only because the previous cell just wrote it.
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [10]:
# Check weights match: compares each named parameter of the bf model against the torch model
# for equality and prints one line per parameter.
# NOTE(review): unclear from here whether this helper raises on a mismatch or only prints — confirm;
# if it only prints, a failing comparison would not stop Restart & Run All.
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight dense.bias for bf and torch are equal? True
Value of param weight dense.weight for bf and torch are equal? True
Value of param weight LayerNorm.weight for bf and torch are equal? True
Value of param weight LayerNorm.bias for bf and torch are equal? True


### Check model output after forward pass matches for BF and Torch

In [11]:
# Run both models on identical inputs and check their outputs agree to within atol=1e-6.
# train(False) puts the torch model in eval mode (disabling dropout, if any) so the outputs are deterministic.
# NOTE(review): bf_model is not switched to an eval mode here — confirm BfBertOutput applies no dropout.
# outputs_torch is deliberately NOT computed under torch.no_grad(): the backward cell below reuses its graph.
torch_model.train(False)
outputs_bf = bf_model(hidden_states=hidden_states, input_tensor=input_tensor)
outputs_torch = torch_model(hidden_states=hidden_states_torch, input_tensor=input_tensor_torch)

# The models may return either one output or a sequence of outputs; compare pairwise in both cases.
if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    output_pairs = zip(outputs_bf, outputs_torch)
else:
    output_pairs = [(outputs_bf, outputs_torch)]

for out_bf, out_torch in output_pairs:
    check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)

Output of bf and torch are within 1e-06? True


### Check param grads after backward pass match for BF and Torch

In [12]:
%time
# Torch backward pass
torch_model.train(True)
outputs_torch.backward(gradient=torch.ones_like(outputs_torch))

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 7.15 µs


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [13]:
%time 
# BF backward pass
outputs_bf.backprop(values_to_compute=("grad",))

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.48 µs


: 

: 

In [None]:
# Run the actual check: compare each named parameter's gradient in bf vs torch to within atol=1e-5
# (looser than the 1e-6 used for forward outputs) and print one line per parameter.
# NOTE(review): unclear from here whether this helper raises on a mismatch or only prints — confirm.
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-5, print_output=True)

Grad of param dense.bias for bf and torch are within 1e-05? True
Grad of param dense.weight for bf and torch are within 1e-05? True
Grad of param LayerNorm.weight for bf and torch are within 1e-05? True
Grad of param LayerNorm.bias for bf and torch are within 1e-05? True
