# Task 2: LLaVA Caption Generation
- In this notebook we test the LLaVA model by generating descriptive captions for images of bird species.
- The model leverages advanced language-vision alignment to produce detailed and nuanced captions (noticeably better results than BLIP-2).

In [None]:
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
from utils import *

from transformers import LlavaProcessor, LlavaForConditionalGeneration

### Hyperparameters

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

### Dataset

In [None]:
# Read the CUB-200-2011 metadata tables. Only the first three values returned by
# load_cub_dataset are used in this notebook; the remaining eight are discarded.
data_dir = 'data/CUB_200_2011'

images, labels, classes, _, _, _, _, _, _, _, _ = load_cub_dataset(data_dir)

cub_tables = (images, labels, classes)

# Preview the first rows of each table ...
for table in cub_tables:
    print(table.head())

# ... then report each table's (rows, columns) shape.
for table in cub_tables:
    print(table.shape)

   image_id                                          file_path
0         1  001.Black_footed_Albatross/Black_Footed_Albatr...
1         2  001.Black_footed_Albatross/Black_Footed_Albatr...
2         3  001.Black_footed_Albatross/Black_Footed_Albatr...
3         4  001.Black_footed_Albatross/Black_Footed_Albatr...
4         5  001.Black_footed_Albatross/Black_Footed_Albatr...
   image_id  class_id
0         1         1
1         2         1
2         3         1
3         4         1
4         5         1
   class_id                  class_name
0         1  001.Black_footed_Albatross
1         2        002.Laysan_Albatross
2         3         003.Sooty_Albatross
3         4       004.Groove_billed_Ani
4         5          005.Crested_Auklet
(11788, 2)
(11788, 2)
(200, 2)


In [None]:
# Sample image used to try out the captioning model.
img_path = 'data/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg'

# Decode the file and keep an RGB copy.
with Image.open(img_path) as source_image:
    img = source_image.convert('RGB')

### LLaVA

In [None]:
# Checkpoint converted to the Hugging Face format. The transformers LLaVA classes
# are documented with the `llava-hf` checkpoints; "liuhaotian/llava-v1.5-7b" is
# published in the upstream LLaVA project layout.
# NOTE(review): confirm this is the intended LLaVA-1.5 7B variant.
LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"

llava_processor = LlavaProcessor.from_pretrained(LLAVA_MODEL_ID)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    LLAVA_MODEL_ID,
    # Half precision on GPU to reduce memory use; full precision on CPU.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)


def generate_llava_caption(image, prompt="Describe this bird in detail.", max_new_tokens=128):
    """Generate a caption for ``image`` with the loaded LLaVA model.

    Args:
        image: The image to caption (the notebook passes an RGB ``PIL.Image``).
        prompt: Instruction given to the model alongside the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated caption as a string, without the prompt text and with
        surrounding whitespace stripped.
    """
    # LLaVA-1.5 chat format; "<image>" marks where the image features go.
    conversation = f"USER: <image>\n{prompt} ASSISTANT:"

    # Keyword arguments avoid depending on the processor's positional order,
    # which has differed between transformers releases.
    inputs = llava_processor(images=image, text=conversation, return_tensors="pt")

    # Put the inputs on the model's device; the dtype cast only affects
    # floating-point tensors, so the token ids stay integers.
    inputs = inputs.to(llava_model.device, llava_model.dtype)

    output = llava_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    caption_ids = output[0][prompt_length:]
    return llava_processor.decode(caption_ids, skip_special_tokens=True).strip()

In [None]:
# Caption the sample image and print the result.
llava_caption = generate_llava_caption(image=img)
print("LLaVA Caption:", llava_caption)
