In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# Configure JAX through the top-level package. Importing the `jax.config`
# submodule directly is deprecated and unavailable in recent JAX releases,
# and it also bound a global named `config` that a later cell reuses for the
# BertConfig object.
import jax

# Enable 64-bit floats so the Brunoflow (JAX) side can hold float64 inputs.
# This is set right after the environment magic and before any array is created.
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BertConfig,
    BertSelfAttention,
    BfBertSelfAttention,
    BertSelfOutput,
    BfBertSelfOutput,
    BertAttention,
    BfBertAttention,
    BertLayer,
    BfBertLayer,
    BertEncoder,
    BfBertEncoder,
    BaseModelOutputWithPastAndCrossAttentions,
    BfBaseModelOutputWithPastAndCrossAttentions,
)
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from utils import check_bf_param_weights_match_torch, check_bf_model_outputs_match_torch_outputs, check_bf_param_grads_allclose_torch
# Seed torch's global RNG so the `torch.randn` test inputs created below are
# reproducible across kernel restarts. The call returns the Generator object,
# which is why its repr shows up as this cell's output.
torch.manual_seed(0)


env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fbd049e1710>

In [2]:
# Build the torch reference encoder and the Brunoflow encoder from one shared
# BERT configuration file. The torch model is constructed first, then the
# Brunoflow one.
BERT_CONFIG_PATH = "../../brunoflow/models/bert/config.json"
config = BertConfig.from_pretrained(BERT_CONFIG_PATH)

torch_model = BertEncoder(config)
bf_model = BfBertEncoder(config)

In [3]:
# Random inputs for both models: sampled once with torch, then mirrored into
# JAX arrays so the two frameworks see the same numbers.
# NOTE(review): dimensions are presumably (batch, sequence, hidden) - confirm
# against the encoder's config.
BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE = 2, 19, 768

hidden_states_torch = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
attention_mask_torch = torch.randn(BATCH_SIZE, 1, 1, SEQ_LEN)

# float64 copies of the torch tensors for the Brunoflow model.
hidden_states, attention_mask = (
    jnp.array(torch_tensor.numpy(), dtype=jnp.float64)
    for torch_tensor in (hidden_states_torch, attention_mask_torch)
)

In [4]:
%%time
outputs_torch = torch_model(hidden_states_torch, attention_mask_torch)


CPU times: user 377 ms, sys: 32.1 ms, total: 409 ms
Wall time: 53.7 ms


In [12]:
# Display the type of the container returned by the torch encoder's forward pass.
type(outputs_torch)

transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions

In [5]:
%%time
outputs_bf = bf_model(hidden_states, attention_mask)


2022-12-26 09:40:40.105076: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 18.00MiB (rounded to 18874368)requested by op 
2022-12-26 09:40:40.109100: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:492] ****************************************************************************************************
2022-12-26 09:40:40.109193: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 18874368 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   18.00MiB
              constant allocation:         0B
        maybe_live_out allocation:   18.00MiB
     preallocated temp allocation:         0B
                 total allocation:   36.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 18874368 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   18.00MiB
              constant allocation:         0B
        maybe_live_out allocation:   18.00MiB
     preallocated temp allocation:         0B
                 total allocation:   36.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 18.00MiB
		Entry Parameter Subshape: f64[768,3072]
		==========================

	Buffer 2:
		Size: 18.00MiB
		Operator: op_name="jit(transpose)/jit(main)/transpose[permutation=(1, 0)]" source_file="/home/kevin/code/rycolab/brunoflow/brunoflow/func/linalg.py" source_line=41
		XLA Label: transpose
		Shape: f64[3072,768]
		==========================



In [6]:
# Sanity check: the Brunoflow forward pass must yield outputs whose shapes
# agree with those of the torch reference.
if not isinstance(outputs_bf, (list, tuple)):
    # Single output object: compare the two shapes directly.
    assert outputs_torch.shape == outputs_bf.shape
else:
    # Several outputs: same count on both sides, then compare position by
    # position. Integer indexing is used on purpose (rather than iterating),
    # since the torch result may be a container that is not a plain sequence.
    assert len(outputs_bf) == len(outputs_torch)
    assert all(
        outputs_torch[idx].shape == outputs_bf[idx].shape
        for idx in range(len(outputs_bf))
    )


In [7]:
# Persist the torch BertEncoder's weights to disk so they can be loaded into
# the Brunoflow model in the next cell.
save_path = "bertlayer_torch.pt"
torch.save(obj=torch_model.state_dict(), f=save_path)

In [8]:
# Read back the saved torch BertEncoder weights and load them into the
# Brunoflow model; the following cells then compare weights, outputs and
# gradients between the two.
torch_state_dict = torch.load(save_path)
bf_model.load_state_dict(torch_state_dict)

<All keys matched successfully>

### Check weights of BF model and Torch model match exactly

In [9]:
# Check weights match: after loading the torch state dict, compare the
# parameters of the Brunoflow model against the torch model. The helper comes
# from the local `utils` module; per the output below it reports one
# "are equal?" line per parameter name.
check_bf_param_weights_match_torch(bf_model, torch_model)

Value of param weight attention.self.query.bias for bf and torch are equal? True
Value of param weight attention.self.query.weight for bf and torch are equal? True
Value of param weight attention.self.key.bias for bf and torch are equal? True
Value of param weight attention.self.key.weight for bf and torch are equal? True
Value of param weight attention.self.value.bias for bf and torch are equal? True
Value of param weight attention.self.value.weight for bf and torch are equal? True
Value of param weight attention.output.dense.bias for bf and torch are equal? True
Value of param weight attention.output.dense.weight for bf and torch are equal? True
Value of param weight attention.output.LayerNorm.weight for bf and torch are equal? True
Value of param weight attention.output.LayerNorm.bias for bf and torch are equal? True
Value of param weight intermediate.dense.bias for bf and torch are equal? True
Value of param weight intermediate.dense.weight for bf and torch are equal? True
Value of

### Check model output after forward pass matches for BF and Torch

In [10]:
# Re-run both forward passes now that the two models share weights and compare
# the resulting values within an absolute tolerance. The outputs produced here
# are also the ones the backward-pass cells below differentiate.
OUTPUT_ATOL = 1e-6

torch_model.eval()  # put the torch model in evaluation mode for the comparison
outputs_bf = bf_model(hidden_states=hidden_states, attention_mask=attention_mask)
outputs_torch = torch_model(hidden_states=hidden_states_torch, attention_mask=attention_mask_torch)

if not isinstance(outputs_bf, (list, tuple)):
    # Single output object: compare it directly.
    check_bf_model_outputs_match_torch_outputs(outputs_bf, outputs_torch, atol=OUTPUT_ATOL)
else:
    # Several outputs: same count on both sides, then compare pairwise by index.
    assert len(outputs_bf) == len(outputs_torch)
    for idx in range(len(outputs_bf)):
        check_bf_model_outputs_match_torch_outputs(outputs_bf[idx], outputs_torch[idx], atol=OUTPUT_ATOL)

Output of bf and torch are within 1e-06? True


### Check gradients after the backward pass match for BF and Torch

In [11]:
%%time
# Torch backward pass
torch_model.train(True)

if isinstance(outputs_torch, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_torch[0].backward(gradient=torch.ones_like(outputs_torch[0]))
else:
    outputs_torch.backward(gradient=torch.ones_like(outputs_torch))

CPU times: user 45.8 ms, sys: 1.09 ms, total: 46.9 ms
Wall time: 45.4 ms


In [12]:
%%time 
# BF backward pass
if isinstance(outputs_bf, (list, tuple)):
    assert len(outputs_bf) == len(outputs_torch)
    outputs_bf[0].backprop(values_to_compute=("grad",))
else:
    outputs_bf.backprop(values_to_compute=("grad",))

CPU times: user 1min 38s, sys: 1min 7s, total: 2min 46s
Wall time: 1min 37s


In [14]:
# Run the actual check: compare the parameter gradients of the Brunoflow model
# with those of the torch model within an absolute tolerance of 1e-5. This
# relies on both backward-pass cells above having been run first.
# `print_output=True` produces the per-parameter "within 1e-05?" lines shown below.
check_bf_param_grads_allclose_torch(bf_model, torch_model, atol=1e-5, print_output=True)

Grad of param attention.self.query.bias for bf and torch are within 1e-05? True
Grad of param attention.self.query.weight for bf and torch are within 1e-05? True
Grad of param attention.self.key.bias for bf and torch are within 1e-05? True
Grad of param attention.self.key.weight for bf and torch are within 1e-05? True
Grad of param attention.self.value.bias for bf and torch are within 1e-05? True
Grad of param attention.self.value.weight for bf and torch are within 1e-05? True
Grad of param attention.output.dense.bias for bf and torch are within 1e-05? True
Grad of param attention.output.dense.weight for bf and torch are within 1e-05? True
Grad of param attention.output.LayerNorm.weight for bf and torch are within 1e-05? True
Grad of param attention.output.LayerNorm.bias for bf and torch are within 1e-05? True
Grad of param intermediate.dense.bias for bf and torch are within 1e-05? True
Grad of param intermediate.dense.weight for bf and torch are within 1e-05? True
Grad of param output