In [1]:
# Enable 64-bit floats in JAX before any arrays are created; the bf inputs below are built
# as jnp.float64. `jax.config.update` replaces the deprecated (and, in newer JAX, removed)
# `from jax.config import config`, and avoids binding the name `config`, which is rebound
# to a BertConfig in the next cell.
import jax
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
# Seed torch's global RNG so the random inputs and the torch-side weight init are
# reproducible. The trailing semicolon suppresses the returned Generator's repr, whose
# memory address otherwise changes on every run and pollutes the saved output.
torch.manual_seed(0);


<torch._C.Generator at 0x2b74a5272730>

In [2]:
# Build the torch layer and its brunoflow twin from one shared BertConfig so their layer
# sizes agree. They are initialized independently; the torch weights are copied into the
# bf model further down via a saved state dict.
bert_config_path = "../../brunoflow/models/bert/config.json"
config = BertConfig.from_pretrained(pretrained_model_name_or_path=bert_config_path)
torch_model, bf_model = BertAttention(config), BfBertAttention(config)

2022-12-24 01:14:05.873009: W external/org_tensorflow/tensorflow/compiler/xla/service/platform_util.cc:193] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable


In [3]:
# Init inputs to bf and torch models.
# The hidden size is read from the BERT config (instead of a hard-coded 768) so the inputs
# always match the layer's weight shapes.
batch_size, seq_len = 2, 19
hidden_states_torch = torch.randn(size=(batch_size, seq_len, config.hidden_size))
# Shape (batch, 1, 1, seq): presumably an additive mask broadcast over heads and query
# positions, as in HF BERT's extended attention mask (TODO confirm). Random values are
# sufficient because bf and torch are compared on the same mask.
attention_mask_torch = torch.randn(size=(batch_size, 1, 1, seq_len))

# bf consumes float64 jax arrays (x64 is enabled in the first cell) built from the same
# values; torch keeps its default dtype, so outputs are compared with a tolerance later.
hidden_states = jnp.array(hidden_states_torch.numpy(), dtype=jnp.float64)
attention_mask = jnp.array(attention_mask_torch.numpy(), dtype=jnp.float64)

In [4]:
%%time
outputs_torch = torch_model(hidden_states_torch, attention_mask_torch)


CPU times: user 14.8 ms, sys: 1.81 ms, total: 16.7 ms
Wall time: 15.8 ms


In [5]:
%%time
outputs_bf = bf_model(hidden_states, attention_mask)


CPU times: user 743 ms, sys: 19.9 ms, total: 763 ms
Wall time: 737 ms


In [6]:
# Check that the bf forward pass runs and that each output's shape matches torch's.
# Both sides must agree on returning a sequence vs. a single output; checking only bf would
# hit an AttributeError on `outputs_torch.shape` if torch returned a tuple while bf
# returned a single node.
bf_returns_seq = isinstance(outputs_bf, (list, tuple))
torch_returns_seq = isinstance(outputs_torch, (list, tuple))
assert bf_returns_seq == torch_returns_seq, (
    f"bf returned {type(outputs_bf).__name__}, torch returned {type(outputs_torch).__name__}"
)

if bf_returns_seq:
    # Multiple outputs: compare them pairwise.
    assert len(outputs_bf) == len(outputs_torch), (
        f"bf returned {len(outputs_bf)} outputs, torch returned {len(outputs_torch)}"
    )
    for i, (out_bf, out_torch) in enumerate(zip(outputs_bf, outputs_torch)):
        assert out_torch.shape == out_bf.shape, (
            f"output {i}: torch shape {out_torch.shape} != bf shape {out_bf.shape}"
        )
else:
    assert outputs_torch.shape == outputs_bf.shape, (
        f"torch shape {outputs_torch.shape} != bf shape {outputs_bf.shape}"
    )


In [7]:
# Save the torch BertAttention weights to a checkpoint in the working directory so they can
# be loaded into the bf model in the next cell.
save_path = "bertattention_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [8]:
# Load the torch BertAttention state dict into the bf model so both share identical weights;
# the following cells then check weights, forward outputs, and gradients against torch.
# The checkpoint was written by the previous cell of this notebook, so unpickling it with
# torch.load is trusted here. The returned value is left as the cell's output so the
# key-matching report stays visible.
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check that the weights of the BF and Torch models match exactly

In [9]:
# Check that every bf parameter equals its torch counterpart after loading the state dict;
# the helper prints one line per parameter.
# NOTE(review): confirm whether the helper also raises on a mismatch; otherwise a failure
# here only shows up as "False" in the printed output.
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight self.query.bias for bf and torch are equal? True
Value of param weight self.query.weight for bf and torch are equal? True
Value of param weight self.key.bias for bf and torch are equal? True
Value of param weight self.key.weight for bf and torch are equal? True
Value of param weight self.value.bias for bf and torch are equal? True
Value of param weight self.value.weight for bf and torch are equal? True
Value of param weight output.dense.bias for bf and torch are equal? True
Value of param weight output.dense.weight for bf and torch are equal? True
Value of param weight output.LayerNorm.weight for bf and torch are equal? True
Value of param weight output.LayerNorm.bias for bf and torch are equal? True


### Check that model outputs after the forward pass match for BF and Torch

In [10]:
# Check that forward outputs match for bf and torch. torch_model is switched to eval mode
# so any dropout inside it is disabled. The bf model gets no mode switch.
# NOTE(review): confirm BfBertAttention applies no dropout here.
# outputs_torch from this cell is reused by the backward-pass cell below, so it is
# computed with autograd enabled.
FORWARD_ATOL = 1e-6
torch_model.eval()  # equivalent to torch_model.train(False)
outputs_bf = bf_model(hidden_states=hidden_states, attention_mask=attention_mask)
outputs_torch = torch_model(hidden_states=hidden_states_torch, attention_mask=attention_mask_torch)

# Both sides must agree on returning a sequence vs. a single output before comparing.
bf_returns_seq = isinstance(outputs_bf, (list, tuple))
assert bf_returns_seq == isinstance(outputs_torch, (list, tuple)), (
    f"bf returned {type(outputs_bf).__name__}, torch returned {type(outputs_torch).__name__}"
)

if bf_returns_seq:
    assert len(outputs_bf) == len(outputs_torch), (
        f"bf returned {len(outputs_bf)} outputs, torch returned {len(outputs_torch)}"
    )
    for out_bf, out_torch in zip(outputs_bf, outputs_torch):
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=FORWARD_ATOL)
else:
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=FORWARD_ATOL)

Output of bf and torch are within 1e-06? True


### Check that gradients after the backward pass match for BF and Torch

In [11]:
%%time
# Torch backward pass
torch_model.train(True)
outputs_torch[0].backward(gradient=torch.ones_like(outputs_torch[0]))

CPU times: user 14.5 ms, sys: 1.08 ms, total: 15.6 ms
Wall time: 14.5 ms


In [12]:
%%time
# BF backward pass
outputs_bf[0].backprop(values_to_compute=("grad",))

CPU times: user 8.36 s, sys: 16.9 s, total: 25.3 s
Wall time: 19.6 s


In [13]:
# Compare every bf parameter gradient against torch's and print one line per parameter.
# The tolerance is looser than the forward-output check (1e-6), presumably because error
# accumulates through backprop.
grad_atol = 1e-5
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=grad_atol, print_output=True)

Grad of param self.query.bias for bf and torch are within 1e-05? True
Grad of param self.query.weight for bf and torch are within 1e-05? True
Grad of param self.key.bias for bf and torch are within 1e-05? True
Grad of param self.key.weight for bf and torch are within 1e-05? True
Grad of param self.value.bias for bf and torch are within 1e-05? True
Grad of param self.value.weight for bf and torch are within 1e-05? True
Grad of param output.dense.bias for bf and torch are within 1e-05? True
Grad of param output.dense.weight for bf and torch are within 1e-05? True
Grad of param output.LayerNorm.weight for bf and torch are within 1e-05? True
Grad of param output.LayerNorm.bias for bf and torch are within 1e-05? True
