In [1]:
from models.bert import MLXBertModel
from transformers import BertModel, AutoTokenizer
from mlx.utils import tree_unflatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


## Load transformers Bert Model

In [2]:
og_model = BertModel.from_pretrained("bert-base-uncased") 
og_model.eval()

config = og_model.config
og_state = og_model.state_dict()

In [3]:
converted_weights = {k: mx.array(v.numpy()) for k, v in og_state.items() if k is not None}
print(converted_weights.keys())

np.savez("converted_bert.npz", **converted_weights)

dict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bia

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 

In [5]:
## ORIGINAL OUTPUT
encoded = tokenizer("hello", return_tensors="pt")
print(encoded)

{'input_ids': tensor([[ 101, 7592,  102]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}


In [6]:
## Inference Test
og_model_outputs = og_model(
	encoded["input_ids"],
	attention_mask = encoded["attention_mask"],
	token_type_ids = encoded["token_type_ids"]
)

In [7]:
print(og_model_outputs[0].shape)
print(og_model_outputs[1].shape)

torch.Size([1, 3, 768])
torch.Size([1, 768])


## Load MLX Model

In [8]:
model = MLXBertModel(config)
print(model)

MLXBertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
    (dropout): Dropout(p=0.09999999999999998)
  )
  (encoder): BertEncoder(
    (layer.0): BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(input_dims=768, output_dims=768, bias=True)
          (key): Linear(input_dims=768, output_dims=768, bias=True)
          (value): Linear(input_dims=768, output_dims=768, bias=True)
          (dropout): Dropout(p=0.09999999999999998)
        )
        (output): BertSelfOutput(
          (dense): Linear(input_dims=768, output_dims=768, bias=True)
          (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
          (dropout): Dropout(p=0.09999999999999998)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(input_dims=768, ou

In [9]:
loaded_weights = mx.load("converted_bert.npz")

model.update(tree_unflatten(list(loaded_weights.items())))

In [10]:
model.eval()

In [11]:
mx.eval(model.parameters())

In [12]:
# Weight Comparison
print(type(og_model.encoder.layer[0].attention.self.query.weight))
print(type(model.encoder.layer[0].attention.self.query.weight))

## Check Weight Values
print("transformers", og_model.encoder.layer[0].attention.self.query.weight[0][:10])
print("transformers", og_model.encoder.layer[0].attention.self.query.weight.dtype)
print("MLX", model.encoder.layer[0].attention.self.query.weight[0][:10])
print("MLX", model.encoder.layer[0].attention.self.query.weight.dtype)

<class 'torch.nn.parameter.Parameter'>
<class 'mlx.core.array'>
transformers tensor([-0.0164,  0.0261, -0.0263,  0.0360, -0.0203,  0.0531,  0.0137,  0.0225,
         0.0029, -0.0002], grad_fn=<SliceBackward0>)
transformers torch.float32
MLX array([-0.0164057, 0.0260757, -0.026277, ..., 0.0225361, 0.00293946, -0.000168063], dtype=float32)
MLX float32


In [13]:
encoded = tokenizer("hello", return_tensors="np")
print(encoded)

input_ids = mx.array(encoded["input_ids"])
token_type_ids = mx.array(encoded["token_type_ids"])
attention_mask = mx.array(encoded["attention_mask"])

{'input_ids': array([[ 101, 7592,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}


In [14]:
print(input_ids)
print(type(model.embeddings.word_embeddings.weight))
model.embeddings.word_embeddings.weight[input_ids]

array([[101, 7592, 102]], dtype=int64)
<class 'mlx.core.array'>


array([[[0.0136303, -0.0264904, -0.0235031, ..., 0.00868047, 0.00713399, 0.0151473],
        [-0.00431649, -0.0330471, -0.0217315, ..., -0.0424661, -0.0126787, -0.0388732],
        [-0.0145212, -0.00996149, 0.00602628, ..., -0.0250345, 0.00463789, -0.00153777]]], dtype=float32)

In [15]:
## Inference Test
model_outputs = model(
	input_ids,
	attention_mask = attention_mask,
	token_type_ids = token_type_ids
)

In [16]:
sequence_output, pooled_output = model_outputs

In [17]:
print(sequence_output.shape)
print(pooled_output.shape)

[1, 3, 768]
[1, 768]


## Output Comparison

In [18]:
print(sequence_output[0,0,:10])
print(og_model_outputs[0][0,0,:10])

array([-0.306098, 0.262229, -0.189619, ..., 0.379691, -0.125868, -0.114839], dtype=float32)
tensor([-0.3061,  0.2622, -0.1896, -0.1443, -0.1412, -0.1420,  0.1758,  0.3797,
        -0.1259, -0.1148], grad_fn=<SliceBackward0>)


In [19]:
print(pooled_output[:,0])
print(og_model_outputs[1][:,0])

array([-0.773578], dtype=float32)
tensor([-0.7736], grad_fn=<SelectBackward0>)
