In [2]:
from google.colab import drive
import os

drive.mount('/content/drive')

drive.mount('/content/drive')

dataset_path = "/content/drive/My Drive/Uni/Bird Exploration/CUB_200_2011"

if os.path.exists(dataset_path):
  print(f"✅ CUB dataset is available at: {dataset_path}")
else:
  print("❌ Dataset not found! Run the download script first.")

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ CUB dataset is available at: /content/drive/My Drive/Uni/Bird Exploration/CUB_200_2011


In [None]:

!pip install ftfy regex tqdm nltk peft wandb weave
# !pip install git+https://github.com/openai/CLIP.git
!pip install --upgrade -q accelerate bitsandbytes
!pip install git+https://github.com/huggingface/transformers.git
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
# import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os, json
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import nltk

from peft import LoraConfig, get_peft_model
from transformers import CLIPProcessor, CLIPModel


nltk.download('punkt_tab')


Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)


In [None]:
import wandb

wandb.login()

wandb.init(
    project="CLIP_LoRA_Bird_Exploration",  # Set a project name
    name="CLIP_LoRA_Training_Run",  # Run name
    id="d4sze36t",
    config={"epochs": 5, "learning_rate": 5e-5, "batch_size": 8}  # Track hyperparameters
)


# ✅ Save Model Function (Fixed)
def save_model(epoch):
    save_path = f"/content/drive/MyDrive/Uni/Bird Exploration/clip_lora_trained_epoch_{epoch+1}"
    os.makedirs(save_path, exist_ok=True)

    # ✅ Save the entire CLIP + LoRA Model
    clip_model.save_pretrained(save_path)

    # ✅ Save tokenizer & processor
    processor.save_pretrained(save_path)

    print(f"✅ Model saved at {save_path}")

def compute_test_loss(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, encoded_texts in dataloader:
            images, encoded_texts = images.to(device), encoded_texts.to(device)

            # Forward pass
            outputs_pos = model.get_text_features(input_ids=encoded_texts)
            outputs_img = model.get_image_features(pixel_values=images)

            # Ensure batch sizes are equal
            min_batch_size = min(outputs_pos.shape[0], outputs_img.shape[0])
            outputs_pos, outputs_img = outputs_pos[:min_batch_size], outputs_img[:min_batch_size]

            # Contrastive loss
            labels = torch.ones(min_batch_size).to(device)
            loss = loss_fn(outputs_pos, outputs_img, labels)

            total_loss += loss.item()

    return total_loss / len(dataloader)  # Average test loss



In [None]:
import torch
import wandb
from transformers import CLIPProcessor, CLIPModel
from peft import PeftModel
from tqdm import tqdm
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# ✅ Load LoRA Model & Processor
save_path = "/content/drive/MyDrive/Uni/Bird Exploration/clip_lora_trained_final"  # Adjust if needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Load CLIP Model & Apply LoRA Adapter

model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# clip_model = PeftModel.from_pretrained(clip_model, save_path).to(device)

# # ✅ Load Processor
# processor = CLIPProcessor.from_pretrained(save_path)

# print("✅ LoRA Adapter successfully loaded! Resuming training...")

# ✅ Load Dataset
import json
with open("/content/drive/MyDrive/Uni/Bird Exploration/CUB_200_2011/llava_captions.json", "r") as f:
    dataset = json.load(f)

# ✅ Dataset Class for Bird Exploration
class BirdDataset(Dataset):
    def __init__(self, dataset, root_dir="/content/drive/My Drive/Uni/Bird Exploration/CUB_200_2011/images/"):
        self.data = list(dataset.values())
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def chunk_text(self, text, chunk_size=77, overlap=10):
        """Splits text into overlapping 77-token chunks for CLIP compatibility."""
        tokens = processor.tokenizer(
            text, return_tensors="pt", padding=False, truncation=False
        )["input_ids"].squeeze()

        # Ensure correct tensor format
        if len(tokens.shape) == 1:
            tokens = tokens.unsqueeze(0)

        chunks = []
        for i in range(0, len(tokens[0]), chunk_size - overlap):
            chunk = tokens[0, i:i+chunk_size]

            # ✅ Ensure every chunk is 77 tokens (Pad if necessary)
            if len(chunk) < chunk_size:
                padding = torch.zeros(chunk_size - len(chunk), dtype=torch.long)
                chunk = torch.cat([chunk, padding])

            chunks.append(chunk)

        return chunks if chunks else [torch.zeros(chunk_size, dtype=torch.long)]

    def __getitem__(self, idx):
        positive_sample = self.data[idx]

        # Load Image
        class_label = positive_sample["class_label"]
        image_filename = os.path.basename(positive_sample["image_path"])
        image_path = os.path.join(self.root_dir, class_label, image_filename)

        if not os.path.exists(image_path):
            image = torch.zeros((3, 224, 224))
        else:
            image = Image.open(image_path).convert("RGB")
            image = self.transform(image)

        # Process text (Chunk LLAVA Captions Properly)
        text = positive_sample.get("llava_text", "").strip()
        text_chunks = self.chunk_text(text)

        # ✅ Ensure every text chunk has exactly 77 tokens
        encoded_texts = torch.stack(text_chunks)

        return image, encoded_texts

# ✅ Custom Collate Function
def collate_fn(batch):
    images, text_chunks = zip(*batch)

    images_tensor = torch.stack(images).to(device)

    # ✅ Ensure batch consistency for text sequences
    max_length = 77  # Enforce max token length
    texts_tensor = pad_sequence(
        [chunk[:max_length] for chunks in text_chunks for chunk in chunks],
        batch_first=True, padding_value=0
    ).to(device)

    return images_tensor, texts_tensor

# ✅ Split Dataset into Train (70%), Validation (15%), Test (15%)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
dataset_items = list(dataset.items())
train_data, val_data, test_data = random_split(dataset_items, [train_size, val_size, test_size])
train_dataset, val_dataset, test_dataset = BirdDataset(dict(train_data)), BirdDataset(dict(val_data)), BirdDataset(dict(test_data))

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

# ✅ Optimizer & Contrastive Loss
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=1e-3)
loss_fn = torch.nn.CosineEmbeddingLoss()

# ✅ Training Loop with Validation Loss & Model Saving
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

# ✅ Function to Compute Validation Loss
def compute_val_loss(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, encoded_texts in dataloader:
            images, encoded_texts = images.to(device), encoded_texts.to(device)

            outputs_pos = model.get_text_features(input_ids=encoded_texts)
            outputs_img = model.get_image_features(pixel_values=images)

            min_batch_size = min(outputs_pos.shape[0], outputs_img.shape[0])
            outputs_pos, outputs_img = outputs_pos[:min_batch_size], outputs_img[:min_batch_size]

            labels = torch.ones(min_batch_size).to(device)
            loss = loss_fn(outputs_pos, outputs_img, labels)
            total_loss += loss.item()

    return total_loss / len(dataloader)  # Average validation loss

# ✅ Resume Training from Epoch 1 to 5
num_epochs = 5
for epoch in range(num_epochs):
    clip_model.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

    for step, (images, encoded_texts) in enumerate(progress_bar):
        optimizer.zero_grad()
        images, encoded_texts = images.to(device), encoded_texts.to(device)

        # Forward pass
        outputs_pos = clip_model.get_text_features(input_ids=encoded_texts)
        outputs_img = clip_model.get_image_features(pixel_values=images)

        min_batch_size = min(outputs_pos.shape[0], outputs_img.shape[0])
        outputs_pos, outputs_img = outputs_pos[:min_batch_size], outputs_img[:min_batch_size]

        labels = torch.ones(min_batch_size).to(device)
        loss = loss_fn(outputs_pos, outputs_img, labels)

        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        # ✅ Log validation loss every few steps
        if step % 100 == 0:
            val_loss = compute_val_loss(clip_model, val_loader, loss_fn)
            wandb.log({"Train Loss": loss.item(), "Validation Loss (Step)": val_loss})
            print(f"Step {step}: Train Loss {loss.item():.4f}, Val Loss {val_loss:.4f}")

        progress_bar.set_postfix(loss=loss.item())

    # ✅ Compute & Log Final Validation Loss
    val_loss = compute_val_loss(clip_model, val_loader, loss_fn)
    wandb.log({"Validation Loss (Epoch)": val_loss})

    # ✅ Reduce LR if validation loss stops improving
    scheduler.step(val_loss)

    # ✅ Save the model at the end of each epoch
    save_model(epoch, clip_model, processor)

wandb.finish()


✅ LoRA Adapter successfully loaded! Resuming training...
trainable params: 491,520 || all params: 151,768,833 || trainable%: 0.3239


Epoch 1/5:   0%|          | 0/1179 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (199 > 77). Running this sequence through the model will result in indexing errors
Epoch 1/5:  22%|██▏       | 255/1179 [03:45<13:38,  1.13it/s, loss=0.000394]


KeyboardInterrupt: 

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print(device)
model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

cuda


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

In [None]:
with open("/content/drive/MyDrive/Uni/Bird Exploration/CUB_200_2011/llava_captions.json", "r") as f:
    dataset = json.load(f)

In [None]:

# from torch.nn.utils.rnn import pad_sequence
# class BirdDataset(Dataset):
#     def __init__(self, dataset, root_dir="/content/drive/My Drive/Uni/Bird Exploration/CUB_200_2011/images/"):
#         self.data = list(dataset.values())
#         self.root_dir = root_dir
#         self.transform = transforms.Compose([
#             transforms.Resize((224, 224)),
#             transforms.ToTensor(),
#         ])

#     def __len__(self):
#         return len(self.data)

#     def chunk_text(self, text, chunk_size=77, overlap=10):
#         """Splits text into 77-token chunks for CLIP compatibility."""
#         tokens = processor.tokenizer(
#             text, return_tensors="pt", padding=False, truncation=True, max_length=chunk_size
#         )["input_ids"].squeeze()
#         return tokens.unsqueeze(0)  # Ensure tensor format is correct

#     def __getitem__(self, idx):
#         positive_sample = self.data[idx]

#         # Load Image
#         class_label = positive_sample["class_label"]
#         image_filename = os.path.basename(positive_sample["image_path"])
#         image_path = os.path.join(self.root_dir, class_label, image_filename)

#         if not os.path.exists(image_path):
#             image = torch.zeros((3, 224, 224))
#         else:
#             image = Image.open(image_path).convert("RGB")
#             image = self.transform(image)

#         # Process text
#         text = positive_sample.get("llava_text", "").strip()
#         if not text:
#             text = "No description available."

#         text_chunks = self.chunk_text(text)
#         encoded_texts = text_chunks  # No need to stack

#         # Negative Sample (wrong bird text)
#         neg_idx = np.random.randint(0, len(self.data))
#         while neg_idx == idx:
#             neg_idx = np.random.randint(0, len(self.data))

#         negative_text = self.data[neg_idx].get("llava_text", "").strip()
#         if not negative_text:
#             negative_text = "No negative description available."

#         negative_chunks = self.chunk_text(negative_text)
#         encoded_negative_texts = negative_chunks

#         return image, encoded_texts, encoded_negative_texts

# def collate_fn(batch):
#     images, text_chunks, neg_text_chunks = zip(*batch)

#     images_tensor = torch.stack(images).to(device)

#     # Pad sequences properly
#     texts_tensor = pad_sequence(text_chunks, batch_first=True, padding_value=0).to(device)
#     neg_texts_tensor = pad_sequence(neg_text_chunks, batch_first=True, padding_value=0).to(device)

#     return images_tensor, texts_tensor, neg_texts_tensor\


from torch.nn.utils.rnn import pad_sequence
class BirdDataset(Dataset):
    def __init__(self, dataset, root_dir="/content/drive/My Drive/Uni/Bird Exploration/CUB_200_2011/images/"):
        self.data = list(dataset.values())
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def chunk_text(self, text, chunk_size=77, overlap=10):
        """Splits text into overlapping 77-token chunks for CLIP compatibility."""
        tokens = processor.tokenizer(
            text, return_tensors="pt", padding=False, truncation=False
        )["input_ids"].squeeze()

        # Create overlapping chunks
        chunks = []
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk = tokens[i:i+chunk_size]
            if len(chunk) < chunk_size:  # Pad shorter chunks
                chunk = torch.cat([chunk, torch.zeros(chunk_size - len(chunk), dtype=torch.long)])
            chunks.append(chunk)

        return chunks  # Return all chunks

    def __getitem__(self, idx):
      positive_sample = self.data[idx]

      # Load Image
      class_label = positive_sample["class_label"]
      image_filename = os.path.basename(positive_sample["image_path"])
      image_path = os.path.join(self.root_dir, class_label, image_filename)

      if not os.path.exists(image_path):
          image = torch.zeros((3, 224, 224))
      else:
          image = Image.open(image_path).convert("RGB")
          image = self.transform(image)

      # Process text (Chunk LLAVA Captions Properly)
      text = positive_sample.get("llava_text", "").strip()
      if not text:
          text = "No description available."

      text_chunks = self.chunk_text(text)  # **Get multiple chunks**
      encoded_texts = torch.stack(text_chunks)  # Convert to tensor

      # Negative Sample (wrong bird text)
      neg_idx = np.random.randint(0, len(self.data))
      while neg_idx == idx:
          neg_idx = np.random.randint(0, len(self.data))

      negative_text = self.data[neg_idx].get("llava_text", "").strip()
      if not negative_text:
          negative_text = "No negative description available."

      negative_chunks = self.chunk_text(negative_text)
      encoded_negative_texts = torch.stack(negative_chunks)

      return image, encoded_texts, encoded_negative_texts

def collate_fn(batch):
    images, text_chunks, neg_text_chunks = zip(*batch)

    images_tensor = torch.stack(images).to(device)

    # **Only Keep Necessary Tokens in Each Sequence**
    max_length = min(77, max(chunk.shape[0] for chunks in text_chunks for chunk in chunks))
    max_neg_length = min(77, max(chunk.shape[0] for chunks in neg_text_chunks for chunk in chunks))

    texts_tensor = pad_sequence(
        [chunk[:max_length] for chunks in text_chunks for chunk in chunks],
        batch_first=True, padding_value=0
    ).to(device)

    neg_texts_tensor = pad_sequence(
        [chunk[:max_neg_length] for chunks in neg_text_chunks for chunk in chunks],
        batch_first=True, padding_value=0
    ).to(device)

    return images_tensor, texts_tensor, neg_texts_tensor




def chunk_text(self, text, chunk_size=77, overlap=10):
    """Splits text into overlapping 77-token chunks for CLIP compatibility."""
    tokens = processor.tokenizer(
        text, return_tensors="pt", padding=False, truncation=False
    )["input_ids"].squeeze()

    chunks = []
    for i in range(0, len(tokens), chunk_size - overlap):
        chunk = tokens[i:i + chunk_size]
        if len(chunk) < chunk_size:  # Pad shorter chunks
            chunk = torch.cat([chunk, torch.zeros(chunk_size - len(chunk), dtype=torch.long)])
        chunks.append(chunk)

    # **Ensure Every Chunk is Exactly 77 Tokens Before Returning**
    return [chunk[:chunk_size] for chunk in chunks]  # ✅ Now, no chunk exceeds 77 tokens


# def collate_fn(batch):
#     images, text_chunks, neg_text_chunks = zip(*batch)

#     images_tensor = torch.stack(images).to(device)

#     # Pad and batch all text chunks
#     texts_tensor = pad_sequence([chunk for chunks in text_chunks for chunk in chunks], batch_first=True, padding_value=0).to(device)
#     neg_texts_tensor = pad_sequence([chunk for chunks in neg_text_chunks for chunk in chunks], batch_first=True, padding_value=0).to(device)

#     return images_tensor, texts_tensor, neg_texts_tensor



# def chunk_text(self, text, chunk_size=77, overlap=10):
#     """Splits text into overlapping 77-token chunks for CLIP compatibility."""
#     tokens = processor.tokenizer(
#         text, return_tensors="pt", padding=False, truncation=False
#     )["input_ids"].squeeze()

#     chunks = []
#     for i in range(0, len(tokens), chunk_size - overlap):
#         chunk = tokens[i:i + chunk_size]
#         if len(chunk) < chunk_size:  # Pad shorter chunks
#             chunk = torch.cat([chunk, torch.zeros(chunk_size - len(chunk), dtype=torch.long)])
#         chunks.append(chunk)

#     return chunks  # Returns all chunks (not just one)






In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import random_split, DataLoader
from transformers import CLIPProcessor, CLIPModel
from peft import LoraConfig, get_peft_model, PeftModel

wandb.init(
    project="Bird Exploration",  # Set a project name
    name="CLIP_LoRA_Training_Run",  # Run name
    config={"epochs": 5, "learning_rate": 5e-5, "batch_size": 8}  # Track hyperparameters
)


# Enable CUDA Debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Load CLIP Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# ✅ Apply LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]
)
clip_model = get_peft_model(clip_model, lora_config)
clip_model.print_trainable_parameters()

# ✅ Split Dataset into Train & Test (80/20 split)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# ✅ Convert dataset to list before splitting
dataset_items = list(dataset.items())  # Convert dictionary to list for indexing
train_data, test_data = random_split(dataset_items, [train_size, test_size])

# ✅ Create BirdDataset instances from split data
train_dataset = BirdDataset(dict(train_data))
test_dataset = BirdDataset(dict(test_data))

# ✅ Initialize Data Loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)  # ✅ Shuffle ON for training
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)  # ✅ Shuffle OFF for testing

# ✅ Optimizer & Contrastive Loss
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=5e-5)
loss_fn = torch.nn.CosineEmbeddingLoss()  # ✅ Use Contrastive Loss

# ✅ Function to Compute Accuracy
def compute_accuracy(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, encoded_texts, _ in dataloader:
            images, encoded_texts = images.to(device), encoded_texts.to(device)

            # Get embeddings
            outputs_text = model.get_text_features(input_ids=encoded_texts)
            outputs_image = model.get_image_features(pixel_values=images)

            # Compute cosine similarity
            similarities = F.cosine_similarity(outputs_text, outputs_image)

            # Count correct predictions (Threshold = 0.5)
            correct += (similarities > 0.5).sum().item()
            total += len(similarities)

    return correct / total  # Accuracy as percentage

# ✅ Training Loop
# ✅ Training Loop with `wandb` Logging
num_epochs = 5
for epoch in range(num_epochs):
    clip_model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, encoded_texts, encoded_neg_texts in progress_bar:
        optimizer.zero_grad()

        # Move to GPU
        images, encoded_texts, encoded_neg_texts = images.to(device), encoded_texts.to(device), encoded_neg_texts.to(device)

        # Forward pass (positive text pairs)
        outputs_pos = clip_model.get_text_features(input_ids=encoded_texts)
        outputs_img = clip_model.get_image_features(pixel_values=images)

        # Forward pass (negative text pairs)
        outputs_neg = clip_model.get_text_features(input_ids=encoded_neg_texts)

        # **Ensure batch sizes are equal**
        min_batch_size = min(outputs_pos.shape[0], outputs_img.shape[0])
        outputs_pos, outputs_neg, outputs_img = outputs_pos[:min_batch_size], outputs_neg[:min_batch_size], outputs_img[:min_batch_size]

        # Target Labels: Positive = 1, Negative = -1
        labels = torch.cat([torch.ones(min_batch_size), -torch.ones(min_batch_size)]).to(device)

        # Compute Contrastive Loss
        loss = loss_fn(
            torch.cat([outputs_pos, outputs_neg]),
            torch.cat([outputs_img, outputs_img]),
            labels
        )

        # Backpropagation
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # ✅ Log Loss to `wandb`
        wandb.log({"Train Loss": loss.item()})

        # Update Progress Bar
        progress_bar.set_postfix(loss=loss.item())

    # ✅ Compute Accuracy & Log to `wandb`
    train_acc = compute_accuracy(clip_model, train_loader)
    test_acc = compute_accuracy(clip_model, test_loader)

    wandb.log({"Train Accuracy": train_acc, "Test Accuracy": test_acc})

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f} - Train Acc: {train_acc:.4f} - Test Acc: {test_acc:.4f}")

# ✅ Finish `wandb` Logging
wandb.finish()

# ✅ Save CLIP Model & LoRA Weights
save_path = "clip_lora_trained"
os.makedirs(save_path, exist_ok=True)

clip_model.save_pretrained(save_path)
clip_model.save_adapter(save_path, "lora")
processor.save_pretrained(save_path)

print(f"✅ Model and weights saved to {save_path}")

# ✅ Load CLIP Model & LoRA Weights for Inference
def load_clip_lora():
    clip_model = CLIPModel.from_pretrained(save_path).to(device)
    clip_model = PeftModel.from_pretrained(clip_model, f"{save_path}/lora")
    processor = CLIPProcessor.from_pretrained(save_path)
    print("✅ Model and LoRA weights loaded successfully!")
    return clip_model, processor

# ✅ Load & Test Model After Training
clip_model, processor = load_clip_lora()


In [None]:
# train_dataset = BirdDataset(dataset)
# train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)


# optimizer = torch.optim.Adam(clip_model.parameters(), lr=5e-5)
# loss_fn = torch.nn.CosineEmbeddingLoss()  # Contrastive loss



In [None]:
# import os
# import torch
# import torch.nn.functional as F
# from peft import LoraConfig, get_peft_model
# from transformers import CLIPProcessor, CLIPModel

# # Enable CUDA Debugging
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["TORCH_USE_CUDA_DSA"] = "1"

# # Restart CUDA
# torch.cuda.empty_cache()

# # Load CLIP Model
# model_name = "openai/clip-vit-base-patch32"
# clip_model = CLIPModel.from_pretrained(model_name)
# clip_model.to(device)  # Move after loading
# processor = CLIPProcessor.from_pretrained(model_name)

# # Apply LoRA
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=32,
#     lora_dropout=0.1,
#     target_modules=["q_proj", "v_proj"]
# )
# clip_model = get_peft_model(clip_model, lora_config)
# clip_model.print_trainable_parameters()

# # Optimizer & Training Setup
# optimizer = torch.optim.AdamW(clip_model.parameters(), lr=5e-5)

# # Training Loop
# num_epochs = 5
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# for epoch in range(num_epochs):
#     clip_model.train()
#     total_loss = 0
#     for images, encoded_texts, encoded_neg_texts in train_loader:
#         optimizer.zero_grad()

#         # Move to GPU
#         images = images.to(device)
#         encoded_texts = encoded_texts.to(device)
#         encoded_neg_texts = encoded_neg_texts.to(device)

#         # Debugging: Print tensor shapes
#         print(f"Encoded texts shape: {encoded_texts.shape}")  # Should be [batch, 77]
#         print(f"Negative texts shape: {encoded_neg_texts.shape}")  # Should be [batch, 77]

#         # Forward pass (positive pairs)
#         outputs = clip_model(input_ids=encoded_texts, pixel_values=images)
#         logits_per_image = outputs.logits_per_image

#         # Ensure Correct Target Tensor
#         targets = torch.arange(len(images), dtype=torch.long).to(device)
#         print(f"Targets shape: {targets.shape}")  # Should match batch_size
#         print(f"Target values: {targets}")  # Should be 0 to batch_size-1

#         # Compute Loss
#         loss = F.cross_entropy(logits_per_image, targets)

#         # Forward pass (negative pairs)
#         neg_outputs = clip_model(input_ids=encoded_neg_texts, pixel_values=images)
#         neg_logits = neg_outputs.logits_per_image

#         # Compute Negative Loss
#         neg_loss = F.cross_entropy(neg_logits, targets)
#         loss += neg_loss

#         # Backpropagation
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

# # Save LoRA fine-tuned CLIP
# clip_model.save_pretrained("clip_lora_trained")

import os
import torch
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import CLIPProcessor, CLIPModel

# Enable CUDA Debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Load CLIP Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# Apply Corrected LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"]
)
clip_model = get_peft_model(clip_model, lora_config)
clip_model.print_trainable_parameters()

# Optimizer & Contrastive Loss
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=5e-5)
loss_fn = torch.nn.CosineEmbeddingLoss()  # ✅ Use Contrastive Loss

# Training Loop
num_epochs = 5
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

for epoch in range(num_epochs):
    clip_model.train()
    total_loss = 0
    for images, encoded_texts, encoded_neg_texts in train_loader:
        optimizer.zero_grad()

        # Move to GPU
        images = images.to(device)
        encoded_texts = encoded_texts.to(device)
        encoded_neg_texts = encoded_neg_texts.to(device)

        # Forward pass (positive text pairs)
        outputs_pos = clip_model(input_ids=encoded_texts, pixel_values=images)
        embeddings_pos = outputs_pos.text_embeds

        # Forward pass (negative text pairs)
        outputs_neg = clip_model(input_ids=encoded_neg_texts, pixel_values=images)
        embeddings_neg = outputs_neg.text_embeds

        # Target Labels: Positive = 1, Negative = -1
        labels = torch.cat([torch.ones(len(embeddings_pos)), -torch.ones(len(embeddings_neg))]).to(device)

        # Compute Contrastive Loss
        loss = loss_fn(torch.cat([embeddings_pos, embeddings_neg]), torch.cat([embeddings_pos, embeddings_neg]), labels)

        # Backpropagation
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

# Save LoRA fine-tuned CLIP
clip_model.save_pretrained("clip_lora_trained")



In [None]:
def retrieve_image(query_text):
    inputs = processor(text=[query_text], return_tensors="pt").to(device)

    with torch.no_grad():
        text_features = clip_model(**inputs).text_embeds

    similarities = {}

    for key, sample in dataset.items():
        image = Image.open(sample["image_path"]).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)

        with torch.no_grad():
            image_features = clip_model(**inputs).image_embeds

        similarity = torch.nn.functional.cosine_similarity(text_features, image_features)
        similarities[key] = similarity.item()

    best_match = max(similarities, key=similarities.get)
    return dataset[best_match]["image_path"]

# Example test
query_text = "A cliff swallow with a red forehead perched on a wooden post."
print(retrieve_image(query_text))


In [None]:
def retrieve_text(image_path):
    image = Image.open(image_path).convert("RGB")
    image_inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        image_features = clip_model.get_image_features(**image_inputs)

    similarities = {}

    for key, sample in dataset.items():
        text = sample["llava_text"]
        text_inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)

        with torch.no_grad():
            text_features = clip_model.get_text_features(**text_inputs)

        similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
        similarities[key] = similarity.item()

    best_match = max(similarities, key=similarities.get)
    return dataset[best_match]["llava_text"]

test_image = "/kaggle/input/cub2002011/CUB_200_2011/images/002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg"
print(retrieve_text(test_image))