```
This notebook shows how we generated the LLaVA captions (`llava_captions`) for the whole dataset. Images are processed in batches, each batch is saved to its own JSON file, and the batch files are concatenated into a single JSON file at the end.

```

# Imports

In [None]:
# !pip install ftfy regex tqdm
# !pip install git+https://github.com/openai/CLIP.git
# !pip install --upgrade -q accelerate bitsandbytes
# !pip install git+https://github.com/huggingface/transformers.git
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os, json
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from utils import *

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to c:\users\egypt\appdata\local\temp\pip-req-build-ubv0k7rj
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Building wheels for collected packages: clip
  Building wheel for clip (setup.py): started
  Building wheel for clip (setup.py): finished with status 'done'
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369570 sha256=62c95addecbaab850064d692edfb3d05db9aec32c5a2f1cdcd31e8b3c11d0b32
  Stored in directory: C:\Users\EGYPT\AppData\Local\Temp\pip-ephem-wheel-cache-ih6wu0hp\wheels\c8\e4\e1\11374c111387672fc2068dfbe0d4b424cb9cdd1b2e184a71b5
Successfully built clip
Installing collected packages: clip
Successfully installed clip-1.0


  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git 'C:\Users\EGYPT\AppData\Local\Temp\pip-req-build-ubv0k7rj'
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 0.10.1 requires torch==1.10.1, but you have torch 2.6.0 which is incompatible.
torchvision 0.11.2 requires torch==1.10.1, but you have torch 2.6.0 which is incompatible.


^C


In [2]:
# Pick the compute device: GPU when PyTorch can see one, CPU otherwise.
# (Later cells move the LLaVA inputs onto the GPU.)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

# Load Dataset

In [None]:
# Load the CUB-200-2011 dataset from its root directory.
data_dir = '/kaggle/input/cub2002011/CUB_200_2011'
images_dir = os.path.join(data_dir, 'images')
parts_dir = os.path.join(data_dir, 'parts')

# load_cub_dataset comes from the local `utils` module; its 11th return value is discarded here.
(images, labels, classes, bounding_boxes, parts, part_locs, parts_click_locs,
 attributes, certainties, image_attribute_labels, _) = load_cub_dataset(data_dir)

# Preview the first rows of the three core tables, then report their shapes.
for _frame in (images, labels, classes):
    print(_frame.head())
for _frame in (images, labels, classes):
    print(_frame.shape)

In [None]:
# Image preprocessing for the dataset: fixed 224x224 resize, tensor conversion, per-channel
# normalisation. NOTE(review): the mean/std look like CLIP's RGB statistics rounded to 3 d.p. — confirm.
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.481, 0.457, 0.408),
            std=(0.268, 0.261, 0.275),
        ),
    ]
)

# LLaVA-1.5 7B, loaded with 4-bit quantized weights and a float16 compute dtype.
# device_map="auto" lets the loader decide where the layers live.
llava_model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
llava_processor = AutoProcessor.from_pretrained(llava_model_id)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    llava_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

In [None]:
class CustomDataset(Dataset):
    """CUB-200-2011 images paired with visible part annotations and LLaVA captions.

    Construction reads ``images.txt``, ``parts/parts.txt`` and ``parts/part_locs.txt``
    from the parent directory of ``data_dir``. The dataset covers the rows at positions
    ``[start, end)`` of ``images.txt`` (``end`` is clamped to the number of rows).
    When ``process_batches`` is true, those rows are captioned with LLaVA and saved as
    one ``batch_{batch_start}_{batch_end}.json`` file per ``batch_size`` rows in ``save_dir``.
    ``__getitem__`` reads parts and caption text back from the matching batch file.

    Args:
        data_dir: CUB ``images`` directory (one sub-directory per class).
        start: First 0-based row position of ``images.txt`` to include.
        end: Row position (exclusive) to stop at.
        process_batches: If true, generate and save the batch JSON files now.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        use_llava: If false, store a placeholder text instead of calling LLaVA.
        batch_size: Number of images per saved JSON batch file (must be > 0).
        save_dir: Directory that receives the batch JSON files.
        skip_existing: If true, batches whose JSON file already exists are not
            regenerated, so an interrupted run can be resumed.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``start`` is negative.
    """

    def __init__(self, data_dir, start, end, process_batches=True, transform=None, use_llava=True,
                 batch_size=500, save_dir="processed_batches", skip_existing=False):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if start < 0:
            raise ValueError(f"start must be non-negative, got {start}")

        self.transform = transform
        self.image_dir = data_dir
        self.image_ids = []    # CUB image_id of each sample, aligned with image_paths/labels
        self.image_paths = []
        self.labels = []
        self.parts_annotations = {}
        self.text_prompts = {}  # kept for backward compatibility; not used internally
        self.use_llava = use_llava
        self.batch_size = batch_size
        self.save_dir = save_dir
        self.skip_existing = skip_existing
        self.start_idx = start
        self.end_idx = end

        os.makedirs(save_dir, exist_ok=True)

        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Metadata files sit next to the images directory; normpath tolerates a trailing slash.
        root_dir = os.path.dirname(os.path.normpath(data_dir))

        images_file = os.path.join(root_dir, 'images.txt')
        images_df = pd.read_csv(images_file, sep=' ', names=['image_id', 'file_path'], index_col=0)

        parts_file = os.path.join(root_dir, 'parts', 'parts.txt')
        parts_df = pd.read_fwf(parts_file, colspecs=[(0, 2), (2, None)], header=None, names=['part_id', 'part_name'])
        self.part_names = parts_df.set_index('part_id')['part_name'].to_dict()

        part_locs_file = os.path.join(root_dir, 'parts', 'part_locs.txt')
        part_locs_df = pd.read_csv(part_locs_file, sep=r'\s+', names=['image_id', 'part_id', 'x', 'y', 'visible'])

        # Keep only visible parts with a known name. A vectorised filter replaces a slow
        # iterrows() loop; file order (and therefore per-image part order) is preserved.
        visible = part_locs_df[
            (part_locs_df['visible'] == 1) & part_locs_df['part_id'].isin(list(self.part_names))
        ]
        for image_id, part_id, x, y in visible[['image_id', 'part_id', 'x', 'y']].itertuples(index=False):
            self.parts_annotations.setdefault(int(image_id), []).append({
                'part_name': self.part_names[int(part_id)],
                'x': float(x),
                'y': float(y),
            })

        # Effective exclusive stop position: never run past the last row of images.txt.
        self._stop_idx = max(start, min(end, len(images_df)))

        # Populate the samples served by __len__/__getitem__ (previously left empty).
        for image_id, file_path in images_df['file_path'].iloc[start:self._stop_idx].items():
            class_name = file_path.split('/')[0]  # CUB paths look like "<class>/<file>"
            self.image_ids.append(int(image_id))
            self.image_paths.append(os.path.join(self.image_dir, file_path))
            self.labels.append(self.class_to_idx[class_name])

        if process_batches:
            self.process_batches(images_df)

    def _num_batches(self):
        """Number of batch files needed to cover positions [start_idx, _stop_idx)."""
        num_images = self._stop_idx - self.start_idx
        return -(-num_images // self.batch_size)  # ceiling division: no empty trailing batch

    def _batch_bounds(self, batch_idx):
        """Return (batch_start, batch_end) row positions of batch ``batch_idx``; end is exclusive."""
        batch_start = self.start_idx + batch_idx * self.batch_size
        return batch_start, min(batch_start + self.batch_size, self._stop_idx)

    def _batch_file(self, batch_start, batch_end):
        """Path of the JSON file holding rows [batch_start, batch_end)."""
        return os.path.join(self.save_dir, f"batch_{batch_start}_{batch_end}.json")

    def process_batches(self, images_df):
        """Caption every image in the dataset range and save one JSON file per batch.

        Each file maps ``str(image_id)`` to a dict with keys ``image_path``,
        ``class_label``, ``parts`` and ``llava_text``. Files are written to a temporary
        path and then moved into place with ``os.replace``, so an interrupted run does
        not leave a truncated batch that ``skip_existing`` would later skip.

        Args:
            images_df: ``images.txt`` as loaded in ``__init__`` (indexed by image_id,
                with a ``file_path`` column).
        """
        for batch_idx in range(self._num_batches()):
            batch_start, batch_end = self._batch_bounds(batch_idx)
            batch_file = self._batch_file(batch_start, batch_end)

            if self.skip_existing and os.path.exists(batch_file):
                print(f"Skipping existing batch: {batch_file}")
                continue

            batch_data = {}
            rows = images_df.iloc[batch_start:batch_end]
            for image_id, row in tqdm(rows.iterrows(), total=len(rows), desc="Processing Images in batch"):
                file_path = row['file_path']
                class_name = file_path.split('/')[0]  # get class name from path
                img_path = os.path.join(self.image_dir, file_path)

                parts = self.parts_annotations.get(int(image_id), [])

                if self.use_llava:
                    llava_text = self.generate_llava_prompt(img_path, parts, class_name)
                else:
                    llava_text = "No description available."

                batch_data[str(image_id)] = {
                    "image_path": img_path,
                    "class_label": class_name,
                    "parts": parts,
                    "llava_text": llava_text
                }

            tmp_file = batch_file + ".tmp"
            with open(tmp_file, "w") as f:
                json.dump(batch_data, f, indent=4)
            os.replace(tmp_file, batch_file)

            print(f"Saved batch: {batch_file}")

    def __len__(self):
        """Number of images in [start, end) (clamped to the size of images.txt)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, parts, text)`` for the ``idx``-th image of the range.

        ``parts`` and ``text`` come from the saved batch file when it exists. Otherwise
        they fall back to the in-memory part annotations and a placeholder text.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image_id = self.image_ids[idx]

        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        # Use the same batch arithmetic and naming as process_batches.
        batch_start, batch_end = self._batch_bounds(idx // self.batch_size)
        batch_file = self._batch_file(batch_start, batch_end)

        if os.path.exists(batch_file):
            with open(batch_file, "r") as f:
                batch_data = json.load(f)
            image_data = batch_data.get(str(image_id), {})

            parts = image_data.get("parts", [])
            text = image_data.get("llava_text", "No description available.")
        else:
            parts = self.parts_annotations.get(image_id, [])
            text = "No description available."

        return image, torch.tensor(label, dtype=torch.long), parts, text

    def generate_llava_prompt(self, img_path, visible_parts, class_name, max_new_tokens=1000):
        """Caption one image with LLaVA.

        The model is asked one question per visible part plus one about the environment.
        The answers are joined with single spaces. When there are no visible parts, a
        fixed fallback sentence is returned and the model is not called.

        Uses the module-level ``llava_processor`` and ``llava_model``.

        Args:
            img_path: Path to the image file.
            visible_parts: List of dicts with at least a ``part_name`` key.
            class_name: Class directory name, inserted into the prompts.
            max_new_tokens: Generation budget per prompt.

        Returns:
            The combined caption string.
        """
        if not visible_parts:
            # Fallback text only; this is not a model output.
            return f"Describe the bird in the picture. It is a {class_name}."

        prompts = [f"USER: <image>\nPlease describe the {part['part_name']} of the bird in the picture in one sentence.\nASSISTANT:" for part in visible_parts]
        prompts.append(f"USER: <image>\nPlease describe the environment of the image given that the bird is a {class_name}.\nASSISTANT:")

        # Decode the image once (it is the same for every prompt) and close the file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        generated_caption = []
        for prompt in prompts:
            inputs = llava_processor(text=prompt, images=[image], padding=True, return_tensors="pt").to(llava_model.device)
            output = llava_model.generate(**inputs, max_new_tokens=max_new_tokens)
            for text in llava_processor.batch_decode(output, skip_special_tokens=True):
                # Keep only the assistant's answer, without the echoed prompt or surrounding whitespace.
                generated_caption.append(text.split("ASSISTANT:")[-1].strip())

        return " ".join(generated_caption)


# Reuse the images directory derived from `data_dir` in the "Load Dataset" cell, so the
# whole notebook points at a single copy of CUB-200-2011. This cell previously hard-coded
# a different root ("/data/CUB_200_2011/images").
image_dir = images_dir

# Caption images at positions [5000, 6000) of images.txt in batches of 100;
# each batch is written to its own JSON file.
custom_dataset = CustomDataset(image_dir, start=5000, end=6000,
                               process_batches=True, transform=data_transforms, batch_size=100)

# Inspect the first sample of the processed range.
ex_image, ex_label, ex_parts, ex_text = custom_dataset[0]
print(f"Class Label: {ex_label}, Visible Parts: {ex_parts}")
print(f"Llava-Generated Text: {ex_text}")


Processing Images in batch: 100%|██████████| 5/5 [01:50<00:00, 22.05s/it]


Saved batch: processed_batches/batch_5000_5005.json


Processing Images in batch:  20%|██        | 1/5 [00:39<02:37, 39.35s/it]


KeyboardInterrupt: 