In [1]:
prompt = """Don't forget these rules:
 
1. **Be Direct and Concise**: Provide straightforward descriptions without adding interpretative or speculative elements.
2. **Use Segmented Details**: Break down details about different elements of an image into distinct sentences, focusing on one aspect at a time.
3. **Maintain a Descriptive Focus**: Prioritize purely visible elements of the image, avoiding conclusions or inferences.
4. **Follow a Logical Structure**: Begin with the central figure or subject and expand outward, detailing its appearance before addressing the surrounding setting.
5. **Avoid Juxtaposition**: Do not use comparison or contrast language; keep the description purely factual.
6. **Incorporate Specificity**: Mention age, gender, race, and specific brands or notable features when present, and clearly identify the medium if it's discernible. 
 
When writing descriptions, prioritize clarity and direct observation over embellishment or interpretation.
 
Write a detailed description of this image, do not forget about the texts on it if they exist. Also, do not forget to mention the type / style of the image. No bullet points."""

system_message = "You are an expert in image analysis."

In [2]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": sample["link"]}}    
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["caption"]}]
            }
        ]
    }

dataset_id = "laion/gpt4v-dataset"
dataset = load_dataset(dataset_id, split="train")
dataset = [format_data(sample) for sample in dataset]

print(dataset[0]["messages"])

[{'role': 'system', 'content': [{'type': 'text', 'text': 'You are an expert in image analysis.'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': "Don't forget these rules:\n \n1. **Be Direct and Concise**: Provide straightforward descriptions without adding interpretative or speculative elements.\n2. **Use Segmented Details**: Break down details about different elements of an image into distinct sentences, focusing on one aspect at a time.\n3. **Maintain a Descriptive Focus**: Prioritize purely visible elements of the image, avoiding conclusions or inferences.\n4. **Follow a Logical Structure**: Begin with the central figure or subject and expand outward, detailing its appearance before addressing the surrounding setting.\n5. **Avoid Juxtaposition**: Do not use comparison or contrast language; keep the description purely factual.\n6. **Incorporate Specificity**: Mention age, gender, race, and specific brands or notable features when present, and clearly identify the medium if 

In [4]:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 5/5 [00:14<00:00,  2.84s/it]


In [5]:
text = processor.apply_chat_template(
    dataset[2]["messages"], tokenize=False, add_generation_prompt=False
)

In [7]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r = 8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CASUAL_LM"
)

In [15]:
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="my-awesome-llama", 
    gradient_checkpointing=True,
    gradient_accumulation_steps=8,
    bf16=True,
    remove_unused_columns=False
)


In [28]:
from transformers import LlavaForConditionalGeneration

def collate_fn(examples):
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["image_url"] for example in examples]
    if isinstance(model, LlavaForConditionalGeneration):
        images = [image[0] for image in images]

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  #
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

In [29]:
from trl import SFTTrainer

#print(dataset)

processor.tokenizer.pad_token = processor.tokenizer.eos_token
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    eval_dataset=dataset,
    tokenizer=processor.tokenizer,
)


ValueError: You need to provide either `dataset_text_field` or `formatting_func` argument. Alternatively, you can skip the dataset preparation by using `SFTConfig(dataset_kwargs={'skip_prepare_dataset': True})`.