In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import brunoflow as bf
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from jax import numpy as jnp
import numpy as np
import transformers
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    BertForMaskedLM, 
    BertTokenizer, 
    BertTokenizerFast, 
    BertEmbeddings,
    BfBertEmbeddings,
    BfBertEncoder,
    BertConfig,
    BfBertSelfAttention,
    BfBertForMaskedLM
)
from collections import Counter, OrderedDict
from typing import List

torch.manual_seed(0)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe5ae9d94b0>

In [2]:
def convert_sentence_to_tokens_and_target_idx(sent: str, tokenizer):
    pre, target, post = sent.split("***")
    if "mask" in target.lower():
        target = ["[MASK]"]
    else:
        target = tokenizer.tokenize(target)
    tokens = ["[CLS]"] + tokenizer.tokenize(pre)
    target_idx = len(tokens)
    # print(target_idx)
    tokens += target + tokenizer.tokenize(post) + ["[SEP]"]
    return tokens, target_idx

In [3]:
# Establish data
# model_id = "google/bert_uncased_L-2_H-128_A-2"
model_id = "google/bert_uncased_L-6_H-128_A-2"
config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json")
tokenizer = BertTokenizerFast.from_pretrained(model_id)
text = "a 1770s map of philadelphia 's naval defenses ***mask*** a fort on the island , but it is unidentified ."
good_word = "shows"
bad_word = "show"
word_ids = tokenizer.convert_tokens_to_ids([good_word, bad_word])

# tokenize text and pass into model
tokens, target_idx = convert_sentence_to_tokens_and_target_idx(text, tokenizer)
input_ids = np.expand_dims(tokenizer.convert_tokens_to_ids(tokens), axis=0)
jax_input_ids = bf.Node(jnp.array(input_ids, dtype=int), name="inputs")

print(input_ids, tokens)
# tokens = tokenizer(text, return_tensors="pt", padding=True)
# input_ids = tokens["input_ids"]
# jax_input_ids = bf.Node(jnp.array(input_ids.numpy(), dtype=int), name="inputs")
# print(input_ids, input_ids.shape)

2023-01-11 11:42:18.562355: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


[[  101  1037 17711  2015  4949  1997  4407  1005  1055  3987 13345   103
   1037  3481  2006  1996  2479  1010  2021  2009  2003 20293  1012   102]] ['[CLS]', 'a', '1770', '##s', 'map', 'of', 'philadelphia', "'", 's', 'naval', 'defenses', '[MASK]', 'a', 'fort', 'on', 'the', 'island', ',', 'but', 'it', 'is', 'unidentified', '.', '[SEP]']


In [4]:
# Create BfBertForMaskedLM model
config = BertConfig.from_pretrained(pretrained_model_name_or_path="../../brunoflow/models/bert/config-tiny.json")
bf_model = BfBertForMaskedLM.from_pretrained(model_id)

# Visualize output of forward pass of BfBertEmbeddings
bf_model.train(False)
out_bf = bf_model(input_ids=jax_input_ids).logits # shape = (vs, seq_len)
qoi = out_bf[:, target_idx] # shape = (1, vs)
qoi = qoi[:, word_ids[0]] - qoi[:, word_ids[1]] # shape = (1,)
# out_bf.visualize(collapse_to_modules=True)
# print(bf_embs)
bf_model.train(True)

qoi.backprop(values_to_compute=("max_grad",))

  bf.Parameter(jnp.zeros(self.position_ids.shape, dtype=jnp.int64), name="position_ids"),




Some weights of the model checkpoint at google/bert_uncased_L-6_H-128_A-2 were not used when initializing BfBertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BfBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BfBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).






In [5]:
qoi.visualize()




In [6]:
def find_matching_nodes(root: bf.Node, name: str):
    def _find_matching_nodes(root: bf.Node, name: str, visited=set()):
        assert isinstance(root, bf.Node), f"root input must be a Node, instead received {root}"
        if root in visited:
            return []
            
        matching_nodes = []
        if root.name is not None and name in root.name:
            matching_nodes.append(root)
            # return [root]
        for inp in root.inputs:
            if isinstance(inp, bf.Node):
                matching_nodes_in_subtree = _find_matching_nodes(inp, name, visited=visited)
                visited.add(inp)
                if matching_nodes_in_subtree:
                    matching_nodes = matching_nodes_in_subtree + matching_nodes
        
        return matching_nodes
    return _find_matching_nodes(root, name, visited=set())

### Does BERT use the skip or self-attention mechanism more?

In [7]:
input_to_bert_attn_nodes = find_matching_nodes(out_bf[0], "input to bertattention")


In [8]:
print(len(input_to_bert_attn_nodes)) # should match number of bert layers
print(input_to_bert_attn_nodes[-1].shape) # (1, seq_len, hidden_sz)
layer0_bert_attn_input = input_to_bert_attn_nodes[0]
print("names of parent nodes of input to bert attn:", len(layer0_bert_attn_input.get_parents()), [p.name for p in layer0_bert_attn_input.get_parents()])

6
(1, 24, 128)
names of parent nodes of input to bert attn: 4 ['matmul', 'matmul', 'matmul', 'combine self_attention_output and bert attention input 8788907714527']


In [9]:
def summarize_max_grad_parents_bert_attn_input(input_to_bert_attn_node: bf.Node, tokens: List[str]):
    # Number of hidden units corresponding to each max grad parent option (for the input to bert attention Node)
    count_hidden_unit_max_grad_parents: OrderedDict = OrderedDict({
        tokens[i]: [
            (k.name, v) for k,v in Counter(input_to_bert_attn_node.get_max_grad_parent()[0, i]).items()
        ] for i in range(len(input_to_bert_attn_node.get_max_grad_parent()[0]))
    }) # keys are tokens, values are a list of the counts of # emb units for that token which have each of the possible max grad parents

    skip_and_attn_max_grads_per_word = []
    for i in range(len(input_to_bert_attn_node.get_max_grad_parent()[0])): # each word
        max_grad_parent_for_emb = input_to_bert_attn_node.get_max_grad_parent()[0, i] # shape = (emb_sz,)
        skip_max_grad = 0
        attn_max_grad = 0
        for j in range(len(max_grad_parent_for_emb)):
            emb_unit_max_grad_val = input_to_bert_attn_node.max_grad_of_output_wrt_node[0][0][i][j]
            emb_unit_max_grad_parent = input_to_bert_attn_node.max_grad_of_output_wrt_node[1][0][i][j]
            if emb_unit_max_grad_parent.name == "matmul":
                attn_max_grad += emb_unit_max_grad_val
            elif "combine self_attention_output and bert attention input" in emb_unit_max_grad_parent.name:
                skip_max_grad += emb_unit_max_grad_val
            else:
                raise ValueError(f"uhoh! received an unexpected parent, {emb_unit_max_grad_parent}")
        skip_and_attn_max_grads_per_word.append((skip_max_grad, attn_max_grad))

    grad_diff_between_skip_and_attention = OrderedDict({tokens[i]: skip_and_attn_max_grads_per_word[i][0] - skip_and_attn_max_grads_per_word[i][1] for i in range(len(skip_and_attn_max_grads_per_word))})

    return count_hidden_unit_max_grad_parents, grad_diff_between_skip_and_attention

In [10]:
from pprint import PrettyPrinter
p = PrettyPrinter()
for layer in input_to_bert_attn_nodes:
    counts, grads = summarize_max_grad_parents_bert_attn_input(layer, tokens)
    p.pprint(counts)
    p.pprint(grads)
    print()

OrderedDict([('[CLS]',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                128)]),
             ('a',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                128)]),
             ('1770',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                110),
               ('matmul', 17),
               ('matmul', 1)]),
             ('##s',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                127),
               ('matmul', 1)]),
             ('map',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                128)]),
             ('of',
              [('combine self_attention_output and bert attention input '
                '8788907714527',
                128)]

### Does the gradient go more through the LM prediction head or the bert encoder?

In [36]:
word_embs_nodes = find_matching_nodes(qoi, "emb weights (30522")
assert len(word_embs_nodes) == 1
word_embs_node = word_embs_nodes[0]
word_embs_per_token_node = word_embs_node[input_ids]
word_embs_per_token_node.max_grad_of_output_wrt_node = (word_embs_node.max_grad_of_output_wrt_node[0][input_ids], word_embs_node.max_grad_of_output_wrt_node[1][input_ids])
word_embs_per_token_node.max_neg_grad_of_output_wrt_node = (word_embs_node.max_neg_grad_of_output_wrt_node[0][input_ids], word_embs_node.max_neg_grad_of_output_wrt_node[1][input_ids])
word_embs_per_token_node.parents = word_embs_node.parents

In [43]:
word_embs_per_token_node.shape

(1, 24, 128)

In [38]:
[p.name for p in word_embs_per_token_node.get_parents()]

['transpose', 'get_embedding']

In [39]:
def summarize_max_grad_word_embs(word_embs_node: bf.Node, tokens: List[str]):
    # Number of hidden units corresponding to each max grad parent option (for the input to bert attention Node)
    count_hidden_unit_max_grad_parents: OrderedDict = OrderedDict({
        tokens[i]: [
            (k.name, v) for k,v in Counter(word_embs_node.get_max_grad_parent()[0, i]).items()
        ] for i in range(len(word_embs_node.get_max_grad_parent()[0]))
    }) # keys are tokens, values are a list of the counts of # emb units for that token which have each of the possible max grad parents

    max_grad_buckets_for_all_words = []
    parent_names = [p.name for p in word_embs_node.get_parents()]
    for i in range(len(word_embs_node.get_max_grad_parent()[0])): # each word
        max_grad_parent_for_emb = word_embs_node.get_max_grad_parent()[0, i] # shape = (emb_sz,)
        max_grad_buckets = dict.fromkeys(parent_names, 0.)
        for j in range(len(max_grad_parent_for_emb)):
            emb_unit_max_grad_val = word_embs_node.max_grad_of_output_wrt_node[0][0][i][j]
            emb_unit_max_grad_parent = word_embs_node.max_grad_of_output_wrt_node[1][0][i][j]
            max_grad_buckets[emb_unit_max_grad_parent.name] += emb_unit_max_grad_val

        max_grad_buckets_for_all_words.append((max_grad_buckets[parent_names[0]], max_grad_buckets[parent_names[1]]))

    grad_diff_between_skip_and_attention = OrderedDict({**{"parent_names": parent_names}, **{tokens[i]: max_grad_buckets_for_all_words[i][0] - max_grad_buckets_for_all_words[i][1] for i in range(len(max_grad_buckets_for_all_words))}})

    return count_hidden_unit_max_grad_parents, grad_diff_between_skip_and_attention

In [40]:
summarize_max_grad_word_embs(word_embs_per_token_node, tokens)

(OrderedDict([('[CLS]', [('get_embedding', 128)]),
              ('a', [('get_embedding', 128)]),
              ('1770', [('get_embedding', 128)]),
              ('##s', [('get_embedding', 128)]),
              ('map', [('get_embedding', 128)]),
              ('of', [('get_embedding', 128)]),
              ('philadelphia', [('get_embedding', 128)]),
              ("'", [('get_embedding', 128)]),
              ('s', [('get_embedding', 128)]),
              ('naval', [('get_embedding', 128)]),
              ('defenses', [('get_embedding', 128)]),
              ('[MASK]', [('get_embedding', 128)]),
              ('fort', [('get_embedding', 128)]),
              ('on', [('get_embedding', 128)]),
              ('the', [('get_embedding', 128)]),
              ('island', [('get_embedding', 128)]),
              (',', [('get_embedding', 128)]),
              ('but', [('get_embedding', 128)]),
              ('it', [('get_embedding', 128)]),
              ('is', [('get_embedding', 128)]),
      

Conclusion: grad entirely goes through the bert encoder

### How does gradient travel through key/value/query structures per layer?

In [51]:
# Distinguish the matmul parents
for input_to_bert_attn_layer in input_to_bert_attn_nodes:
    for parent in input_to_bert_attn_layer.get_parents():
        curr_parent = parent
        if curr_parent.name == "matmul":
            while "bertselfattention" not in curr_parent.name:
                assert len(curr_parent.get_parents()) == 1
                curr_parent = curr_parent.get_parents()[0]
            parent.name = f"matmul ({curr_parent.name})"


In [52]:
input_to_bert_attn_nodes[0].get_parents()

[node(name: matmul (bertselfattention key), val: [[[ 0.6101066  -1.758796   -0.9247101  ... -1.3825     -1.5679675
     0.36370578]
   [-0.11861765 -2.4003637  -0.5975313  ...  2.406166    0.49865592
     0.96619177]
   [ 0.15247516  1.4226599   1.0073868  ... -0.13554578  0.38859507
     0.127669  ]
   ...
   [-0.64928347 -0.44245568  0.5023466  ... -0.23630762  1.103121
     1.1080121 ]
   [-1.3285136  -1.8599983  -0.3786667  ...  0.14449848  0.48642528
    -0.21755914]
   [-0.4534198  -2.4247804  -1.1860352  ...  0.74313396 -0.37907493
     0.11684031]]], grad: [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   ...
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]),
 node(name: matmul (bertselfattention value), val: [[[-1.4332561  -1.8980492  -0.8303858  ...  0.08597035 -1.1728458
    -1.1266255 ]
   [-1.1197389  -0.8480332   0.42633006 ...  0.01381378 -0.81014633
    -0.21485434]
   [-0.6148334   0.27644622 -1.0346075

In [53]:
def summarize_max_grad_kvq(input_to_bert_attn_node: bf.Node, tokens: List[str]):
    # Number of hidden units corresponding to each max grad parent option (for the input to bert attention Node)
    count_hidden_unit_max_grad_parents: OrderedDict = OrderedDict({
        tokens[i]: [
            (k.name, v) for k,v in Counter(input_to_bert_attn_node.get_max_grad_parent()[0, i]).items()
        ] for i in range(len(input_to_bert_attn_node.get_max_grad_parent()[0]))
    }) # keys are tokens, values are a list of the counts of # emb units for that token which have each of the possible max grad parents

    max_grad_buckets_for_all_words = []
    parent_names = [p.name for p in input_to_bert_attn_node.get_parents()]
    for i in range(len(input_to_bert_attn_node.get_max_grad_parent()[0])): # each word
        max_grad_parent_for_emb = input_to_bert_attn_node.get_max_grad_parent()[0, i] # shape = (emb_sz,)
        max_grad_buckets = dict.fromkeys(parent_names, 0.)
        for j in range(len(max_grad_parent_for_emb)):
            emb_unit_max_grad_val = input_to_bert_attn_node.max_grad_of_output_wrt_node[0][0][i][j]
            emb_unit_max_grad_parent = input_to_bert_attn_node.max_grad_of_output_wrt_node[1][0][i][j]
            max_grad_buckets[emb_unit_max_grad_parent.name] += emb_unit_max_grad_val

        max_grad_buckets_for_all_words.append(max_grad_buckets)

    grad_diff_between_skip_and_attention = OrderedDict({tokens[i]: max_grad_buckets_for_all_words[i] for i in range(len(max_grad_buckets_for_all_words))})

    return count_hidden_unit_max_grad_parents, grad_diff_between_skip_and_attention

In [57]:
summarize_max_grad_kvq(input_to_bert_attn_nodes[-3], tokens)

(OrderedDict([('[CLS]',
               [('combine self_attention_output and bert attention input 8788907682070',
                 128)]),
              ('a',
               [('combine self_attention_output and bert attention input 8788907682070',
                 80),
                ('matmul (bertselfattention key)', 25),
                ('matmul (bertselfattention value)', 23)]),
              ('1770',
               [('matmul (bertselfattention key)', 101),
                ('combine self_attention_output and bert attention input 8788907682070',
                 25),
                ('matmul (bertselfattention value)', 2)]),
              ('##s',
               [('matmul (bertselfattention value)', 23),
                ('combine self_attention_output and bert attention input 8788907682070',
                 101),
                ('matmul (bertselfattention key)', 3),
                ('matmul (bertselfattention mixed_query_layer)', 1)]),
              ('map',
               [('combine

In [None]:
# summarize_max_grad_parents_bert_attn_input(layer0_bert_attn_input, tokens)

({'[CLS]': [('matmul', 53),
   ('matmul', 74),
   ('combine self_attention_output and bert attention input 8758245107119',
    1)],
  'a': [('combine self_attention_output and bert attention input 8758245107119',
    128)],
  '1770': [('matmul', 31), ('matmul', 97)],
  '##s': [('matmul', 31), ('matmul', 97)],
  'map': [('matmul', 68), ('matmul', 60)],
  'of': [('matmul', 29),
   ('matmul', 87),
   ('combine self_attention_output and bert attention input 8758245107119',
    12)],
  'philadelphia': [('matmul', 85), ('matmul', 43)],
  "'": [('matmul', 25),
   ('matmul', 84),
   ('combine self_attention_output and bert attention input 8758245107119',
    19)],
  's': [('matmul', 55),
   ('matmul', 56),
   ('combine self_attention_output and bert attention input 8758245107119',
    17)],
  'naval': [('combine self_attention_output and bert attention input 8758245107119',
    104),
   ('matmul', 22),
   ('matmul', 2)],
  'defenses': [('combine self_attention_output and bert attention input 8

In [None]:
# summarize_max_grad_parents_bert_attn_input(layer1_bert_attn_input, tokens)

({'[CLS]': [('matmul', 79), ('matmul', 49)],
  'a': [('matmul', 96), ('matmul', 32)],
  '1770': [('matmul', 77), ('matmul', 51)],
  '##s': [('matmul', 53), ('matmul', 75)],
  'map': [('matmul', 49), ('matmul', 79)],
  'of': [('matmul', 89), ('matmul', 39)],
  'philadelphia': [('matmul', 76), ('matmul', 52)],
  "'": [('matmul', 98), ('matmul', 30)],
  's': [('matmul', 78), ('matmul', 50)],
  'naval': [('matmul', 33), ('matmul', 95)],
  'defenses': [('matmul', 93), ('matmul', 35)],
  '[MASK]': [('matmul', 71),
   ('combine self_attention_output and bert attention input 8758245089735',
    57)],
  'fort': [('matmul', 73), ('matmul', 55)],
  'on': [('matmul', 72), ('matmul', 56)],
  'the': [('matmul', 94), ('matmul', 34)],
  'island': [('matmul', 77), ('matmul', 51)],
  ',': [('matmul', 72), ('matmul', 56)],
  'but': [('matmul', 77), ('matmul', 51)],
  'it': [('matmul', 92), ('matmul', 36)],
  'is': [('matmul', 72), ('matmul', 56)],
  'unidentified': [('matmul', 79), ('matmul', 49)],
  '.'

In [None]:
# Number of hidden units corresponding to each max grad parent option (for the input to bert attention Node)
{
    tokens[i]: [
        (k.name, v) for k,v in Counter(layer0_bert_attn_input.get_max_grad_parent()[0, i]).items()
    ] for i in range(len(layer0_bert_attn_input.get_max_grad_parent()[0]))
}

{'[CLS]': [('matmul', 53),
  ('matmul', 74),
  ('combine self_attention_output and bert attention input 8758245107119', 1)],
 'a': [('combine self_attention_output and bert attention input 8758245107119',
   128)],
 '1770': [('matmul', 31), ('matmul', 97)],
 '##s': [('matmul', 31), ('matmul', 97)],
 'map': [('matmul', 68), ('matmul', 60)],
 'of': [('matmul', 29),
  ('matmul', 87),
  ('combine self_attention_output and bert attention input 8758245107119',
   12)],
 'philadelphia': [('matmul', 85), ('matmul', 43)],
 "'": [('matmul', 25),
  ('matmul', 84),
  ('combine self_attention_output and bert attention input 8758245107119',
   19)],
 's': [('matmul', 55),
  ('matmul', 56),
  ('combine self_attention_output and bert attention input 8758245107119',
   17)],
 'naval': [('combine self_attention_output and bert attention input 8758245107119',
   104),
  ('matmul', 22),
  ('matmul', 2)],
 'defenses': [('combine self_attention_output and bert attention input 8758245107119',
   120),
  ('ma

In [None]:
# Number of hidden units corresponding to each max grad parent option (for the input to bert attention Node)
skip_and_attn_max_grads_per_word = []
for i in range(len(layer0_bert_attn_input.get_max_grad_parent()[0])): # each word
    max_grad_parent_for_emb = layer0_bert_attn_input.get_max_grad_parent()[0, i] # shape = (emb_sz,)
    skip_max_grad = 0
    attn_max_grad = 0
    for j in range(len(max_grad_parent_for_emb)):
        emb_unit_max_grad_val = layer0_bert_attn_input.max_grad_of_output_wrt_node[0][0][i][j]
        emb_unit_max_grad_parent = layer0_bert_attn_input.max_grad_of_output_wrt_node[1][0][i][j]
        if emb_unit_max_grad_parent.name == "matmul":
            attn_max_grad += emb_unit_max_grad_val
        elif "combine self_attention_output and bert attention input" in emb_unit_max_grad_parent.name:
            skip_max_grad += emb_unit_max_grad_val
        else:
            raise ValueError(f"uhoh! received an unexpected parent, {emb_unit_max_grad_parent}")
    skip_and_attn_max_grads_per_word.append((skip_max_grad, attn_max_grad))

{tokens[i]: skip_and_attn_max_grads_per_word[i][0] - skip_and_attn_max_grads_per_word[i][1] for i in range(len(skip_and_attn_max_grads_per_word))}
# print({tokens[i]: skip_and_attn_max_grads_per_word[i][0] for i in range(len(skip_and_attn_max_grads_per_word))})

# {
#     tokenizer.convert_ids_to_tokens(input_ids[i]): [
#         (k.name, v) for k,v in Counter(layer0_bert_attn_input.get_max_grad_parent()[0, i]).items()
#     ] for i in range(len(layer0_bert_attn_input.get_max_grad_parent()[0]))
# }

{'[CLS]': DeviceArray(-28.904734, dtype=float32),
 'a': DeviceArray(79.73425, dtype=float32),
 '1770': DeviceArray(-7.1745143, dtype=float32),
 '##s': DeviceArray(-2.7482662, dtype=float32),
 'map': DeviceArray(-17.659466, dtype=float32),
 'of': DeviceArray(-4.1306715, dtype=float32),
 'philadelphia': DeviceArray(-19.473045, dtype=float32),
 "'": DeviceArray(-6.3064027, dtype=float32),
 's': DeviceArray(-6.633123, dtype=float32),
 'naval': DeviceArray(12.5280485, dtype=float32),
 'defenses': DeviceArray(66.57097, dtype=float32),
 '[MASK]': DeviceArray(289.098, dtype=float32),
 'fort': DeviceArray(26.292835, dtype=float32),
 'on': DeviceArray(-5.327896, dtype=float32),
 'the': DeviceArray(-3.2386632, dtype=float32),
 'island': DeviceArray(-9.03645, dtype=float32),
 ',': DeviceArray(-3.7696393, dtype=float32),
 'but': DeviceArray(-5.7957606, dtype=float32),
 'it': DeviceArray(-2.9873562, dtype=float32),
 'is': DeviceArray(-7.618865, dtype=float32),
 'unidentified': DeviceArray(-6.953268,

In [None]:
out_bf.visualize()