In [None]:
!pip install torch torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install faiss-cpu
!pip install faiss-gpu
!pip install faiss
!pip install --upgrade transformers

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-um_caf8p
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-um_caf8p
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369489 sha256=e1a325df98b9ec69887630b9021ef255ef993409a5ef6dc886e637fbf2f4201a
  Stored in directory: /tmp/pip-ephem-w

In [None]:
from google.colab import drive
import torch
import faiss
import numpy as np
from PIL import Image
from transformers import AlignProcessor, AlignModel
from torch.utils.data import DataLoader

# Step 1: Mount Google Drive at /content/drive; the dataset paths defined in
# the following cell live under this mount point. force_remount=True asks for
# a fresh mount even when the drive is already mounted in this session.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# The training images and their label file share one Drive folder, so the
# label path is derived from the image folder path.
train_path       = '/content/drive/My Drive/Building_images_train/'
train_label_path = train_path + 'train_labels.json'
# Output locations used by earlier experiments, kept for reference:
# output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned/'
# output_dir = '/content/drive/My Drive/Image_text models/Clip_fine-tunned_conservative'


In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every label
    becomes its own sample, so an image with three labels contributes three
    (image, label) pairs.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by the dataset itself.
        """
        self.image_dir = image_dir
        # Explicit encoding so non-ASCII labels load the same on every platform.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is a single label; iterating over it directly
            # would produce one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of flattened (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            Tuple (RGB PIL image, label, image file name).
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Custom collate function to handle variable-length text inputs.

    Uses the module-level ``processor`` to turn the raw samples into one
    model-ready batch.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Processed inputs, image_names, labels.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through a single processor call. padding=True pads
    # every label to the longest one in the batch; truncation=True clips labels
    # that exceed the tokenizer's maximum length instead of handing the text
    # encoder a sequence longer than it supports.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Every (image_name, label) combination that the label file marks as a match.
# The dataset holds one sample per (image, label), so a batch can contain the
# same image or the same label more than once.
positive_pairs = set(dataset.samples)

# Training Setup
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 30
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, image_names, labels = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get embeddings
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: is_match[i][j] is 1 when text j is a true label of image i.
        # Using only the diagonal would push apart pairs that really belong
        # together whenever an image or a label repeats inside the batch.
        # Each row is normalised into a probability distribution; without
        # repeats this is exactly the one-hot diagonal target.
        is_match = torch.tensor(
            [[(name, text) in positive_pairs for text in labels] for name in image_names],
            dtype=logits_per_image.dtype,
            device=device,
        )
        image_targets = is_match / is_match.sum(dim=1, keepdim=True)
        text_targets = is_match.t() / is_match.t().sum(dim=1, keepdim=True)

        # Compute loss (symmetric: image-to-text and text-to-image)
        loss = (loss_fn(logits_per_image, image_targets) + loss_fn(logits_per_text, text_targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_batch32/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Epoch 1/30: 100%|██████████| 23/23 [02:47<00:00,  7.29s/it]


Epoch 1 Loss: 2.9201519696608833


Epoch 2/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 2 Loss: 2.4464148853136147


Epoch 3/30: 100%|██████████| 23/23 [00:12<00:00,  1.83it/s]


Epoch 3 Loss: 2.3445256896640942


Epoch 4/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 4 Loss: 2.286290676697441


Epoch 5/30: 100%|██████████| 23/23 [00:12<00:00,  1.78it/s]


Epoch 5 Loss: 2.2826152106989985


Epoch 6/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 6 Loss: 2.2464515644571055


Epoch 7/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 7 Loss: 2.2303415122239487


Epoch 8/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 8 Loss: 2.2419700052427207


Epoch 9/30: 100%|██████████| 23/23 [00:12<00:00,  1.80it/s]


Epoch 9 Loss: 2.226146594337795


Epoch 10/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 10 Loss: 2.2076559222262837


Epoch 11/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 11 Loss: 2.195014741109765


Epoch 12/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 12 Loss: 2.2121589028316997


Epoch 13/30: 100%|██████████| 23/23 [00:12<00:00,  1.82it/s]


Epoch 13 Loss: 2.2069668147874917


Epoch 14/30: 100%|██████████| 23/23 [00:12<00:00,  1.82it/s]


Epoch 14 Loss: 2.189763950264972


Epoch 15/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 15 Loss: 2.1898144846377163


Epoch 16/30: 100%|██████████| 23/23 [00:12<00:00,  1.80it/s]


Epoch 16 Loss: 2.162617766338846


Epoch 17/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 17 Loss: 2.1629305663316147


Epoch 18/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 18 Loss: 2.1532789313274883


Epoch 19/30: 100%|██████████| 23/23 [00:12<00:00,  1.83it/s]


Epoch 19 Loss: 2.176597050998522


Epoch 20/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 20 Loss: 2.144528969474461


Epoch 21/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 21 Loss: 2.1720266134842583


Epoch 22/30: 100%|██████████| 23/23 [00:12<00:00,  1.81it/s]


Epoch 22 Loss: 2.1681152893149336


Epoch 23/30: 100%|██████████| 23/23 [00:12<00:00,  1.82it/s]


Epoch 23 Loss: 2.1352909751560376


Epoch 24/30: 100%|██████████| 23/23 [00:12<00:00,  1.80it/s]


Epoch 24 Loss: 2.1245554322781772


Epoch 25/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 25 Loss: 2.1596639260001806


Epoch 26/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 26 Loss: 2.1624998424364175


Epoch 27/30: 100%|██████████| 23/23 [00:13<00:00,  1.75it/s]


Epoch 27 Loss: 2.162308578905852


Epoch 28/30: 100%|██████████| 23/23 [00:13<00:00,  1.76it/s]


Epoch 28 Loss: 2.1601010094518247


Epoch 29/30: 100%|██████████| 23/23 [00:12<00:00,  1.79it/s]


Epoch 29 Loss: 2.1621434170266856


Epoch 30/30: 100%|██████████| 23/23 [00:12<00:00,  1.80it/s]


Epoch 30 Loss: 2.1390307312426358


[]

In [None]:
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import os

# Reload the checkpoint from output_dir, i.e. from wherever the most recently
# executed training cell saved its model and processor.
model_path = output_dir
print(model_path)
processor = CLIPProcessor.from_pretrained(model_path)
model = CLIPModel.from_pretrained(model_path)

# To query the stock checkpoint instead, load these:
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Function to Compute Similarity Scores
def get_similar_images(query, image_dir, top_k=5):
    """
    Rank the images in a directory by cosine similarity to a text query.

    Uses the module-level ``model`` and ``processor``.

    Args:
        query (str): The text query to search for.
        image_dir (str): Directory containing the images.
        top_k (int): Number of top similar images to return.

    Returns:
        List of tuples (image_name, score) and a list of corresponding images.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare the text input
    text_inputs = processor(text=[query], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    # Get text embedding
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # Normalize

    # Compute similarity scores for all images. Only names and scores are kept
    # here; holding every decoded image in memory is unnecessary when just the
    # top_k are returned.
    image_scores = []
    for image_name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_name)
        # Skip the label file and anything that is not a regular file
        # (e.g. sub-directories), which cannot be opened as an image.
        if image_name.endswith(".json") or not os.path.isfile(image_path):
            continue
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Process the image
        image_inputs = processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

        # Get image embedding
        with torch.no_grad():
            image_features = model.get_image_features(**image_inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # Normalize

        # Compute cosine similarity (both vectors are unit length)
        similarity = torch.matmul(text_features, image_features.T).item()
        image_scores.append((image_name, similarity))

    # Sort by similarity score in descending order
    image_scores.sort(key=lambda x: x[1], reverse=True)

    # Get top_k results and load only those images for display
    top_scores = image_scores[:top_k]
    top_images = []
    for image_name, _ in top_scores:
        with Image.open(os.path.join(image_dir, image_name)) as raw_image:
            top_images.append(raw_image.convert("RGB"))
    return top_scores, top_images

# Display Results
def display_results(query, top_scores, top_images):
    """
    Displays the images with their similarity scores.

    Prints the ranked list first, then draws the images in one vertical
    column. With no images, only the text output is produced.

    Args:
        query (str): The text query.
        top_scores (list): List of tuples (image_name, score).
        top_images (list): List of PIL images.
    """
    print(f"Query: {query}")
    for i, (image_name, score) in enumerate(top_scores):
        print(f"{i+1}: {image_name} (Score: {score:.4f})")

    # A figure cannot be created with zero rows, so stop here when there is
    # nothing to draw.
    if not top_images:
        print("No images to display.")
        return

    # Show images vertically with larger scales. squeeze=False always returns
    # a 2-D array of axes, so a single image needs no special handling.
    fig, axes = plt.subplots(
        len(top_images), 1, figsize=(5, 5 * len(top_images)), squeeze=False
    )
    for ax, image, (image_name, score) in zip(axes[:, 0], top_images, top_scores):
        ax.imshow(image)
        ax.set_title(f"{image_name}\nScore: {score:.4f}", fontsize=14)
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Example Usage: rank the training images against a text query and show the
# 50 best matches.
image_dir = train_path # Change to your image directory
# image_dir = '/content/drive/My Drive/Building_images_test_case/10130818'
query = "bath tub"

top_scores, top_images = get_similar_images(query, image_dir=image_dir, top_k=50)
display_results(query, top_scores, top_images)


# Baseline fine-tuning with a batch size of 8

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every label
    becomes its own sample, so an image with three labels contributes three
    (image, label) pairs.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by the dataset itself.
        """
        self.image_dir = image_dir
        # Explicit encoding so non-ASCII labels load the same on every platform.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is a single label; iterating over it directly
            # would produce one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of flattened (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            Tuple (RGB PIL image, label, image file name).
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Custom collate function to handle variable-length text inputs.

    Uses the module-level ``processor`` to turn the raw samples into one
    model-ready batch.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Processed inputs, image_names, labels.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through a single processor call. padding=True pads
    # every label to the longest one in the batch; truncation=True clips labels
    # that exceed the tokenizer's maximum length instead of handing the text
    # encoder a sequence longer than it supports.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Every (image_name, label) combination that the label file marks as a match.
# The dataset holds one sample per (image, label), so a batch can contain the
# same image or the same label more than once.
positive_pairs = set(dataset.samples)

# Training Setup
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 30
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, image_names, labels = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get embeddings
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: is_match[i][j] is 1 when text j is a true label of image i.
        # Using only the diagonal would push apart pairs that really belong
        # together whenever an image or a label repeats inside the batch.
        # Each row is normalised into a probability distribution; without
        # repeats this is exactly the one-hot diagonal target.
        is_match = torch.tensor(
            [[(name, text) in positive_pairs for text in labels] for name in image_names],
            dtype=logits_per_image.dtype,
            device=device,
        )
        image_targets = is_match / is_match.sum(dim=1, keepdim=True)
        text_targets = is_match.t() / is_match.t().sum(dim=1, keepdim=True)

        # Compute loss (symmetric: image-to-text and text-to-image)
        loss = (loss_fn(logits_per_image, image_targets) + loss_fn(logits_per_text, text_targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_batch8/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Epoch 1/30: 100%|██████████| 90/90 [03:18<00:00,  2.20s/it]


Epoch 1 Loss: 1.5359422557883793


Epoch 2/30: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 2 Loss: 1.246729678577847


Epoch 3/30: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 3 Loss: 1.1727253993352253


Epoch 4/30: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 4 Loss: 1.1543829490741093


Epoch 5/30: 100%|██████████| 90/90 [00:16<00:00,  5.59it/s]


Epoch 5 Loss: 1.168420931365755


Epoch 6/30: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 6 Loss: 1.158016366428799


Epoch 7/30: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 7 Loss: 1.1025220188829634


Epoch 8/30: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 8 Loss: 1.1283703956339095


Epoch 9/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 9 Loss: 1.1743578576379352


Epoch 10/30: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 10 Loss: 1.152445696791013


Epoch 11/30: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 11 Loss: 1.100095839632882


Epoch 12/30: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 12 Loss: 1.0996025655004713


Epoch 13/30: 100%|██████████| 90/90 [00:15<00:00,  5.69it/s]


Epoch 13 Loss: 1.109622124168608


Epoch 14/30: 100%|██████████| 90/90 [00:15<00:00,  5.69it/s]


Epoch 14 Loss: 1.1172988560464647


Epoch 15/30: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 15 Loss: 1.104592544833819


Epoch 16/30: 100%|██████████| 90/90 [00:15<00:00,  5.69it/s]


Epoch 16 Loss: 1.101513546705246


Epoch 17/30: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 17 Loss: 1.1043026208877564


Epoch 18/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 18 Loss: 1.0705316010448667


Epoch 19/30: 100%|██████████| 90/90 [00:15<00:00,  5.70it/s]


Epoch 19 Loss: 1.0923064470291137


Epoch 20/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 20 Loss: 1.071909401151869


Epoch 21/30: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 21 Loss: 1.099150475528505


Epoch 22/30: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 22 Loss: 1.056631393565072


Epoch 23/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 23 Loss: 1.0972596691714394


Epoch 24/30: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 24 Loss: 1.0864543914794922


Epoch 25/30: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 25 Loss: 1.0563523391882579


Epoch 26/30: 100%|██████████| 90/90 [00:15<00:00,  5.70it/s]


Epoch 26 Loss: 1.0570258643892076


Epoch 27/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 27 Loss: 1.0834139062298669


Epoch 28/30: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 28 Loss: 1.0582764983177184


Epoch 29/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 29 Loss: 1.072802930408054


Epoch 30/30: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 30 Loss: 1.0546453360054229


[]

# Freeze everything except the last encoder layer of each tower and the projection layers




In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every label
    becomes its own sample, so an image with three labels contributes three
    (image, label) pairs.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by the dataset itself.
        """
        self.image_dir = image_dir
        # Explicit encoding so non-ASCII labels load the same on every platform.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is a single label; iterating over it directly
            # would produce one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of flattened (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            Tuple (RGB PIL image, label, image file name).
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Custom collate function to handle variable-length text inputs.

    Uses the module-level ``processor`` to turn the raw samples into one
    model-ready batch.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Processed inputs, image_names, labels.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through a single processor call. padding=True pads
    # every label to the longest one in the batch; truncation=True clips labels
    # that exceed the tokenizer's maximum length instead of handing the text
    # encoder a sequence longer than it supports.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last few layers of the vision model
num_unfreeze_vision = 1  # Number of vision encoder layers to unfreeze
for block in model.vision_model.encoder.layers[-num_unfreeze_vision:]:
    for param in block.parameters():
        param.requires_grad = True

# Unfreeze the last few layers of the text model
num_unfreeze_text = 1  # Number of text encoder layers to unfreeze
for block in model.text_model.encoder.layers[-num_unfreeze_text:]:
    for param in block.parameters():
        param.requires_grad = True

# Unfreeze the projection layers
for param in model.visual_projection.parameters():
    param.requires_grad = True

for param in model.text_projection.parameters():
    param.requires_grad = True

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Every (image_name, label) combination that the label file marks as a match.
# The dataset holds one sample per (image, label), so a batch can contain the
# same image or the same label more than once.
positive_pairs = set(dataset.samples)

# Training Setup: only the parameters left trainable above are optimised
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=5e-6
)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, image_names, labels = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get embeddings
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: is_match[i][j] is 1 when text j is a true label of image i.
        # Using only the diagonal would push apart pairs that really belong
        # together whenever an image or a label repeats inside the batch.
        # Each row is normalised into a probability distribution; without
        # repeats this is exactly the one-hot diagonal target.
        is_match = torch.tensor(
            [[(name, text) in positive_pairs for text in labels] for name in image_names],
            dtype=logits_per_image.dtype,
            device=device,
        )
        image_targets = is_match / is_match.sum(dim=1, keepdim=True)
        text_targets = is_match.t() / is_match.t().sum(dim=1, keepdim=True)

        # Compute loss (symmetric: image-to-text and text-to-image)
        loss = (loss_fn(logits_per_image, image_targets) + loss_fn(logits_per_text, text_targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_freeze_1layer/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

Epoch 1/100: 100%|██████████| 90/90 [00:13<00:00,  6.48it/s]


Epoch 1 Loss: 1.689998409483168


Epoch 2/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 2 Loss: 1.4149100091722278


Epoch 3/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 3 Loss: 1.2962137732240888


Epoch 4/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 4 Loss: 1.2506895383199057


Epoch 5/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 5 Loss: 1.2224797738922968


Epoch 6/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 6 Loss: 1.1686806996663412


Epoch 7/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 7 Loss: 1.193711927202013


Epoch 8/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 8 Loss: 1.156329799691836


Epoch 9/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 9 Loss: 1.1396403789520264


Epoch 10/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 10 Loss: 1.1163716140720579


Epoch 11/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 11 Loss: 1.1123610768053267


Epoch 12/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 12 Loss: 1.1191523992353016


Epoch 13/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 13 Loss: 1.106596980492274


Epoch 14/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 14 Loss: 1.0819501181443532


Epoch 15/100: 100%|██████████| 90/90 [00:13<00:00,  6.66it/s]


Epoch 15 Loss: 1.0759023567040762


Epoch 16/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 16 Loss: 1.1000162694189284


Epoch 17/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 17 Loss: 1.103260197242101


Epoch 18/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 18 Loss: 1.0833199507660336


Epoch 19/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 19 Loss: 1.1005433718363444


Epoch 20/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 20 Loss: 1.079317225019137


Epoch 21/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 21 Loss: 1.085166233778


Epoch 22/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 22 Loss: 1.0683797276682323


Epoch 23/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 23 Loss: 1.042302668425772


Epoch 24/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 24 Loss: 1.0534002143475745


Epoch 25/100: 100%|██████████| 90/90 [00:13<00:00,  6.66it/s]


Epoch 25 Loss: 1.081073068247901


Epoch 26/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 26 Loss: 1.0807229535447227


Epoch 27/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 27 Loss: 1.08313743505213


Epoch 28/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 28 Loss: 1.0766234828366175


Epoch 29/100: 100%|██████████| 90/90 [00:13<00:00,  6.67it/s]


Epoch 29 Loss: 1.0454385356770621


Epoch 30/100: 100%|██████████| 90/90 [00:13<00:00,  6.68it/s]


Epoch 30 Loss: 1.0728353030151792


Epoch 31/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 31 Loss: 1.052816143963072


Epoch 32/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 32 Loss: 1.050371868742837


Epoch 33/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 33 Loss: 1.0676832589838239


Epoch 34/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 34 Loss: 1.0711542632844713


Epoch 35/100: 100%|██████████| 90/90 [00:13<00:00,  6.68it/s]


Epoch 35 Loss: 1.0403482973575593


Epoch 36/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 36 Loss: 1.0366137305895486


Epoch 37/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 37 Loss: 1.0460595985253651


Epoch 38/100: 100%|██████████| 90/90 [00:13<00:00,  6.66it/s]


Epoch 38 Loss: 1.0266112287839253


Epoch 39/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 39 Loss: 1.0305299937725068


Epoch 40/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 40 Loss: 1.0327876756588619


Epoch 41/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 41 Loss: 1.0380409742395083


Epoch 42/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 42 Loss: 1.036004998948839


Epoch 43/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 43 Loss: 1.0202363888422648


Epoch 44/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 44 Loss: 1.0200886136955685


Epoch 45/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 45 Loss: 1.0071602715386285


Epoch 46/100: 100%|██████████| 90/90 [00:13<00:00,  6.67it/s]


Epoch 46 Loss: 1.0262692719697952


Epoch 47/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 47 Loss: 1.0256374445226457


Epoch 48/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 48 Loss: 1.0321747452020644


Epoch 49/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 49 Loss: 1.0365797188546924


Epoch 50/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 50 Loss: 1.0340921812587314


Epoch 51/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 51 Loss: 1.0217374089691373


Epoch 52/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 52 Loss: 1.0638878617021772


Epoch 53/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 53 Loss: 0.9965778234932158


Epoch 54/100: 100%|██████████| 90/90 [00:13<00:00,  6.66it/s]


Epoch 54 Loss: 1.0430727826224433


Epoch 55/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 55 Loss: 1.0626152084933387


Epoch 56/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 56 Loss: 1.0449162843326727


Epoch 57/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 57 Loss: 1.0248333950837454


Epoch 58/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 58 Loss: 1.0216944893201192


Epoch 59/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 59 Loss: 1.0366872284147475


Epoch 60/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 60 Loss: 1.03450083302127


Epoch 61/100: 100%|██████████| 90/90 [00:13<00:00,  6.65it/s]


Epoch 61 Loss: 1.0120513872967827


Epoch 62/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 62 Loss: 1.014072236749861


Epoch 63/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 63 Loss: 1.0002612147066328


Epoch 64/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 64 Loss: 1.0561771161026425


Epoch 65/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 65 Loss: 1.020770818988482


Epoch 66/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 66 Loss: 1.0528495404455396


Epoch 67/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 67 Loss: 1.0278810100422966


Epoch 68/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 68 Loss: 1.0299078421460257


Epoch 69/100: 100%|██████████| 90/90 [00:13<00:00,  6.75it/s]


Epoch 69 Loss: 1.020499246650272


Epoch 70/100: 100%|██████████| 90/90 [00:13<00:00,  6.67it/s]


Epoch 70 Loss: 1.0414341217941707


Epoch 71/100: 100%|██████████| 90/90 [00:13<00:00,  6.75it/s]


Epoch 71 Loss: 1.0257096618413926


Epoch 72/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 72 Loss: 1.0341809382041296


Epoch 73/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 73 Loss: 1.0287474562724432


Epoch 74/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 74 Loss: 1.0507609804471334


Epoch 75/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 75 Loss: 1.016028289662467


Epoch 76/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 76 Loss: 1.0383331331941816


Epoch 77/100: 100%|██████████| 90/90 [00:13<00:00,  6.67it/s]


Epoch 77 Loss: 1.0515618019633823


Epoch 78/100: 100%|██████████| 90/90 [00:13<00:00,  6.76it/s]


Epoch 78 Loss: 1.0388903952307171


Epoch 79/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 79 Loss: 1.0294364765286446


Epoch 80/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 80 Loss: 1.036523284183608


Epoch 81/100: 100%|██████████| 90/90 [00:13<00:00,  6.69it/s]


Epoch 81 Loss: 1.016379925277498


Epoch 82/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 82 Loss: 1.0175085660484102


Epoch 83/100: 100%|██████████| 90/90 [00:13<00:00,  6.75it/s]


Epoch 83 Loss: 0.9913618998395072


Epoch 84/100: 100%|██████████| 90/90 [00:13<00:00,  6.67it/s]


Epoch 84 Loss: 1.0146007819308176


Epoch 85/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 85 Loss: 1.0453796525796255


Epoch 86/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 86 Loss: 1.028807873527209


Epoch 87/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 87 Loss: 0.9981936679946052


Epoch 88/100: 100%|██████████| 90/90 [00:13<00:00,  6.75it/s]


Epoch 88 Loss: 1.0155971907907062


Epoch 89/100: 100%|██████████| 90/90 [00:13<00:00,  6.66it/s]


Epoch 89 Loss: 1.0381555557250977


Epoch 90/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 90 Loss: 1.0145976192421384


Epoch 91/100: 100%|██████████| 90/90 [00:13<00:00,  6.68it/s]


Epoch 91 Loss: 1.014439144068294


Epoch 92/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 92 Loss: 0.9943316880199644


Epoch 93/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 93 Loss: 0.9971017923620012


Epoch 94/100: 100%|██████████| 90/90 [00:13<00:00,  6.70it/s]


Epoch 94 Loss: 0.9925498247146607


Epoch 95/100: 100%|██████████| 90/90 [00:13<00:00,  6.73it/s]


Epoch 95 Loss: 1.0166680018107097


Epoch 96/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 96 Loss: 1.020485567384296


Epoch 97/100: 100%|██████████| 90/90 [00:13<00:00,  6.72it/s]


Epoch 97 Loss: 1.0280525008837382


Epoch 98/100: 100%|██████████| 90/90 [00:13<00:00,  6.71it/s]


Epoch 98 Loss: 1.0154608031113943


Epoch 99/100: 100%|██████████| 90/90 [00:13<00:00,  6.75it/s]


Epoch 99 Loss: 1.0464306357834074


Epoch 100/100: 100%|██████████| 90/90 [00:13<00:00,  6.74it/s]


Epoch 100 Loss: 1.0140284690592023


[]

# Freeze 2 layers

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Dataset of (image, label) pairs described by a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes its own sample, so an image with
    several labels is returned several times, once per label.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by any method of this class.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor

        # Flatten {image: [label, ...]} into a list of (image, label) pairs.
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is one label; iterating it directly would create
            # one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (RGB PIL image, label, image file name) for sample `idx`."""
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() produces an independent in-memory image, so the underlying
        # file can be closed immediately instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate dataset samples into one processed batch.

    The images and labels of the batch are handed to the module-level
    `processor` in a single call, which pads the tokenized labels to a common
    length and returns PyTorch tensors.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple (inputs, image_names, labels): the processor output followed by
        the image names and raw label strings in batch order.
    """
    images, labels, image_names = [], [], []
    for image, label, image_name in batch:
        images.append(image)
        labels.append(label)
        image_names.append(image_name)

    # Images and texts go through the processor together. truncation=True caps
    # each label at the tokenizer's maximum length so an unusually long label
    # cannot exceed what the text encoder accepts.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def _unfreeze_last_layers(layers, count):
    """Make the parameters of the last `count` modules in `layers` trainable.

    A count of zero (or less) leaves every layer untouched. The explicit
    guard is needed because `layers[-0:]` is the same slice as `layers[0:]`
    and would select all layers instead of none.
    """
    if count <= 0:
        return
    for block in layers[-count:]:
        for param in block.parameters():
            param.requires_grad = True


# Freeze all parameters, then re-enable gradients only for the parts to train
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last few layers of the vision model
num_unfreeze_vision = 2  # Number of vision encoder layers to unfreeze
_unfreeze_last_layers(model.vision_model.encoder.layers, num_unfreeze_vision)

# Unfreeze the last few layers of the text model
num_unfreeze_text = 2  # Number of text encoder layers to unfreeze
_unfreeze_last_layers(model.text_model.encoder.layers, num_unfreeze_text)

# Unfreeze the projection layers
for projection in (model.visual_projection, model.text_projection):
    for param in projection.parameters():
        param.requires_grad = True

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Training Setup: only parameters that still require gradients are optimized
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass through both encoders; the model returns the
        # image-text similarity logits in both directions.
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Target: sample i matches caption i (the diagonal of the logit matrix).
        # NOTE(review): the dataset emits one sample per (image, label) pair, so
        # a batch may hold the same image or the same label more than once; such
        # duplicates are scored as negatives by these targets.
        targets = torch.arange(logits_per_image.size(0), device=device)

        # Symmetric cross-entropy over both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean loss per batch for this epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_freeze_2layer/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

Epoch 1/100: 100%|██████████| 90/90 [00:14<00:00,  6.38it/s]


Epoch 1 Loss: 1.6508085979355707


Epoch 2/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 2 Loss: 1.3420739541451137


Epoch 3/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 3 Loss: 1.2473464959197573


Epoch 4/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 4 Loss: 1.1914492845535278


Epoch 5/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 5 Loss: 1.1869719445705413


Epoch 6/100: 100%|██████████| 90/90 [00:13<00:00,  6.53it/s]


Epoch 6 Loss: 1.1442803117964002


Epoch 7/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 7 Loss: 1.1503643016020457


Epoch 8/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 8 Loss: 1.0924428986178505


Epoch 9/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 9 Loss: 1.1001955807209014


Epoch 10/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 10 Loss: 1.1153179195192124


Epoch 11/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 11 Loss: 1.0992622087399164


Epoch 12/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 12 Loss: 1.1205142703321245


Epoch 13/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 13 Loss: 1.1098178055551318


Epoch 14/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 14 Loss: 1.0960359546873304


Epoch 15/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 15 Loss: 1.0954757259951697


Epoch 16/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 16 Loss: 1.086807600988282


Epoch 17/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 17 Loss: 1.0756394624710084


Epoch 18/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 18 Loss: 1.0949955072667863


Epoch 19/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 19 Loss: 1.050253951218393


Epoch 20/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 20 Loss: 1.0501759267515607


Epoch 21/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 21 Loss: 1.06385874317752


Epoch 22/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 22 Loss: 1.0772103488445282


Epoch 23/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 23 Loss: 1.02248914324575


Epoch 24/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 24 Loss: 1.05554012854894


Epoch 25/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 25 Loss: 1.0385284023152457


Epoch 26/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 26 Loss: 1.0761136876212225


Epoch 27/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 27 Loss: 1.0455051571130753


Epoch 28/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 28 Loss: 1.0530082176129023


Epoch 29/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 29 Loss: 1.0434229228231642


Epoch 30/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 30 Loss: 1.0584559857845306


Epoch 31/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 31 Loss: 1.0320874194304148


Epoch 32/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 32 Loss: 1.0755344172318777


Epoch 33/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 33 Loss: 1.059370645880699


Epoch 34/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 34 Loss: 1.0511359983020359


Epoch 35/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 35 Loss: 1.0338207019699945


Epoch 36/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 36 Loss: 1.053313747048378


Epoch 37/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 37 Loss: 1.027905375096533


Epoch 38/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 38 Loss: 1.0650100198056962


Epoch 39/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 39 Loss: 1.0214624504248302


Epoch 40/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 40 Loss: 1.0462151736021041


Epoch 41/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 41 Loss: 1.055959858165847


Epoch 42/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 42 Loss: 1.0389994637833702


Epoch 43/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 43 Loss: 1.0483826067712572


Epoch 44/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 44 Loss: 1.035704673661126


Epoch 45/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 45 Loss: 1.0252594729264577


Epoch 46/100: 100%|██████████| 90/90 [00:13<00:00,  6.52it/s]


Epoch 46 Loss: 1.0202156957652835


Epoch 47/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 47 Loss: 1.0379280808899138


Epoch 48/100: 100%|██████████| 90/90 [00:13<00:00,  6.59it/s]


Epoch 48 Loss: 1.0412673350837496


Epoch 49/100: 100%|██████████| 90/90 [00:13<00:00,  6.65it/s]


Epoch 49 Loss: 1.0265872809622023


Epoch 50/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 50 Loss: 1.0252533660994636


Epoch 51/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 51 Loss: 1.0302919406029913


Epoch 52/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 52 Loss: 0.99832999739382


Epoch 53/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 53 Loss: 1.0292255696323183


Epoch 54/100: 100%|██████████| 90/90 [00:13<00:00,  6.63it/s]


Epoch 54 Loss: 1.0068368719683753


Epoch 55/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 55 Loss: 1.0366940197017458


Epoch 56/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 56 Loss: 1.0304039465056525


Epoch 57/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 57 Loss: 1.0411895073122448


Epoch 58/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 58 Loss: 1.0109772637486458


Epoch 59/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 59 Loss: 1.0539869278669358


Epoch 60/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 60 Loss: 1.0324943572282792


Epoch 61/100: 100%|██████████| 90/90 [00:13<00:00,  6.65it/s]


Epoch 61 Loss: 1.0264933036433326


Epoch 62/100: 100%|██████████| 90/90 [00:13<00:00,  6.60it/s]


Epoch 62 Loss: 1.0580240064197117


Epoch 63/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 63 Loss: 1.0388229929738575


Epoch 64/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 64 Loss: 1.0299209237098694


Epoch 65/100: 100%|██████████| 90/90 [00:13<00:00,  6.61it/s]


Epoch 65 Loss: 1.0070424477259319


Epoch 66/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 66 Loss: 1.0338302191760804


Epoch 67/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 67 Loss: 1.0375620156526566


Epoch 68/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 68 Loss: 1.0116859257221222


Epoch 69/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 69 Loss: 0.9925439324643877


Epoch 70/100: 100%|██████████| 90/90 [00:13<00:00,  6.62it/s]


Epoch 70 Loss: 1.0180228180355495


Epoch 71/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 71 Loss: 1.0199477407667372


Epoch 72/100: 100%|██████████| 90/90 [00:13<00:00,  6.50it/s]


Epoch 72 Loss: 1.0423897779650158


Epoch 73/100: 100%|██████████| 90/90 [00:13<00:00,  6.46it/s]


Epoch 73 Loss: 0.9932647552755144


Epoch 74/100: 100%|██████████| 90/90 [00:13<00:00,  6.50it/s]


Epoch 74 Loss: 1.0194002628326415


Epoch 75/100: 100%|██████████| 90/90 [00:13<00:00,  6.44it/s]


Epoch 75 Loss: 1.0113732202185526


Epoch 76/100: 100%|██████████| 90/90 [00:13<00:00,  6.49it/s]


Epoch 76 Loss: 1.0158264878723355


Epoch 77/100: 100%|██████████| 90/90 [00:13<00:00,  6.45it/s]


Epoch 77 Loss: 1.0472906665669548


Epoch 78/100: 100%|██████████| 90/90 [00:13<00:00,  6.52it/s]


Epoch 78 Loss: 1.0245472997426988


Epoch 79/100: 100%|██████████| 90/90 [00:13<00:00,  6.52it/s]


Epoch 79 Loss: 1.016881466574139


Epoch 80/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 80 Loss: 1.0295650904377303


Epoch 81/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 81 Loss: 1.0309089708659385


Epoch 82/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 82 Loss: 1.014120907915963


Epoch 83/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 83 Loss: 1.0343182630009122


Epoch 84/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 84 Loss: 1.0322421309020784


Epoch 85/100: 100%|██████████| 90/90 [00:13<00:00,  6.57it/s]


Epoch 85 Loss: 1.0136645118395486


Epoch 86/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 86 Loss: 1.0351771576537026


Epoch 87/100: 100%|██████████| 90/90 [00:13<00:00,  6.50it/s]


Epoch 87 Loss: 1.0426621344354419


Epoch 88/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 88 Loss: 0.9980060976412561


Epoch 89/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 89 Loss: 1.0176732559998831


Epoch 90/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 90 Loss: 1.0063970178365707


Epoch 91/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 91 Loss: 1.0397711081637278


Epoch 92/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 92 Loss: 1.004975966612498


Epoch 93/100: 100%|██████████| 90/90 [00:13<00:00,  6.56it/s]


Epoch 93 Loss: 0.9759337027867635


Epoch 94/100: 100%|██████████| 90/90 [00:13<00:00,  6.55it/s]


Epoch 94 Loss: 1.006122069557508


Epoch 95/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 95 Loss: 1.030162391397688


Epoch 96/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 96 Loss: 1.0524790896309746


Epoch 97/100: 100%|██████████| 90/90 [00:13<00:00,  6.47it/s]


Epoch 97 Loss: 1.0062871078650157


Epoch 98/100: 100%|██████████| 90/90 [00:13<00:00,  6.58it/s]


Epoch 98 Loss: 1.034164637989468


Epoch 99/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 99 Loss: 1.0099311043818793


Epoch 100/100: 100%|██████████| 90/90 [00:13<00:00,  6.54it/s]


Epoch 100 Loss: 1.0118419922060438


[]

# Fine-tune only the last 10 encoder layers and the projection layers (everything else frozen)

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Dataset of (image, label) pairs described by a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes its own sample, so an image with
    several labels is returned several times, once per label.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by any method of this class.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor

        # Flatten {image: [label, ...]} into a list of (image, label) pairs.
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is one label; iterating it directly would create
            # one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (RGB PIL image, label, image file name) for sample `idx`."""
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() produces an independent in-memory image, so the underlying
        # file can be closed immediately instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate dataset samples into one processed batch.

    The images and labels of the batch are handed to the module-level
    `processor` in a single call, which pads the tokenized labels to a common
    length and returns PyTorch tensors.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple (inputs, image_names, labels): the processor output followed by
        the image names and raw label strings in batch order.
    """
    images, labels, image_names = [], [], []
    for image, label, image_name in batch:
        images.append(image)
        labels.append(label)
        image_names.append(image_name)

    # Images and texts go through the processor together. truncation=True caps
    # each label at the tokenizer's maximum length so an unusually long label
    # cannot exceed what the text encoder accepts.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def _unfreeze_last_layers(layers, count):
    """Make the parameters of the last `count` modules in `layers` trainable.

    A count of zero (or less) leaves every layer untouched. The explicit
    guard is needed because `layers[-0:]` is the same slice as `layers[0:]`
    and would select all layers instead of none.
    """
    if count <= 0:
        return
    for block in layers[-count:]:
        for param in block.parameters():
            param.requires_grad = True


# Freeze all parameters, then re-enable gradients only for the parts to train
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last few layers of the vision model
num_unfreeze_vision = 10  # Number of vision encoder layers to unfreeze
_unfreeze_last_layers(model.vision_model.encoder.layers, num_unfreeze_vision)

# Unfreeze the last few layers of the text model
num_unfreeze_text = 10  # Number of text encoder layers to unfreeze
_unfreeze_last_layers(model.text_model.encoder.layers, num_unfreeze_text)

# Unfreeze the projection layers
for projection in (model.visual_projection, model.text_projection):
    for param in projection.parameters():
        param.requires_grad = True

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Training Setup: only parameters that still require gradients are optimized
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass through both encoders; the model returns the
        # image-text similarity logits in both directions.
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Target: sample i matches caption i (the diagonal of the logit matrix).
        # NOTE(review): the dataset emits one sample per (image, label) pair, so
        # a batch may hold the same image or the same label more than once; such
        # duplicates are scored as negatives by these targets.
        targets = torch.arange(logits_per_image.size(0), device=device)

        # Symmetric cross-entropy over both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean loss per batch for this epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_freeze_10layer/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

Epoch 1/100: 100%|██████████| 90/90 [00:16<00:00,  5.49it/s]


Epoch 1 Loss: 1.6042993783950805


Epoch 2/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 2 Loss: 1.2681684679455227


Epoch 3/100: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 3 Loss: 1.252566048171785


Epoch 4/100: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 4 Loss: 1.157747703128391


Epoch 5/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 5 Loss: 1.15138518081771


Epoch 6/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 6 Loss: 1.1487221346961127


Epoch 7/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 7 Loss: 1.1593861159351138


Epoch 8/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 8 Loss: 1.107711550924513


Epoch 9/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 9 Loss: 1.1058588915401035


Epoch 10/100: 100%|██████████| 90/90 [00:16<00:00,  5.59it/s]


Epoch 10 Loss: 1.1094505879614087


Epoch 11/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 11 Loss: 1.0922796338796616


Epoch 12/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 12 Loss: 1.1080534531010522


Epoch 13/100: 100%|██████████| 90/90 [00:16<00:00,  5.59it/s]


Epoch 13 Loss: 1.072628562649091


Epoch 14/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 14 Loss: 1.07061241335339


Epoch 15/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 15 Loss: 1.1019275860653983


Epoch 16/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 16 Loss: 1.0991683691740035


Epoch 17/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 17 Loss: 1.0982932302686903


Epoch 18/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 18 Loss: 1.087209822071923


Epoch 19/100: 100%|██████████| 90/90 [00:16<00:00,  5.57it/s]


Epoch 19 Loss: 1.082083613342709


Epoch 20/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 20 Loss: 1.0411112599902683


Epoch 21/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 21 Loss: 1.0618518908818564


Epoch 22/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 22 Loss: 1.0786213662889268


Epoch 23/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 23 Loss: 1.0587264074219598


Epoch 24/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 24 Loss: 1.0484385271867116


Epoch 25/100: 100%|██████████| 90/90 [00:16<00:00,  5.57it/s]


Epoch 25 Loss: 1.0586055285400815


Epoch 26/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 26 Loss: 1.0392515463961496


Epoch 27/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 27 Loss: 1.037819155719545


Epoch 28/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 28 Loss: 1.0727485325601367


Epoch 29/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 29 Loss: 1.0722738805744383


Epoch 30/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 30 Loss: 1.0730877919329538


Epoch 31/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 31 Loss: 1.03470778465271


Epoch 32/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 32 Loss: 1.0490541411770715


Epoch 33/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 33 Loss: 1.0394524388843112


Epoch 34/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 34 Loss: 1.0536157760355207


Epoch 35/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 35 Loss: 1.0308429698149364


Epoch 36/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 36 Loss: 1.0462716533078087


Epoch 37/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 37 Loss: 1.013976949453354


Epoch 38/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 38 Loss: 1.0566611329714457


Epoch 39/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 39 Loss: 1.0540140648682912


Epoch 40/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 40 Loss: 1.03531410296758


Epoch 41/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 41 Loss: 1.061013208495246


Epoch 42/100: 100%|██████████| 90/90 [00:15<00:00,  5.68it/s]


Epoch 42 Loss: 1.0766001866923438


Epoch 43/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 43 Loss: 1.020225308338801


Epoch 44/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 44 Loss: 1.0541974743207296


Epoch 45/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 45 Loss: 1.038666311899821


Epoch 46/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 46 Loss: 1.0525204608837764


Epoch 47/100: 100%|██████████| 90/90 [00:15<00:00,  5.67it/s]


Epoch 47 Loss: 1.0301720009909736


Epoch 48/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 48 Loss: 1.032989932762252


Epoch 49/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 49 Loss: 1.0492051439152823


Epoch 50/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 50 Loss: 1.0724315835369957


Epoch 51/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 51 Loss: 1.045020572344462


Epoch 52/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 52 Loss: 1.062402496735255


Epoch 53/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 53 Loss: 1.0395480268531376


Epoch 54/100: 100%|██████████| 90/90 [00:16<00:00,  5.57it/s]


Epoch 54 Loss: 1.0320887141757542


Epoch 55/100: 100%|██████████| 90/90 [00:16<00:00,  5.56it/s]


Epoch 55 Loss: 1.0365233835246828


Epoch 56/100: 100%|██████████| 90/90 [00:16<00:00,  5.48it/s]


Epoch 56 Loss: 1.0288296070363787


Epoch 57/100: 100%|██████████| 90/90 [00:16<00:00,  5.48it/s]


Epoch 57 Loss: 1.0341628743542566


Epoch 58/100: 100%|██████████| 90/90 [00:16<00:00,  5.51it/s]


Epoch 58 Loss: 1.0107795619302327


Epoch 59/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 59 Loss: 1.019809744755427


Epoch 60/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 60 Loss: 1.0301227546400493


Epoch 61/100: 100%|██████████| 90/90 [00:16<00:00,  5.57it/s]


Epoch 61 Loss: 1.025265653596984


Epoch 62/100: 100%|██████████| 90/90 [00:16<00:00,  5.58it/s]


Epoch 62 Loss: 1.0373506612247891


Epoch 63/100: 100%|██████████| 90/90 [00:16<00:00,  5.51it/s]


Epoch 63 Loss: 1.0383623517221874


Epoch 64/100: 100%|██████████| 90/90 [00:16<00:00,  5.59it/s]


Epoch 64 Loss: 1.0218967388073603


Epoch 65/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 65 Loss: 1.0182910521825155


Epoch 66/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 66 Loss: 1.026273677746455


Epoch 67/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 67 Loss: 1.0222249921825197


Epoch 68/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 68 Loss: 1.032025349802441


Epoch 69/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 69 Loss: 1.0328230324718688


Epoch 70/100: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 70 Loss: 1.0460636360777749


Epoch 71/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 71 Loss: 1.0298694736427731


Epoch 72/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 72 Loss: 1.0282256563504537


Epoch 73/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 73 Loss: 1.0254932830731074


Epoch 74/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 74 Loss: 1.0387339472770691


Epoch 75/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 75 Loss: 1.013740403122372


Epoch 76/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 76 Loss: 1.0144969953431024


Epoch 77/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 77 Loss: 1.0154777394400702


Epoch 78/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 78 Loss: 1.0117676628960504


Epoch 79/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 79 Loss: 0.990981423523691


Epoch 80/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 80 Loss: 1.0432870003912185


Epoch 81/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 81 Loss: 1.0303444898790783


Epoch 82/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 82 Loss: 1.0483780172136095


Epoch 83/100: 100%|██████████| 90/90 [00:16<00:00,  5.60it/s]


Epoch 83 Loss: 1.021624845597479


Epoch 84/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 84 Loss: 1.0397756030162175


Epoch 85/100: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 85 Loss: 1.0699537300401263


Epoch 86/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 86 Loss: 0.9956395241949293


Epoch 87/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 87 Loss: 1.0681516783105003


Epoch 88/100: 100%|██████████| 90/90 [00:16<00:00,  5.62it/s]


Epoch 88 Loss: 1.0134561826785407


Epoch 89/100: 100%|██████████| 90/90 [00:15<00:00,  5.65it/s]


Epoch 89 Loss: 1.0190034061670303


Epoch 90/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 90 Loss: 1.020582299762302


Epoch 91/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 91 Loss: 1.0063775585757362


Epoch 92/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 92 Loss: 1.0104183024830289


Epoch 93/100: 100%|██████████| 90/90 [00:16<00:00,  5.61it/s]


Epoch 93 Loss: 1.0048756152391434


Epoch 94/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 94 Loss: 1.0313951538668737


Epoch 95/100: 100%|██████████| 90/90 [00:15<00:00,  5.66it/s]


Epoch 95 Loss: 1.0164359894063737


Epoch 96/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 96 Loss: 0.9936663793192969


Epoch 97/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 97 Loss: 1.0256016731262207


Epoch 98/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 98 Loss: 1.0023028870423636


Epoch 99/100: 100%|██████████| 90/90 [00:15<00:00,  5.64it/s]


Epoch 99 Loss: 1.0105150471131006


Epoch 100/100: 100%|██████████| 90/90 [00:15<00:00,  5.63it/s]


Epoch 100 Loss: 1.0248105817370945


[]

# InfoNCE Loss

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Dataset of (image, label) pairs described by a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes its own sample, so an image with
    several labels is returned several times, once per label.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
            processor: CLIP processor. It is stored on the instance but not
                used by any method of this class.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor

        # Flatten {image: [label, ...]} into a list of (image, label) pairs.
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string is one label; iterating it directly would create
            # one sample per character.
            if isinstance(labels, str):
                labels = [labels]
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (RGB PIL image, label, image file name) for sample `idx`."""
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() produces an independent in-memory image, so the underlying
        # file can be closed immediately instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate dataset samples into one processed batch.

    The images and labels of the batch are handed to the module-level
    `processor` in a single call, which pads the tokenized labels to a common
    length and returns PyTorch tensors.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple (inputs, image_names, labels): the processor output followed by
        the image names and raw label strings in batch order.
    """
    images, labels, image_names = [], [], []
    for image, label, image_name in batch:
        images.append(image)
        labels.append(label)
        image_names.append(image_name)

    # Images and texts go through the processor together. truncation=True caps
    # each label at the tokenizer's maximum length so an unusually long label
    # cannot exceed what the text encoder accepts.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Define the InfoNCE Loss Function
def info_nce_loss(image_embeddings, text_embeddings, logit_scale):
    """
    Computes the symmetric InfoNCE loss for the CLIP model.

    Row ``i`` of ``image_embeddings`` and row ``i`` of ``text_embeddings`` are
    treated as the positive pair; every other row in the batch is a negative.

    Args:
        image_embeddings: Image embeddings tensor, one row per sample.
        text_embeddings: Text embeddings tensor, one row per sample, in the
            same order as ``image_embeddings``.
        logit_scale: Logit scaling factor (usually obtained from the model)
            that multiplies the cosine similarities.

    Returns:
        Scalar loss value: the mean of the image-to-text and text-to-image
        cross-entropy losses.
    """
    functional = torch.nn.functional

    # L2-normalize the embeddings. normalize() clamps the norm with a small
    # epsilon, so an all-zero embedding yields zeros instead of NaN.
    image_embeddings = functional.normalize(image_embeddings, dim=-1)
    text_embeddings = functional.normalize(text_embeddings, dim=-1)

    # Scaled cosine similarity between every image and every text
    logits = torch.matmul(image_embeddings, text_embeddings.T) * logit_scale

    # Ground truth labels (diagonal elements are positives), created directly
    # on the embeddings' device to avoid a host-to-device copy.
    batch_size = image_embeddings.size(0)
    labels = torch.arange(batch_size, device=image_embeddings.device)

    # Cross-entropy in both retrieval directions, averaged
    loss_i2t = functional.cross_entropy(logits, labels)
    loss_t2i = functional.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2.0

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
# A dedicated seeded generator makes the shuffle order repeatable from run to
# run without touching the global RNG state.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)

# Training Setup
# The model is placed on its target device first so the optimizer is built
# over parameters that already live where training happens.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# Fine-Tuning Loop
# NOTE(review): the dataset holds one sample per (image, label) pair, so a
# batch may contain the same image or the same label text more than once.
# info_nce_loss treats only the diagonal as positive, so such duplicates are
# pushed apart as negatives - confirm this is acceptable for these labels.
epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        image_embeddings = outputs.image_embeds
        text_embeddings = outputs.text_embeds
        # The model stores the scale in log space; exp() gives the multiplier.
        logit_scale = model.logit_scale.exp()

        # Compute InfoNCE loss
        loss = info_nce_loss(image_embeddings, text_embeddings, logit_scale)
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean per-batch loss for the epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_NCEloss/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)


Epoch 1/100: 100%|██████████| 90/90 [00:17<00:00,  5.24it/s]


Epoch 1 Loss: 1.5853077431519826


Epoch 2/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 2 Loss: 1.2576130277580686


Epoch 3/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 3 Loss: 1.245283994409773


Epoch 4/100: 100%|██████████| 90/90 [00:16<00:00,  5.38it/s]


Epoch 4 Loss: 1.2047805077499814


Epoch 5/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 5 Loss: 1.1379893435372246


Epoch 6/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 6 Loss: 1.1423303292857276


Epoch 7/100: 100%|██████████| 90/90 [00:16<00:00,  5.38it/s]


Epoch 7 Loss: 1.1637189070383707


Epoch 8/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 8 Loss: 1.129258375035392


Epoch 9/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 9 Loss: 1.1058789289659925


Epoch 10/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 10 Loss: 1.1069667869144015


Epoch 11/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 11 Loss: 1.0993518124024073


Epoch 12/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 12 Loss: 1.0855102837085724


Epoch 13/100: 100%|██████████| 90/90 [00:16<00:00,  5.45it/s]


Epoch 13 Loss: 1.1089292897118463


Epoch 14/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 14 Loss: 1.095831452144517


Epoch 15/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 15 Loss: 1.0996181478103002


Epoch 16/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 16 Loss: 1.0738052076763578


Epoch 17/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 17 Loss: 1.0954957314663463


Epoch 18/100: 100%|██████████| 90/90 [00:16<00:00,  5.45it/s]


Epoch 18 Loss: 1.0712362163596683


Epoch 19/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 19 Loss: 1.0799297094345093


Epoch 20/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 20 Loss: 1.08804372118579


Epoch 21/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 21 Loss: 1.095726917187373


Epoch 22/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 22 Loss: 1.0968442704942492


Epoch 23/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 23 Loss: 1.0808910621537102


Epoch 24/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 24 Loss: 1.0668689482741887


Epoch 25/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 25 Loss: 1.0879038688209322


Epoch 26/100: 100%|██████████| 90/90 [00:16<00:00,  5.32it/s]


Epoch 26 Loss: 1.0685463362269931


Epoch 27/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 27 Loss: 1.0712467875745562


Epoch 28/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 28 Loss: 1.065511683622996


Epoch 29/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 29 Loss: 1.0516919255256654


Epoch 30/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 30 Loss: 1.0532794836494657


Epoch 31/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 31 Loss: 1.064936402771208


Epoch 32/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 32 Loss: 1.0540701468785605


Epoch 33/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 33 Loss: 1.0656849486960305


Epoch 34/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 34 Loss: 1.0444475962056055


Epoch 35/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 35 Loss: 1.0408451634976599


Epoch 36/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 36 Loss: 1.0944158063994514


Epoch 37/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 37 Loss: 1.0595342112912072


Epoch 38/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 38 Loss: 1.039101904961798


Epoch 39/100: 100%|██████████| 90/90 [00:16<00:00,  5.45it/s]


Epoch 39 Loss: 1.0054027563995784


Epoch 40/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 40 Loss: 1.0748381286859512


Epoch 41/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 41 Loss: 1.056810513138771


Epoch 42/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 42 Loss: 1.016345586710506


Epoch 43/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 43 Loss: 1.0307846006419923


Epoch 44/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 44 Loss: 1.0721919655799865


Epoch 45/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 45 Loss: 1.0509478204780154


Epoch 46/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 46 Loss: 1.0337784469127655


Epoch 47/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 47 Loss: 1.0520914806260002


Epoch 48/100: 100%|██████████| 90/90 [00:16<00:00,  5.37it/s]


Epoch 48 Loss: 1.063833643330468


Epoch 49/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 49 Loss: 1.0297347353564368


Epoch 50/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 50 Loss: 1.0279345972670448


Epoch 51/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 51 Loss: 1.06149924993515


Epoch 52/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 52 Loss: 1.0311456011401283


Epoch 53/100: 100%|██████████| 90/90 [00:16<00:00,  5.37it/s]


Epoch 53 Loss: 1.046382137801912


Epoch 54/100: 100%|██████████| 90/90 [00:16<00:00,  5.37it/s]


Epoch 54 Loss: 1.0259472754266528


Epoch 55/100: 100%|██████████| 90/90 [00:16<00:00,  5.34it/s]


Epoch 55 Loss: 1.0515583402580686


Epoch 56/100: 100%|██████████| 90/90 [00:16<00:00,  5.38it/s]


Epoch 56 Loss: 1.0726051721307965


Epoch 57/100: 100%|██████████| 90/90 [00:16<00:00,  5.37it/s]


Epoch 57 Loss: 1.0118743154737684


Epoch 58/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 58 Loss: 1.0387346953153611


Epoch 59/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 59 Loss: 1.060973866780599


Epoch 60/100: 100%|██████████| 90/90 [00:16<00:00,  5.34it/s]


Epoch 60 Loss: 1.0407769504520628


Epoch 61/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 61 Loss: 1.0462628768550024


Epoch 62/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 62 Loss: 1.0487950980663299


Epoch 63/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 63 Loss: 1.0635191768407821


Epoch 64/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 64 Loss: 1.0787701275613573


Epoch 65/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 65 Loss: 1.0460608416133457


Epoch 66/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 66 Loss: 1.0335392544666926


Epoch 67/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 67 Loss: 1.0045287208424674


Epoch 68/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 68 Loss: 1.0098267767164442


Epoch 69/100: 100%|██████████| 90/90 [00:16<00:00,  5.45it/s]


Epoch 69 Loss: 1.0243418653806051


Epoch 70/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 70 Loss: 1.0342219832870696


Epoch 71/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 71 Loss: 1.0452001872989867


Epoch 72/100: 100%|██████████| 90/90 [00:16<00:00,  5.46it/s]


Epoch 72 Loss: 1.019735227028529


Epoch 73/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 73 Loss: 1.0327225873867671


Epoch 74/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 74 Loss: 1.0330238822433684


Epoch 75/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 75 Loss: 1.0231142752700382


Epoch 76/100: 100%|██████████| 90/90 [00:16<00:00,  5.44it/s]


Epoch 76 Loss: 1.0322735293043985


Epoch 77/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 77 Loss: 1.0007755117283927


Epoch 78/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 78 Loss: 1.0421234097745684


Epoch 79/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 79 Loss: 1.0168669376108381


Epoch 80/100: 100%|██████████| 90/90 [00:16<00:00,  5.48it/s]


Epoch 80 Loss: 1.0239905301067564


Epoch 81/100: 100%|██████████| 90/90 [00:16<00:00,  5.46it/s]


Epoch 81 Loss: 0.9967061377233929


Epoch 82/100: 100%|██████████| 90/90 [00:16<00:00,  5.38it/s]


Epoch 82 Loss: 1.0492814865377214


Epoch 83/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 83 Loss: 1.042879830300808


Epoch 84/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 84 Loss: 1.0168524814976587


Epoch 85/100: 100%|██████████| 90/90 [00:16<00:00,  5.35it/s]


Epoch 85 Loss: 1.024218617214097


Epoch 86/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 86 Loss: 1.008716662062539


Epoch 87/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 87 Loss: 1.016300353076723


Epoch 88/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 88 Loss: 1.0281335105498632


Epoch 89/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 89 Loss: 1.0075913624631034


Epoch 90/100: 100%|██████████| 90/90 [00:16<00:00,  5.37it/s]


Epoch 90 Loss: 1.0286931938595243


Epoch 91/100: 100%|██████████| 90/90 [00:16<00:00,  5.46it/s]


Epoch 91 Loss: 1.0436984714534547


Epoch 92/100: 100%|██████████| 90/90 [00:16<00:00,  5.43it/s]


Epoch 92 Loss: 1.0238313630223275


Epoch 93/100: 100%|██████████| 90/90 [00:16<00:00,  5.41it/s]


Epoch 93 Loss: 1.03293101158407


Epoch 94/100: 100%|██████████| 90/90 [00:16<00:00,  5.45it/s]


Epoch 94 Loss: 1.0365278389718797


Epoch 95/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 95 Loss: 1.0173213862710528


Epoch 96/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 96 Loss: 0.999022948079639


Epoch 97/100: 100%|██████████| 90/90 [00:16<00:00,  5.38it/s]


Epoch 97 Loss: 1.0309170378579033


Epoch 98/100: 100%|██████████| 90/90 [00:16<00:00,  5.40it/s]


Epoch 98 Loss: 1.0202074299256008


Epoch 99/100: 100%|██████████| 90/90 [00:16<00:00,  5.39it/s]


Epoch 99 Loss: 1.0406276524066924


Epoch 100/100: 100%|██████████| 90/90 [00:16<00:00,  5.42it/s]


Epoch 100 Loss: 1.0291113393174278


[]

# Batch 1

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes one sample, so an image with several
    labels appears several times in the dataset.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
                Each value may be a list of label strings or a single string.
            processor: CLIP processor for preprocessing. It is only stored on
                the instance; this class never calls it.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string would otherwise be iterated character by character,
            # producing one bogus single-letter "label" per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label, image_name)`` where ``image`` is an RGB
            PIL image, ``label`` is the label text and ``image_name`` is the
            JSON key the image was listed under.
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory image, so the handle can be closed right away
        # instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate (image, label, image_name) samples into one CLIP-ready batch.

    Uses the module-level ``processor`` defined in this cell.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple ``(inputs, image_names, labels)``: the processor output for the
        whole batch, followed by the image names and the label strings in
        batch order.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through the processor in a single call.
    # padding=True pads the texts to the longest label in the batch;
    # truncation=True cuts labels that exceed the tokenizer's maximum length,
    # which would otherwise fail inside the text encoder.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
import warnings

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
# A dedicated seeded generator makes the shuffle order repeatable from run to
# run without touching the global RNG state.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)

# With a single pair per batch the similarity matrix has one entry, the
# softmax over it is 1 and the cross-entropy (and its gradient) is exactly 0.
# Surface that loudly instead of silently spending 100 epochs on a no-op.
if dataloader.batch_size is not None and dataloader.batch_size < 2:
    warnings.warn(
        "batch_size=1 gives a 1x1 similarity matrix, so the contrastive "
        "cross-entropy loss is always 0 and no learning signal reaches the model.",
        RuntimeWarning,
    )

# Training Setup
# The model is placed on its target device first so the optimizer is built
# over parameters that already live where training happens.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass: the model returns the scaled similarity matrices
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: the i-th image is paired with the i-th text (the diagonal)
        targets = torch.arange(len(logits_per_image), device=device)

        # Symmetric contrastive loss: mean of both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean per-batch loss for the epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_batch1/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

# Batch 2

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes one sample, so an image with several
    labels appears several times in the dataset.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
                Each value may be a list of label strings or a single string.
            processor: CLIP processor for preprocessing. It is only stored on
                the instance; this class never calls it.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string would otherwise be iterated character by character,
            # producing one bogus single-letter "label" per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label, image_name)`` where ``image`` is an RGB
            PIL image, ``label`` is the label text and ``image_name`` is the
            JSON key the image was listed under.
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory image, so the handle can be closed right away
        # instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate (image, label, image_name) samples into one CLIP-ready batch.

    Uses the module-level ``processor`` defined in this cell.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple ``(inputs, image_names, labels)``: the processor output for the
        whole batch, followed by the image names and the label strings in
        batch order.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through the processor in a single call.
    # padding=True pads the texts to the longest label in the batch;
    # truncation=True cuts labels that exceed the tokenizer's maximum length,
    # which would otherwise fail inside the text encoder.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
# A dedicated seeded generator makes the shuffle order repeatable from run to
# run without touching the global RNG state.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)

# Training Setup
# The model is placed on its target device first so the optimizer is built
# over parameters that already live where training happens.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
# NOTE(review): the dataset holds one sample per (image, label) pair, so a
# batch may contain the same image or the same label text twice; the diagonal
# targets below then treat the duplicate as a negative - confirm this is
# acceptable for these labels.
epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass: the model returns the scaled similarity matrices
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: the i-th image is paired with the i-th text (the diagonal)
        targets = torch.arange(len(logits_per_image), device=device)

        # Symmetric contrastive loss: mean of both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean per-batch loss for the epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_batch2/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

Epoch 1/50: 100%|██████████| 360/360 [00:33<00:00, 10.66it/s]


Epoch 1 Loss: 0.5422555976152782


Epoch 2/50: 100%|██████████| 360/360 [00:32<00:00, 10.99it/s]


Epoch 2 Loss: 0.36681859780605414


Epoch 3/50: 100%|██████████| 360/360 [00:32<00:00, 10.98it/s]


Epoch 3 Loss: 0.36905015648317707


Epoch 4/50: 100%|██████████| 360/360 [00:32<00:00, 11.01it/s]


Epoch 4 Loss: 0.3561517131895041


Epoch 5/50: 100%|██████████| 360/360 [00:32<00:00, 11.00it/s]


Epoch 5 Loss: 0.31306870295779593


Epoch 6/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 6 Loss: 0.3380009062918816


Epoch 7/50: 100%|██████████| 360/360 [00:32<00:00, 11.02it/s]


Epoch 7 Loss: 0.33837770262867806


Epoch 8/50: 100%|██████████| 360/360 [00:32<00:00, 11.01it/s]


Epoch 8 Loss: 0.3007974369688832


Epoch 9/50: 100%|██████████| 360/360 [00:32<00:00, 10.91it/s]


Epoch 9 Loss: 0.32058403150034187


Epoch 10/50: 100%|██████████| 360/360 [00:32<00:00, 11.01it/s]


Epoch 10 Loss: 0.3027906331452818


Epoch 11/50: 100%|██████████| 360/360 [00:32<00:00, 10.93it/s]


Epoch 11 Loss: 0.30588754773379556


Epoch 12/50: 100%|██████████| 360/360 [00:32<00:00, 10.98it/s]


Epoch 12 Loss: 0.3088043427970054


Epoch 13/50: 100%|██████████| 360/360 [00:32<00:00, 11.03it/s]


Epoch 13 Loss: 0.3113676084660357


Epoch 14/50: 100%|██████████| 360/360 [00:32<00:00, 10.99it/s]


Epoch 14 Loss: 0.30029616990328173


Epoch 15/50: 100%|██████████| 360/360 [00:32<00:00, 11.00it/s]


Epoch 15 Loss: 0.2965634963154116


Epoch 16/50: 100%|██████████| 360/360 [00:33<00:00, 10.89it/s]


Epoch 16 Loss: 0.31841355477058136


Epoch 17/50: 100%|██████████| 360/360 [00:32<00:00, 10.96it/s]


Epoch 17 Loss: 0.31320219774355185


Epoch 18/50: 100%|██████████| 360/360 [00:32<00:00, 10.95it/s]


Epoch 18 Loss: 0.27725361713628127


Epoch 19/50: 100%|██████████| 360/360 [00:32<00:00, 11.05it/s]


Epoch 19 Loss: 0.3095376330178422


Epoch 20/50: 100%|██████████| 360/360 [00:32<00:00, 10.94it/s]


Epoch 20 Loss: 0.27502692898674974


Epoch 21/50: 100%|██████████| 360/360 [00:32<00:00, 10.93it/s]


Epoch 21 Loss: 0.24436061703471143


Epoch 22/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 22 Loss: 0.28289897146473175


Epoch 23/50: 100%|██████████| 360/360 [00:32<00:00, 11.02it/s]


Epoch 23 Loss: 0.29910545622380547


Epoch 24/50: 100%|██████████| 360/360 [00:32<00:00, 10.93it/s]


Epoch 24 Loss: 0.27806528806687286


Epoch 25/50: 100%|██████████| 360/360 [00:32<00:00, 10.97it/s]


Epoch 25 Loss: 0.30212670332347596


Epoch 26/50: 100%|██████████| 360/360 [00:32<00:00, 11.03it/s]


Epoch 26 Loss: 0.2894323523094701


Epoch 27/50: 100%|██████████| 360/360 [00:32<00:00, 11.03it/s]


Epoch 27 Loss: 0.29371342713903914


Epoch 28/50: 100%|██████████| 360/360 [00:32<00:00, 11.05it/s]


Epoch 28 Loss: 0.28060510193225935


Epoch 29/50: 100%|██████████| 360/360 [00:32<00:00, 10.98it/s]


Epoch 29 Loss: 0.28582902942106964


Epoch 30/50: 100%|██████████| 360/360 [00:32<00:00, 10.98it/s]


Epoch 30 Loss: 0.2736297665370532


Epoch 31/50: 100%|██████████| 360/360 [00:32<00:00, 10.93it/s]


Epoch 31 Loss: 0.26909467954429955


Epoch 32/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 32 Loss: 0.26855142265036797


Epoch 33/50: 100%|██████████| 360/360 [00:32<00:00, 11.00it/s]


Epoch 33 Loss: 0.24893369112128721


Epoch 34/50: 100%|██████████| 360/360 [00:32<00:00, 10.97it/s]


Epoch 34 Loss: 0.2847012169886044


Epoch 35/50: 100%|██████████| 360/360 [00:32<00:00, 10.92it/s]


Epoch 35 Loss: 0.26611327594815876


Epoch 36/50: 100%|██████████| 360/360 [00:32<00:00, 10.92it/s]


Epoch 36 Loss: 0.30853297044820416


Epoch 37/50: 100%|██████████| 360/360 [00:33<00:00, 10.91it/s]


Epoch 37 Loss: 0.2639168383313896


Epoch 38/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 38 Loss: 0.2835930725646727


Epoch 39/50: 100%|██████████| 360/360 [00:33<00:00, 10.91it/s]


Epoch 39 Loss: 0.2557371320778717


Epoch 40/50: 100%|██████████| 360/360 [00:32<00:00, 10.95it/s]


Epoch 40 Loss: 0.2685751440280506


Epoch 41/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 41 Loss: 0.2799218522676509


Epoch 42/50: 100%|██████████| 360/360 [00:32<00:00, 11.04it/s]


Epoch 42 Loss: 0.25934069613301175


Epoch 43/50: 100%|██████████| 360/360 [00:32<00:00, 11.00it/s]


Epoch 43 Loss: 0.25278312064554376


Epoch 44/50: 100%|██████████| 360/360 [00:32<00:00, 11.01it/s]


Epoch 44 Loss: 0.24906301133119085


Epoch 45/50: 100%|██████████| 360/360 [00:32<00:00, 11.11it/s]


Epoch 45 Loss: 0.26446798131915883


Epoch 46/50: 100%|██████████| 360/360 [00:32<00:00, 10.93it/s]


Epoch 46 Loss: 0.2994256928164961


Epoch 47/50: 100%|██████████| 360/360 [00:32<00:00, 11.08it/s]


Epoch 47 Loss: 0.2780609607923149


Epoch 48/50: 100%|██████████| 360/360 [00:32<00:00, 11.07it/s]


Epoch 48 Loss: 0.24725178575487614


Epoch 49/50: 100%|██████████| 360/360 [00:32<00:00, 11.08it/s]


Epoch 49 Loss: 0.27804317695382047


Epoch 50/50: 100%|██████████| 360/360 [00:32<00:00, 11.03it/s]


Epoch 50 Loss: 0.23301027235432029


[]

# Batch 16

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Flat dataset of (image, label) pairs built from a JSON label file.

    The JSON file maps each image file name to its label(s). Every
    (image, label) combination becomes one sample, so an image with several
    labels appears several times in the dataset.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
                Each value may be a list of label strings or a single string.
            processor: CLIP processor for preprocessing. It is only stored on
                the instance; this class never calls it.
        """
        self.image_dir = image_dir
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string would otherwise be iterated character by character,
            # producing one bogus single-letter "label" per character.
            if isinstance(labels, str):
                labels = [labels]
            # Flatten to image-label pairs
            self.samples.extend((image_name, label) for label in labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, label, image_name)`` where ``image`` is an RGB
            PIL image, ``label`` is the label text and ``image_name`` is the
            JSON key the image was listed under.
        """
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Image.open is lazy and keeps the file open. convert() returns an
        # independent in-memory image, so the handle can be closed right away
        # instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate (image, label, image_name) samples into one CLIP-ready batch.

    Uses the module-level ``processor`` defined in this cell.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple ``(inputs, image_names, labels)``: the processor output for the
        whole batch, followed by the image names and the label strings in
        batch order.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # Images and texts go through the processor in a single call.
    # padding=True pads the texts to the longest label in the batch;
    # truncation=True cuts labels that exceed the tokenizer's maximum length,
    # which would otherwise fail inside the text encoder.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
# A dedicated seeded generator makes the shuffle order repeatable from run to
# run without touching the global RNG state.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)

# Training Setup
# The model is placed on its target device first so the optimizer is built
# over parameters that already live where training happens.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
# NOTE(review): the dataset holds one sample per (image, label) pair, so a
# batch may contain the same image or the same label text more than once; the
# diagonal targets below then treat those duplicates as negatives - confirm
# this is acceptable for these labels.
epochs = 30
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass: the model returns the scaled similarity matrices
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Targets: the i-th image is paired with the i-th text (the diagonal)
        targets = torch.arange(len(logits_per_image), device=device)

        # Symmetric contrastive loss: mean of both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean per-batch loss for the epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_batch16/'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

Epoch 1/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 1 Loss: 2.233604245715671


Epoch 2/30: 100%|██████████| 45/45 [00:13<00:00,  3.32it/s]


Epoch 2 Loss: 1.8275928576787313


Epoch 3/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 3 Loss: 1.7465017371707492


Epoch 4/30: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s]


Epoch 4 Loss: 1.7377519687016805


Epoch 5/30: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s]


Epoch 5 Loss: 1.7251242028342353


Epoch 6/30: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s]


Epoch 6 Loss: 1.679741628964742


Epoch 7/30: 100%|██████████| 45/45 [00:13<00:00,  3.28it/s]


Epoch 7 Loss: 1.665297638045417


Epoch 8/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 8 Loss: 1.686546672715081


Epoch 9/30: 100%|██████████| 45/45 [00:13<00:00,  3.27it/s]


Epoch 9 Loss: 1.6679582277933755


Epoch 10/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 10 Loss: 1.6492950201034546


Epoch 11/30: 100%|██████████| 45/45 [00:13<00:00,  3.32it/s]


Epoch 11 Loss: 1.6117117934756808


Epoch 12/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 12 Loss: 1.604825864897834


Epoch 13/30: 100%|██████████| 45/45 [00:13<00:00,  3.33it/s]


Epoch 13 Loss: 1.6381810771094427


Epoch 14/30: 100%|██████████| 45/45 [00:13<00:00,  3.29it/s]


Epoch 14 Loss: 1.6339445167117648


Epoch 15/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 15 Loss: 1.5970677110883924


Epoch 16/30: 100%|██████████| 45/45 [00:13<00:00,  3.28it/s]


Epoch 16 Loss: 1.6116660118103028


Epoch 17/30: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 17 Loss: 1.6135146379470826


Epoch 18/30: 100%|██████████| 45/45 [00:13<00:00,  3.32it/s]


Epoch 18 Loss: 1.5939011838701036


Epoch 19/30: 100%|██████████| 45/45 [00:13<00:00,  3.29it/s]


Epoch 19 Loss: 1.6095197333229914


Epoch 20/30: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s]


Epoch 20 Loss: 1.62249755859375


Epoch 21/30: 100%|██████████| 45/45 [00:13<00:00,  3.32it/s]


Epoch 21 Loss: 1.6034669637680055


Epoch 22/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 22 Loss: 1.5938016414642333


Epoch 23/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 23 Loss: 1.6129871580335828


Epoch 24/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 24 Loss: 1.5978549718856812


Epoch 25/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 25 Loss: 1.582426921526591


Epoch 26/30: 100%|██████████| 45/45 [00:13<00:00,  3.28it/s]


Epoch 26 Loss: 1.5808081759346857


Epoch 27/30: 100%|██████████| 45/45 [00:13<00:00,  3.31it/s]


Epoch 27 Loss: 1.5919511874516805


Epoch 28/30: 100%|██████████| 45/45 [00:13<00:00,  3.29it/s]


Epoch 28 Loss: 1.5870842297871908


Epoch 29/30: 100%|██████████| 45/45 [00:13<00:00,  3.29it/s]


Epoch 29 Loss: 1.5786115831798977


Epoch 30/30: 100%|██████████| 45/45 [00:13<00:00,  3.30it/s]


Epoch 30 Loss: 1.5666036817762587


[]

# Test: partial unfreezing

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Dataset of (image, label) pairs flattened from an image -> labels JSON file.

    Every label of an image becomes its own sample, so an image with three
    labels contributes three samples that all load the same picture.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
                Each value may be a list of label strings or a single string.
            processor: CLIP processor. Kept on the instance, but not used by
                the dataset itself.
        """
        self.image_dir = image_dir
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string has to be wrapped first: iterating it directly
            # would produce one sample per character instead of one per label.
            if isinstance(labels, str):
                labels = [labels]
            for label in labels:  # Flatten to image-label pairs
                self.samples.append((image_name, label))

    def __len__(self):
        """Return the number of flattened (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(RGB PIL image, label, image_name)``."""
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate (image, label, image_name) samples into one CLIP-ready batch.

    Uses the module-level ``processor`` that this cell creates after the
    function definition; it only has to exist by the time the DataLoader
    starts iterating.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple ``(inputs, image_names, labels)``: the processor output for the
        whole batch, plus the image names and label strings as plain lists in
        batch order.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # One processor call handles both modalities. padding=True pads the texts
    # to the longest label in the batch; truncation=True cuts over-long labels
    # to the tokenizer's maximum length, which otherwise would overflow the
    # text encoder's fixed context window and raise during the forward pass.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Freeze all parameters, then re-enable only the parts that should be fine-tuned
for param in model.parameters():
    param.requires_grad = False

num_unfreeze_vision = 2  # Number of vision encoder layers to unfreeze
num_unfreeze_text = 2  # Number of text encoder layers to unfreeze

vision_layers = model.vision_model.encoder.layers
text_layers = model.text_model.encoder.layers

# The slices use an explicit start index: `layers[-0:]` would select *every*
# layer, so setting a count to 0 would silently unfreeze the whole encoder.
trainable_modules = [
    *vision_layers[max(0, len(vision_layers) - num_unfreeze_vision):],
    *text_layers[max(0, len(text_layers) - num_unfreeze_text):],
    model.visual_projection,
    model.text_projection,
]
for module in trainable_modules:
    for param in module.parameters():
        param.requires_grad = True

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# Move the model before building the optimizer so both refer to the same device tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training Setup
# Only the unfrozen parameters are handed to the optimizer; frozen ones never
# receive gradients, so listing them would only add dead bookkeeping.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-6)
# `verbose=True` is deprecated on PyTorch LR schedulers, so learning-rate
# reductions are reported explicitly in the epoch loop instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-Tuning Loop
epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass: the model returns the image<->text similarity logits
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity
        logits_per_text = outputs.logits_per_text  # Text-to-image similarity

        # Target: sample i in the batch matches text i (the diagonal).
        # NOTE(review): the dataset yields one sample per (image, label) pair,
        # so the same image or the same label text can occur several times in
        # a batch; those duplicates are then scored as negatives. Confirm this
        # is acceptable for the label set in use.
        targets = torch.arange(logits_per_image.size(0), device=device)

        # Symmetric cross-entropy over both retrieval directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Update scheduler with the epoch's mean training loss
    avg_loss = total_loss / len(dataloader)
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(avg_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    print(f"Epoch {epoch + 1} Loss: {avg_loss}")

# Save the Fine-tuned Model
output_dir = '/content/drive/My Drive/Image_text models/CLIP_fine-tunned_test/'
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)


Epoch 1/100: 100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


Epoch 1 Loss: 2.3692666901482475


Epoch 2/100: 100%|██████████| 45/45 [00:12<00:00,  3.59it/s]


Epoch 2 Loss: 1.9242683675554064


Epoch 3/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 3 Loss: 1.8630759689542982


Epoch 4/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 4 Loss: 1.768665901819865


Epoch 5/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 5 Loss: 1.7280426687664456


Epoch 6/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 6 Loss: 1.7239666329489813


Epoch 7/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 7 Loss: 1.7132461918724908


Epoch 8/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 8 Loss: 1.6612792174021402


Epoch 9/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 9 Loss: 1.6644848240746393


Epoch 10/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 10 Loss: 1.6558197445339626


Epoch 11/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 11 Loss: 1.6474936167399088


Epoch 12/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 12 Loss: 1.6096653434965345


Epoch 13/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 13 Loss: 1.615617405043708


Epoch 14/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 14 Loss: 1.6098404725392659


Epoch 15/100: 100%|██████████| 45/45 [00:12<00:00,  3.68it/s]


Epoch 15 Loss: 1.602302818828159


Epoch 16/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 16 Loss: 1.6119689305623373


Epoch 17/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 17 Loss: 1.5812545352511935


Epoch 18/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 18 Loss: 1.5767789151933458


Epoch 19/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 19 Loss: 1.604551943143209


Epoch 20/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 20 Loss: 1.6257047759162055


Epoch 21/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 21 Loss: 1.6039236015743679


Epoch 22/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 22 Loss: 1.5844619989395141


Epoch 23/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 23 Loss: 1.579642497168647


Epoch 24/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 24 Loss: 1.5897045294443766


Epoch 25/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 25 Loss: 1.5419976393381754


Epoch 26/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 26 Loss: 1.5348955379592049


Epoch 27/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 27 Loss: 1.531223914358351


Epoch 28/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 28 Loss: 1.5084482696321275


Epoch 29/100: 100%|██████████| 45/45 [00:12<00:00,  3.61it/s]


Epoch 29 Loss: 1.4970042043262057


Epoch 30/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 30 Loss: 1.4944061279296874


Epoch 31/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 31 Loss: 1.519358233610789


Epoch 32/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 32 Loss: 1.4863146596484713


Epoch 33/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 33 Loss: 1.5107393821080526


Epoch 34/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 34 Loss: 1.517406235800849


Epoch 35/100: 100%|██████████| 45/45 [00:12<00:00,  3.68it/s]


Epoch 35 Loss: 1.502162684334649


Epoch 36/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 36 Loss: 1.5075544834136962


Epoch 37/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 37 Loss: 1.4848721875084772


Epoch 38/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 38 Loss: 1.517831524213155


Epoch 39/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 39 Loss: 1.5081548876232571


Epoch 40/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 40 Loss: 1.5173715538448758


Epoch 41/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 41 Loss: 1.529546512497796


Epoch 42/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 42 Loss: 1.5086971547868517


Epoch 43/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 43 Loss: 1.4815297683080038


Epoch 44/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 44 Loss: 1.5270763529671563


Epoch 45/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 45 Loss: 1.497962146335178


Epoch 46/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 46 Loss: 1.511531615257263


Epoch 47/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 47 Loss: 1.4923734267552693


Epoch 48/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 48 Loss: 1.4738168080647787


Epoch 49/100: 100%|██████████| 45/45 [00:12<00:00,  3.59it/s]


Epoch 49 Loss: 1.516599138577779


Epoch 50/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 50 Loss: 1.499723858303494


Epoch 51/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 51 Loss: 1.4972797128889295


Epoch 52/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 52 Loss: 1.5044786559210883


Epoch 53/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 53 Loss: 1.4971455971399943


Epoch 54/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 54 Loss: 1.5021616829766167


Epoch 55/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 55 Loss: 1.5137620992130703


Epoch 56/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 56 Loss: 1.4910206741756864


Epoch 57/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 57 Loss: 1.5197299506929185


Epoch 58/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 58 Loss: 1.4937411149342854


Epoch 59/100: 100%|██████████| 45/45 [00:12<00:00,  3.58it/s]


Epoch 59 Loss: 1.4985980934566923


Epoch 60/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 60 Loss: 1.4914293024275038


Epoch 61/100: 100%|██████████| 45/45 [00:12<00:00,  3.58it/s]


Epoch 61 Loss: 1.5162235312991672


Epoch 62/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 62 Loss: 1.494189206759135


Epoch 63/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 63 Loss: 1.5152239428626166


Epoch 64/100: 100%|██████████| 45/45 [00:12<00:00,  3.61it/s]


Epoch 64 Loss: 1.500564530160692


Epoch 65/100: 100%|██████████| 45/45 [00:12<00:00,  3.61it/s]


Epoch 65 Loss: 1.4679153203964233


Epoch 66/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 66 Loss: 1.4909656365712485


Epoch 67/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 67 Loss: 1.490247329076131


Epoch 68/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 68 Loss: 1.495192697313097


Epoch 69/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 69 Loss: 1.4764116684595743


Epoch 70/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 70 Loss: 1.5049020767211914


Epoch 71/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 71 Loss: 1.5268583986494275


Epoch 72/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 72 Loss: 1.5022629366980658


Epoch 73/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 73 Loss: 1.4969072182973227


Epoch 74/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 74 Loss: 1.479095803366767


Epoch 75/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 75 Loss: 1.4849638077947829


Epoch 76/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 76 Loss: 1.4826335695054795


Epoch 77/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 77 Loss: 1.4921074204974705


Epoch 78/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 78 Loss: 1.4925076749589707


Epoch 79/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 79 Loss: 1.5214261054992675


Epoch 80/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 80 Loss: 1.51809647348192


Epoch 81/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 81 Loss: 1.51360993915134


Epoch 82/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 82 Loss: 1.4885270675023397


Epoch 83/100: 100%|██████████| 45/45 [00:12<00:00,  3.61it/s]


Epoch 83 Loss: 1.5033178753323024


Epoch 84/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 84 Loss: 1.4898732079399957


Epoch 85/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 85 Loss: 1.4840957747565375


Epoch 86/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 86 Loss: 1.4925250238842434


Epoch 87/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 87 Loss: 1.5206630812750923


Epoch 88/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 88 Loss: 1.491842926873101


Epoch 89/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 89 Loss: 1.5119426250457764


Epoch 90/100: 100%|██████████| 45/45 [00:12<00:00,  3.67it/s]


Epoch 90 Loss: 1.5028040091196695


Epoch 91/100: 100%|██████████| 45/45 [00:12<00:00,  3.60it/s]


Epoch 91 Loss: 1.492215535375807


Epoch 92/100: 100%|██████████| 45/45 [00:12<00:00,  3.64it/s]


Epoch 92 Loss: 1.507954223950704


Epoch 93/100: 100%|██████████| 45/45 [00:12<00:00,  3.61it/s]


Epoch 93 Loss: 1.4969696707195705


Epoch 94/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 94 Loss: 1.511050910419888


Epoch 95/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 95 Loss: 1.5125292115741307


Epoch 96/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 96 Loss: 1.5013363149431016


Epoch 97/100: 100%|██████████| 45/45 [00:12<00:00,  3.63it/s]


Epoch 97 Loss: 1.5125376727845934


Epoch 98/100: 100%|██████████| 45/45 [00:12<00:00,  3.62it/s]


Epoch 98 Loss: 1.4972224235534668


Epoch 99/100: 100%|██████████| 45/45 [00:12<00:00,  3.65it/s]


Epoch 99 Loss: 1.5043695158428616


Epoch 100/100: 100%|██████████| 45/45 [00:12<00:00,  3.66it/s]


Epoch 100 Loss: 1.4822502970695495


[]

# Conservative loss

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import os
import json
from PIL import Image
from tqdm import tqdm

# Dataset Class for Variable Label Lengths
class ImageTextDataset(Dataset):
    """Dataset of (image, label) pairs flattened from an image -> labels JSON file.

    Every label of an image becomes its own sample, so an image with three
    labels contributes three samples that all load the same picture.
    """

    def __init__(self, image_dir, json_path, processor):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            json_path (str): Path to the JSON file with image-label mappings.
                Each value may be a list of label strings or a single string.
            processor: CLIP processor. Kept on the instance, but not used by
                the dataset itself.
        """
        self.image_dir = image_dir
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.samples = []
        for image_name, labels in self.data.items():
            # A bare string has to be wrapped first: iterating it directly
            # would produce one sample per character instead of one per label.
            if isinstance(labels, str):
                labels = [labels]
            for label in labels:  # Flatten to image-label pairs
                self.samples.append((image_name, label))

    def __len__(self):
        """Return the number of flattened (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(RGB PIL image, label, image_name)``."""
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return image, label, image_name

# Custom Collate Function
def collate_fn(batch):
    """
    Collate (image, label, image_name) samples into one CLIP-ready batch.

    Uses the module-level ``processor`` that this cell creates after the
    function definition; it only has to exist by the time the DataLoader
    starts iterating.

    Args:
        batch: List of tuples (image, label, image_name).

    Returns:
        Tuple ``(inputs, image_names, labels)``: the processor output for the
        whole batch, plus the image names and label strings as plain lists in
        batch order.
    """
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    image_names = [item[2] for item in batch]

    # One processor call handles both modalities. padding=True pads the texts
    # to the longest label in the batch; truncation=True cuts over-long labels
    # to the tokenizer's maximum length, which otherwise would overflow the
    # text encoder's fixed context window and raise during the forward pass.
    inputs = processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    return inputs, image_names, labels

# Load Pre-trained CLIP Model and Processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Dataset and DataLoader
image_dir = train_path  # Replace with your image directory
json_path = train_label_path  # Replace with your label file path
dataset = ImageTextDataset(image_dir, json_path, processor)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# Move the model before building the optimizer so both refer to the same device tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training Setup: every parameter is trainable in this run (nothing is frozen)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# The CLIP contrastive loss multiplies the cosine similarities by a learned
# temperature (exp(logit_scale)). Without it the logits are confined to
# [-1, 1]; for a full batch of 16 the loss then cannot drop below
# log(1 + 15 * e^-2) ~= 1.11, and in the recorded run it plateaued around 2.24
# (random guessing is ln(16) ~= 2.77).
# Set to False to reproduce the original unscaled-cosine behaviour.
USE_LOGIT_SCALE = True

# Fine-Tuning Loop
epochs = 100
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs, _, _ = batch
        pixel_values = inputs["pixel_values"].to(device)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get embeddings
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        image_embeds = outputs.image_embeds  # Image embeddings
        text_embeds = outputs.text_embeds  # Text embeddings

        # Normalize embeddings to unit length so the dot product is a cosine similarity.
        # NOTE(review): CLIPModel is expected to return unit-norm embeddings
        # already, in which case this is a harmless no-op - confirm for the
        # installed transformers version.
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # Compute similarity matrix (rows: images, columns: texts)
        logits_per_image = image_embeds @ text_embeds.t()
        if USE_LOGIT_SCALE:
            logits_per_image = logits_per_image * model.logit_scale.exp()
        # Text-to-image logits are the transpose of the same matrix
        logits_per_text = logits_per_image.t()

        # Target: sample i in the batch matches text i (the diagonal).
        # NOTE(review): the dataset yields one sample per (image, label) pair,
        # so the same image or the same label text can occur several times in
        # a batch; those duplicates are then scored as negatives.
        targets = torch.arange(logits_per_image.size(0), device=device)

        # Compute loss (CLIP symmetric loss)
        loss_image = torch.nn.functional.cross_entropy(logits_per_image, targets)
        loss_text = torch.nn.functional.cross_entropy(logits_per_text, targets)
        loss = (loss_image + loss_text) / 2
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} Loss: {total_loss / len(dataloader)}")

output_dir = '/content/drive/My Drive/Image_text models/Clip_fine-tunned_conservative'

# Save the Fine-tuned Model
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)


Epoch 1/100: 100%|██████████| 45/45 [00:14<00:00,  3.02it/s]


Epoch 1 Loss: 2.61843310991923


Epoch 2/100: 100%|██████████| 45/45 [00:14<00:00,  3.15it/s]


Epoch 2 Loss: 2.391757493548923


Epoch 3/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 3 Loss: 2.316099198659261


Epoch 4/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 4 Loss: 2.2940815819634333


Epoch 5/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 5 Loss: 2.2805609597100154


Epoch 6/100: 100%|██████████| 45/45 [00:14<00:00,  3.14it/s]


Epoch 6 Loss: 2.2755589061313204


Epoch 7/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 7 Loss: 2.272165303760105


Epoch 8/100: 100%|██████████| 45/45 [00:14<00:00,  3.14it/s]


Epoch 8 Loss: 2.2694707340664335


Epoch 9/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 9 Loss: 2.261663007736206


Epoch 10/100: 100%|██████████| 45/45 [00:14<00:00,  3.15it/s]


Epoch 10 Loss: 2.2572187900543215


Epoch 11/100: 100%|██████████| 45/45 [00:14<00:00,  3.15it/s]


Epoch 11 Loss: 2.263004774517483


Epoch 12/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 12 Loss: 2.258334000905355


Epoch 13/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 13 Loss: 2.250767464107937


Epoch 14/100: 100%|██████████| 45/45 [00:14<00:00,  3.08it/s]


Epoch 14 Loss: 2.2581855773925783


Epoch 15/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 15 Loss: 2.2580406188964846


Epoch 16/100: 100%|██████████| 45/45 [00:14<00:00,  3.14it/s]


Epoch 16 Loss: 2.2546072747972277


Epoch 17/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 17 Loss: 2.2522464275360106


Epoch 18/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 18 Loss: 2.2460405190785724


Epoch 19/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 19 Loss: 2.236226028866238


Epoch 20/100: 100%|██████████| 45/45 [00:14<00:00,  3.16it/s]


Epoch 20 Loss: 2.2538653055826825


Epoch 21/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 21 Loss: 2.2429238425360785


Epoch 22/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 22 Loss: 2.24307918548584


Epoch 23/100: 100%|██████████| 45/45 [00:14<00:00,  3.13it/s]


Epoch 23 Loss: 2.242837821112739


Epoch 24/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 24 Loss: 2.239001512527466


Epoch 25/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 25 Loss: 2.2509723292456734


Epoch 26/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 26 Loss: 2.253504906760322


Epoch 27/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 27 Loss: 2.245178953806559


Epoch 28/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 28 Loss: 2.2470795101589625


Epoch 29/100: 100%|██████████| 45/45 [00:14<00:00,  3.07it/s]


Epoch 29 Loss: 2.2491931014590794


Epoch 30/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 30 Loss: 2.2438429090711804


Epoch 31/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 31 Loss: 2.245377598868476


Epoch 32/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 32 Loss: 2.2478711234198676


Epoch 33/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 33 Loss: 2.2493165810902913


Epoch 34/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 34 Loss: 2.250622918870714


Epoch 35/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 35 Loss: 2.2446562078264023


Epoch 36/100: 100%|██████████| 45/45 [00:14<00:00,  3.06it/s]


Epoch 36 Loss: 2.24194622569614


Epoch 37/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 37 Loss: 2.250500551859538


Epoch 38/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 38 Loss: 2.240746980243259


Epoch 39/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 39 Loss: 2.2384305318196613


Epoch 40/100: 100%|██████████| 45/45 [00:14<00:00,  3.07it/s]


Epoch 40 Loss: 2.2462343745761446


Epoch 41/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 41 Loss: 2.239237239625719


Epoch 42/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 42 Loss: 2.235830889807807


Epoch 43/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 43 Loss: 2.2480659749772816


Epoch 44/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 44 Loss: 2.2431240399678547


Epoch 45/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 45 Loss: 2.2387666861216227


Epoch 46/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 46 Loss: 2.231555869844225


Epoch 47/100: 100%|██████████| 45/45 [00:14<00:00,  3.09it/s]


Epoch 47 Loss: 2.247238206863403


Epoch 48/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 48 Loss: 2.243992222679986


Epoch 49/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 49 Loss: 2.247679090499878


Epoch 50/100: 100%|██████████| 45/45 [00:14<00:00,  3.11it/s]


Epoch 50 Loss: 2.240541362762451


Epoch 51/100: 100%|██████████| 45/45 [00:14<00:00,  3.08it/s]


Epoch 51 Loss: 2.2475527180565726


Epoch 52/100: 100%|██████████| 45/45 [00:14<00:00,  3.10it/s]


Epoch 52 Loss: 2.249759038289388


Epoch 53/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 53 Loss: 2.244593932893541


Epoch 54/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 54 Loss: 2.2410761621263293


Epoch 55/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 55 Loss: 2.2447218788994685


Epoch 56/100: 100%|██████████| 45/45 [00:14<00:00,  3.12it/s]


Epoch 56 Loss: 2.2509480794270833


Epoch 57/100: 100%|██████████| 45/45 [00:14<00:00,  3.21it/s]


Epoch 57 Loss: 2.2348995102776423


Epoch 58/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 58 Loss: 2.25010052257114


Epoch 59/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 59 Loss: 2.234870280159844


Epoch 60/100: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s]


Epoch 60 Loss: 2.2440619574652776


Epoch 61/100: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s]


Epoch 61 Loss: 2.2421982447306315


Epoch 62/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 62 Loss: 2.238228612475925


Epoch 63/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 63 Loss: 2.231844594743517


Epoch 64/100: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s]


Epoch 64 Loss: 2.23938537173801


Epoch 65/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 65 Loss: 2.2455813354916043


Epoch 66/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 66 Loss: 2.233431355158488


Epoch 67/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 67 Loss: 2.236192904578315


Epoch 68/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 68 Loss: 2.2397953775193957


Epoch 69/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 69 Loss: 2.2499874697791205


Epoch 70/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 70 Loss: 2.242641639709473


Epoch 71/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 71 Loss: 2.242232338587443


Epoch 72/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 72 Loss: 2.2428574721018473


Epoch 73/100: 100%|██████████| 45/45 [00:14<00:00,  3.18it/s]


Epoch 73 Loss: 2.2356852584415012


Epoch 74/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 74 Loss: 2.2349746174282497


Epoch 75/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 75 Loss: 2.2363783253563776


Epoch 76/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 76 Loss: 2.2368129200405544


Epoch 77/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 77 Loss: 2.2394453949398465


Epoch 78/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 78 Loss: 2.243478547202216


Epoch 79/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 79 Loss: 2.2298188474443226


Epoch 80/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 80 Loss: 2.231358559926351


Epoch 81/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 81 Loss: 2.244440449608697


Epoch 82/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 82 Loss: 2.237566910849677


Epoch 83/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 83 Loss: 2.232971864276462


Epoch 84/100: 100%|██████████| 45/45 [00:13<00:00,  3.23it/s]


Epoch 84 Loss: 2.24236724641588


Epoch 85/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 85 Loss: 2.236549054251777


Epoch 86/100: 100%|██████████| 45/45 [00:14<00:00,  3.21it/s]


Epoch 86 Loss: 2.232776647143894


Epoch 87/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 87 Loss: 2.240918000539144


Epoch 88/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 88 Loss: 2.238979371388753


Epoch 89/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 89 Loss: 2.243789397345649


Epoch 90/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 90 Loss: 2.2374096976386175


Epoch 91/100: 100%|██████████| 45/45 [00:13<00:00,  3.22it/s]


Epoch 91 Loss: 2.244846789042155


Epoch 92/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 92 Loss: 2.2454488701290556


Epoch 93/100: 100%|██████████| 45/45 [00:13<00:00,  3.24it/s]


Epoch 93 Loss: 2.241017034318712


Epoch 94/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 94 Loss: 2.240464374754164


Epoch 95/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 95 Loss: 2.251050519943237


Epoch 96/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 96 Loss: 2.2407872094048393


Epoch 97/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 97 Loss: 2.241880014207628


Epoch 98/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 98 Loss: 2.23750491672092


Epoch 99/100: 100%|██████████| 45/45 [00:13<00:00,  3.26it/s]


Epoch 99 Loss: 2.241205750571357


Epoch 100/100: 100%|██████████| 45/45 [00:13<00:00,  3.25it/s]


Epoch 100 Loss: 2.241130797068278


[]