In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env JAX_PLATFORM_NAME=cpu

# Turn on 64-bit precision in JAX (it defaults to 32-bit). This is done at
# startup, before any JAX arrays are created.
# The flag is set through `jax.config` rather than `from jax.config import config`:
# that import path is deprecated in newer JAX releases, and it would also bind a
# global named `config`, which this notebook later rebinds to a BertConfig.
import jax
jax.config.update("jax_enable_x64", True)
from dataclasses import is_dataclass
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
    BertForMaskedLM,
    BfBertForMaskedLM,
    BertForSequenceClassification,
    BfBertForSequenceClassification,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_equivalent_class, check_dataclass_keys_match, check_model_outputs_allclose, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
# Seed torch's global RNG with a fixed value so randomly initialised torch weights
# are repeatable from one run of the notebook to the next.
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: JAX_PLATFORM_NAME=cpu


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fc832be1750>

In [2]:
# Scratch check: build the torch classifier from the toy config and report how
# many parameters it has (displayed as the cell's output).
config = BertConfig.from_pretrained(
    pretrained_model_name_or_path="../../brunoflow/models/bert/config-toy.json"
)
m = BertForSequenceClassification(config=config)
sum(param.numel() for param in m.parameters())

9378

In [4]:
# Check that the Brunoflow classifier can be constructed from the current `config`.
# NOTE(review): this rebinds `m`; the resulting model is not used by any later
# cell visible in this notebook.
m = BfBertForSequenceClassification(config=config)




In [4]:
# Build the torch reference model and the Brunoflow (bf) model under test.
#
# Each *_FROM_MODEL_ID flag chooses between loading the pretrained checkpoint
# named by `model_id` (True) and a fresh model built from `config` (False).
#
# NOTE (original author's caveat): BF_FROM_MODEL_ID=True is a hacky path that
# somewhat does not work; from_pretrained for Brunoflow is probably somewhat
# broken, and building from config is the workaround. The errors also appear to
# be bounded only by 0.01.
BF_FROM_MODEL_ID = False
TORCH_FROM_MODEL_ID = True

# Alternative (much larger) checkpoint: model_id = "bert-base-uncased"
model_id = "google/bert_uncased_L-2_H-128_A-2"
config = BertConfig.from_pretrained(
    pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json"
)

# The torch model is created first and the bf model second; keep this order.
torch_model = (
    BertForSequenceClassification.from_pretrained(model_id)
    if TORCH_FROM_MODEL_ID
    else BertForSequenceClassification(config)
)
bf_model = (
    BfBertForSequenceClassification.from_pretrained(model_id)
    if BF_FROM_MODEL_ID
    else BfBertForSequenceClassification(config)
)

# Total number of parameters in the torch model (displayed as the cell's output)
sum(param.numel() for param in torch_model.parameters())

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w



4386178

In [5]:
# Tokenize a small two-text batch with the checkpoint's tokenizer; padding=True
# is passed so both texts can be returned together as torch tensors.
tokenizer = BertTokenizerFast.from_pretrained(model_id)
text = [
    "hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!",
    "yo yo what's up",
]
tokens = tokenizer(text, return_tensors="pt", padding=True)

# Torch-side inputs: the token ids plus one integer label per input text
input_ids_torch = tokens["input_ids"]
labels_torch = torch.tensor([0, 1], dtype=torch.long)

# Brunoflow-side inputs: the same values converted torch -> numpy -> jax
input_ids_bf, labels_bf = (
    jnp.array(tensor.numpy()) for tensor in (input_ids_torch, labels_torch)
)

In [6]:
%%time
# First (timed) forward pass through the torch model. Only the type of the
# returned object is inspected here; output values are compared in a later cell.
outputs_torch = torch_model(input_ids_torch)
print(type(outputs_torch))


<class 'transformers.modeling_outputs.SequenceClassifierOutput'>
CPU times: user 189 ms, sys: 109 µs, total: 189 ms
Wall time: 36.1 ms


In [7]:
%%time
# First (timed) forward pass through the Brunoflow model, for comparison with
# the torch timing. Only the type of the returned object is inspected here.
outputs_bf = bf_model(input_ids_bf)
print(type(outputs_bf))

<class 'transformers.modeling_bf_outputs.BfSequenceClassifierOutput'>
CPU times: user 1.12 s, sys: 34.1 ms, total: 1.16 s
Wall time: 1.16 s


In [8]:
# Structural check only: the bf forward output must have the same layout as the
# torch output (no values are compared in this cell).
if isinstance(outputs_bf, (list, tuple)):
    # Sequence output: same number of entries, and each pair agrees on shape
    assert len(outputs_bf) == len(outputs_torch)
    for i, out_bf in enumerate(outputs_bf):
        out_torch = outputs_torch[i]
        assert out_torch.shape == out_bf.shape
elif is_dataclass(outputs_bf):
    # Dataclass output: defer to the project helpers for class and key checks
    check_equivalent_class(outputs_bf, outputs_torch)
    check_dataclass_keys_match(outputs_bf, outputs_torch)
else:
    # Single array/tensor output
    assert outputs_torch.shape == outputs_bf.shape


In [9]:
# Save the torch BertForSequenceClassification weights (its state dict) to file
save_path = "bertforseqclass_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [10]:
# Load the saved torch BertForSequenceClassification state dict into the BF model.
# The weight, output and gradient checks happen in the cells that follow.
# Skipped when the BF model was itself built with from_pretrained
# (BF_FROM_MODEL_ID=True).
if not BF_FROM_MODEL_ID:
    bf_model.load_state_dict(torch.load(save_path))

### Check weights of BF model and Torch model match exactly

In [13]:
# Check weights match: the project helper reports, parameter by parameter,
# whether the bf and torch weight values are equal (see the printed output).
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight bert.embeddings.word_embeddings.weight for bf and torch are equal? True
Value of param weight bert.embeddings.position_embeddings.weight for bf and torch are equal? True
Value of param weight bert.embeddings.token_type_embeddings.weight for bf and torch are equal? True
Value of param weight bert.embeddings.LayerNorm.weight for bf and torch are equal? True
Value of param weight bert.embeddings.LayerNorm.bias for bf and torch are equal? True
Value of param weight bert.encoder.layer.0.attention.self.query.weight for bf and torch are equal? True
Value of param weight bert.encoder.layer.0.attention.self.query.bias for bf and torch are equal? True
Value of param weight bert.encoder.layer.0.attention.self.key.weight for bf and torch are equal? True
Value of param weight bert.encoder.layer.0.attention.self.key.bias for bf and torch are equal? True
Value of param weight bert.encoder.layer.0.attention.self.value.weight for bf and torch are equal? True
Value of param weight 

### Check model output after forward pass matches for BF and Torch

In [14]:
# Set the dropout probability of every Dropout layer in the torch model to 0,
# printing each layer's probability before and after the change.
# Layers are matched with isinstance() instead of comparing the string returned
# by the private `_get_name()` helper, so the check does not depend on a
# non-public API.
# Only torch_model is modified here; bf_model is not touched by this cell.
for name, module in torch_model.named_modules():
    if isinstance(module, torch.nn.Dropout):
        print(name, module.p)
        module.p = 0
        print(name, module.p)

bert.embeddings.dropout 0.1
bert.embeddings.dropout 0
bert.encoder.layer.0.attention.self.dropout 0.1
bert.encoder.layer.0.attention.self.dropout 0
bert.encoder.layer.0.attention.output.dropout 0.1
bert.encoder.layer.0.attention.output.dropout 0
bert.encoder.layer.0.output.dropout 0.1
bert.encoder.layer.0.output.dropout 0
bert.encoder.layer.1.attention.self.dropout 0.1
bert.encoder.layer.1.attention.self.dropout 0
bert.encoder.layer.1.attention.output.dropout 0.1
bert.encoder.layer.1.attention.output.dropout 0
bert.encoder.layer.1.output.dropout 0.1
bert.encoder.layer.1.output.dropout 0
dropout 0.1
dropout 0


In [15]:
# Switch both models out of training mode, rerun the forward pass on the same
# inputs, and compare the bf output values against torch.
torch_model.eval()
bf_model.train(False)

outputs_bf = bf_model(input_ids_bf)
outputs_torch = torch_model(input_ids_torch)

if isinstance(outputs_bf, (list, tuple)):
    # Sequence output: compare entry by entry
    assert len(outputs_bf) == len(outputs_torch)
    for i, out_bf in enumerate(outputs_bf):
        out_torch = outputs_torch[i]
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)
elif is_dataclass(outputs_bf):
    # Dataclass output: the helper compares the fields and prints diff statistics.
    # NOTE(review): this branch uses atol=1e-2 while the other two use 1e-6.
    check_model_outputs_allclose(
        outputs_bf,
        outputs_torch,
        print_stats=True,
        atol=1e-2,
    )
else:
    # Single array/tensor output
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-6)

Checking diff between BF and torch for logits:
Output of bf and torch are within 0.01? True
	Stats on diff in outputs between bf and torch:                   0
count  4.000000e+00
mean   5.000582e-08
std    1.432968e-08
min    2.938199e-08
25%    4.707900e-08
50%    5.410022e-08
75%    5.702704e-08
max    6.244085e-08


### Check grads after backward pass match for BF and Torch

In [16]:
%%time
# Torch backward pass: pick the tensor to differentiate from the forward output,
# then backpropagate an all-ones upstream gradient through it.
torch_model.train()

if isinstance(outputs_torch, (list, tuple)):
    # Sequence output: differentiate the first entry
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node_torch = outputs_torch[0]
else:
    # Dataclass output: differentiate the logits; otherwise the output itself
    backprop_node_torch = (
        outputs_torch.logits if is_dataclass(outputs_torch) else outputs_torch
    )

backprop_node_torch.backward(
    gradient=torch.ones_like(backprop_node_torch),
)

CPU times: user 20.2 ms, sys: 981 µs, total: 21.2 ms
Wall time: 9.69 ms


In [17]:
%%time 
# BF backward pass: pick the node to differentiate from the Brunoflow forward
# output, then run Brunoflow's backprop requesting only the "grad" values.

if isinstance(outputs_bf, (list, tuple)):
    # Sequence output: differentiate the first entry
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node = outputs_bf[0]
elif is_dataclass(outputs_bf):
    # Fixed: this branch indexes outputs_bf, so it must test outputs_bf.
    # It previously tested outputs_torch (copied from the torch backward cell).
    backprop_node = outputs_bf.logits
else:
    # Single node output
    backprop_node = outputs_bf

backprop_node.backprop(values_to_compute=("grad",))



CPU times: user 10.3 s, sys: 5.16 s, total: 15.4 s
Wall time: 7.86 s


In [18]:
# Compare every parameter gradient of the bf model against its torch counterpart
# within the given relative/absolute tolerances, printing per-parameter results
# and asserting on mismatch (use_assert=True).
check_bf_param_grads_allclose_torch(
    bf_model,
    torch_model,
    rtol=6e-2,
    atol=1e-2,
    print_output=True,
    print_stats=True,
    use_assert=True,
)

Grad of param bert.embeddings.word_embeddings.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.embeddings.position_embeddings.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.embeddings.token_type_embeddings.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.embeddings.LayerNorm.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.embeddings.LayerNorm.bias for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.encoder.layer.0.attention.self.query.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.encoder.layer.0.attention.self.query.bias for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.encoder.layer.0.attention.self.key.weight for bf and torch are within rtol=0.06, atol=0.01? True
Grad of param bert.encoder.layer.0.attention.self.key.bias for bf and torch are within rtol=0.06, atol=0.01? True
