# Imports

In [1]:
# !pip install ftfy regex tqdm
# # !pip install git+https://github.com/openai/CLIP.git
# !pip install --upgrade -q accelerate bitsandbytes
# !pip install git+https://github.com/huggingface/transformers.git
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
# import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os, json
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

# Load Dataset

In [3]:
# # Load the CUB-200-2011 dataset
# def load_cub_dataset(data_dir):
#     images = pd.read_csv(os.path.join(data_dir, 'images.txt'), sep=' ', names=['image_id', 'file_path'])
#     labels = pd.read_csv(os.path.join(data_dir, 'image_class_labels.txt'), sep=' ', names=['image_id', 'class_id'])
#     classes = pd.read_csv(os.path.join(data_dir, 'classes.txt'), sep=' ', names=['class_id', 'class_name'])
#     bounding_boxes = pd.read_csv(os.path.join(data_dir, 'bounding_boxes.txt'), sep=' ', names=['image_id', 'x', 'y', 'width', 'height'])
#     part_locs = pd.read_csv(os.path.join(data_dir, 'parts/part_locs.txt'), sep=' ', names=['img_id', 'part_id', 'x', 'y', 'visible'])
#     # parts = pd.read_csv(os.path.join(data_dir, 'parts/parts.txt'), delimiter =' ', names=['part_id', 'part_name'])
#     parts = pd.read_fwf(os.path.join(data_dir, 'parts/parts.txt'), colspecs=[(0, 2), (2, None)], header=None, names=['part_id', 'part_name'])
#     parts_click_locs = pd.read_csv(os.path.join(data_dir, 'parts/part_click_locs.txt'), sep = ' ', names=['image_id', 'part_id', 'x', 'y', 'visible', 'time'])
#     attributes = pd.read_csv(os.path.join(data_dir, 'attributes/attributes.txt'), sep = ' ', names=['attribute_id', 'attribute_name'])
#     certainties = pd.read_fwf(os.path.join(data_dir, 'attributes/certainties.txt'), colspecs=[(0, 1), (2, None)], names=["certainty_id", "certainty_name"])
#     image_attribute_labels = pd.read_csv(os.path.join(data_dir, 'attributes/image_attribute_labels.txt'),
#                                          # sep = ' ',
#                                          names=['image_id', 'attribute_id', 'is_present', 'certainty_id', 'time'],
#                                          delim_whitespace=True, usecols=range(5)
#                                         )
#     return images, labels, classes,  bounding_boxes, parts, part_locs, parts_click_locs, attributes, certainties, image_attribute_labels
# data_dir = '/kaggle/input/cub2002011/CUB_200_2011'
# images_dir = os.path.join(data_dir, 'images')
# parts_dir = os.path.join(data_dir, 'parts')

# images, labels, classes, bounding_boxes, parts, part_locs, parts_click_locs, attributes, certainties, image_attribute_labels = load_cub_dataset(data_dir)

# print(images.head())
# print(labels.head())
# print(classes.head())

# print(images.shape)
# print(labels.shape)
# print(classes.shape)

In [4]:
# for _, row in images.iterrows():
#     print(row['image_id'])
#     break

In [5]:
# Image preprocessing applied to every dataset sample: fixed 224x224 resize,
# conversion to a tensor, then per-channel normalisation.
# NOTE(review): the mean/std values look like rounded CLIP statistics - confirm.
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.481, 0.457, 0.408),
            std=(0.268, 0.261, 0.275),
        ),
    ]
)

# LLaVA checkpoint used for caption generation, loaded with 4-bit
# quantisation and float16 compute; `device_map="auto"` leaves the placement
# of the weights to the loader.
llava_model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
llava_processor = AutoProcessor.from_pretrained(llava_model_id)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    llava_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Loading checkpoint shards: 100%|██████████| 3/3 [00:38<00:00, 12.69s/it]


In [6]:
class CustomDataset(Dataset):
    """CUB-style image dataset paired with part annotations and LLaVA captions.

    The dataset covers rows ``[start, end)`` (zero-based positions) of
    ``images.txt``. Metadata is read from the parent directory of ``data_dir``:
    ``images.txt``, ``parts/parts.txt`` and ``parts/part_locs.txt``.

    Captions are generated batch by batch and written to
    ``<save_dir>/batch_<first>_<last>.json``; every ``batch_*.json`` found in
    ``save_dir`` is then read back into ``text_prompts`` so ``__getitem__`` can
    serve the caption without re-opening files.

    Args:
        data_dir: Directory holding one sub-directory of images per class.
        start: First row of ``images.txt`` to include (inclusive, positional).
        end: Last row of ``images.txt`` to include (exclusive, positional).
        process_batches: When true, generate and save captions on construction.
        transform: Optional callable applied to each loaded PIL image.
        use_llava: When false, a placeholder caption is stored instead of
            calling the model.
        batch_size: Number of images written to each JSON file.
        save_dir: Directory for the JSON batch files (created if missing).
        overwrite: When true (default, the historical behaviour) existing batch
            files are regenerated; pass ``False`` to resume an interrupted run.
    """

    # Caption used whenever no generated text is available for an image.
    DEFAULT_TEXT = "No description available."

    def __init__(self, data_dir, start, end, process_batches=True, transform=None, use_llava=True, batch_size=500, save_dir="processed", overwrite=True):
        self.transform = transform
        self.image_dir = data_dir
        self.image_ids = []      # image id of every sample, parallel to image_paths
        self.image_paths = []
        self.labels = []
        self.parts_annotations = {}   # image_id -> list of visible part dicts
        self.text_prompts = {}        # image_id -> generated caption
        self.use_llava = use_llava
        self.batch_size = batch_size
        self.save_dir = save_dir
        self.start_idx = start
        self.end_idx = end

        os.makedirs(save_dir, exist_ok=True)

        # Only directories are classes; a stray file would otherwise shift
        # every class index.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # normpath strips a trailing separator so dirname really is the parent.
        metadata_dir = os.path.dirname(os.path.normpath(data_dir))

        images_file = os.path.join(metadata_dir, 'images.txt')
        images_df = pd.read_csv(images_file, sep=' ', names=['image_id', 'file_path'], index_col=0)

        parts_file = os.path.join(metadata_dir, 'parts', 'parts.txt')
        parts_df = pd.read_fwf(parts_file, colspecs=[(0, 2), (2, None)], header=None, names=['part_id', 'part_name'])
        self.part_names = parts_df.set_index('part_id')['part_name'].to_dict()

        part_locs_file = os.path.join(metadata_dir, 'parts', 'part_locs.txt')
        part_locs_df = pd.read_csv(part_locs_file, sep=r'\s+', names=['image_id', 'part_id', 'x', 'y', 'visible'])

        # Keep only parts flagged as visible and whose id has a known name.
        visible_parts = part_locs_df[part_locs_df['visible'] == 1]
        for row in visible_parts.itertuples(index=False):
            part_id = int(row.part_id)
            if part_id not in self.part_names:
                continue
            self.parts_annotations.setdefault(int(row.image_id), []).append({
                'part_name': self.part_names[part_id],
                'x': float(row.x),
                'y': float(row.y)
            })

        # Register the samples of the requested slice so that __len__ and
        # __getitem__ have something to serve.
        stop = min(self.end_idx, len(images_df))
        for image_id, file_path in images_df['file_path'].iloc[self.start_idx:stop].items():
            class_name = file_path.split('/')[0]  # class is the first path component
            self.image_ids.append(int(image_id))
            self.image_paths.append(os.path.join(self.image_dir, file_path))
            self.labels.append(self.class_to_idx[class_name])

        if process_batches:
            self.process_batches(images_df, overwrite=overwrite)

        self._load_text_prompts()

    def _batch_ranges(self, num_images):
        """Return the ``(first, last)`` row ranges of all batches.

        Ranges start at ``start_idx``, advance by ``batch_size`` and never go
        past ``end_idx`` or ``num_images``; no empty range is produced.
        """
        stop = min(self.end_idx, num_images)
        return [
            (first, min(first + self.batch_size, stop))
            for first in range(self.start_idx, stop, self.batch_size)
        ]

    def _load_text_prompts(self):
        """Fill ``text_prompts`` from every ``batch_*.json`` file in ``save_dir``.

        Only captions of images that belong to this dataset are kept.
        """
        wanted = set(self.image_ids)
        for name in sorted(os.listdir(self.save_dir)):
            if not (name.startswith("batch_") and name.endswith(".json")):
                continue
            with open(os.path.join(self.save_dir, name), "r") as f:
                batch_data = json.load(f)
            for key, entry in batch_data.items():
                image_id = int(key)
                if image_id in wanted:
                    self.text_prompts[image_id] = entry.get("llava_text", self.DEFAULT_TEXT)

    def process_batches(self, images_df, overwrite=True):
        """Processes images in batches and saves results in separate JSON files.

        Args:
            images_df: Frame indexed by image id with a ``file_path`` column.
            overwrite: Regenerate batch files that already exist (default);
                when false, existing files are left untouched and skipped.
        """
        for batch_start, batch_end in self._batch_ranges(len(images_df)):
            batch_file = os.path.join(self.save_dir, f"batch_{batch_start}_{batch_end}.json")

            if not overwrite and os.path.exists(batch_file):
                print(f"Loaded existing batch: {batch_file}")
                continue

            batch_data = {}
            batch_rows = images_df.iloc[batch_start:batch_end]

            for image_id, row in tqdm(batch_rows.iterrows(), total=len(batch_rows), desc="Processing Images in batch"):
                file_path = row['file_path']
                class_name = file_path.split('/')[0]  # Extract class name from path
                img_path = os.path.join(self.image_dir, file_path)

                # Get part annotations (if available)
                parts = self.parts_annotations.get(image_id, [])

                # Generate Llava text
                if self.use_llava:
                    llava_text = self.generate_llava_prompt(img_path, parts, class_name)
                else:
                    llava_text = self.DEFAULT_TEXT

                # Store full dataset info
                batch_data[str(image_id)] = {
                    "image_path": img_path,
                    "class_label": class_name,
                    "parts": parts,
                    "llava_text": llava_text
                }

            # Write to a temporary name first so an interrupted run never
            # leaves a truncated JSON file behind.
            tmp_file = batch_file + ".tmp"
            with open(tmp_file, "w") as f:
                json.dump(batch_data, f, indent=4)
            os.replace(tmp_file, batch_file)

            print(f"Saved batch: {batch_file}")

    def __len__(self):
        """Returns the number of images in the ``[start, end)`` slice."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Loads an image, its label, its visible parts and its saved caption.

        Returns:
            ``(image, label, parts, text)`` where ``label`` is a ``torch.long``
            scalar tensor, ``parts`` is a list of part dicts (possibly empty)
            and ``text`` falls back to ``DEFAULT_TEXT`` when no caption was
            saved for the image.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image_id = self.image_ids[idx]

        # Load and preprocess image; convert() returns a fully loaded copy, so
        # the file handle can be released right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        parts = self.parts_annotations.get(image_id, [])
        text = self.text_prompts.get(image_id, self.DEFAULT_TEXT)

        return image, torch.tensor(label, dtype=torch.long), parts, text

    def generate_llava_prompt(self, img_path, visible_parts, class_name):
        """Generates a text description for an image using Llava.

        One prompt is issued per visible part plus one about the environment;
        the answers are joined with single spaces. Uses the notebook-level
        ``llava_processor`` and ``llava_model``.

        When ``visible_parts`` is empty no model call is made and a fixed
        sentence naming the class is returned instead.
        """
        if not visible_parts:
            return f"Describe the bird in the picture. It is a {class_name}."

        prompts = [f"USER: <image>\nPlease describe the {part['part_name']} of the bird in the picture in one sentence.\nASSISTANT:" for part in visible_parts]
        prompts.append(f"USER: <image>\nPlease describe the environment of the image given that the bird is a {class_name}.\nASSISTANT:")

        # The image is the same for every prompt: decode it once.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        generated_caption = []
        for prompt in prompts:
            # Send the inputs to wherever the model lives instead of assuming CUDA.
            inputs = llava_processor(text=prompt, images=[image], padding=True, return_tensors="pt").to(llava_model.device)
            output = llava_model.generate(**inputs, max_new_tokens=1000)
            generated_text = llava_processor.batch_decode(output, skip_special_tokens=True)

            # Keep only the answer that follows the final "ASSISTANT:" marker.
            generated_caption.extend(text.split("ASSISTANT:")[-1] for text in generated_text)

        return " ".join(generated_caption)



In [None]:
# Build the path with os.path.join: the literal "data\images" only worked
# because "\i" is not a recognised escape sequence, which Python warns about.
image_dir = os.path.join("data", "images")

custom_dataset = CustomDataset(image_dir, start=2000, end=5000,
                               process_batches=True, transform=data_transforms, batch_size=200)
# sample_image, sample_label, sample_parts, sample_text  = custom_dataset[0]
# print(f"Class Label: {sample_label}, Visible Parts: {sample_parts}")
# print(f"Llava-Generated Text: {sample_text}")

Processing Images in batch: 100%|██████████| 200/200 [9:11:18<00:00, 165.39s/it]  


Saved batch: processed\batch_2000_2200.json


Processing Images in batch:   2%|▏         | 3/200 [04:49<5:20:51, 97.72s/it]

In [None]:
def collate_samples(batch):
    """Collate ``(image, label, parts, text)`` samples into one batch.

    Images and labels are stacked into tensors. ``parts`` and ``text`` are
    returned as plain per-sample lists: the part lists differ in length from
    image to image, which the default collate function rejects.
    """
    images, labels, parts, texts = zip(*batch)
    return torch.stack(images), torch.stack(labels), list(parts), list(texts)


SPLIT_SEED = 42  # fixed so the train/test split is the same on every run

dataset_size = len(custom_dataset)
train_size = int(0.8 * dataset_size)
test_size = dataset_size - train_size

train_dataset, test_dataset = random_split(
    custom_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_samples)

test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_samples)
class_names = custom_dataset.classes  # List of class names (e.g., bird species)
num_classes = len(class_names)
print("loaded")

## CLIP

In [8]:
# Imported in this cell so the CLIP section runs on a fresh kernel.
import clip

# CLIP ViT-B/32 weights plus the matching image preprocessing pipeline.
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 154MiB/s]


In [None]:
def get_clip_img_features(img_path):
    """Encode the image stored at ``img_path`` with CLIP.

    Uses the notebook-level ``preprocess``, ``model`` and ``device``.

    Returns:
        NumPy array of image features with a leading batch dimension of 1.
    """
    pil_image = Image.open(img_path)
    batch = preprocess(pil_image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)
    with torch.no_grad():
        encoded = model.encode_image(batch)
    return encoded.cpu().numpy()

def get_clip_text_features(text):
    """Encode one string with CLIP's text encoder.

    Uses the notebook-level ``clip``, ``model`` and ``device``.

    Args:
        text: The string to encode. Text longer than CLIP's context length is
            truncated rather than raising an error.

    Returns:
        NumPy array of text features with a leading batch dimension of 1.
    """
    # Without truncate=True, clip.tokenize raises on over-long input.
    text_inputs = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs).cpu().numpy()
    return text_features

def _as_feature_vector(features):
    """Convert a tensor or array-like to a 1-D float32 CPU tensor."""
    vector = torch.as_tensor(features).detach().to(device="cpu", dtype=torch.float32).squeeze()
    if vector.ndim == 0:
        # A single-element input squeezes down to a scalar; keep it 1-D.
        vector = vector.reshape(1)
    if vector.ndim != 1:
        raise ValueError(
            f"expected a single feature vector, got shape {tuple(vector.shape)}"
        )
    return vector


def get_cosine_similarity(img_features, txt_features) -> float:
    """
    Computes the cosine similarity between an image and a text feature vector.

    Parameters:
    img_features (torch.Tensor | array-like): Feature vector for the image;
        singleton dimensions (e.g. a leading batch dimension of 1) are removed.
    txt_features (torch.Tensor | array-like): Feature vector for the text.

    Returns:
    float: Cosine similarity score between -1 and 1, or 0.0 when either
        vector has zero norm.

    Raises:
    ValueError: If an input does not reduce to a single 1-D vector.
    """
    # The CLIP helpers in this notebook return NumPy arrays (possibly
    # float16), so convert explicitly before using torch operations.
    img_vector = _as_feature_vector(img_features)
    txt_vector = _as_feature_vector(txt_features)

    # Compute L2 norms
    norm_img = torch.norm(img_vector, p=2)
    norm_txt = torch.norm(txt_vector, p=2)

    # Avoid division by zero
    if norm_img == 0 or norm_txt == 0:
        return 0.0

    similarity = torch.dot(img_vector, txt_vector) / (norm_img * norm_txt)
    return similarity.item()

def chunking_llava(text, img_path, window_size=77):
    """Average CLIP similarity between an image and each sentence of ``text``.

    ``text`` is split on ``'.'``; every non-empty sentence is encoded with
    ``get_clip_text_features`` and compared with the image features through
    ``get_cosine_similarity``.

    Args:
        text: Caption to score.
        img_path: Path of the image the caption describes.
        window_size: Currently unused; kept for interface compatibility.

    Returns:
        float: Mean of the per-sentence similarities, or ``0.0`` when ``text``
        contains no sentence (the image is not encoded in that case).
    """
    sentences = [sentence.strip() for sentence in text.split('.') if sentence.strip()]
    if not sentences:
        # Nothing to score; also avoids dividing by zero below.
        return 0.0

    # The feature helpers return NumPy arrays; convert to float32 tensors
    # because the similarity is computed with torch operations.
    img_features = torch.as_tensor(get_clip_img_features(img_path), dtype=torch.float32)

    similarity_total = 0.0
    for sentence in sentences:
        text_features = torch.as_tensor(get_clip_text_features(sentence), dtype=torch.float32)
        similarity_total += get_cosine_similarity(img_features, text_features)
    return similarity_total / len(sentences)

In [None]:
def _embed_caption(caption):
    """Return one text embedding per caption: the mean of its sentence embeddings.

    The caption is split on ``'.'``; when no sentence is found the caption is
    encoded as a whole so that every sample still gets a row.
    """
    sentences = [part.strip() for part in caption.split('.') if part.strip()] or [caption]
    sentence_features = np.vstack([get_clip_text_features(sentence) for sentence in sentences])
    return sentence_features.astype(np.float32).mean(axis=0, keepdims=True)


def extract_embeddings(dataloader, model, device):
    """Compute CLIP image and text embeddings for every sample of a dataloader.

    Args:
        dataloader: Iterable of ``(images, labels, parts, texts)`` batches,
            where ``images`` is a batch of already-preprocessed image tensors,
            ``labels`` a tensor of class indices and ``texts`` a sequence of
            caption strings. ``parts`` is ignored.
        model: Object exposing ``encode_image`` (the CLIP model).
        device: Device the image batch is moved to before encoding.

    Returns:
        ``(image_embeds, text_embeds, labels)`` as NumPy arrays with one row
        (or entry) per sample, in iteration order.

    Raises:
        ValueError: If the dataloader yields no batch.
    """
    image_embeds = []
    text_embeds = []
    labels_list = []

    with torch.no_grad():
        for imgs, labels, _, texts in dataloader:
            labels_list.extend(labels.cpu().numpy())

            # The loader already yields image tensors, so encode the whole
            # batch directly instead of re-reading files from disk.
            batch_features = model.encode_image(imgs.to(device))
            image_embeds.append(batch_features.float().cpu().numpy())

            text_embeds.append(np.vstack([_embed_caption(caption) for caption in texts]))

    if not image_embeds:
        raise ValueError("extract_embeddings received a dataloader with no batches")

    image_embeds = np.vstack(image_embeds)
    text_embeds = np.vstack(text_embeds)
    labels_list = np.array(labels_list)

    return image_embeds, text_embeds, labels_list

# Extract embeddings for the train and test sets
train_image_embeds, train_text_embeds, train_labels = extract_embeddings(train_dataloader, model, device)
test_image_embeds, test_text_embeds, test_labels = extract_embeddings(test_dataloader, model, device)

# The following cells (saving, classifier, `baseline`) read `train_embeds` and
# `test_embeds`. `baseline` feeds image features only to the classifier, so the
# image embeddings are used here; the text embeddings stay available separately.
# To train on both instead: np.hstack((train_image_embeds, train_text_embeds))
train_embeds = train_image_embeds
test_embeds = test_image_embeds

# Print embedding shapes
print(f"Train: {train_embeds.shape[0]} images with embeddings of size {train_embeds.shape[1]}")
print(f"Test: {test_embeds.shape[0]} images with embeddings of size {test_embeds.shape[1]}")


In [None]:
# Persist the embeddings and labels as .npy files in the working directory.
for _filename, _array in (
    ('train_embeds.npy', train_embeds),
    ('test_embeds.npy', test_embeds),
    ('train_labels.npy', train_labels),
    ('test_labels.npy', test_labels),
):
    np.save(_filename, _array)

## Classifier

In [None]:
# Logistic regression on the extracted embeddings. `multi_class` is left at
# its default: with the lbfgs solver that already is the multinomial
# formulation for multi-class targets, and passing the argument explicitly is
# deprecated in recent scikit-learn releases.
clf = LogisticRegression(max_iter=1000, solver="lbfgs")
clf.fit(train_embeds, train_labels)

train_preds = clf.predict(train_embeds)
test_preds = clf.predict(test_embeds)

# Compute accuracy for each set
train_acc = accuracy_score(train_labels, train_preds)
test_acc = accuracy_score(test_labels, test_preds)

# Print results
print(f"Train Accuracy: {train_acc:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
import joblib

# Write the fitted classifier to disk so it can be reloaded without retraining.
model_path = "logistic_regression_model.pkl"
joblib.dump(clf, model_path)

In [None]:
def baseline(img_path):
    """Predict the class of one image from its CLIP image features.

    The image is loaded as RGB, passed through ``data_transforms``, encoded
    with ``model.encode_image`` and classified with ``clf`` (all notebook-level
    objects).

    Returns:
        The first (and only) entry of ``clf.predict`` for this image.
    """
    pil_image = Image.open(img_path).convert('RGB')
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension before encoding
        batch = data_transforms(pil_image).unsqueeze(0).to(device)
        features = model.encode_image(batch).cpu().numpy()
    predictions = clf.predict(features)
    return predictions[0]

In [None]:
img_id = 800

# Resolve the image through images.txt, which sits next to the image
# directory, so this cell does not depend on tables from other cells.
demo_images_df = pd.read_csv(
    os.path.join(os.path.dirname(image_dir), 'images.txt'),
    sep=' ', names=['image_id', 'file_path'], index_col=0,
)
file_path = demo_images_df.loc[img_id, 'file_path']
img_path = os.path.join(image_dir, file_path)


def format_class_name(dir_name):
    """Drop the 4-character prefix and replace underscores with spaces.

    For example ``'001.Some_Bird'`` becomes ``'Some Bird'``.
    """
    return " ".join(dir_name[4:].split('_'))


# The true class is the first component of the image path.
class_name = format_class_name(file_path.split('/')[0])
print(f'Class: {class_name}')

# `baseline` returns a class index into the dataset's sorted class list.
pred_class = baseline(img_path)
pred_class_name = format_class_name(custom_dataset.classes[pred_class])
print(f'Pred Class: {pred_class_name}')

fig, ax = plt.subplots()
with Image.open(img_path) as image:
    ax.imshow(image)
ax.set_title(f"{class_name} (predicted: {pred_class_name})")
ax.axis('off')
plt.show()