In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", True)
from jax import numpy as jnp
from dataclasses import is_dataclass
import transformers
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
    BertOnlyMLMHead,
    BfBertOnlyMLMHead,
    BertPredictionHeadTransform,
    BfBertPredictionHeadTransform,
    BertLMPredictionHead,
    BfBertLMPredictionHead
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_equivalent_class, check_dataclass_keys_match, check_model_outputs_allclose, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
# Seed torch's global RNG so the random weights/inputs below are reproducible.
# Trailing `;` keeps the returned Generator object out of the cell output.
torch.manual_seed(0);


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f4e4679f6f0>

In [2]:
# Build the torch and BF prediction heads from the same tiny BERT config.
# Each model is constructed independently here; the torch weights are copied
# into the BF model later via a saved state dict.
config = BertConfig.from_pretrained(
    pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json",
)
torch_model, bf_model = BertLMPredictionHead(config), BfBertLMPredictionHead(config)

2023-01-04 23:08:57.203807: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


In [3]:
# Random input tensor for the torch model
# (presumably (batch, seq_len, hidden_size) — TODO confirm 128 matches the config's hidden size)
hidden_states_torch = torch.randn(2, 19, 128)

# Same values as a float64 JAX array, used as the BF model's input
hidden_states = jnp.asarray(hidden_states_torch.numpy(), dtype=jnp.float64)


In [4]:
%%time
outputs_torch = torch_model(hidden_states_torch)


CPU times: user 8.6 ms, sys: 2.74 ms, total: 11.3 ms
Wall time: 5.58 ms


In [5]:
# Inspect the container type of the torch output; the checks below branch on
# list/tuple vs. dataclass vs. plain tensor outputs.
type(outputs_torch)

torch.Tensor

In [6]:
%%time
outputs_bf = bf_model(hidden_states)


CPU times: user 471 ms, sys: 61.1 ms, total: 532 ms
Wall time: 373 ms


In [7]:
# Sanity check: the BF forward pass ran and its output is structured/shaped like torch's.
# Only structure is compared here; values are compared after the weights are synced.
if isinstance(outputs_bf, (list, tuple)):
    # Sequence output: lengths must agree, then compare shapes element by element
    assert len(outputs_bf) == len(outputs_torch)
    for idx, out_bf in enumerate(outputs_bf):
        out_torch = outputs_torch[idx]
        assert out_torch.shape == out_bf.shape
elif is_dataclass(outputs_bf):
    # Dataclass output: delegate to the project helpers
    # (presumably they compare the output class and its field names — verify in utils)
    check_equivalent_class(outputs_bf, outputs_torch)
    check_dataclass_keys_match(outputs_bf, outputs_torch)
else:
    # Single array/tensor output
    assert outputs_torch.shape == outputs_bf.shape


In [8]:
# Save the torch model's state dict to a file in the current working directory.
# NOTE(review): the file name says "bertonlymlmhead", but torch_model in this notebook is a
# BertLMPredictionHead — the name looks carried over from another notebook; confirm before renaming.
save_path = "bertonlymlmhead_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [9]:
# Load the saved torch state dict (a BertLMPredictionHead, despite the file name) into the
# BF model so that the weight, output, and gradient checks below compare identical parameters.
bf_model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [10]:
# Check that every parameter of the BF model equals the corresponding torch parameter
# after loading the state dict (the helper prints one line per parameter).
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight bias for bf and torch are equal? True
Value of param weight transform.dense.weight for bf and torch are equal? True
Value of param weight transform.dense.bias for bf and torch are equal? True
Value of param weight transform.LayerNorm.weight for bf and torch are equal? True
Value of param weight transform.LayerNorm.bias for bf and torch are equal? True
Value of param weight decoder.weight for bf and torch are equal? True


### Check model output after forward pass matches for BF and Torch

In [11]:
# Re-run both forward passes now that the weights are identical, then compare values.
torch_model.eval()  # inference mode for the torch side
outputs_bf = bf_model(hidden_states=hidden_states)
outputs_torch = torch_model(hidden_states=hidden_states_torch)

# NOTE(review): each branch uses a different absolute tolerance (1e-6 / 1e-2 / 1e-4) — confirm intended.
if isinstance(outputs_bf, (list, tuple)):
    # Sequence output: compare pairwise
    assert len(outputs_bf) == len(outputs_torch)
    for out_bf, out_torch in zip(outputs_bf, outputs_torch):
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)
elif is_dataclass(outputs_bf):
    # Dataclass output: compared as a whole by the project helper
    check_model_outputs_allclose(outputs_bf, outputs_torch, print_stats=True, atol=1e-2)
else:
    # Single array/tensor output
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-4)

Output of bf and torch are within 0.0001? True


### Check grad after backward pass matches for BF and torch

In [12]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node_torch = outputs_torch[0]
elif is_dataclass(outputs_torch):
    backprop_node_torch = outputs_torch.last_hidden_state
else:
    backprop_node_torch = outputs_torch
    
backprop_node_torch.backward(gradient=torch.ones_like(backprop_node_torch))

CPU times: user 74 ms, sys: 697 µs, total: 74.7 ms
Wall time: 9.82 ms


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [13]:
%%time 
# BF backward pass

if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node = outputs_bf[0]
elif is_dataclass(outputs_torch):
    backprop_node = outputs_bf.last_hidden_state
else:
    backprop_node = outputs_bf
    
backprop_node.backprop(values_to_compute=("grad",))

CPU times: user 1.97 s, sys: 793 ms, total: 2.76 s
Wall time: 1.96 s


In [16]:
# Compare the parameter gradients of the BF and torch models after both backward passes.
# NOTE(review): with use_assert=False the helper presumably only prints the comparison and
# does not raise on mismatch — confirm, and consider asserting if this is meant as a test.
check_bf_param_grads_allclose_torch(
    bf_model,
    torch_model,
    atol=1e-3,
    print_output=True,
    use_assert=False,
)

Grad of param bias for bf and torch are within 0.001? True
Grad of param transform.dense.weight for bf and torch are within 0.001? True
Grad of param transform.dense.bias for bf and torch are within 0.001? True
Grad of param transform.LayerNorm.weight for bf and torch are within 0.001? True
Grad of param transform.LayerNorm.bias for bf and torch are within 0.001? True
Grad of param decoder.weight for bf and torch are within 0.001? True
