## Import Libraries

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
import json
import transformers
import gradio as gr
import spaces
from threading import Thread
import time

## Download and save model

In [None]:
# Checkpoint identifier passed to from_pretrained for both model and processor.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# The access token is read from a local JSON file with an "access_token" key,
# so it never has to be written into this notebook.
with open('access_token.json', 'r', encoding='utf-8') as file:
    data = json.load(file)
access_token = data["access_token"]

# Use the Apple GPU only when PyTorch reports it; a hard-coded "mps" fails on
# machines without it, and the fine-tuning cell already falls back to the CPU.
download_device = "mps" if torch.backends.mps.is_available() else "cpu"

model = transformers.MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=download_device,
    token=access_token
)

processor = transformers.AutoProcessor.from_pretrained(model_id, token=access_token)

# Save model and processor locally; later cells load from these directories.
model.save_pretrained("llama")
processor.save_pretrained("processor")

## Fine-tune the model with custom dataset

In [None]:
# ==============================
# 1️⃣ Load LLaMA-3.2-11B-Vision-Instruct Model and Tokenizer
# ==============================
# Local directory the download cell saved the model weights to.
model_name = "llama/"

# Prefer the Apple GPU (MPS) backend when PyTorch reports it, else use the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The tokenizer is loaded from the directory the processor was saved to.
tokenizer = AutoTokenizer.from_pretrained('processor/')
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the checkpoint was saved from a bfloat16 model and training
# sets bf16=True, yet the weights are loaded as float16 here — confirm the
# intended dtype.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": device},  # "" maps the whole model onto one device
    torch_dtype=torch.float16,
)

# ==============================
# 2️⃣ Define LoRA PEFT Configuration
# ==============================
lora_settings = {
    "r": 16,                                 # rank of the LoRA update
    "lora_alpha": 32,                        # scaling factor
    "lora_dropout": 0.05,                    # dropout on the LoRA path
    "target_modules": ["q_proj", "v_proj"],  # modules that receive adapters
    "bias": "none",
    "task_type": "CAUSAL_LM",
}
peft_config = LoraConfig(**lora_settings)

# Wrap the base model with the adapters and report the trainable parameters.
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

# ==============================
# 3️⃣ Load Custom Dataset
# ==============================
# Read the training records from a local JSON file. Each record is expected to
# look like {"instruction": "...", "input": "...", "output": "..."}; these are
# the keys format_data reads.
# The result is indexed with 'train' when it is handed to the trainer.
dataset = load_dataset("json", data_files="dataset/dataset.json")

# Format the dataset into instruction-tuning format
def format_data(example):
    """Render one dataset record as a single instruction-tuning prompt.

    Args:
        example: Mapping with "instruction" and "output" entries and an
            optional "input" entry.

    Returns:
        A dict ``{"text": prompt}``. The prompt contains an ``### Input:``
        section only when "input" is present and truthy.

    Raises:
        KeyError: If "instruction" or "output" is missing.
    """
    sections = [f"### Instruction: {example['instruction']}"]
    # A record without extra context may omit "input" altogether (or hold
    # None / ""); .get() treats all of these alike instead of raising
    # KeyError on the missing key.
    extra_context = example.get("input")
    if extra_context:
        sections.append(f"### Input: {extra_context}")
    sections.append(f"### Response: {example['output']}")
    return {"text": "\n".join(sections)}

# Run format_data over the dataset; it returns a {"text": ...} entry for each
# record it is given.
dataset = dataset.map(format_data)

# ==============================
# 4️⃣ Define Training Arguments
# ==============================
sft_settings = {
    # Where checkpoints are written, and how often.
    "output_dir": "./fine-tuned-llama",
    "save_strategy": "steps",
    "save_steps": 500,
    "push_to_hub": False,
    # Batching: 2 samples per step, gradients accumulated over 4 steps.
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    # No evaluation pass is configured; log every 50 steps.
    "eval_strategy": "no",
    "logging_steps": 50,
    # Optimizer and schedule.
    "optim": "adamw_torch",
    "learning_rate": 2e-4,
    "weight_decay": 0.01,
    "warmup_ratio": 0.1,
    "max_steps": -1,  # no explicit step cap; adjust based on dataset size
    # Precision and sequence handling.
    "bf16": True,
    "max_seq_length": 512,
    "packing": False,
}
training_args = SFTConfig(**sft_settings)

# ==============================
# 5️⃣ Train the Model with SFTTrainer
# ==============================
trainer = SFTTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset['train'],
)
trainer.train()

# ==============================
# 6️⃣ Save the Fine-Tuned Model
# ==============================
# The trained model and the tokenizer are written to the same directory, in
# that order.
for save_to in (trainer.save_model, tokenizer.save_pretrained):
    save_to("./fine-tuned-llama")

print("Fine-tuning complete! Model saved at ./fine-tuned-llama")

## Load trained model

In [None]:
# Attach the adapter saved in ./fine-tuned-llama to the base model.
# NOTE(review): base_model was already passed to get_peft_model in the
# fine-tuning cell — confirm that re-wrapping the same object is intended
# rather than loading a fresh base model first.
model = PeftModel.from_pretrained(base_model, './fine-tuned-llama')
# Use the Apple GPU only when it exists; a hard-coded "mps" raises on machines
# without it, whereas the fine-tuning cell already falls back to the CPU.
inference_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(inference_device)
model.eval()  # evaluation mode for inference

## Test with a sample prompt

In [None]:
# Example instruction prompt, in the same layout format_data produces, ending
# at the response marker so the model writes the answer.
input_text = "### Instruction: Which player is going to get most points in next gameweek in FPL?\n### Response:"

# Tokenize the prompt and move every tensor to the device the model lives on.
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate response. temperature / top_p only take effect when sampling is
# enabled, so request it explicitly instead of relying on whatever generation
# defaults the checkpoint ships with.
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

# Decode and print the first (only) returned sequence.
response = tokenizer.decode(output[0], skip_special_tokens=True)
print("\n📝 Model Response:\n", response)

## Test in a chatbot

In [None]:
def _user_message(text, image_count):
    """Build a user chat message: one text part followed by image placeholders."""
    content = [{"type": "text", "text": text}]
    content.extend({"type": "image"} for _ in range(image_count))
    return {"role": "user", "content": content}


@spaces.GPU
def bot_streaming(message, history, max_new_tokens=250):
    """Stream the model's reply to a (possibly multimodal) chat message.

    Args:
        message: Dict with a "text" string and a "files" list. A file is
            either a path string or a dict with a "path" entry.
        history: Earlier turns as ``[user, assistant]`` pairs. A pair whose
            user part is a tuple/list is treated as an uploaded image whose
            first element is the file path; it belongs to the text turn that
            follows it. (Assumes Gradio's pair-style history — TODO confirm
            against the installed Gradio version.)
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    # `Image` was used here without ever being imported. transformers (already
    # a dependency of this file) ships a loader that opens a path and returns
    # an RGB PIL image, so no additional package is needed.
    from transformers.image_utils import load_image

    txt = message["text"]
    messages = []
    images = []

    # Rebuild the conversation in order. Images are collected until the next
    # text turn and attached to it; this avoids the index arithmetic on
    # neighbouring entries, which wrapped around at the first turn and could
    # run past the end of the history.
    pending_images = 0
    for turn in history:
        user_part, bot_part = turn[0], turn[1]
        if isinstance(user_part, (tuple, list)):
            images.append(load_image(user_part[0]))
            pending_images += 1
            continue
        # An image sent without text is followed by a turn whose user part is
        # not a string; keep the turn but give it an empty text.
        turn_text = user_part if isinstance(user_part, str) else ""
        messages.append(_user_message(turn_text, pending_images))
        messages.append({"role": "assistant", "content": [{"type": "text", "text": bot_part}]})
        pending_images = 0

    # Add the current message. As before, an image is attached only when
    # exactly one file was uploaded.
    files = message["files"]
    if len(files) == 1:
        upload = files[0]
        # Examples arrive as plain path strings, regular uploads as dicts.
        path = upload if isinstance(upload, str) else upload["path"]
        images.append(load_image(path))
        pending_images += 1
    # Any image left over from the history is attached here so the number of
    # placeholders always equals the number of images given to the processor.
    messages.append(_user_message(txt, pending_images))

    texts = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Put the inputs on whatever device the model is on instead of assuming
    # "mps".
    target_device = MODEL.device
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to(target_device)
    else:
        inputs = processor(text=texts, return_tensors="pt").to(target_device)
    streamer = transformers.TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)

    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=int(max_new_tokens))

    # generate() blocks until it is finished, so it runs in a worker thread
    # while this generator forwards the text the streamer receives.
    thread = Thread(target=MODEL.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

    # The streamer is exhausted once generation has ended; collect the worker.
    thread.join()


# bot_streaming reads the model through this module-level name.
MODEL = model

demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Multimodal Llama",
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            # The default has to lie inside [minimum, maximum]; 250 matches
            # bot_streaming's own max_new_tokens default.
            maximum=250,
            value=250,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    cache_examples=False,
    description="Test LLAMA",
    stop_btn="Stop Generation",
    fill_height=False,
    multimodal=True,
)

demo.launch(debug=True)