In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", True)
from dataclasses import is_dataclass
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
    BertModel,
    BfBertModel,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_equivalent_class, check_dataclass_keys_match, check_model_outputs_allclose, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7ff0e20de6f0>

In [2]:
# Init torch and bf models
BF_FROM_MODEL_ID = False
TORCH_FROM_MODEL_ID = True
# model_id = "bert-base-uncased"
model_id = "google/bert_uncased_L-2_H-128_A-2"
config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json")

torch_model = BertModel.from_pretrained(model_id)
torch_model2 = BertModel(config)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Establish data
tokenizer = BertTokenizerFast.from_pretrained(model_id)
text = ["hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!", "yo hi what's up"]
tokens = tokenizer(text, return_tensors="pt", padding=True)

# Create torch and bf inputs to model
input_ids_torch = tokens["input_ids"]
labels_torch = torch.ones_like(input_ids_torch)

In [4]:
%%time
outputs_torch = torch_model(input_ids_torch)
print(type(outputs_torch))


<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>
CPU times: user 92.3 ms, sys: 0 ns, total: 92.3 ms
Wall time: 14 ms


In [5]:
# Save torch BertModel to file
save_path = "bertmodel_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [6]:
# Load state dict for BertSelfAttention into BF and check weights, outputs, and backprop
torch_model2.load_state_dict(torch.load(save_path))

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [7]:
# Check weights match
from torch.nn import Module
def check_torch_param_weights_match_torch(torch_module: Module, torch_module2: Module):
    """Used to verify the weights of the bf model and torch module are equal."""
    torch_params = {name: param for name, param in torch_module.named_parameters()}
    torch_params2 = {name: param for name, param in torch_module2.named_parameters()}
    assert set(torch_params.keys()) == set(
        torch_params2.keys()
    ), f"BF and torch keys do not match: BF contains following extra keys {set(torch_params.keys()).difference(set(torch_params2.keys()))} and is missing keys {set(torch_params.keys()).difference(set(bf_params.keys()))}"

    for name in torch_params.keys():
        print(
            f"Value of param weight {name} for bf and torch are equal? {torch.equal(torch_params[name], torch_params2[name])}"
        )
        assert torch.equal(
            torch_params[name], torch_params2[name]
        ), f"Value of param {name} for bf and torch are not equal."
check_torch_param_weights_match_torch(torch_model, torch_model2)

Value of param weight embeddings.word_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.position_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.token_type_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.LayerNorm.weight for bf and torch are equal? True
Value of param weight embeddings.LayerNorm.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.query.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.query.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.key.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.key.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.value.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.value.bias for bf a

### Check model output after forward pass matches for BF and Torch

In [8]:
# Check output from forward passes match for bf and torch
torch_model.train(False)
torch_model2.train(False)
outputs_torch = torch_model(input_ids_torch)
outputs_torch2 = torch_model2(input_ids_torch)

if isinstance(outputs_torch2, (list, tuple)) or is_dataclass(outputs_torch2):
    assert len(outputs_torch2) == len(outputs_torch)
    for i in range(len(outputs_torch2)):
        out_bf, out_torch = outputs_torch2[i], outputs_torch[i]
        assert torch.allclose(out_bf, out_torch, atol=1e-6)
# elif is_dataclass(outputs_torch2):
#     assert torch.allclose(outputs_torch2, outputs_torch, atol=1e-6)
    # check_model_outputs_allclose(outputs_torch2, outputs_torch, print_stats=True, atol=1e-2)
else:
    assert torch.allclose(outputs_torch2, outputs_torch, atol=1e-6)
    # check_bf_model_outputs_match_torch_outputs(outputs_torch2, outputs_torch, atol=1e-6)

### Check grad after backward pass matches for BF and torch

In [9]:
# Check grads equal before backward passes
torch_model.train(True)
torch_model2.train(True)

def check_torch2_param_grads_allclose_torch(
    torch_module: Module, torch_module2: Module, atol=1e-6
):
    """Used to verify that grad after backward passes for bf and torch are close for all params in the network."""
    bf_params = {name: param for name, param in torch_module2.named_parameters()}
    torch_params = {name: param for name, param in torch_module.named_parameters()}
    assert set(bf_params.keys()) == set(
        torch_params.keys()
    ), f"BF and torch keys do not match: BF contains following extra keys {set(bf_params.keys()).difference(set(torch_params.keys()))} and is missing keys {set(torch_params.keys()).difference(set(bf_params.keys()))}"

    not_allclose_params = []
    for name in bf_params.keys():
        if torch_params[name].grad is None:
            assert bf_params[name].grad is None
            # bf_grad_is_zero = jnp.array_equal(bf_params[name].grad, jnp.zeros_like(bf_params[name].grad))
            # print(f"No grad for param {name} for torch. BF grad is zero? {bf_params.grad is None}")
            # if not bf_grad_is_zero:
            #     not_allclose_params.append(name)
        else:
            assert torch.allclose(torch_params[name].grad, bf_params[name].grad, atol=atol)
            # is_allclose = jnp.allclose(bf_params[name].grad, torch_params[name].grad.numpy(), atol=atol)
            # if print_output:
            #     print(f"Grad of param {name} for bf and torch are within {atol}? {is_allclose}")
            # if not is_allclose:
            #     diff = jnp.abs(bf_params[name].grad - torch_params[name].grad.numpy())
            #     diff_df = pd.DataFrame(diff)
            #     not_allclose_params.append(name)
            #     if print_stats:
            #         print(f"\tStats on diff in grad for {name} between bf and torch: {diff_df.describe()}")

    # if use_assert:
    #     assert not not_allclose_params, f"Grad of params {not_allclose_params} for bf and torch are not within {atol}."


check_torch2_param_grads_allclose_torch(torch_model2, torch_model, atol=1e-6)

In [10]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_torch) == len(outputs_torch2)
    backprop_node_torch = outputs_torch[0]
elif is_dataclass(outputs_torch):
    backprop_node_torch = outputs_torch.last_hidden_state
else:
    backprop_node_torch = outputs_torch
    
backprop_node_torch.backward(gradient=torch.ones_like(backprop_node_torch))

CPU times: user 22.7 ms, sys: 379 ms, total: 401 ms
Wall time: 864 ms


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [11]:
%%time
# Torch2 backward pass
torch_model2.train(True)

if isinstance(outputs_torch2, (list, tuple)):
    assert len(outputs_torch) == len(outputs_torch2)
    backprop_node_torch2 = outputs_torch2[0]
elif is_dataclass(outputs_torch2):
    backprop_node_torch2 = outputs_torch2.last_hidden_state
else:
    backprop_node_torch2 = outputs_torch2
    
backprop_node_torch2.backward(gradient=torch.ones_like(backprop_node_torch))

CPU times: user 4.43 ms, sys: 25.6 ms, total: 30 ms
Wall time: 4.45 ms


In [12]:
# Run the actual check
check_torch2_param_grads_allclose_torch(torch_model, torch_model2, atol=1e-6)

In [14]:
str(torch_model.state_dict()) == str(torch_model2.state_dict())

True

In [15]:
diff = torch.abs(dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"].grad - dict(torch_model2.named_parameters())["embeddings.word_embeddings.weight"].grad)

In [16]:
import pandas as pd
pd.DataFrame(torch.argmax(diff, axis=0)).value_counts()

0    128
dtype: int64

In [17]:
diff.shape

torch.Size([30522, 128])

In [18]:
max(diff[diff.nonzero()]), torch.argmax(diff[diff.nonzero()]), diff[diff.nonzero()][torch.argmax(diff[diff.nonzero()])]

ValueError: max() arg is an empty sequence

In [None]:
import numpy as np
np.array(diff.nonzero())

array([[  101,   101,   101, ..., 15060, 15060, 15060],
       [    0,     1,     2, ...,   125,   126,   127]])

In [None]:
dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"]

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0074, -0.0071, -0.0071,  ..., -0.0098, -0.0134,  0.0195],
        [ 0.0149, -0.0253,  0.0369,  ...,  0.0013,  0.0134, -0.0148],
        ...,
        [-0.0044, -0.0172, -0.0067,  ..., -0.0144, -0.0035,  0.0195],
        [-0.0062, -0.0042,  0.0309,  ...,  0.0092, -0.0042,  0.0180],
        [ 0.0138,  0.0062, -0.0471,  ...,  0.0233,  0.0133, -0.0152]],
       requires_grad=True)

In [None]:
weight = dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"]
weight_grads = weight.grad
weight_grads[weight_grads.nonzero()[0]], weight_grads.nonzero()[0]
weight_grads[weight_grads.nonzero()].shape, weight_grads.nonzero().shape

(torch.Size([2816, 2, 128]), torch.Size([2816, 2]))

In [None]:
weight, dict(bf_model.named_parameters())["embeddings.word_embeddings.weight"].val


(Parameter containing:
 tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0074, -0.0071, -0.0071,  ..., -0.0098, -0.0134,  0.0195],
         [ 0.0149, -0.0253,  0.0369,  ...,  0.0013,  0.0134, -0.0148],
         ...,
         [-0.0044, -0.0172, -0.0067,  ..., -0.0144, -0.0035,  0.0195],
         [-0.0062, -0.0042,  0.0309,  ...,  0.0092, -0.0042,  0.0180],
         [ 0.0138,  0.0062, -0.0471,  ...,  0.0233,  0.0133, -0.0152]],
        requires_grad=True),
 DeviceArray([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
                0.        ,  0.        ],
              [-0.00743891, -0.00712189, -0.00712206, ..., -0.00981665,
               -0.01344507,  0.01948544],
              [ 0.01491467, -0.02525557,  0.03694721, ...,  0.00127098,
                0.01342611, -0.01483316],
              ...,
              [-0.00436839, -0.01722462, -0.00671953, ..., -0.01438164,
               -0.00348806,  0.01950711],
              [-0.00620748, -0.0