In [2]:
from models.bert import MLXBertForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from mlx.utils import tree_unflatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np

## Load transformers Bert Model

In [3]:
pretrained_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
og_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name) 
og_model.eval()

config = og_model.config
og_state = og_model.state_dict()

config.json:   0%|          | 0.00/953 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/669M [00:00<?, ?B/s]

In [4]:
converted_weights = {k: mx.array(v.numpy()) for k, v in og_state.items() if k is not None}
print(converted_weights.keys())

np.savez("converted_bert_clf.npz", **converted_weights)

dict_keys(['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.0.attention.output.dense.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.output.

In [5]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) 

tokenizer_config.json:   0%|          | 0.00/39.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/872k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [6]:
## ORIGINAL OUTPUT
encoded = tokenizer("hello", return_tensors="pt")
print(encoded)

{'input_ids': tensor([[  101, 29155,   102]]), 'token_type_ids': tensor([[0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1]])}


In [11]:
## Inference Test
og_model_outputs = og_model(
	encoded["input_ids"],
	attention_mask = encoded["attention_mask"],
	token_type_ids = encoded["token_type_ids"],
	return_dict = False
)

In [12]:
print(og_model_outputs)

(tensor([[-0.4207, -1.1027, -0.2918,  0.1923,  1.3280]],
       grad_fn=<AddmmBackward0>),)


## Load MLX Model

In [13]:
model = MLXBertForSequenceClassification(config)
print(model)

MLXBertForSequenceClassification(
  (bert): MLXBertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(105879, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
      (dropout): Dropout(p=0.09999999999999998)
    )
    (encoder): BertEncoder(
      (layer.0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(input_dims=768, output_dims=768, bias=True)
            (key): Linear(input_dims=768, output_dims=768, bias=True)
            (value): Linear(input_dims=768, output_dims=768, bias=True)
            (dropout): Dropout(p=0.09999999999999998)
          )
          (output): BertSelfOutput(
            (dense): Linear(input_dims=768, output_dims=768, bias=True)
            (LayerNorm): LayerNorm(768, eps=1e-12, affine=True)
            (dropout): Dropout(p=0.09999999999999998)
          )
   

In [14]:
loaded_weights = mx.load("converted_bert_clf.npz")

model.update(tree_unflatten(list(loaded_weights.items())))

In [15]:
model.eval()

In [16]:
mx.eval(model.parameters())

In [18]:
encoded = tokenizer("hello", return_tensors="np")
print(encoded)

input_ids = mx.array(encoded["input_ids"])
token_type_ids = mx.array(encoded["token_type_ids"])
attention_mask = mx.array(encoded["attention_mask"])

{'input_ids': array([[  101, 29155,   102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}


In [20]:
## Inference Test
model_outputs = model(
	input_ids,
	attention_mask = attention_mask,
	token_type_ids = token_type_ids
)

In [22]:
mlx_output = model_outputs

In [23]:
print(og_model_outputs)
print(mlx_output)

(tensor([[-0.4207, -1.1027, -0.2918,  0.1923,  1.3280]],
       grad_fn=<AddmmBackward0>),)
(array([[-0.420726, -1.10273, -0.291769, 0.192265, 1.32797]], dtype=float32),)
