In [None]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch

# Hugging Face Hub checkpoint used for both the processor and the model.
model_id = "llava-hf/llama3-llava-next-8b-hf"

# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple MPS support.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Keep half precision on accelerators (the original setting); use float32 on
# CPU, where float16 support is limited in some PyTorch builds.
model_dtype = torch.float32 if device.type == "cpu" else torch.float16

processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
).to(device)

In [ ]:
#import torch
#from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

#MODEL_ID ="google/paligemma2-3b-pt-448"
#DEVICE = torch.device("mps")

#processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
#model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)

In [ ]:
# Freeze the three pre-trained components (vision tower, multimodal projector
# and language model) so their existing weights receive no gradients.
frozen_submodules = (
    model.vision_tower,
    model.multi_modal_projector,
    model.language_model,
)
for submodule in frozen_submodules:
    for weight in submodule.parameters():
        weight.requires_grad = False

In [ ]:
# Report each parameter that is still trainable: its name, shape and
# requires_grad flag, followed by a 50-dash separator line.
separator = "-" * 50
still_trainable = (
    (param_name, tensor)
    for param_name, tensor in model.named_parameters()
    if tensor.requires_grad
)
for param_name, tensor in still_trainable:
    print(f"Parameter Name: {param_name}")
    print(f"Shape: {tensor.shape}")
    print(f"Requires Grad: {tensor.requires_grad}")
    print(separator)

In [ ]:
#from peft import get_peft_model, LoraConfig

#lora_config = LoraConfig(
#     r=32,
#     target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
#     task_type="CAUSAL_LM",
# )

#model = get_peft_model(model, lora_config)
#model.print_trainable_parameters()

In [ ]:
from MoLE import LoRA_MOE_LM

class Args:
    """Settings object passed to every ``LoRA_MOE_LM`` wrapper as ``args``."""

    dense_moe = False  # Switch between dense and sparse routing for MoLE
    lora_rank = 32
    lora_alpha = 64
    num_experts = 3


args = Args()

# Replace the MLP of every language-model layer with a LoRA_MOE_LM wrapper
# built around the original MLP, then move the wrapper to the target device.
decoder_layers = model.language_model.model.layers
num_layers = len(decoder_layers)

for decoder_layer in decoder_layers:
    moe_mlp = LoRA_MOE_LM(
        args=args,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        num_experts=args.num_experts,
        original_module=decoder_layer.mlp,
    )
    decoder_layer.mlp = moe_mlp.to(device)

In [ ]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from PIL import Image
import requests

# Define the single training example
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
# Fail fast on a stalled connection or an HTTP error instead of handing an
# error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Prompt-only inputs. Their sequence length is used below to decide which
# leading positions are excluded from the loss.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

# Define the target output (you must define the expected correct answer here)
target_answer = "The image shows a radar system."  # Example target output

# The answer must be part of the model input, and the labels must have the
# same shape as input_ids. Passing only the tokenized answer as labels (as
# before) gives a length mismatch between logits and labels.
# NOTE(review): assumes tokenizing `prompt` yields an exact prefix of
# tokenizing `prompt + answer` — confirm for this tokenizer.
eos_token = processor.tokenizer.eos_token or ""
train_text = prompt + target_answer + eos_token
train_inputs = processor(images=image, text=train_text, return_tensors="pt").to(device)

# Supervise only the answer tokens: -100 is the index ignored by the
# cross-entropy loss, so the prompt positions contribute nothing.
prompt_length = inputs.input_ids.shape[1]
labels = train_inputs.input_ids.clone()
labels[:, :prompt_length] = -100

# Fine-tuning parameters
# Only hand the optimizer the parameters that are actually trainable; the
# frozen ones never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-5)
# Not used in the loop below: the loss is read from outputs.loss.
loss_fn = CrossEntropyLoss()

# Fine-tune the model
model.train()
epochs = 1  # Fine-tune for a single epoch for this test
for epoch in range(epochs):
    optimizer.zero_grad()
    # Forward pass
    outputs = model(**train_inputs, labels=labels)
    loss = outputs.loss

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")