In [None]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch

model_id = "llava-hf/llama3-llava-next-8b-hf"

# Pick the compute device: Apple MPS (the original target) first, then CUDA,
# then CPU, so the notebook still runs on machines without MPS instead of
# failing at `.to(device)`.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Half precision on accelerators (as originally configured). CPU float16
# kernel coverage is patchier, especially for the backward pass used later,
# so fall back to float32 there.
model_dtype = torch.float16 if device.type in ("mps", "cuda") else torch.float32

processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype=model_dtype).to(device)

In [ ]:
from torch.nn import Module
from MoLE import LoRA_MOE_LM

class Args:
    """Hyperparameters for the MoLE (mixture of LoRA experts) MLP wrappers.

    The values are read as plain attributes (``args.lora_rank`` etc.) when the
    MLP blocks are wrapped with ``LoRA_MOE_LM``.
    """

    # Rank of each LoRA adapter.
    lora_rank: int = 4
    # LoRA alpha hyperparameter handed to LoRA_MOE_LM.
    lora_alpha: int = 1
    # Number of LoRA experts per wrapped MLP.
    num_experts: int = 4
    # Routing mode switch for MoLE: dense routing when True, sparse when False.
    dense_moe: bool = False


args = Args()

def replace_with_mole_layers(module: Module, args: Args) -> int:
    """Recursively wrap every gated-MLP submodule of ``module`` in ``LoRA_MOE_LM``.

    A child counts as a gated MLP when it has ``gate_proj``, ``up_proj`` and
    ``down_proj`` attributes. Each such child is replaced in place on its
    parent by ``LoRA_MOE_LM(args, args.lora_rank, args.lora_alpha,
    args.num_experts, child)``. All other children are searched recursively.
    A matched child is not descended into, so its inner projections are never
    wrapped a second time.

    Args:
        module: Root module, modified in place.
        args: Hyperparameters forwarded to ``LoRA_MOE_LM``.

    Returns:
        The number of submodules that were wrapped. Existing callers may
        ignore it.
    """
    num_wrapped = 0
    # Snapshot the children first: the loop reassigns entries in the parent's
    # child registry, so avoid iterating the live generator while mutating it.
    for name, child in list(module.named_children()):
        if all(hasattr(child, attr) for attr in ("gate_proj", "down_proj", "up_proj")):
            setattr(module, name, LoRA_MOE_LM(args, args.lora_rank, args.lora_alpha, args.num_experts, child))
            num_wrapped += 1
        else:
            num_wrapped += replace_with_mole_layers(child, args)
    return num_wrapped


num_wrapped = replace_with_mole_layers(model, args)
# Without this check, a naming mismatch would silently leave the model
# unwrapped and the training cell would fine-tune plain weights instead.
if num_wrapped == 0:
    raise RuntimeError("replace_with_mole_layers found no gate_proj/up_proj/down_proj MLPs to wrap")
print(f"Wrapped {num_wrapped} MLP blocks with LoRA_MOE_LM")
# NOTE(review): the LoRA_MOE_LM modules are created after the model was moved
# to `device`. If they allocate new parameters on the default (CPU) device,
# the forward pass would hit a device mismatch. Moving again is a no-op for
# weights already on `device`. This does not change dtypes — confirm what
# dtype LoRA_MOE_LM uses for its parameters.
model.to(device)

In [ ]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from PIL import Image
import requests

# Download the single training image. The timeout keeps the cell from hanging
# indefinitely, and raise_for_status surfaces HTTP errors instead of handing
# an error page to PIL.
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
# add_generation_prompt=True ends the prompt where the assistant reply begins,
# so the target answer can be appended directly after it.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Define the target output (you must define the expected correct answer here)
target_answer = "The image shows a radar system."  # Example target output

# Causal-LM labels must be token-aligned with input_ids, so the answer is
# appended to the prompt rather than passed as a separate, shorter label
# tensor. add_special_tokens=False keeps the tokenizer from inserting a BOS
# token in the middle of the sequence.
target_ids = processor.tokenizer(target_answer, add_special_tokens=False, return_tensors="pt").input_ids
# Append EOS (when the tokenizer defines one) so the model also learns to stop.
eos_token_id = processor.tokenizer.eos_token_id
if eos_token_id is not None:
    target_ids = torch.cat([target_ids, torch.tensor([[eos_token_id]], dtype=target_ids.dtype)], dim=1)

# Prepare inputs: run prompt + image through the processor, then append the answer tokens.
inputs = processor(images=image, text=prompt, return_tensors="pt")
prompt_len = inputs["input_ids"].shape[1]
inputs["input_ids"] = torch.cat([inputs["input_ids"], target_ids], dim=1)
inputs["attention_mask"] = torch.cat(
    [inputs["attention_mask"], torch.ones(target_ids.shape, dtype=inputs["attention_mask"].dtype)],
    dim=1,
)

# -100 is the ignore index of the cross-entropy loss, so only the answer
# tokens are supervised, not the prompt/image positions.
labels = torch.full_like(inputs["input_ids"], -100)
labels[:, prompt_len:] = target_ids

inputs = inputs.to(device)
labels = labels.to(device)
target_ids = target_ids.to(device)

# Fine-tuning parameters.
# Only hand the optimizer parameters that can actually receive gradients.
# NOTE(review): whether the base weights are frozen depends on LoRA_MOE_LM;
# if they are not, this fine-tunes the whole 8B model.
# NOTE(review): AdamW steps of ~1e-5 applied directly to float16 weights can
# underflow. Consider keeping the trainable (LoRA/router) parameters in float32.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-5)
# Currently unused: the model computes its own loss when `labels` is passed.
loss_fn = CrossEntropyLoss()

# Fine-tune the model
model.train()
epochs = 1  # Fine-tune for a single epoch for this test
for epoch in range(epochs):
    optimizer.zero_grad()
    # Forward pass; labels have the same shape as input_ids.
    outputs = model(**inputs, labels=labels)
    loss = outputs.loss

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")