In [14]:
from models.bert import MLXBertModel
from transformers import BertModel, AutoTokenizer
from mlx.utils import tree_unflatten
import mlx.core as mx
import mlx.nn as nn

## Load transformers Bert Model

In [2]:
og_model = BertModel.from_pretrained("bert-base-uncased") 
config = og_model.config
og_state = og_model.state_dict()

In [3]:
converted_weights = {k: mx.array(v.numpy()) for k, v in og_state.items() if k is not None}
print(converted_weights.keys())

dict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bia

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 

In [13]:
## ORIGINAL OUTPUT
encoded = tokenizer("hello", return_tensors="pt")
print(encoded)

{'input_ids': tensor([[ 101, 7592,  102]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}


In [15]:
## Inference Test
og_model_outputs = og_model(
	encoded["input_ids"],
	attention_mask = encoded["attention_mask"],
	token_type_ids = encoded["token_type_ids"]
)

In [19]:
print(og_model_outputs[0].shape)
print(og_model_outputs[1].shape)

torch.Size([1, 3, 768])
torch.Size([1, 768])


## Load MLX Model

In [5]:
model = MLXBertModel(config)
print(model)

MLXBertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
    (dropout): Dropout(p=0.09999999999999998)
  )
  (encoder): BertEncoder(
    (layer.0): BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(input_dims=768, output_dims=768, bias=True)
          (key): Linear(input_dims=768, output_dims=768, bias=True)
          (value): Linear(input_dims=768, output_dims=768, bias=True)
          (dropout): Dropout(p=0.09999999999999998)
        )
        (output): BertSelfOutput(
          (dense): Linear(input_dims=768, output_dims=768, bias=True)
          (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
          (dropout): Dropout(p=0.09999999999999998)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(input_dims=768, ou

In [6]:
model.update(tree_unflatten(list(converted_weights.items())))

In [7]:
mx.eval(model.parameters())

In [8]:
encoded = tokenizer("hello", return_tensors="np")
print(encoded)

input_ids = mx.array(encoded["input_ids"])
token_type_ids = mx.array(encoded["token_type_ids"])
attention_mask = mx.array(encoded["attention_mask"])

{'input_ids': array([[ 101, 7592,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}


In [9]:
print(input_ids)
print(type(model.embeddings.word_embeddings.weight))
model.embeddings.word_embeddings.weight[input_ids]

array([[101, 7592, 102]], dtype=int64)
<class 'mlx.core.array'>


array([[[0.0136303, -0.0264904, -0.0235031, ..., 0.00868047, 0.00713399, 0.0151473],
        [-0.00431649, -0.0330471, -0.0217315, ..., -0.0424661, -0.0126787, -0.0388732],
        [-0.0145212, -0.00996149, 0.00602628, ..., -0.0250345, 0.00463789, -0.00153777]]], dtype=float32)

In [10]:
## Inference Test
model_outputs = model(
	input_ids,
	attention_mask = attention_mask,
	token_type_ids = token_type_ids
)

In [11]:
sequence_output, pooled_output = model_outputs

In [12]:
print(sequence_output.shape)
print(pooled_output.shape)

[1, 3, 768]
[1, 768]


In [29]:
print(pooled_output[:,0])
print(og_model_outputs[1][:,0])

array([-0.944804], dtype=float32)
tensor([-0.7736], grad_fn=<SelectBackward0>)
