In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
from jax.config import config
config.update("jax_enable_x64", True)
from dataclasses import is_dataclass
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
    BertModel,
    BfBertModel,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_equivalent_class, check_dataclass_keys_match, check_model_outputs_allclose, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7feba2cce6f0>

In [2]:
# Init torch and bf models
BF_FROM_MODEL_ID = False ### NOTE: BECAUSE THIS IS SUPER HACKY THIS SOMEWHAT DOES NOT WORK WHEN SET TO TRUE. FROM_PRETRAINED FOR BRUNOFLOW IS PROBABLY SOMEWHAT BROKEN, BUT AT LEAST THIS IS A WORKAROUND. Also it looks like the errors are only bounded by 0.01 :/.
TORCH_FROM_MODEL_ID = True
# model_id = "bert-base-uncased"
model_id = "google/bert_uncased_L-2_H-128_A-2"
config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json")

if TORCH_FROM_MODEL_ID:
    torch_model = BertModel.from_pretrained(model_id)
else:
    torch_model = BertModel(config)
if BF_FROM_MODEL_ID:
    bf_model = BfBertModel.from_pretrained(model_id)
else:
    bf_model = BfBertModel(config)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2023-01-05 00:13:21.320889: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dri



In [3]:
# Establish data
tokenizer = BertTokenizerFast.from_pretrained(model_id)
# text = ["hello hello hello there"]#, "yo hi what's up", "bleh this sucks"]
text = ["hello I want to eat some [MASK] meat today. It's thanksgiving [MASK] all!", "yo uo hi what's up"] #, "bleh this sucks"]
tokens = tokenizer(text, return_tensors="pt", padding=True)

# Create torch and bf inputs to model
input_ids_torch = tokens["input_ids"]
labels_torch = torch.ones_like(input_ids_torch)

input_ids_bf = jnp.array(input_ids_torch.numpy())
labels_bf = jnp.array(labels_torch.numpy())

In [4]:
%%time
outputs_torch = torch_model(input_ids_torch)
print(type(outputs_torch))


<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>
CPU times: user 89.4 ms, sys: 0 ns, total: 89.4 ms
Wall time: 13.8 ms


In [5]:
%%time
outputs_bf = bf_model(input_ids_bf)
print(type(outputs_bf))

<class 'transformers.modeling_bf_outputs.BfBaseModelOutputWithPoolingAndCrossAttentions'>
CPU times: user 1 s, sys: 14.4 ms, total: 1.01 s
Wall time: 1 s


In [6]:
# Check that forward pass for bf works and matches output shape with torch
if isinstance(outputs_bf, (list, tuple)):
    # Handle case where outputs is a tuple/list and not just a single item
    assert len(outputs_bf) == len(outputs_torch)
    for i in range(len(outputs_bf)):
        out_bf, out_torch = outputs_bf[i], outputs_torch[i] 
        assert(out_torch.shape == out_bf.shape)
elif is_dataclass(outputs_bf):
    check_equivalent_class(outputs_bf, outputs_torch)
    check_dataclass_keys_match(outputs_bf, outputs_torch)
else:
    assert(outputs_torch.shape == outputs_bf.shape)


In [7]:
# Save torch BertModel to file
save_path = "bertmodel_torch.pt"
torch.save(torch_model.state_dict(), save_path)

In [8]:
# Load state dict for BertModel into BF and check weights, outputs, and backprop
if not BF_FROM_MODEL_ID:
    bf_model.load_state_dict(torch.load(save_path))

### Check weights of BF model and Torch model match exactly

In [9]:
def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if check_node_equals_tensor(key_item_1[1], key_item_2[1]):
            pass
        else:
            models_differ += 1
            if (key_item_1[0] == key_item_2[0]):
                print('Mismatch found at', key_item_1[0])
            else:
                raise Exception
    if models_differ == 0:
        print('Models match perfectly! :)')

compare_models(bf_model, torch_model)

Models match perfectly! :)


In [10]:
# Check weights match
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight embeddings.word_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.position_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.token_type_embeddings.weight for bf and torch are equal? True
Value of param weight embeddings.LayerNorm.weight for bf and torch are equal? True
Value of param weight embeddings.LayerNorm.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.query.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.query.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.key.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.key.bias for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.value.weight for bf and torch are equal? True
Value of param weight encoder.layer.0.attention.self.value.bias for bf a

### Check model output after forward pass matches for BF and Torch

In [11]:
# Set all dropouts to 0
for name, module in torch_model.named_modules():
    if module._get_name() == "Dropout":
        print(name, module.p)
        module.p = 0
        print(name, module.p)

embeddings.dropout 0.1
embeddings.dropout 0
encoder.layer.0.attention.self.dropout 0.1
encoder.layer.0.attention.self.dropout 0
encoder.layer.0.attention.output.dropout 0.1
encoder.layer.0.attention.output.dropout 0
encoder.layer.0.output.dropout 0.1
encoder.layer.0.output.dropout 0
encoder.layer.1.attention.self.dropout 0.1
encoder.layer.1.attention.self.dropout 0
encoder.layer.1.attention.output.dropout 0.1
encoder.layer.1.attention.output.dropout 0
encoder.layer.1.output.dropout 0.1
encoder.layer.1.output.dropout 0


In [12]:
# Check output from forward passes match for bf and torch
torch_model.train(False)
bf_model.train(False)
outputs_bf = bf_model(input_ids_bf)
outputs_torch = torch_model(input_ids_torch)

if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    for i in range(len(outputs_bf)):
        out_bf, out_torch = outputs_bf[i], outputs_torch[i]
        check_bf_model_outputs_match_torch_outputs(out_bf, out_torch, atol=1e-6)
elif is_dataclass(outputs_bf):
    check_model_outputs_allclose(outputs_bf, outputs_torch, print_stats=True, atol=1e-2)
else:
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=1e-6)

Checking diff between BF and torch for last_hidden_state:
Output of bf and torch are within 0.01? True
	Stats on diff in outputs between bf and torch:                   0
count  4.864000e+03
mean   5.016467e-07
std    4.475989e-07
min    3.336587e-11
25%    1.814181e-07
50%    3.889932e-07
75%    6.923858e-07
max    4.264745e-06
Checking diff between BF and torch for pooler_output:
Output of bf and torch are within 0.01? True
	Stats on diff in outputs between bf and torch:                   0
count  2.560000e+02
mean   1.013631e-07
std    1.682405e-07
min    1.373492e-10
25%    1.481915e-08
50%    3.548915e-08
75%    9.879269e-08
max    1.147149e-06


### Check grad after backward pass matches for BF and torch

In [13]:
# Check grads equal before backward passes
torch_model.train(True)
bf_model.train(True)
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-2, print_output=False, print_stats=False, use_assert=True)

No grad for param embeddings.word_embeddings.weight for torch. BF grad is zero? True
No grad for param embeddings.position_embeddings.weight for torch. BF grad is zero? True
No grad for param embeddings.token_type_embeddings.weight for torch. BF grad is zero? True
No grad for param embeddings.LayerNorm.weight for torch. BF grad is zero? True
No grad for param embeddings.LayerNorm.bias for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.query.weight for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.query.bias for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.key.weight for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.key.bias for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.value.weight for torch. BF grad is zero? True
No grad for param encoder.layer.0.attention.self.value.bias for torch. BF grad is zero? True
No grad for param e

In [14]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node_torch = outputs_torch[0]
elif is_dataclass(outputs_torch):
    backprop_node_torch = outputs_torch.last_hidden_state
else:
    backprop_node_torch = outputs_torch
    
backprop_node_torch.backward(gradient=torch.ones_like(backprop_node_torch))

CPU times: user 14.3 ms, sys: 2.49 ms, total: 16.8 ms
Wall time: 7.77 ms


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [15]:
%%time 
# BF backward pass
bf_model.train(True)

if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    backprop_node = outputs_bf[0]
elif is_dataclass(outputs_torch):
    backprop_node = outputs_bf.last_hidden_state
else:
    backprop_node = outputs_bf
    
backprop_node.backprop(values_to_compute=("grad",))

CPU times: user 9.51 s, sys: 5.61 s, total: 15.1 s
Wall time: 7.52 s




In [16]:
# Run the actual check
check_bf_param_grads_allclose_torch(bf_model, torch_model, rtol=1e-4, atol=1e-5, print_output=True, print_stats=True, use_assert=True)

Grad of param embeddings.word_embeddings.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param embeddings.position_embeddings.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param embeddings.token_type_embeddings.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param embeddings.LayerNorm.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param embeddings.LayerNorm.bias for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param encoder.layer.0.attention.self.query.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param encoder.layer.0.attention.self.query.bias for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param encoder.layer.0.attention.self.key.weight for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param encoder.layer.0.attention.self.key.bias for bf and torch are within rtol=0.0001, atol=1e-05? True
Grad of param enco

### Miscellaneous extra code for testing/debugging the problem

In [17]:
for child in torch_model.children():
    if child._get_name() == "BertEmbeddings":
        break
for name, grandchild in child.named_modules():
    if name == "word_embeddings":
        break

torch_emb = child
torch_word_emb = grandchild

for child in bf_model.children():
    if child._get_name() == "BfBertEmbeddings":
        break
for name, grandchild in child.named_modules():
    if name == "word_embeddings":
        break
    
bf_emb = child
bf_word_emb = grandchild

# print(torch_emb)
# print(bf_emb)
# print(torch_word_emb)
# print(bf_word_emb)

In [18]:
check_bf_param_weights_match_torch(bf_emb, torch_emb)
out_emb_bf = bf_emb(input_ids_bf)
out_emb_torch = torch_emb(input_ids_torch)
check_bf_model_outputs_match_torch_outputs(out_emb_bf, out_emb_torch, atol=1e-3)
out_emb_bf.backprop(values_to_compute=("grad",))
out_emb_torch.backward(gradient=torch.ones_like(out_emb_torch))

topk = torch.topk(dict(torch_emb.named_parameters())["word_embeddings.weight"].grad.flatten(), 5).indices
dict(torch_emb.named_parameters())["word_embeddings.weight"].reshape(-1)[13056]

for k in topk:
    k = int(k)
    print(f"torch grad at [{k // 128, k % 128}]:", dict(torch_emb.named_parameters())["word_embeddings.weight"].grad[k // 128, k % 128])
    print(f"bf grad at [{k // 128, k % 128}]:", dict(bf_emb.named_parameters())["word_embeddings.weight"].grad[k // 128, k % 128])
    print(f"torch val at [{k // 128, k % 128}]:", dict(torch_emb.named_parameters())["word_embeddings.weight"][k // 128, k % 128])
    print(f"bf val at [{k // 128, k % 128}]:", dict(bf_emb.named_parameters())["word_embeddings.weight"].val[k // 128, k % 128])
    print()

check_bf_param_grads_allclose_torch(bf_emb, torch_emb, atol=1e-2, print_output=True, print_stats=False, use_assert=True)

Value of param weight word_embeddings.weight for bf and torch are equal? True
Value of param weight position_embeddings.weight for bf and torch are equal? True
Value of param weight token_type_embeddings.weight for bf and torch are equal? True
Value of param weight LayerNorm.weight for bf and torch are equal? True
Value of param weight LayerNorm.bias for bf and torch are equal? True
Output of bf and torch are within 0.001? True
torch grad at [(102, 0)]: tensor(26.4956)
bf grad at [(102, 0)]: 26.495648581566883
torch val at [(102, 0)]: tensor(-0.0756, grad_fn=<SelectBackward0>)
bf val at [(102, 0)]: -0.07556523

torch grad at [(1055, 0)]: tensor(20.6873)
bf grad at [(1055, 0)]: 20.68725014962451
torch val at [(1055, 0)]: tensor(-0.0034, grad_fn=<SelectBackward0>)
bf val at [(1055, 0)]: -0.0034130977

torch grad at [(1005, 43)]: tensor(20.6858)
bf grad at [(1005, 43)]: 20.685824248006
torch val at [(1005, 43)]: tensor(-0.0300, grad_fn=<SelectBackward0>)
bf val at [(1005, 43)]: -0.0300145



In [19]:
torch_word_emb.zero_grad()
bf_word_emb.zero_grad()

check_bf_param_weights_match_torch(bf_word_emb, torch_word_emb)
out_emb_bf = bf_word_emb(input_ids_bf)
out_emb_torch = torch_word_emb(input_ids_torch)
check_bf_model_outputs_match_torch_outputs(out_emb_bf, out_emb_torch, atol=1e-6)
out_emb_bf.backprop(values_to_compute=("grad",))
out_emb_torch.backward(gradient=torch.ones_like(out_emb_torch))

topk = torch.topk(dict(torch_word_emb.named_parameters())["weight"].grad.flatten(), 5).indices

for k in topk:
    k = int(k)
    print(f"torch grad at [{k // 128, k % 128}]:", dict(torch_word_emb.named_parameters())["weight"].grad[k // 128, k % 128])
    print(f"bf grad at [{k // 128, k % 128}]:", dict(bf_word_emb.named_parameters())["weight"].grad[k // 128, k % 128])
    print()

print("Max torch word emb grad:", torch_word_emb.weight.grad.max())
print("Max bf_word_emb grad:", bf_word_emb.weight.grad.max())
check_bf_param_grads_allclose_torch(bf_word_emb, torch_word_emb, atol=1e-2, print_output=True, print_stats=False, use_assert=True)


Value of param weight weight for bf and torch are equal? True
Output of bf and torch are within 1e-06? True
torch grad at [(101, 2)]: tensor(2.)
bf grad at [(101, 2)]: 2.0

torch grad at [(101, 1)]: tensor(2.)
bf grad at [(101, 1)]: 2.0

torch grad at [(101, 3)]: tensor(2.)
bf grad at [(101, 3)]: 2.0

torch grad at [(101, 4)]: tensor(2.)
bf grad at [(101, 4)]: 2.0

torch grad at [(101, 0)]: tensor(2.)
bf grad at [(101, 0)]: 2.0

Max torch word emb grad: tensor(2.)
Max bf_word_emb grad: 2.0
Grad of param weight for bf and torch are within rtol=0.001, atol=0.01? True


In [20]:
### Check that when the embedding is saved and then loaded by torch that the grads are equal
word_emb_save_path = "bertmodel_torch_word_emb.pt"
torch.save(torch_word_emb.state_dict(), word_emb_save_path)
torch_word_emb_loaded = torch.nn.Embedding(num_embeddings=torch_word_emb.num_embeddings, embedding_dim=torch_word_emb.embedding_dim, padding_idx=torch_word_emb.padding_idx)
torch_word_emb_loaded.load_state_dict(torch.load(word_emb_save_path))

# Check state_dicts are same
str(torch_word_emb.state_dict()) == str(torch_word_emb_loaded.state_dict())
torch_word_emb.zero_grad()
torch_word_emb_loaded.zero_grad()

# Check outputs are same
out_emb_torch = torch_word_emb(input_ids_torch)
out_emb_torch_loaded = torch_word_emb_loaded(input_ids_torch)
assert torch.equal(out_emb_torch, out_emb_torch_loaded)

# check grads are the same after backprop
out_emb_torch.backward(gradient=torch.ones_like(out_emb_torch))
out_emb_torch_loaded.backward(gradient=torch.ones_like(out_emb_torch_loaded))

def check_torch2_param_grads_allclose_torch(
    torch_module, torch_module2, atol=1e-6
):
    """Used to verify that grad after backward passes for bf and torch are close for all params in the network."""
    bf_params = {name: param for name, param in torch_module2.named_parameters()}
    torch_params = {name: param for name, param in torch_module.named_parameters()}
    assert set(bf_params.keys()) == set(
        torch_params.keys()
    ), f"BF and torch keys do not match: BF contains following extra keys {set(bf_params.keys()).difference(set(torch_params.keys()))} and is missing keys {set(torch_params.keys()).difference(set(bf_params.keys()))}"

    not_allclose_params = []
    for name in bf_params.keys():
        if torch_params[name].grad is None:
            assert bf_params[name].grad is None
        else:
            assert torch.allclose(torch_params[name].grad, bf_params[name].grad, atol=atol)


check_torch2_param_grads_allclose_torch(torch_word_emb, torch_word_emb_loaded, atol=1e-6)

torch_word_emb_loaded.weight.grad.max()
torch_word_emb.weight.grad.max()

tensor(2.)

In [21]:
### Check that when the embedding is saved by torch and then loaded by bf that the grads are equal
from brunoflow.net import Embedding
bf_word_emb_loaded = Embedding(num_embeddings=torch_word_emb.num_embeddings, embedding_dim=torch_word_emb.embedding_dim, padding_idx=torch_word_emb.padding_idx)
bf_word_emb_loaded.load_state_dict(torch.load(word_emb_save_path))

torch_word_emb_loaded.zero_grad()
bf_word_emb_loaded.zero_grad()

# Check weights
check_bf_param_weights_match_torch(bf_word_emb_loaded, torch_word_emb_loaded)

# Check outputs
out_emb_bf = bf_word_emb_loaded(input_ids_bf)
out_emb_torch = torch_word_emb_loaded(input_ids_torch)
check_bf_model_outputs_match_torch_outputs(out_emb_bf, out_emb_torch, atol=1e-6)

# Check grads
out_emb_bf.backprop(values_to_compute=("grad",))
out_emb_torch.backward(gradient=torch.ones_like(out_emb_torch))

print("Max torch word emb grad:", torch_word_emb_loaded.weight.grad.max())
print("Max bf_word_emb grad:", bf_word_emb_loaded.weight.grad.max())

check_bf_param_grads_allclose_torch(bf_word_emb_loaded, torch_word_emb_loaded, atol=1e-2, print_output=True, print_stats=False, use_assert=True)


Value of param weight weight for bf and torch are equal? True
Output of bf and torch are within 1e-06? True
Max torch word emb grad: tensor(2.)
Max bf_word_emb grad: 2.0
Grad of param weight for bf and torch are within rtol=0.001, atol=0.01? True


In [22]:
vars(torch_word_emb), vars(bf_word_emb)
# num_embeddings: int
#     embedding_dim: int
#     padding_idx: Optional[int]
#     max_norm: Optional[float]
#     norm_type: float
#     scale_grad_by_freq: bool
#     weight: Tensor
#     sparse: bool

({'training': True,
  '_parameters': OrderedDict([('weight',
                Parameter containing:
                tensor([[-4.1018e-03, -3.0695e-02, -3.5295e-03,  ...,  1.8925e-02,
                          3.7396e-03, -2.9233e-03],
                        [-4.2748e-04, -3.6929e-02, -1.7168e-02,  ...,  2.9314e-02,
                         -1.0398e-02,  2.6772e-02],
                        [ 5.9418e-03,  4.2119e-03, -1.9566e-02,  ...,  1.6799e-02,
                         -2.7802e-02, -6.9017e-03],
                        ...,
                        [ 3.5573e-02, -1.5891e-02,  4.9951e-03,  ...,  5.4071e-03,
                         -1.1270e-02, -6.9528e-05],
                        [-8.7018e-03, -2.2516e-02,  3.1993e-03,  ...,  2.7591e-02,
                         -1.9554e-02,  2.4023e-03],
                        [-7.8904e-02, -7.5407e-02, -4.6660e-03,  ..., -5.3340e-03,
                         -4.4993e-02,  5.9842e-02]], requires_grad=True))]),
  '_buffers': OrderedDict(),
  '_non_

In [23]:
topk = torch.topk(dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"].grad.flatten(), 5).indices
dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"].reshape(-1)[13056]

for k in topk:
    k = int(k)
    print(f"torch grad at [{k // 128, k % 128}]:", dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"].grad[k // 128, k % 128])
    print(f"bf grad at [{k // 128, k % 128}]:", dict(bf_model.named_parameters())["embeddings.word_embeddings.weight"].grad[k // 128, k % 128])
    print(f"torch val at [{k // 128, k % 128}]:", dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"][k // 128, k % 128])
    print(f"bf val at [{k // 128, k % 128}]:", dict(bf_model.named_parameters())["embeddings.word_embeddings.weight"].val[k // 128, k % 128])
    print()

torch grad at [(101, 2)]: tensor(2.)
bf grad at [(101, 2)]: 2.0
torch val at [(101, 2)]: tensor(-0.5026, grad_fn=<SelectBackward0>)
bf val at [(101, 2)]: -0.5025546

torch grad at [(101, 1)]: tensor(2.)
bf grad at [(101, 1)]: 2.0
torch val at [(101, 1)]: tensor(-0.0182, grad_fn=<SelectBackward0>)
bf val at [(101, 1)]: -0.01819513

torch grad at [(101, 3)]: tensor(2.)
bf grad at [(101, 3)]: 2.0
torch val at [(101, 3)]: tensor(-0.0100, grad_fn=<SelectBackward0>)
bf val at [(101, 3)]: -0.010009427

torch grad at [(101, 4)]: tensor(2.)
bf grad at [(101, 4)]: 2.0
torch val at [(101, 4)]: tensor(0.0039, grad_fn=<SelectBackward0>)
bf val at [(101, 4)]: 0.0039283177

torch grad at [(101, 0)]: tensor(2.)
bf grad at [(101, 0)]: 2.0
torch val at [(101, 0)]: tensor(0.0177, grad_fn=<SelectBackward0>)
bf val at [(101, 0)]: 0.017665198



In [24]:
dict(torch_model.named_parameters())["embeddings.position_embeddings.weight"].grad, dict(bf_model.named_parameters())["embeddings.position_embeddings.weight"].grad

(tensor([[ 13.3900,  -7.8846,  -4.9994,  ...,  -7.8793,   1.5591,  -8.5621],
         [ 20.8252, -15.4823,  -9.9196,  ...,   7.4338,  -3.8761,  -9.5488],
         [ 24.2847, -20.2617,  -9.8825,  ...,   7.9612,  -6.3824, -13.3367],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]),
 DeviceArray([[ 13.3899493 ,  -7.88462552,  -4.99939799, ...,
                -7.87930506,   1.55909437,  -8.56214457],
              [ 20.8252424 , -15.48232402,  -9.91958797, ...,
                 7.43374901,  -3.87605592,  -9.54876955],
              [ 24.28469348, -20.26168625,  -9.88253943, ...,
                 7.96123939,  -6.38237227, -13.33671944],
              ...,
              [  0.        ,   0.        ,   0.        , ...,
                 0.        ,   0.        ,   0.        ],
              [  0.    

In [25]:
diff = jnp.abs(dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"].grad.numpy() - dict(bf_model.named_parameters())["embeddings.word_embeddings.weight"].grad)

In [26]:
import pandas as pd
pd.DataFrame(jnp.argmax(diff, axis=0)).value_counts()

0    128
dtype: int64

In [27]:
import numpy as np
np.array(diff.nonzero())

array([], shape=(2, 0), dtype=int64)

In [28]:
dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"]

Parameter containing:
tensor([[-4.1018e-03, -3.0695e-02, -3.5295e-03,  ...,  1.8925e-02,
          3.7396e-03, -2.9233e-03],
        [-4.2748e-04, -3.6929e-02, -1.7168e-02,  ...,  2.9314e-02,
         -1.0398e-02,  2.6772e-02],
        [ 5.9418e-03,  4.2119e-03, -1.9566e-02,  ...,  1.6799e-02,
         -2.7802e-02, -6.9017e-03],
        ...,
        [ 3.5573e-02, -1.5891e-02,  4.9951e-03,  ...,  5.4071e-03,
         -1.1270e-02, -6.9528e-05],
        [-8.7018e-03, -2.2516e-02,  3.1993e-03,  ...,  2.7591e-02,
         -1.9554e-02,  2.4023e-03],
        [-7.8904e-02, -7.5407e-02, -4.6660e-03,  ..., -5.3340e-03,
         -4.4993e-02,  5.9842e-02]], requires_grad=True)

In [29]:
weight = dict(torch_model.named_parameters())["embeddings.word_embeddings.weight"]
weight_grads = weight.grad
weight_grads[weight_grads.nonzero()[0]], weight_grads.nonzero()[0]
weight_grads[weight_grads.nonzero()].shape, weight_grads.nonzero().shape

(torch.Size([3072, 2, 128]), torch.Size([3072, 2]))

In [30]:
weight, dict(bf_model.named_parameters())["embeddings.word_embeddings.weight"].val


(Parameter containing:
 tensor([[-4.1018e-03, -3.0695e-02, -3.5295e-03,  ...,  1.8925e-02,
           3.7396e-03, -2.9233e-03],
         [-4.2748e-04, -3.6929e-02, -1.7168e-02,  ...,  2.9314e-02,
          -1.0398e-02,  2.6772e-02],
         [ 5.9418e-03,  4.2119e-03, -1.9566e-02,  ...,  1.6799e-02,
          -2.7802e-02, -6.9017e-03],
         ...,
         [ 3.5573e-02, -1.5891e-02,  4.9951e-03,  ...,  5.4071e-03,
          -1.1270e-02, -6.9528e-05],
         [-8.7018e-03, -2.2516e-02,  3.1993e-03,  ...,  2.7591e-02,
          -1.9554e-02,  2.4023e-03],
         [-7.8904e-02, -7.5407e-02, -4.6660e-03,  ..., -5.3340e-03,
          -4.4993e-02,  5.9842e-02]], requires_grad=True),
 DeviceArray([[-4.1018268e-03, -3.0694773e-02, -3.5295275e-03, ...,
                1.8925212e-02,  3.7396429e-03, -2.9232893e-03],
              [-4.2748044e-04, -3.6928687e-02, -1.7167933e-02, ...,
                2.9313693e-02, -1.0397688e-02,  2.6771577e-02],
              [ 5.9417770e-03,  4.2118742e-03, 