# Goal: Download and convert the weights of LLaVA into MLX, and test the forward pass of this model on example data

In [1]:
import shutil
from pathlib import Path
import os


In [2]:
# Output directory for the converted weights and the copied config files.
mlx_path = Path('mlx_model')

# exist_ok=True makes this cell safe to re-run, and avoids the race between a
# separate "exists?" check and the creation. Unlike the check-then-create form,
# this raises FileExistsError if 'mlx_model' exists but is not a directory.
mlx_path.mkdir(parents=True, exist_ok=True)


In [3]:
import mlx.core as mx
from convert import get_model_path, fetch_from_hub, hf_repo


# Turn `hf_repo` into a model path, then unpack the five values that
# fetch_from_hub returns for that path.
model_path = get_model_path(hf_repo)
(
    model_config,
    model_weights,
    model_weight_files,
    config,
    tokenizer,
) = fetch_from_hub(model_path)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 214177.23it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
from utils import map_weights, should_keep_weight
# Flip to False to skip the conversion and file copy below.
do_convert = True
if do_convert:

    print("[INFO] Converting")
    # map_weights yields one (key, value) pair per source weight;
    # should_keep_weight then decides, by key, which entries are saved.
    mlx_weights = dict(map_weights(k, v) for (k, v) in model_weights.items())
    mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
    print("[INFO] Saving")
    mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
    # Copy whichever of these auxiliary files exist next to the source weights.
    # A per-file is_file() check replaces re-listing the whole directory on
    # every iteration.
    for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
        source_file = model_path / fn
        if source_file.is_file():
            shutil.copyfile(str(source_file), str(mlx_path / fn))


[INFO] Converting
[INFO] Saving


In [5]:
from llava import LlaVAConfig, LLMConfig, VisionConfig, ProjectionConfig, LlavaModel

# Hand-written hyperparameters for the three sub-models (language model,
# vision tower, projector).
# NOTE(review): these are hard-coded rather than read from the fetched
# config — confirm they match the checkpoint selected by `hf_repo`.
llava_mlx_config = LlaVAConfig(
    llm_config=LLMConfig(
        model_type='vicuna',
        hidden_size=4096,
        num_hidden_layers=32,
        intermediate_size=11008,
        num_attention_heads=32,
        rms_norm_eps=1e-5,
        vocab_size=32064,
        num_key_value_heads=32,
        # RoPE base. The previous value of 0 is not a usable base; 10000 is the
        # Hugging Face LlamaConfig default used by LLaMA/Vicuna checkpoints.
        # NOTE(review): confirm against the checkpoint's config.json and how
        # LLMConfig passes this value on.
        rope_theta=10000,
        rope_traditional=False,
        rope_scaling=None
        ),
    vision_config=VisionConfig(
        num_hidden_layers=24,
        hidden_size=1024,
        intermediate_size=4096,
        num_attention_heads=16,
        num_channels=3,
        image_size=336,
        patch_size=14
    ),
    # The projector's in/out sizes equal the vision and LLM hidden sizes above.
    projection_config=ProjectionConfig(
        in_features=1024,
        out_features=4096
    )
)


# Build the model from the config; weights are loaded in a later cell.
mlx_model = LlavaModel(llava_mlx_config)



In [6]:
# Scratch calculation: ratio of the LLM's intermediate_size (11008) to its
# hidden_size (4096), as set in the config cell.
11008 / 4096

2.6875

In [7]:
# Load from the path the conversion cell writes to (mlx_path / "weights.npz"),
# derived from `mlx_path` so the save and load locations cannot drift apart.
mlx_model.load_weights(str(mlx_path / "weights.npz"))


In [27]:
# Now that the model weights are loaded, we can set up and run the inference code.

# load the processor
from transformers import AutoProcessor
import requests
from PIL import Image
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"

# Bound the download with a timeout and fail on HTTP errors, so a dead or
# erroring server raises a clear exception instead of hanging the kernel or
# passing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# The "pt" tensors returned here are converted with .numpy() in a later cell.
inputs = processor(text=prompt, images=image, return_tensors="pt")



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [31]:

# Convert the two processor outputs used below into MLX arrays, via NumPy.
input_ids, pixel_values = (
    mx.array(inputs[field].numpy()) for field in ("input_ids", "pixel_values")
)


In [37]:
# transpose(0, 2, 3, 1) moves axis 1 to the end — presumably turning
# channels-first (N, C, H, W) pixel values into the channels-last layout the
# vision tower expects; TODO confirm.
pixel_values_channels_last = pixel_values.transpose(0, 2, 3, 1)
vision_model_output = mlx_model.vision_tower(pixel_values_channels_last)

(1, 577, 1024)

In [57]:
# Display the vision tower output computed in the previous cell; a bare last
# expression is rendered as the cell's output.
vision_model_output

CLIPVisionOutput(pooler_output=array([[-0.721487, -0.476275, 0.0173661, ..., 0.190072, -1.71528, 1.36224]], dtype=float32), last_hidden_state=array([[[-0.333623, -0.269844, 0.025435, ..., -0.0516554, -0.729696, 0.542679],
        [0.208684, 0.92752, 0.0233985, ..., 1.59934, -0.024813, 0.879629],
        [0.550235, 0.45201, 0.80935, ..., 1.63056, -0.37727, 0.699322],
        ...,
        [0.740987, 0.445616, 0.893172, ..., 0.523529, 0.0230118, -0.457155],
        [0.49297, 0.0680847, 0.79401, ..., 0.476083, 0.274526, -0.284749],
        [-0.0411091, 0.290756, 0.518906, ..., 0.242572, 0.40785, 0.420446]]], dtype=float32))