In [1]:
import brunoflow as bf
import os
import torch
import jax.numpy as jnp
from pprint import PrettyPrinter
# Shared pretty-printer; indent=1 and width=80 are pprint's defaults, spelled out here.
# Note: pprint prints dicts with their keys sorted, not in insertion order.
pp = PrettyPrinter(indent=1, width=80)

In [2]:
# Location of the pretrained BERT checkpoint. os.path.expanduser("~") resolves the home
# directory even where $HOME is unset (e.g. Windows), unlike os.environ["HOME"].
model_path = os.path.join(
    os.path.expanduser("~"), "code", "rycolab", "brunoflow", "models", "bert", "pytorch_model.bin"
)

In [3]:
# Load the checkpoint's parameters (name -> tensor) onto the CPU regardless of the device
# they were saved from, so every tensor lives on the host where `.numpy()` can be called.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
model_dict = torch.load(model_path, map_location="cpu")
model_keys = list(model_dict.keys())
model_keys_to_shape = {k: v.shape for k, v in model_dict.items()}

In [10]:
# Deserialize each torch tensor to a JAX array exactly once (the previous version converted
# every tensor three times), then list torch vs. JAX type/dtype/shape side by side.
[
    (type(v), v.dtype, v.shape, type(j), j.dtype, j.shape)
    for v in model_dict.values()
    for j in (jnp.array(v.numpy()),)
]

[(torch.Tensor,
  torch.float32,
  torch.Size([30522, 768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (30522, 768)),
 (torch.Tensor,
  torch.float32,
  torch.Size([512, 768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (512, 768)),
 (torch.Tensor,
  torch.float32,
  torch.Size([2, 768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (2, 768)),
 (torch.Tensor,
  torch.float32,
  torch.Size([768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (768,)),
 (torch.Tensor,
  torch.float32,
  torch.Size([768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (768,)),
 (torch.Tensor,
  torch.float32,
  torch.Size([768, 768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (768, 768)),
 (torch.Tensor,
  torch.float32,
  torch.Size([768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  (768,)),
 (torch.Tensor,
  torch.float32,
  torch.Size([768, 768]),
  jaxlib.xla_extension.DeviceArray,
  dtype('float32'),
  

In [5]:
# All parameter names stored in the checkpoint, in state-dict order.
state_dict_keys = model_dict.keys()
state_dict_keys

odict_keys(['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.gamma', 'bert.embeddings.LayerNorm.beta', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.0.attention.output.dense.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.gamma', 'bert.encoder.layer.0.attention.output.LayerNorm.beta', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.LayerNorm.gamma', 'bert.encoder.layer.0.output.La

In [4]:
# Inspect one parameter: the embedding LayerNorm scale. This checkpoint names LayerNorm
# parameters `gamma`/`beta` (not `weight`/`bias`).
embeddings_layernorm_gamma = model_dict['bert.embeddings.LayerNorm.gamma']
embeddings_layernorm_gamma



tensor([0.9261, 0.8851, 0.8581, 0.8617, 0.8937, 0.8969, 0.9297, 0.9137, 0.9371,
        0.8084, 0.7992, 0.8071, 0.9031, 0.8198, 0.9100, 0.8493, 0.8152, 0.8613,
        0.9142, 0.8652, 0.9234, 0.8672, 0.9008, 0.8684, 0.8440, 0.8990, 0.7891,
        0.9275, 0.8501, 0.8413, 0.9179, 0.8641, 0.9185, 0.9657, 0.8861, 0.8710,
        0.9103, 0.8739, 0.9133, 0.8880, 0.9130, 0.9374, 0.8823, 0.8622, 0.8812,
        0.8708, 0.8570, 0.9445, 0.9163, 0.9356, 0.9265, 0.8504, 0.9300, 0.3447,
        0.8650, 0.8197, 0.8722, 0.8566, 0.8939, 0.8051, 0.9007, 0.8483, 0.3870,
        0.8889, 0.8923, 0.8772, 0.8963, 0.9548, 0.8944, 0.8946, 0.9471, 0.9489,
        0.9349, 0.7814, 0.9255, 0.7943, 0.8806, 0.3857, 0.7900, 0.8478, 0.8886,
        0.9215, 0.9292, 0.8990, 0.7790, 0.8255, 0.8717, 0.8778, 0.9021, 0.9190,
        0.8605, 0.8762, 0.7084, 0.8599, 0.8981, 0.8092, 0.4021, 0.7917, 0.8923,
        0.9118, 0.9459, 0.9489, 0.8744, 0.8402, 0.8031, 0.2923, 0.9314, 0.9065,
        0.8852, 0.8115, 0.9090, 0.8948, 

In [4]:
# Preview the first few parameter names, then report how many parameters there are in total.
num_preview_keys = 8
first_keys = model_keys[:num_preview_keys]
pp.pprint(first_keys)
print(f"Num keys: {len(model_keys)}")

['bert.embeddings.word_embeddings.weight',
 'bert.embeddings.position_embeddings.weight',
 'bert.embeddings.token_type_embeddings.weight',
 'bert.embeddings.LayerNorm.gamma',
 'bert.embeddings.LayerNorm.beta',
 'bert.encoder.layer.0.attention.self.query.weight',
 'bert.encoder.layer.0.attention.self.query.bias',
 'bert.encoder.layer.0.attention.self.key.weight']
Num keys: 207


In [5]:
# Shapes of the first 8 parameters (model_keys shares model_keys_to_shape's key order).
# pprint sorts dict keys, so the display order is alphabetical rather than state-dict order.
first_shapes = {name: model_keys_to_shape[name] for name in model_keys[:8]}
pp.pprint(first_shapes)

{'bert.embeddings.LayerNorm.beta': torch.Size([768]),
 'bert.embeddings.LayerNorm.gamma': torch.Size([768]),
 'bert.embeddings.position_embeddings.weight': torch.Size([512, 768]),
 'bert.embeddings.token_type_embeddings.weight': torch.Size([2, 768]),
 'bert.embeddings.word_embeddings.weight': torch.Size([30522, 768]),
 'bert.encoder.layer.0.attention.self.key.weight': torch.Size([768, 768]),
 'bert.encoder.layer.0.attention.self.query.bias': torch.Size([768]),
 'bert.encoder.layer.0.attention.self.query.weight': torch.Size([768, 768])}


In [58]:
def torch_tensor_to_jax(tensor: torch.Tensor):
    """Copy a torch tensor's values into a new JAX array.

    Args:
        tensor: any torch tensor. It may require grad or live on an accelerator.

    Returns:
        A JAX array with the same shape and values. It is a copy, so it is not linked to
        the torch autograd graph or the torch storage.
    """
    # .numpy() raises for tensors that require grad or are not on the CPU, so detach from
    # autograd and move to host memory first. Both are no-ops for plain CPU tensors.
    return jnp.array(tensor.detach().cpu().numpy())

def torch_tensor_to_bf(tensor: torch.Tensor, name=None):
    """Wrap a torch tensor's values in a brunoflow Node.

    Args:
        tensor: the torch tensor to convert (via torch_tensor_to_jax).
        name: optional label passed through to the Node.

    Returns:
        A bf.Node whose value is the JAX copy of ``tensor``.
    """
    jax_value = torch_tensor_to_jax(tensor)
    return bf.Node(jax_value, name=name)

def torch_bert_layer_to_bf_bert_layer(torch_model_dict, layer_name):
    """Fetch one checkpoint entry and wrap it as a brunoflow Node named after its key.

    Despite the function name, ``layer_name`` is a single parameter key
    (e.g. 'bert.embeddings.token_type_embeddings.weight'), not a whole layer.

    Args:
        torch_model_dict: mapping from parameter name to torch tensor.
        layer_name: key to look up; it is also used as the Node's name.

    Returns:
        A bf.Node holding a JAX copy of the tensor.
    """
    return torch_tensor_to_bf(torch_model_dict[layer_name], name=layer_name)

# Run one small parameter (the 2 x 768 token-type embedding table) through each converter.
demo_key = 'bert.embeddings.token_type_embeddings.weight'
jax_tensor = torch_tensor_to_jax(model_dict[demo_key])
bf_tensor = torch_tensor_to_bf(model_dict[demo_key])
bf_bert_layer = torch_bert_layer_to_bf_bert_layer(model_dict, demo_key)

# All three conversions must agree on shape.
assert tuple(bf_tensor.shape) == tuple(bf_bert_layer.shape) == tuple(jax_tensor.shape)

# type(...) is the Python container class, not the element dtype, so label it `type=`.
print(f"JAX tensor: {jax_tensor}, type={type(jax_tensor)}, dtype={jax_tensor.dtype}, shape={jax_tensor.shape}\n")
print(f"bf tensor: {bf_tensor}, type={type(bf_tensor)}, shape={bf_tensor.shape}\n")
print(f"bf bert layer: {bf_bert_layer}, type={type(bf_bert_layer)}, shape={bf_bert_layer.shape}\n")

JAX tensor: [[ 0.00043164  0.01098826  0.00370439 ... -0.00661185 -0.00336983
  -0.00864201]
 [ 0.00111319 -0.00299169 -0.00317028 ...  0.00474542 -0.0052443
  -0.01121742]], dtype=<class 'jaxlib.xla_extension.DeviceArray'>, shape=(2, 768)

bf tensor: node(name: None, val: [[ 0.00043164  0.01098826  0.00370439 ... -0.00661185 -0.00336983
  -0.00864201]
 [ 0.00111319 -0.00299169 -0.00317028 ...  0.00474542 -0.0052443
  -0.01121742]], grad: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]), dtype=<class 'brunoflow.ad.node.Node'>, shape=(2, 768)

bf bert layer: node(name: bert.embeddings.token_type_embeddings.weight, val: [[ 0.00043164  0.01098826  0.00370439 ... -0.00661185 -0.00336983
  -0.00864201]
 [ 0.00111319 -0.00299169 -0.00317028 ...  0.00474542 -0.0052443
  -0.01121742]], grad: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]), dtype=<class 'brunoflow.ad.node.Node'>, shape=(2, 768)



In [59]:
# Same as bf.Node(torch_tensor_to_jax(...)) with no name, expressed through the helper.
torch_tensor_to_bf(model_dict['bert.embeddings.token_type_embeddings.weight'])

node(name: None, val: [[ 0.00043164  0.01098826  0.00370439 ... -0.00661185 -0.00336983
  -0.00864201]
 [ 0.00111319 -0.00299169 -0.00317028 ...  0.00474542 -0.0052443
  -0.01121742]], grad: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]])

In [64]:
from transformers import BertForMaskedLM



In [66]:
# Reference HuggingFace model for comparison. The warning about unused
# 'cls.seq_relationship.*' weights is expected: the masked-LM head does not use them.
hf_checkpoint = "bert-base-uncased"
model = BertForMaskedLM.from_pretrained(hf_checkpoint)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
