# Goal: Download and convert the weights of LlaVA into MLX, and test the forward pass of this model on example data

In [2]:
import shutil
from pathlib import Path
import os


In [3]:
mlx_path = Path('mlx_model')

if not os.path.exists(mlx_path):
    os.makedirs(mlx_path)


In [4]:
import mlx.core as mx
from convert import get_model_path, fetch_from_hub, hf_repo


model_path = get_model_path(hf_repo)
model_config, model_weights, model_weight_files, config, tokenizer = fetch_from_hub(model_path)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 202950.19it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from utils import map_weights, should_keep_weight
do_convert = False
if do_convert:

    print("[INFO] Converting")
    mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())
    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
    print("[INFO] Saving")
    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
    for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        if fn in os.listdir(model_path):
            shutil.copyfile(
                str(model_path / f"{fn}"),
                str(mlx_path / f"{fn}"),
            )


In [1]:
from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel

llava_mlx_config = LlaVAConfig(
    llm_config=LLMConfig(
        model_type='vicuna',
        hidden_size=4096,
        num_hidden_layers=32,
        intermediate_size=11008,
        num_attention_heads=32,
        rms_norm_eps=1e-5,
        vocab_size=32064,
        num_key_value_heads=32,
        rope_theta=0,
        rope_traditional=False,
        rope_scaling=None
        ),
    vision_config=VisionConfig(
        num_hidden_layers=24,
        hidden_size=1024,
        intermediate_size=4096,
        num_attention_heads=16,
        num_channels=3,
        image_size=336,
        patch_size=14
    ),
    projection_config=ProjectionConfig(
        in_features=1024,
        out_features=4096
    )
)


mlx_model = LlavaModel(llava_mlx_config)



In [7]:
11008 / 4096

2.6875

In [2]:
mlx_model.load_weights('mlx_model/weights.npz')


ValueError: Received parameters not in model: language_model.layers.20.feed_forward.w3.weight language_model.layers.18.feed_forward.w2.weight language_model.layers.20.feed_forward.w1.weight language_model.layers.0.feed_forward.w1.weight language_model.layers.23.attention.wq.weight language_model.layers.23.ffn_norm.weight language_model.layers.16.feed_forward.w1.weight language_model.layers.31.attention_norm.weight language_model.output.weight language_model.layers.27.attention_norm.weight language_model.layers.25.attention_norm.weight language_model.layers.28.attention_norm.weight language_model.layers.20.feed_forward.w2.weight language_model.layers.17.attention.wo.weight language_model.layers.1.attention.wq.weight language_model.layers.27.feed_forward.w3.weight language_model.layers.19.feed_forward.w1.weight language_model.layers.14.attention_norm.weight language_model.layers.21.feed_forward.w2.weight language_model.layers.16.attention.wo.weight language_model.layers.22.attention_norm.weight language_model.layers.4.attention.wk.weight language_model.layers.13.feed_forward.w1.weight language_model.layers.30.attention.wv.weight language_model.layers.5.feed_forward.w3.weight language_model.layers.20.attention_norm.weight language_model.layers.13.feed_forward.w2.weight language_model.layers.22.feed_forward.w2.weight language_model.layers.15.attention.wo.weight language_model.layers.26.attention.wo.weight language_model.layers.5.feed_forward.w1.weight language_model.layers.16.attention_norm.weight language_model.layers.4.attention.wq.weight language_model.layers.9.feed_forward.w1.weight language_model.layers.20.attention.wq.weight language_model.layers.9.attention.wv.weight language_model.layers.10.ffn_norm.weight language_model.layers.8.attention.wo.weight language_model.layers.3.attention.wv.weight language_model.layers.0.ffn_norm.weight language_model.layers.4.feed_forward.w3.weight language_model.layers.2.attention.wv.weight language_model.layers.7.attention.wv.weight language_model.layers.24.attention.wq.weight language_model.layers.11.feed_forward.w2.weight language_model.layers.0.attention.wo.weight language_model.layers.7.feed_forward.w3.weight language_model.layers.17.feed_forward.w1.weight language_model.layers.31.attention.wo.weight language_model.layers.26.attention.wk.weight language_model.layers.0.feed_forward.w3.weight language_model.layers.2.ffn_norm.weight language_model.layers.13.attention_norm.weight language_model.layers.19.attention.wk.weight language_model.layers.18.attention.wq.weight language_model.layers.10.attention.wo.weight language_model.layers.30.attention.wq.weight language_model.layers.5.feed_forward.w2.weight language_model.layers.5.attention.wv.weight language_model.layers.25.attention.wq.weight language_model.layers.3.feed_forward.w3.weight language_model.layers.9.attention.wo.weight language_model.layers.29.feed_forward.w1.weight language_model.layers.2.feed_forward.w3.weight language_model.layers.0.attention.wk.weight language_model.layers.11.attention.wv.weight language_model.layers.20.attention.wo.weight language_model.layers.16.ffn_norm.weight language_model.layers.2.feed_forward.w2.weight language_model.layers.27.attention.wk.weight language_model.tok_embeddings.weight language_model.layers.14.ffn_norm.weight language_model.layers.12.ffn_norm.weight language_model.layers.22.attention.wo.weight language_model.layers.12.attention.wq.weight language_model.layers.19.attention.wq.weight language_model.layers.11.attention.wq.weight language_model.layers.6.attention.wv.weight language_model.layers.26.feed_forward.w3.weight language_model.layers.26.feed_forward.w2.weight language_model.layers.17.attention.wq.weight language_model.layers.18.feed_forward.w3.weight language_model.layers.29.attention.wk.weight language_model.layers.29.feed_forward.w2.weight language_model.layers.8.attention.wq.weight language_model.layers.2.attention_norm.weight language_model.layers.5.attention.wo.weight language_model.layers.23.feed_forward.w1.weight language_model.layers.30.feed_forward.w1.weight language_model.layers.2.attention.wo.weight language_model.layers.18.attention.wk.weight language_model.layers.13.attention.wo.weight language_model.layers.3.ffn_norm.weight language_model.layers.23.attention_norm.weight language_model.layers.10.feed_forward.w2.weight language_model.layers.1.ffn_norm.weight language_model.layers.21.ffn_norm.weight language_model.layers.30.attention.wo.weight language_model.layers.11.attention.wk.weight language_model.layers.7.attention.wo.weight language_model.layers.17.feed_forward.w3.weight language_model.layers.13.ffn_norm.weight language_model.layers.3.feed_forward.w1.weight language_model.layers.18.attention.wo.weight language_model.layers.22.feed_forward.w1.weight language_model.layers.15.attention.wk.weight language_model.layers.15.attention.wv.weight language_model.layers.28.feed_forward.w2.weight language_model.layers.21.feed_forward.w1.weight language_model.layers.12.feed_forward.w2.weight language_model.layers.23.feed_forward.w3.weight language_model.layers.19.ffn_norm.weight language_model.layers.18.attention_norm.weight language_model.layers.22.feed_forward.w3.weight language_model.layers.14.attention.wo.weight language_model.layers.9.ffn_norm.weight language_model.layers.13.attention.wk.weight language_model.layers.28.attention.wo.weight language_model.layers.26.attention.wq.weight language_model.layers.24.ffn_norm.weight language_model.layers.23.attention.wo.weight language_model.layers.10.attention_norm.weight language_model.layers.16.feed_forward.w2.weight language_model.layers.19.feed_forward.w2.weight language_model.layers.23.attention.wk.weight language_model.layers.2.feed_forward.w1.weight language_model.layers.11.feed_forward.w1.weight language_model.layers.4.feed_forward.w2.weight language_model.layers.23.attention.wv.weight language_model.layers.27.feed_forward.w1.weight language_model.layers.17.feed_forward.w2.weight language_model.layers.12.attention_norm.weight language_model.layers.30.feed_forward.w2.weight language_model.layers.15.ffn_norm.weight language_model.layers.12.feed_forward.w1.weight language_model.layers.28.attention.wq.weight language_model.layers.10.attention.wq.weight language_model.layers.4.ffn_norm.weight language_model.layers.14.feed_forward.w1.weight language_model.layers.3.feed_forward.w2.weight language_model.layers.12.feed_forward.w3.weight language_model.layers.21.attention.wq.weight language_model.layers.10.attention.wv.weight language_model.layers.21.feed_forward.w3.weight language_model.layers.6.feed_forward.w2.weight language_model.layers.20.ffn_norm.weight language_model.layers.25.attention.wk.weight language_model.layers.17.attention.wk.weight language_model.layers.26.attention_norm.weight language_model.layers.25.feed_forward.w2.weight language_model.layers.1.attention_norm.weight language_model.layers.26.attention.wv.weight language_model.layers.19.attention.wo.weight language_model.layers.14.feed_forward.w3.weight language_model.layers.14.attention.wv.weight language_model.layers.29.ffn_norm.weight language_model.layers.14.feed_forward.w2.weight language_model.layers.1.attention.wk.weight language_model.layers.4.attention.wv.weight language_model.layers.22.attention.wq.weight language_model.layers.3.attention.wq.weight language_model.layers.16.attention.wv.weight language_model.layers.21.attention.wo.weight language_model.layers.26.ffn_norm.weight language_model.layers.29.attention.wq.weight language_model.layers.7.attention.wq.weight language_model.layers.21.attention_norm.weight language_model.layers.24.attention.wo.weight language_model.layers.5.attention_norm.weight language_model.layers.18.feed_forward.w1.weight language_model.layers.26.feed_forward.w1.weight language_model.layers.31.attention.wv.weight language_model.layers.25.feed_forward.w1.weight language_model.layers.27.ffn_norm.weight language_model.layers.6.feed_forward.w1.weight language_model.layers.28.feed_forward.w1.weight language_model.layers.1.feed_forward.w3.weight language_model.layers.8.feed_forward.w2.weight language_model.layers.20.attention.wk.weight language_model.layers.2.attention.wq.weight language_model.layers.4.feed_forward.w1.weight language_model.layers.9.attention.wq.weight language_model.layers.15.feed_forward.w1.weight language_model.layers.7.ffn_norm.weight language_model.layers.0.feed_forward.w2.weight language_model.layers.30.attention_norm.weight language_model.layers.13.attention.wv.weight language_model.layers.10.feed_forward.w1.weight language_model.layers.5.attention.wq.weight language_model.layers.16.feed_forward.w3.weight language_model.layers.28.ffn_norm.weight language_model.layers.31.feed_forward.w1.weight language_model.layers.12.attention.wo.weight language_model.layers.27.attention.wo.weight language_model.layers.15.feed_forward.w3.weight language_model.layers.29.attention.wo.weight language_model.layers.27.attention.wv.weight language_model.layers.14.attention.wq.weight language_model.layers.5.attention.wk.weight language_model.layers.1.feed_forward.w1.weight language_model.layers.20.attention.wv.weight language_model.layers.23.feed_forward.w2.weight language_model.layers.8.attention.wk.weight language_model.layers.5.ffn_norm.weight language_model.layers.21.attention.wv.weight language_model.layers.29.attention_norm.weight language_model.layers.10.feed_forward.w3.weight language_model.layers.1.feed_forward.w2.weight language_model.layers.24.feed_forward.w3.weight language_model.layers.11.ffn_norm.weight language_model.layers.9.attention_norm.weight language_model.layers.4.attention.wo.weight language_model.layers.25.attention.wo.weight language_model.layers.7.feed_forward.w2.weight language_model.layers.9.feed_forward.w2.weight language_model.layers.14.attention.wk.weight language_model.layers.27.feed_forward.w2.weight language_model.layers.13.attention.wq.weight language_model.layers.15.attention_norm.weight language_model.layers.28.attention.wv.weight language_model.layers.0.attention_norm.weight language_model.layers.0.attention.wv.weight language_model.layers.7.attention.wk.weight language_model.layers.29.feed_forward.w3.weight language_model.layers.3.attention.wk.weight language_model.layers.28.feed_forward.w3.weight language_model.layers.22.attention.wv.weight language_model.layers.22.attention.wk.weight language_model.layers.6.attention.wq.weight language_model.layers.1.attention.wo.weight language_model.layers.18.attention.wv.weight language_model.layers.8.attention.wv.weight language_model.layers.6.ffn_norm.weight language_model.layers.25.ffn_norm.weight language_model.layers.8.attention_norm.weight language_model.layers.6.attention.wk.weight language_model.layers.29.attention.wv.weight language_model.layers.19.attention_norm.weight language_model.layers.19.attention.wv.weight language_model.layers.6.attention.wo.weight language_model.layers.12.attention.wk.weight language_model.layers.9.feed_forward.w3.weight language_model.layers.8.feed_forward.w1.weight language_model.layers.10.attention.wk.weight language_model.layers.17.ffn_norm.weight language_model.layers.21.attention.wk.weight language_model.layers.15.attention.wq.weight language_model.layers.11.attention_norm.weight language_model.layers.24.attention.wk.weight language_model.layers.31.feed_forward.w2.weight language_model.layers.18.ffn_norm.weight language_model.layers.30.feed_forward.w3.weight language_model.layers.22.ffn_norm.weight language_model.layers.28.attention.wk.weight language_model.layers.9.attention.wk.weight language_model.layers.24.feed_forward.w2.weight language_model.layers.17.attention_norm.weight language_model.layers.17.attention.wv.weight language_model.layers.1.attention.wv.weight language_model.layers.31.ffn_norm.weight language_model.layers.31.attention.wk.weight language_model.layers.24.feed_forward.w1.weight language_model.layers.8.feed_forward.w3.weight language_model.layers.25.attention.wv.weight language_model.layers.7.feed_forward.w1.weight language_model.layers.31.attention.wq.weight language_model.layers.15.feed_forward.w2.weight language_model.layers.30.ffn_norm.weight language_model.layers.0.attention.wq.weight language_model.layers.31.feed_forward.w3.weight language_model.layers.13.feed_forward.w3.weight language_model.layers.19.feed_forward.w3.weight language_model.layers.6.attention_norm.weight language_model.layers.4.attention_norm.weight language_model.layers.12.attention.wv.weight language_model.layers.8.ffn_norm.weight language_model.layers.30.attention.wk.weight language_model.layers.3.attention.wo.weight language_model.layers.16.attention.wq.weight language_model.layers.11.feed_forward.w3.weight language_model.layers.25.feed_forward.w3.weight language_model.layers.3.attention_norm.weight language_model.layers.2.attention.wk.weight language_model.layers.16.attention.wk.weight language_model.layers.7.attention_norm.weight language_model.layers.27.attention.wq.weight language_model.layers.6.feed_forward.w3.weight language_model.layers.24.attention.wv.weight language_model.layers.11.attention.wo.weight language_model.layers.24.attention_norm.weight.

In [7]:
# TODO: load images, and test generate 

In [8]:
# TODO: compare with hf version's model weights as well 

# Load model directly
from transformers import AutoProcessor, AutoModelForPreTraining
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = AutoModelForPreTraining.from_pretrained("llava-hf/llava-1.5-7b-hf")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.48s/it]


In [31]:
mlx_model.language_model.layers[0].attention.wq.weight

array([[-0.00692749, -0.0147705, -0.00254822, ..., 0.00500488, 0.00238037, -0.0027771],
       [0.0155029, -0.00343323, 0.00121307, ..., -0.00964355, -0.0110474, 0.00744629],
       [-0.0157471, 0.0144043, 0.000104904, ..., 0.00619507, 0.0189209, -0.00415039],
       ...,
       [1.54972e-06, 0.00866699, 0.000881195, ..., 0.00946045, -0.0301514, 0.0107422],
       [0.0253906, 0.00994873, 0.00454712, ..., -0.0319824, -0.0148926, -0.0130005],
       [-0.0108643, -0.00534058, 0.00102234, ..., 0.0164795, 0.0150146, -0.00811768]], dtype=float16)

In [32]:
model.language_model.model.layers[0].self_attn.q_proj.weight

Parameter containing:
tensor([[-6.9275e-03, -1.4771e-02, -2.5482e-03,  ...,  5.0049e-03,
          2.3804e-03, -2.7771e-03],
        [ 1.5503e-02, -3.4332e-03,  1.2131e-03,  ..., -9.6436e-03,
         -1.1047e-02,  7.4463e-03],
        [-1.5747e-02,  1.4404e-02,  1.0490e-04,  ...,  6.1951e-03,
          1.8921e-02, -4.1504e-03],
        ...,
        [ 1.5497e-06,  8.6670e-03,  8.8120e-04,  ...,  9.4604e-03,
         -3.0151e-02,  1.0742e-02],
        [ 2.5391e-02,  9.9487e-03,  4.5471e-03,  ..., -3.1982e-02,
         -1.4893e-02, -1.3000e-02],
        [-1.0864e-02, -5.3406e-03,  1.0223e-03,  ...,  1.6479e-02,
          1.5015e-02, -8.1177e-03]], requires_grad=True)

In [None]:
# They seem to be the same!

In [9]:
import requests
from PIL import Image

image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)

In [12]:
prompts = ['USER: <image> What are the things I should think aboutwhen I visit this place? ASSISTANT:'
           ]

In [13]:
inputs = processor(prompts, images=[image], padding=True, return_tensors="pt")

In [18]:

inputs

{'input_ids': tensor([[    1,  3148,  1001, 29901, 29871, 32000, 29871,  1724,   526,   278,
          2712,   306,   881,  1348,  1048,  8256,   306,  6493,   445,  2058,
         29973,   319,  1799,  9047, 13566, 29901]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]]), 'pixel_values': tensor([[[[ 1.2734,  1.2734,  1.2734,  ...,  1.1274,  1.1274,  1.0982],
          [ 1.2734,  1.2734,  1.2880,  ...,  1.1274,  1.1274,  1.1128],
          [ 1.2880,  1.2880,  1.2880,  ...,  1.1274,  1.1274,  1.1274],
          ...,
          [-0.9456, -0.9164, -0.9164,  ..., -1.0769, -1.0769, -1.0769],
          [-0.9602, -0.9310, -0.9018,  ..., -1.0915, -1.0915, -1.0915],
          [-0.9602, -0.9748, -0.2448,  ..., -1.1061, -1.1061, -1.1207]],

         [[ 1.6397,  1.6397,  1.6397,  ...,  1.5196,  1.5196,  1.5196],
          [ 1.6397,  1.6397,  1.6547,  ...,  1.5196,  1.5196,  1.5196],
          [ 1.6547,  1.6547,  1.6547,  ...,  1.5

In [17]:
config

LlavaConfig {
  "_name_or_path": "/Users/noahkasmanoff/.cache/huggingface/hub/models--llava-hf--llava-1.5-7b-hf/snapshots/05ae2434cbb430be33edcba0c5203e7023f785b7",
  "architectures": [
    "LlavaForConditionalGeneration"
  ],
  "ignore_index": -100,
  "image_token_index": 32000,
  "model_type": "llava",
  "pad_token_id": 32001,
  "projector_hidden_act": "gelu",
  "text_config": {
    "_name_or_path": "lmsys/vicuna-7b-v1.5",
    "architectures": [
      "LlamaForCausalLM"
    ],
    "max_position_embeddings": 4096,
    "model_type": "llama",
    "rms_norm_eps": 1e-05,
    "torch_dtype": "float16",
    "vocab_size": 32064
  },
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.37.2",
  "vision_config": {
    "hidden_size": 1024,
    "image_size": 336,
    "intermediate_size": 4096,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "vocab_size": 3200