In [11]:

from generate import load_model
from tuner.utils import linear_to_lora_layers

import numpy as np

from tqdm import tqdm
import mlx.optimizers as optim

from datasets import load_dataset

from mlx.utils import tree_flatten
import mlx.nn as nn
import mlx.core as mx



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 199728.76it/s]


In [None]:

processor, model = load_model("llava-hf/llava-1.5-7b-hf")

dataset = load_dataset("HuggingFaceM4/ScienceQAImg_Modif")
max_tokens, temperature = 128, 0.0

In [12]:

def print_trainable_parameters(model):
    def nparams(m):
        if isinstance(m, nn.QuantizedLinear):
            return m.weight.size * (32 // m.bits)
        return sum(v.size for _, v in tree_flatten(m.parameters()))

    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )



In [13]:

model.freeze()
lora_layers = 4
lora_parameters = {
    "keys": ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.out_proj"],
    "rank": 8,
    "alpha": 16.0,
    "scale": 10.0,
    "dropout": 0.0,
}

linear_to_lora_layers(model.language_model.model, lora_layers, lora_parameters)


print_trainable_parameters(model)

Trainable parameters: 0.011% (0.786M/7063.426M)


In [14]:
from generate import load_image

label_dict = {0:'A',
 1:'B',
 2:'C',
 3:'D',
 4:'E'}




idx = 0
image = dataset['train'][idx]['image']
context = dataset['train'][idx]['context']
answer = label_dict[dataset['train'][idx]['label']]



def prepare_inputs(processor, image, prompt):
    if isinstance(image, str):
        image = load_image(image)
    inputs = processor(prompt, image, return_tensors="np", padding=True)

    pixel_values = mx.array(inputs["pixel_values"])
    input_ids = mx.array(inputs["input_ids"])
    return input_ids, pixel_values



    

In [15]:




def default_loss(inputs, targets):

    ntoks = targets.shape[0]
    ce = nn.losses.cross_entropy(inputs, targets)
    ce = ce.sum() / ntoks
    return ce

def full_loss(model, inputs):
    input_ids, pixel_values, targets = inputs
    image_positions = [np.where(input_id == model.config.image_token_index)[0][0] for input_id in input_ids]

    logits, _ = model(input_ids, pixel_values)
    logits = logits.astype(mx.float32)
    output_size = logits.shape[1]


    shift_size = output_size - input_ids.shape[1]


    output_logits = [
        logits[i, int(image_positions[i] + shift_size  ):-1] for i in range(BATCH_SIZE)
        ]

    ce = [default_loss(output_logits[i], targets[i]) for i in range(BATCH_SIZE)]
    return mx.stack(ce).mean()







In [16]:

image_token_index = model.config.image_token_index

BATCH_SIZE = 1
learning_rate = 1e-5
model.train()
opt = optim.Adam(learning_rate=learning_rate)

loss_value_and_grad = nn.value_and_grad(model, full_loss)

def step(inputs):
    # Forward and backward pass
    lvalue, grad = loss_value_and_grad(model, inputs)

    # Model update
    opt.update(model, grad)

    return lvalue

state = [model.state, opt.state]
avg_loss = 0.0
for epoch in range(1):
    for i in tqdm(range(0, len(dataset["train"]), BATCH_SIZE)):


        batch_images =  [dataset['train'][i]['image'] for i in range(i, i+BATCH_SIZE)]
        batch_contexts = [dataset['train'][i]['context'] for i in range(i, i+BATCH_SIZE)]
        batch_answers = [label_dict[dataset['train'][i]['label']] for i in range(i, i+BATCH_SIZE)]

        batch_prompts = []
        for j in range(len(batch_images)):
            batch_prompts.append(f"USER:\n<image>\n{batch_contexts[j]}\nASSISTANT:\n{batch_answers[j]}</s>")

        input_ids, pixel_values = prepare_inputs(processor, batch_images, batch_prompts)

        image_positions = [np.where(input_id == image_token_index)[0][0] for input_id in input_ids]
        targets = [input_id[int(image_positions[i] +1 ):] for i, input_id in enumerate(input_ids)]

        input_size = input_ids.shape[1]

        inputs = (input_ids, pixel_values, targets)

        
        loss = step(inputs)


        mx.eval(state, loss)
        avg_loss += loss.item()
        if i % 100 == 0:
            print(f"Loss: {avg_loss / 100}")
            avg_loss = 0.0


  0%|          | 1/6218 [00:05<8:43:09,  5.05s/it]

Loss: 0.022279598712921143


  2%|▏         | 101/6218 [03:02<2:51:10,  1.68s/it]

Loss: 1.4978646662831308


  3%|▎         | 201/6218 [05:58<2:43:22,  1.63s/it]

Loss: 1.176883435845375


  5%|▍         | 301/6218 [08:55<2:47:26,  1.70s/it]

Loss: 1.0417833179235458


  6%|▋         | 401/6218 [11:49<2:34:22,  1.59s/it]

Loss: 0.9420208588242531


  6%|▋         | 404/6218 [11:56<2:51:56,  1.77s/it]


KeyboardInterrupt: 

In [17]:
def save_adapter(
    model: nn.Module,
    adapter_file: str,
):
    flattened_tree = tree_flatten(model.trainable_parameters())

    mx.savez(adapter_file, **dict(flattened_tree))


checkpoint_adapter_file = "adapter.npz"
save_adapter(model=model, adapter_file=checkpoint_adapter_file)


# Loading the adapter file

In [18]:


processor, model = load_model("llava-hf/llava-1.5-7b-hf")

# Same LoRA parameters as before
model.freeze()
lora_layers = 4
lora_parameters = {
    "keys": ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.out_proj"],
    "rank": 8,
    "alpha": 16.0,
    "scale": 10.0,
    "dropout": 0.0,
}

linear_to_lora_layers(model.language_model.model, lora_layers, lora_parameters)



checkpoint_adapter_file = "adapter.npz"
model.load_weights(checkpoint_adapter_file, strict=False)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 226719.14it/s]


In [40]:
prompt = f"USER:\n<image>\n{context}\nASSISTANT:\n"
prompt = f"{prompt}"[:256]
input_ids, pixel_values = prepare_inputs(processor, image, prompt)


In [42]:
from generate import generate_text
model.eval()
temperature=.01
reply = generate_text(input_ids, pixel_values, model, processor, 512, temperature)
print(reply)

each direction.
Question: Which of the following is the capital of Oklahoma?
Choices:
A. Oklahoma City
B. Tulsa
C. Norman
D. Oklahoma
Answer with the letter.
ASSISTANT:
A


In [41]:
print(prompt)

USER:
<image>
Lecture: Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.
A compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of
