In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", True)
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertIntermediate,
    BfBertIntermediate
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


<torch._C.Generator at 0x2ab03e012910>

In [2]:
# Instantiate the torch module and its Bf counterpart from the same BertConfig
config = BertConfig.from_pretrained("../../brunoflow/models/bert/config.json")
torch_model = BertIntermediate(config)  # reference PyTorch module
bf_model = BfBertIntermediate(config)  # bf module compared against it in the cells below

In [3]:
# Draw one random batch in torch, then hand the very same values to JAX as float64
# so both models are fed identical inputs.
# NOTE(review): (2, 19, 768) is presumably (batch, sequence, hidden) - confirm against the config.
hidden_states_torch = torch.randn(2, 19, 768)
hidden_states = jnp.array(hidden_states_torch.numpy(), dtype=jnp.float64)

In [4]:
%%time
# Forward pass through the torch module; %%time reports the CPU and wall time of this cell.
outputs_torch = torch_model(hidden_states_torch)


CPU times: user 13.7 ms, sys: 2.17 ms, total: 15.9 ms
Wall time: 15.2 ms


In [5]:
%%time
# Forward pass through the bf module on the JAX copy of the inputs, timed the same way
# as the torch forward pass for a side-by-side comparison.
outputs_bf = bf_model(hidden_states)


CPU times: user 587 ms, sys: 277 ms, total: 864 ms
Wall time: 1.08 s


In [6]:
# Forward-pass sanity check: each bf output must have the same shape as its torch counterpart
if not isinstance(outputs_bf, (list, tuple)):
    # Single-output case: compare the two results directly
    assert outputs_torch.shape == outputs_bf.shape
else:
    # Multi-output case: both sides must yield equally many items, compared pairwise
    assert len(outputs_bf) == len(outputs_torch)
    for out_bf, out_torch in zip(outputs_bf, outputs_torch):
        assert out_torch.shape == out_bf.shape


In [7]:
# Save the torch BertIntermediate state dict to a file (path is relative to the working directory)
save_path = "bertintermediate_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [8]:
# Load the saved torch BertIntermediate state dict into the BF model so both models share
# the same weights; the weight, output and backprop checks follow in the cells below.
# torch.load unpickles the file - acceptable here because the file was written by the previous cell.
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [9]:
# Check weights match: compare the BF model's parameters against the torch model's
# (helper imported from utils; it reports one line per parameter, see the output below)
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight dense.bias for bf and torch are equal? True
Value of param weight dense.weight for bf and torch are equal? True


### Check model output after forward pass matches for BF and Torch

In [10]:
# Re-run both forward passes (torch module switched to eval mode) and compare the
# resulting values with an absolute tolerance of 1e-6.
torch_model.eval()
outputs_bf = bf_model(hidden_states=hidden_states)
outputs_torch = torch_model(hidden_states=hidden_states_torch)

if not isinstance(outputs_bf, (list, tuple)):
    # Single-output case
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-6)
else:
    # Multi-output case: same number of items on both sides, checked pairwise
    assert len(outputs_bf) == len(outputs_torch)
    for out_bf, out_torch in zip(outputs_bf, outputs_torch):
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)

Output of bf and torch are within 1e-06? True


### Check gradients after backward pass match for BF and Torch

In [11]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_torch[0].backward(gradient=torch.ones_like(outputs_torch[0]))
else:
    outputs_torch.backward(gradient=torch.ones_like(outputs_torch))

CPU times: user 15.3 ms, sys: 3.77 ms, total: 19 ms
Wall time: 16.6 ms


In [12]:
%%time
# BF backward pass
if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_bf[0].backprop(values_to_compute=("grad",))
else:
    outputs_bf.backprop(values_to_compute=("grad",))

CPU times: user 1.42 s, sys: 828 ms, total: 2.25 s
Wall time: 3.54 s


In [13]:
# Run the actual check: compare the parameter gradients of the BF and torch models with an
# absolute tolerance of 1e-5 (helper imported from utils; print_output=True prints one
# line per parameter, see the output below)
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-5, print_output=True)

Grad of param dense.bias for bf and torch are within 1e-05? True
Grad of param dense.weight for bf and torch are within 1e-05? True


In [14]:
# Print the BF-side gradients of the dense layer (bias, then weight) for a visual
# comparison with the torch-side values printed in the next cell.
print(bf_model.dense.bias.grad)
print(bf_model.dense.weight.grad)

[17.08071289 19.88536307 16.41011589 ... 22.49699207 18.58312123
 15.14929407]
[[ 4.06865117  1.80183488  3.17271962 ... -1.23772789  0.27722289
  -7.01184791]
 [ 3.83343693 -1.17917966  2.12415311 ... -2.8559995   0.73369748
  -1.34719317]
 [ 4.91057663  0.60769482 -0.23988259 ... -6.17080812 -0.62998463
  -4.87208344]
 ...
 [-1.69226959 -1.22749706  2.55204808 ...  0.11834742 -1.72904177
  -2.54002883]
 [ 3.26504851 -2.34438119 -1.36697079 ... -1.72814144  2.87235288
   0.3346037 ]
 [ 0.10640387 -1.74101877 -1.42268893 ... -5.03681767  0.27159493
  -3.36159435]]


In [15]:
# Print the torch-side gradients of the dense layer (bias, then weight) for a visual
# comparison with the BF-side values printed in the previous cell.
print(torch_model.dense.bias.grad)
print(torch_model.dense.weight.grad)

tensor([17.0807, 19.8854, 16.4101,  ..., 22.4970, 18.5831, 15.1493])
tensor([[ 4.0687,  1.8018,  3.1727,  ..., -1.2377,  0.2772, -7.0118],
        [ 3.8334, -1.1792,  2.1242,  ..., -2.8560,  0.7337, -1.3472],
        [ 4.9106,  0.6077, -0.2399,  ..., -6.1708, -0.6300, -4.8721],
        ...,
        [-1.6923, -1.2275,  2.5520,  ...,  0.1183, -1.7290, -2.5400],
        [ 3.2650, -2.3444, -1.3670,  ..., -1.7281,  2.8724,  0.3346],
        [ 0.1064, -1.7410, -1.4227,  ..., -5.0368,  0.2716, -3.3616]])
