# Goal: Download and convert the weights of LLaVA into MLX, and test the forward pass of this model on example data.

In [1]:
import shutil
from pathlib import Path
import os


In [2]:
# Directory that receives the converted weights and the copied config/tokenizer files.
mlx_path = Path('mlx_model')

# exist_ok=True makes this a no-op when the directory is already there, so the
# cell can be re-run safely and there is no gap between "check" and "create".
mlx_path.mkdir(parents=True, exist_ok=True)


In [3]:
import mlx.core as mx
from convert import get_model_path, fetch_from_hub, hf_repo


# Resolve the checkpoint location for `hf_repo`, then load everything
# `fetch_from_hub` returns for it (five values, unpacked below).
model_path = get_model_path(hf_repo)
(
    model_config,
    model_weights,
    model_weight_files,
    config,
    tokenizer,
) = fetch_from_hub(model_path)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 181702.70it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
from utils import map_weights, should_keep_weight


# Pass every checkpoint entry through map_weights (which returns a
# (key, value) pair), then keep only the keys should_keep_weight accepts.
print("[INFO] Converting")
mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}

print("[INFO] Saving")
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)

# Copy the auxiliary config/tokenizer files next to the weights. Files that are
# absent from the checkpoint directory are skipped; is_file() avoids re-listing
# the directory for every name and also skips dangling symlinks.
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
    src = model_path / fn
    if src.is_file():
        shutil.copyfile(src, mlx_path / fn)


[INFO] Converting
[INFO] Saving


In [5]:
from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel

# Hyper-parameters for the three parts of the model: the language model,
# the vision tower, and the projection that maps vision features
# (in_features=1024) to the language-model width (out_features=4096).
llava_mlx_config = LlaVAConfig(
    llm_config=LLMConfig(
        dim=4096,
        n_layers=32,
        # Per-head width: dim / n_heads = 4096 / 32. The previous value (4096)
        # only lined up with the checkpoint when n_heads was forced to 1,
        # which turns multi-head attention into a single 4096-wide head.
        # NOTE(review): assumes the llama module sizes its q/k/v projections as
        # n_heads * head_dim -- confirm there.
        head_dim=128,
        hidden_dim=11008,
        norm_eps=1e-5,
        n_heads=32,  # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L14
        n_kv_heads=32,  # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json#L16
        vocab_size=32064,
        # Hugging Face's Llama default; the earlier value of 0 is not a usable
        # rotary base. NOTE(review): confirm how the llama module consumes this.
        rope_theta=10000,
        rope_traditional=False
    ),
    vision_config=VisionConfig(
        num_hidden_layers=24,
        hidden_size=1024,
        intermediate_size=4096,
        num_attention_heads=16,
        num_channels=3,
        image_size=336,
        patch_size=14
    ),
    projection_config=ProjectionConfig(
        in_features=1024,
        out_features=4096
    )
)


mlx_model = LlavaModel(llava_mlx_config)



In [6]:
# Load the converted checkpoint. The path is derived from `mlx_path` so it
# always matches the location the weights were saved to.
mlx_model.load_weights(str(mlx_path / "weights.npz"))


In [7]:
# TODO: load images, and test generate 

In [8]:
# TODO: compare with hf version's model weights as well

# Load the Hugging Face processor and model from the same repo id so they can
# serve as the reference for the MLX conversion.
from transformers import AutoProcessor, AutoModelForPreTraining

hf_llava_repo = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(hf_llava_repo)
model = AutoModelForPreTraining.from_pretrained(hf_llava_repo)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.48s/it]


In [31]:
# Display the `wq` weight of language-model layer 0 in the MLX model, for
# comparison against the Hugging Face model's layer-0 `q_proj` weight.
mlx_model.language_model.layers[0].attention.wq.weight

array([[-0.00692749, -0.0147705, -0.00254822, ..., 0.00500488, 0.00238037, -0.0027771],
       [0.0155029, -0.00343323, 0.00121307, ..., -0.00964355, -0.0110474, 0.00744629],
       [-0.0157471, 0.0144043, 0.000104904, ..., 0.00619507, 0.0189209, -0.00415039],
       ...,
       [1.54972e-06, 0.00866699, 0.000881195, ..., 0.00946045, -0.0301514, 0.0107422],
       [0.0253906, 0.00994873, 0.00454712, ..., -0.0319824, -0.0148926, -0.0130005],
       [-0.0108643, -0.00534058, 0.00102234, ..., 0.0164795, 0.0150146, -0.00811768]], dtype=float16)

In [32]:
# Display the `q_proj` weight of language-model layer 0 in the Hugging Face
# model, for comparison against the MLX model's layer-0 `wq` weight.
model.language_model.model.layers[0].self_attn.q_proj.weight

Parameter containing:
tensor([[-6.9275e-03, -1.4771e-02, -2.5482e-03,  ...,  5.0049e-03,
          2.3804e-03, -2.7771e-03],
        [ 1.5503e-02, -3.4332e-03,  1.2131e-03,  ..., -9.6436e-03,
         -1.1047e-02,  7.4463e-03],
        [-1.5747e-02,  1.4404e-02,  1.0490e-04,  ...,  6.1951e-03,
          1.8921e-02, -4.1504e-03],
        ...,
        [ 1.5497e-06,  8.6670e-03,  8.8120e-04,  ...,  9.4604e-03,
         -3.0151e-02,  1.0742e-02],
        [ 2.5391e-02,  9.9487e-03,  4.5471e-03,  ..., -3.1982e-02,
         -1.4893e-02, -1.3000e-02],
        [-1.0864e-02, -5.3406e-03,  1.0223e-03,  ...,  1.6479e-02,
          1.5015e-02, -8.1177e-03]], requires_grad=True)

In [None]:
# They seem to be the same!

In [9]:
import requests
from PIL import Image

# Download the sample image. A timeout keeps the cell from hanging forever on a
# stalled connection, and raise_for_status() surfaces HTTP errors directly
# instead of letting PIL fail on an error page.
image_url = "https://llava-vl.github.io/static/images/view.jpg"
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

In [None]:
# Single prompt for the processor; `<image>` marks where the image goes.
# (Fixed the missing space in "about when".)
prompts = [
    'USER: <image> What are the things I should think about when I visit this place? ASSISTANT:'
]

In [None]:
# Run the prompts and the image through the Hugging Face processor, asking for
# padded "pt" (PyTorch) tensors.
inputs = processor(
    prompts,
    images=[image],
    padding=True,
    return_tensors="pt",
)