In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", False)
from jax import numpy as jnp
from dataclasses import is_dataclass
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_equivalent_class, check_dataclass_keys_match, check_model_outputs_allclose, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


<torch._C.Generator at 0x2ab1a7002790>

In [2]:
# Init torch and bf models
config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config.json")
torch_model = BertEncoder(config)
bf_model = BfBertEncoder(config)

In [3]:
# Init inputs to bf and torch models
hidden_states_torch = torch.randn(size=(2, 19, 768))
attention_mask_torch = torch.randn(size=(2, 1, 1, 19))

hidden_states = jnp.array(hidden_states_torch.numpy(), dtype=jnp.float64)
attention_mask = jnp.array(attention_mask_torch.numpy(), dtype=jnp.float64)

  hidden_states = jnp.array(hidden_states_torch.numpy(), dtype=jnp.float64)
  attention_mask = jnp.array(attention_mask_torch.numpy(), dtype=jnp.float64)


In [4]:
%%time
outputs_torch = torch_model(hidden_states_torch, attention_mask_torch)


CPU times: user 461 ms, sys: 19.5 ms, total: 481 ms
Wall time: 485 ms


In [5]:
type(outputs_torch)

transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions

In [6]:
%%time
outputs_bf = bf_model(hidden_states, attention_mask)


CPU times: user 3.54 s, sys: 1.46 s, total: 4.99 s
Wall time: 7.75 s


In [7]:
# Check that forward pass for bf works and matches output shape with torch
if isinstance(outputs_bf, (list, tuple)):
    # Handle case where outputs is a tuple/list and not just a single item
    assert len(outputs_bf) == len(outputs_torch)
    for i in range(len(outputs_bf)):
        out_bf, out_torch = outputs_bf[i], outputs_torch[i] 
        assert(out_torch.shape == out_bf.shape)
elif is_dataclass(outputs_bf):
    check_equivalent_class(outputs_bf, outputs_torch)
    check_dataclass_keys_match(outputs_bf, outputs_torch)
else:
    assert(outputs_torch.shape == outputs_bf.shape)


In [8]:
# Save torch BertSelfAttention to file
save_path = "bertlayer_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [9]:
# Load state dict for BertSelfAttention into BF and check weights, outputs, and backprop
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [10]:
# Check weights match
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight layer.0.attention.self.query.bias for bf and torch are equal? True
Value of param weight layer.0.attention.self.query.weight for bf and torch are equal? True
Value of param weight layer.0.attention.self.key.bias for bf and torch are equal? True
Value of param weight layer.0.attention.self.key.weight for bf and torch are equal? True
Value of param weight layer.0.attention.self.value.bias for bf and torch are equal? True
Value of param weight layer.0.attention.self.value.weight for bf and torch are equal? True
Value of param weight layer.0.attention.output.dense.bias for bf and torch are equal? True
Value of param weight layer.0.attention.output.dense.weight for bf and torch are equal? True
Value of param weight layer.0.attention.output.LayerNorm.weight for bf and torch are equal? True
Value of param weight layer.0.attention.output.LayerNorm.bias for bf and torch are equal? True
Value of param weight layer.0.intermediate.dense.bias for bf and torch are equal? True
V

Value of param weight layer.6.attention.self.value.bias for bf and torch are equal? True
Value of param weight layer.6.attention.self.value.weight for bf and torch are equal? True
Value of param weight layer.6.attention.output.dense.bias for bf and torch are equal? True
Value of param weight layer.6.attention.output.dense.weight for bf and torch are equal? True
Value of param weight layer.6.attention.output.LayerNorm.weight for bf and torch are equal? True
Value of param weight layer.6.attention.output.LayerNorm.bias for bf and torch are equal? True
Value of param weight layer.6.intermediate.dense.bias for bf and torch are equal? True
Value of param weight layer.6.intermediate.dense.weight for bf and torch are equal? True
Value of param weight layer.6.output.dense.bias for bf and torch are equal? True
Value of param weight layer.6.output.dense.weight for bf and torch are equal? True
Value of param weight layer.6.output.LayerNorm.weight for bf and torch are equal? True
Value of param we

### Check model output after forward pass matches for BF and Torch

In [17]:
# Check output from forward passes match for bf and torch
torch_model.train(False)
outputs_bf = bf_model(hidden_states=hidden_states, attention_mask=attention_mask)
outputs_torch = torch_model(hidden_states=hidden_states_torch, attention_mask=attention_mask_torch)

if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    for i in range(len(outputs_bf)):
        out_bf, out_torch = outputs_bf[i], outputs_torch[i]
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)
elif is_dataclass(outputs_bf):
    check_model_outputs_allclose(outputs_bf, outputs_torch, print_stats=True, atol=1e-2)
else:
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-6)

Checking diff between BF and torch for last_hidden_state:
Output of bf and torch are within 0.01? True
	Stats on diff in outputs between bf and torch:                   0
count  29184.000000
mean       0.000298
std        0.000224
min        0.000000
25%        0.000119
50%        0.000252
75%        0.000430
max        0.001816


### Check grad after backward pass matches for BF and torch

In [12]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node_torch = outputs_torch[0]
elif is_dataclass(outputs_torch):
    backprop_node_torch = outputs_torch.last_hidden_state
else:
    backprop_node_torch = outputs_torch
    
backprop_node_torch.backward(gradient=torch.ones_like(backprop_node_torch))

CPU times: user 775 ms, sys: 86.5 ms, total: 861 ms
Wall time: 862 ms


In [13]:
%%time 
# BF backward pass

if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node = outputs_bf[0]
elif is_dataclass(outputs_torch):
    backprop_node = outputs_bf.last_hidden_state
else:
    backprop_node = outputs_bf
    
backprop_node.backprop(values_to_compute=("grad",))

CPU times: user 8.34 s, sys: 5.73 s, total: 14.1 s
Wall time: 22.9 s


In [14]:
# Run the actual check
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-5, print_output=True, use_assert=False)

Grad of param layer.0.attention.self.query.bias for bf and torch are within 1e-05? True
Grad of param layer.0.attention.self.query.weight for bf and torch are within 1e-05? True
Grad of param layer.0.attention.self.key.bias for bf and torch are within 1e-05? True
Grad of param layer.0.attention.self.key.weight for bf and torch are within 1e-05? True
Grad of param layer.0.attention.self.value.bias for bf and torch are within 1e-05? True
Grad of param layer.0.attention.self.value.weight for bf and torch are within 1e-05? True
Grad of param layer.0.attention.output.dense.bias for bf and torch are within 1e-05? True
Grad of param layer.0.attention.output.dense.weight for bf and torch are within 1e-05? True
Grad of param layer.0.attention.output.LayerNorm.weight for bf and torch are within 1e-05? True
Grad of param layer.0.attention.output.LayerNorm.bias for bf and torch are within 1e-05? True
Grad of param layer.0.intermediate.dense.bias for bf and torch are within 1e-05? True
Grad of para

	Stats on diff in grad for layer.11.output.LayerNorm.weight between bf and torch:                 0
count  768.000000
mean     0.005615
std      0.004385
min      0.000000
25%      0.002037
50%      0.004748
75%      0.008144
max      0.024424
Grad of param layer.11.output.LayerNorm.bias for bf and torch are within 1e-05? True


In [15]:
print(dict(bf_model.named_parameters())["layer.11.output.LayerNorm.bias"].grad)

[38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38.
 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38. 38

In [16]:
print(dict(torch_model.named_parameters())["layer.11.output.LayerNorm.bias"].grad)

tensor([38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 3