In [None]:
# prompt: pull this repo: "https://github.com/erenyavuz02/LightVision.git"

!git clone https://github.com/erenyavuz02/LightVision.git


Cloning into 'LightVision'...
remote: Enumerating objects: 72, done.[K
remote: Counting objects: 100% (72/72), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 72 (delta 22), reused 57 (delta 12), pack-reused 0 (from 0)[K
Receiving objects: 100% (72/72), 5.71 MiB | 19.44 MiB/s, done.
Resolving deltas: 100% (22/22), done.


# Get Pretrained Models

In [None]:
!mkdir -p checkpoints/

!wget https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt -P checkpoints

--2025-05-11 08:35:18--  https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt
Resolving docs-assets.developer.apple.com (docs-assets.developer.apple.com)... 17.253.118.201, 17.253.118.202, 2403:300:a32:f000::1, ...
Connecting to docs-assets.developer.apple.com (docs-assets.developer.apple.com)|17.253.118.201|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 215934653 (206M) [application/octet-stream]
Saving to: ‘checkpoints/mobileclip_s0.pt’


2025-05-11 08:35:22 (75.6 MB/s) - ‘checkpoints/mobileclip_s0.pt’ saved [215934653/215934653]



# dowload libraries if necessary

In [None]:
!pip install torch
!pip install torchvision
!pip install timm
!pip install open-clip-torch
!pip install datasets
!pip install clip-benchmark

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import sys
sys.path.append('/content/LightVision')


# Libaries, Parameters and Model Testing

In [None]:
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import os
import json
from PIL import Image
from tqdm import tqdm
import mobileclip
import random
import zipfile
import requests
import io

# --- Configuration ---
# Using relative paths for better portability
BASE_DATA_DESTINATION = os.path.join(os.getcwd(), "/content/LightVision/data")
FLICKR8K_IMAGES_FOLDER_NAME = "Images"
CAPTIONS_JSON_FILENAME = "all_captions2.json"
CHECKPOINT_DIR = os.path.join(os.getcwd(), "checkpoints")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")




Load the model

In [None]:
# Load the MobileCLIP model
model_path = os.path.join(CHECKPOINT_DIR, 'mobileclip_s0.pt')
model, _, preprocess = mobileclip.create_model_and_transforms(
    'mobileclip_s0',
    pretrained=model_path
)
model.to(DEVICE)

# Load the tokenizer
tokenizer = mobileclip.get_tokenizer('mobileclip_s0')

model.eval()


CLIP(
  (image_encoder): MCi(
    (model): FastViT(
      (patch_embed): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (1): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
        )
        (2): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (network): ModuleList(
        (0): Sequential(
          (0): RepMixerBlock(
            (token_mixer): RepMixer(
              (reparam_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
            )
            (convffn): ConvFFN(
              (con

Test the base model

"A man."
, "A man sitting on a bench."
, "A man sitting on a red bench in a park."
, "A man sitting on a red bench in a park holding a yellow umbrella."
, "A man sitting on a red bench in a park holding a yellow umbrella while feeding pigeons."

In [None]:

image = preprocess(Image.open("/content/Screenshot 2025-05-11 at 01.34.57.png").convert('RGB')).unsqueeze(0)
text = tokenizer(["The lemon on the left is yellow and the eggplant on the right is purple."
, "The lemon on the left is purple and the eggplant on the right is yellow."
, "The lemon on the right is yellow and the eggplant on the left is purple."
, "The lemon on the right is purple and the eggplant on the left is yellow"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image.half().to(DEVICE))
    text_features = model.encode_text(text.to(DEVICE))
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Set the print options for PyTorch to avoid scientific notation and limit decimal places
torch.set_printoptions(sci_mode=False, precision=4)

print("Label probs:", text_probs)

""

Label probs: tensor([[0.2498, 0.2346, 0.2659, 0.2498]], device='cuda:0')


  with torch.no_grad(), torch.cuda.amp.autocast():


''

In [None]:
tokenizer("At a train station, a group of people, including both young children and adults, are standing on a platform waiting for a train to arrive. The train is already present on the tracks, partially visible on the right side of the image. Some of the people watch the train closely, while others seem to be patiently anticipating its departure. There is a total of eight individuals  waiting ")# for the  train,  with one child in the middle of the platform and the others scattered around. A backpack can be found on the far left side of the platform, suggesting that someone may have set it down while waiting.")

tensor([[49406,   536,   320,  3231,  2631,   267,   320,  1771,   539,  1047,
           267,  2814,  2212,  1888,  2153,   537,  9391,   267,   631,  2862,
           525,   320,  5549,  2680,   556,   320,  3231,   531,  8851,   269,
           518,  3231,   533,  2426,  2881,   525,   518,  7579,   267, 21269,
          8626,   525,   518,  1155,  1145,   539,   518,  2867,   269,   836,
           539,   518,  1047,  1239,   518,  3231, 13478,   267,  1519,  3326,
          7523,   531,   655, 22980, 48067,   902, 17850,   269,   997,   533,
           320,  4445,   539,  7910, 11990,  2680, 49407]])

before : [0.2498, 0.2346, 0.2659, 0.2498]
after : [0.2783, 0.1797, 0.2456, 0.2963]

In [None]:
"A man in a red hat standing next to a yellow car in front of a green grocery store."
→ "A man wearing a bright red hat is casually standing next to a yellow car, which appears to be parked in front of a green grocery store on a sunny afternoon with some people walking nearby."

"A man in a red hat standing next to a yellow car in front of a store."
→ "A man, possibly in his mid-thirties, is seen in a red hat while standing next to a yellow car, which is parked near what seems to be a local store with large windows and some decorative signage."

"A man standing next to a yellow car in front of a green grocery store."
→ "There is a man, casually dressed, standing next to a yellow car that is parked right in front of a green-painted grocery store, where a bicycle rack and some flower pots are also visible."

"A man in a red hat standing next to a car in front of a green grocery store."
→ "Wearing a red hat and dark trousers, a man is standing next to a parked car, which is located near a green grocery store that has various posters and sale signs displayed in the window."

In [None]:

image = preprocess(Image.open("/content/test_image.png").convert('RGB')).unsqueeze(0)
text = tokenizer([ "A man in a red hat standing next to a yellow car in front of a grocery store with green signboard."
, "A man in a red hat standing next to a yellow car in front of a store."
, "A man standing next to a yellow car in front of a grocery store with green signboard."
, "A man in a red hat standing next to a car in front of a grocery store with green signboard."])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image.half().to(DEVICE))
    text_features = model.encode_text(text.to(DEVICE))
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Set the print options for PyTorch to avoid scientific notation and limit decimal places
torch.set_printoptions(sci_mode=False, precision=4)

print("Label probs:", text_probs)



Label probs: tensor([[    0.6294,     0.3700,     0.0001,     0.0005]], device='cuda:0')


  with torch.no_grad(), torch.cuda.amp.autocast():


In [None]:
tokens = tokenizer("place at the harbor where two people are looking out towards ferries in the distance. They stand next to each other near a metal railing with one person hugging the other, possibly sharing a loving moment as they look into the ocean together.")

print(tokens)

tensor([[49406,  1445,   536,   518, 10202,  1234,  1237,  1047,   631,  1312,
           620,  4447, 28489,   530,   518,  7964,   269,   889,  2087,  1131,
           531,  2416,  1010,  2252,   320,  4044,   559,  3299,   593,   637,
          2533, 27058,   518,  1010,   267,  8601,  3567,   320,  3721,  2495,
           601,   889,  1012,  1095,   518,  4918,  1952,   269, 49407,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])


# Changing the Positional Embeddings

In [None]:
print(model.get_positional_embedding() )

def get_positional_embedding(self, lambda2: int = 4):
    """
    Get modified positional embedding for text encoder based on the given formula.
    """
    pos_embed = self.text_encoder.get_positional_embedding().pos_embed.pos_embed
    if pos_embed is None:
        raise ValueError("Positional embedding not found in text encoder.")

    max_pos, embed_dim = pos_embed.shape[2], pos_embed.shape[3]
    modified_pos_embed = torch.zeros((1, 1, max_pos, embed_dim), device=pos_embed.device)

    for pos in range(max_pos):
        if pos <= 20:
            modified_pos_embed[:, :, pos, :] = pos_embed[:, :, pos, :]
        else:
            lower_idx = pos // lambda2
            upper_idx = min(lower_idx + 1, max_pos - 1)  # Ensure upper_idx is within bounds
            alpha = (pos % lambda2) / lambda2
            modified_pos_embed[:, :, pos, :] = (1 - alpha) * pos_embed[:, :, lower_idx, :] + alpha * pos_embed[:, :, upper_idx, :]
    # turn the torch tensor into nn parameter
    modified_pos_embed = torch.nn.Parameter(modified_pos_embed, requires_grad=False)
    return modified_pos_embed

# Example usage
lambda2 = 4
new_pos_embed = get_positional_embedding(model, lambda2)
print("Modified Positional Embedding:", new_pos_embed)

# set the models pos embedding to the new one
model.text_encoder.get_positional_embedding().pos_embed.pos_embed = new_pos_embed


LearnablePositionalEmbedding(num_embeddings=77, embedding_dim=512, padding_idx=None)
Modified Positional Embedding: Parameter containing:
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 4.0536e-03,  1.6300e-03, -7.1365e-04,  ...,  7.1379e-04,
            3.5601e-03, -7.5971e-03],
          [ 7.7365e-03,  2.6204e-03,  1.1954e-03,  ...,  1.4502e-04,
            1.3051e-03, -4.3484e-03],
          ...,
          [ 7.5012e-03,  5.1169e-03,  4.3844e-06,  ...,  1.0989e-03,
            9.3555e-05, -3.0923e-04],
          [ 3.1206e-03,  6.8456e-03, -7.9795e-04,  ...,  2.3707e-03,
           -1.6804e-04, -2.4519e-03],
          [-1.2599e-03,  8.5743e-03, -1.6003e-03,  ...,  3.6424e-03,
           -4.2963e-04, -4.5946e-03]]]], device='cuda:0')


# Testing the model after changing the positional embeddings

In [None]:
image = preprocess(Image.open("/content/LightVision/pngwing.com.png").convert('RGB')).unsqueeze(0)
image = image.to(DEVICE)
text = tokenizer(["a brown dog", "a white dog", "a black dog"])
text = text.to(DEVICE)

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T)

# Set the print options for PyTorch to avoid scientific notation and limit decimal places
torch.set_printoptions(sci_mode=False, precision=4)

print("Label probs:", text_probs)


Label probs: tensor([[19.4844, 19.2188, 19.0000]], device='cuda:0', dtype=torch.float16)


  with torch.no_grad(), torch.cuda.amp.autocast():


# Downloading the captioned images


In [None]:
# --- Configuration ---
import os

# Using relative paths for better portability
# This creates a 'data' directory in the project folder

KAGGLE_FLICKR8K_URL = "https://www.kaggle.com/api/v1/datasets/download/adityajn105/flickr8k"
FLICKR8K_ZIP_FILENAME = "flickr8k.zip"
FLICKR8K_IMAGES_FOLDER_NAME = "Images"
CAPTIONS_CSV_FILENAME = "captions.txt"
OUTPUT_FOLDER_NAME = "output"


def download_file(url: str, destination_path: str):
    print(f"Downloading from {url} to {destination_path}...")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(destination_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    file.write(chunk)
        print("Download complete.")
    except Exception as e:
        print(f"Error: {e}")
        raise

def extract_zip_file(zip_path: str, destination_folder: str):
    print(f"Extracting {zip_path} to {destination_folder}...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        print("Extraction complete.")
    except Exception as e:
        print(f"Extraction error: {e}")
        raise

def setup_data_directory(base_data_path: str):
    images_path = os.path.join(base_data_path, FLICKR8K_IMAGES_FOLDER_NAME)
    output_path = os.path.join(base_data_path, OUTPUT_FOLDER_NAME)
    os.makedirs(base_data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    return base_data_path, images_path, output_path


print(f"Setting up data directories in: {BASE_DATA_DESTINATION}")
base_dir, images_dir, output_dir = setup_data_directory(BASE_DATA_DESTINATION)

zip_file_path = os.path.join(base_dir, FLICKR8K_ZIP_FILENAME)

if not os.path.exists(images_dir):
    print("Images not found. Attempting download...")
    try:
        download_file(KAGGLE_FLICKR8K_URL, zip_file_path)
        extract_zip_file(zip_file_path, base_dir)
        os.remove(zip_file_path)
    except Exception as e:
        print(f"Failed to set up dataset: {e}")
        raise FileNotFoundError(f"Please manually download and extract to: {base_dir}")
else:
    print(f"Images already exist at {images_dir}.")

Setting up data directories in: /content/LightVision/data
Images already exist at /content/LightVision/data/Images.


# Train the model using the downloaded images and custom captions

Importing the required libraries
setting parameters and lookups


## Load the dataset

In [None]:


# Custom Dataset for Flickr8k with the specific JSON caption format
class Flickr8kCaptionedDataset(Dataset):
    def __init__(self, image_dir, captions_file, preprocess_fn, pull_from_json=True):
        self.image_dir = image_dir
        self.preprocess_fn = preprocess_fn

        self.num_samples = 0
        # Create list of samples
        self.samples = []

        if pull_from_json:
            # Load captions from JSON file
            with open(captions_file, 'r') as f:
                self.captions_data = json.load(f)

            # Process JSON with format {"image.jpg": {"long_caption": "...", "short_caption": "..."}, ...}
            for image_name, captions in self.captions_data.items():
                if "long_caption" in captions and "short_caption" in captions:
                    #if image is not in the image directory, skip
                    image_path = os.path.join(self.image_dir, image_name)
                    if not os.path.exists(image_path):
                        print(f"Image {image_path} not found, skipping.")
                        continue
                    # Add both caption types for each image
                    self.samples.append((image_name, captions["short_caption"], captions["long_caption"]))
        else:
            # Use the default Flickr8k captions file
            captions_file = "/content/LightVision/data/captions.txt"
            with open(captions_file, 'r') as f:
                lines = f.readlines()

            # Process the standard Flickr8k format
            # Typically each line has format: "image_name#caption" or "image_name,caption"
            for line in lines:
                line = line.strip()
                if line:
                    # Try to split by common delimiters
                    if '#' in line:
                        parts = line.split('#', 1)
                    else:
                        parts = line.split(',', 1)

                    if len(parts) == 2:
                        image_name, caption = parts
                        #print(f"Image name: {image_name.strip()}, Caption: {caption.strip()}")
                        # Add this check before appending to self.samples
                        image_path = os.path.join(self.image_dir, image_name.strip())
                        if not os.path.exists(image_path):
                            continue  # Skip this caption if image doesn't exist
                        self.samples.append((image_name.strip(), caption.strip(), "standard"))
                        self.num_samples += 1

        print(f"Loaded {len(self.samples)} samples from {captions_file}.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_name, caption, caption_type = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)

        # Load and preprocess the image
        try:
            image = Image.open(image_path).convert('RGB')
            image = self.preprocess_fn(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a random valid sample instead
            return self.__getitem__(random.randint(0, len(self) - 1))

        return image, caption, caption_type

    def __reduce__(self):
        return (self.__class__, (self.image_dir, self.captions_file, self.preprocess_fn, True))


# PCA of LongCLIP

In [None]:
#rewrite PCA to avoid inf
def PCA(input_tensor, PCA_dim):
    # 计算均值
    mean = torch.mean(input_tensor, dim=0)
    # 去均值
    X_centered = input_tensor - mean.unsqueeze(0)
    X_centered = X_centered.float()

    # 使用SVD而不是eig来计算主成分
    U, S, Vt = torch.linalg.svd(X_centered, full_matrices=False)
    principal_components = Vt.T[:, :PCA_dim]

    # 转换到新的维度
    X_transformed = torch.mm(X_centered, principal_components)
    # 恢复到原始空间
    X_reversed = torch.mm(X_transformed, principal_components.T)
    X_reversed += mean

    return X_reversed

## Loss functions


In [None]:
# Contrastive Loss Function
def single_loss(image_embeds, text_embeds, temperature=0.07):
    # Normalize embeddings
    image_embeds = F.normalize(image_embeds, dim=1)
    text_embeds = F.normalize(text_embeds, dim=1)

    # Compute similarity matrix
    logits = torch.matmul(image_embeds, text_embeds.T) / temperature

    # Labels are the positions of the positive pairs
    labels = torch.arange(logits.size(0), device=logits.device)

    # Compute loss in both directions (image->text and text->image)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2

In [None]:
def long_clip_loss(image_embedding, long_embedding, short_embedding):
    image_features_long = image_embedding
    text_features_long = long_embedding
    text_features_short = short_embedding

    # Normalize features
    image_features_long = image_features_long / image_features_long.norm(dim=1, keepdim=True)
    text_features_long = text_features_long / text_features_long.norm(dim=1, keepdim=True)
    text_features_short = text_features_short / text_features_short.norm(dim=1, keepdim=True)

    # Apply PCA to get compressed image features
    image_features_short = PCA(image_features_long, 32)
    image_features_short = image_features_short / image_features_short.norm(dim=1, keepdim=True)

    # Since we're not using distributed training, simplify this part
    image_feat_all_long = image_features_long
    image_features_all_short = image_features_short
    text_feat_all_long = text_features_long
    text_feat_all_short = text_features_short

    # Calculate similarity matrices
    sim_i2tl = torch.matmul(image_features_long, text_feat_all_long.T)
    sim_tl2i = torch.matmul(image_feat_all_long, text_features_long.T)
    sim_tl2i = sim_tl2i.T

    sim_i2ts = torch.matmul(image_features_short, text_feat_all_short.T)
    sim_ts2i = torch.matmul(image_features_all_short, text_features_short.T)
    sim_ts2i = sim_ts2i.T

    # Apply temperature scaling
    logit_scale = model.logit_scale if hasattr(model, 'logit_scale') else 1.0

    if isinstance(logit_scale, torch.nn.Parameter):
        sim_i2tl = logit_scale.exp() * sim_i2tl
        sim_tl2i = logit_scale.exp() * sim_tl2i
        sim_i2ts = logit_scale.exp() * sim_i2ts
        sim_ts2i = logit_scale.exp() * sim_ts2i

    # Create targets for loss calculation
    bs = image_embedding.size(0)
    targets = torch.arange(bs, device=image_embedding.device)

    # Calculate losses
    loss_itcl = (
        F.cross_entropy(sim_i2tl, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_tl2i, targets, label_smoothing=0.1)
    ) / 2

    loss_itcs = (
        F.cross_entropy(sim_i2ts, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_ts2i, targets, label_smoothing=0.1)
    ) / 2

    # single loss by combining the two
    total_loss = (loss_itcl + loss_itcs) / 2

    return total_loss

## Training the model

In [None]:
def train_model(
    images_dir,
    captions_file,
    checkpoint_dir,
    device='cuda',
    batch_size=128,
    learning_rate=1e-4,
    num_epochs=10,
    num_workers=0,
    pull_from_json=False,
    long_clip_loss_fn=None,
    single_loss_fn=None
):
    """
    Train a CLIP model with the given parameters.

    Args:
        images_dir: Directory containing images
        captions_file: Path to captions JSON file
        checkpoint_dir: Directory to save checkpoints
        device: Device to train on ('cuda' or 'cpu')
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
        num_epochs: Number of training epochs
        num_workers: Number of dataloader workers
        pull_from_json: Whether to pull captions from JSON
        long_clip_loss_fn: Loss function for long captions
        single_loss_fn: Loss function for single captions
    """

    # Check if files exist
    if not os.path.exists(images_dir):
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    if not os.path.exists(captions_file):
        print(f"Captions file not found: {captions_file}")
        pull_from_json = False


    dataset = Flickr8kCaptionedDataset(images_dir, captions_file, preprocess, pull_from_json=pull_from_json)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True
    )

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Make sure model is in training mode
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (images, captions, long_captions) in enumerate(progress_bar):
            images = images.to(device)

            # Tokenize the captions
            tokenized_captions = tokenizer(captions).to(device)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():
                image_features = model.encode_image(images)
                text_features = model.encode_text(tokenized_captions)

                # Compute contrastive loss
                if long_captions is not None and long_clip_loss_fn is not None:
                    long_captions = tokenizer(long_captions).to(device)
                    long_text_features = model.encode_text(long_captions)
                    loss = long_clip_loss_fn(image_features, text_features, long_text_features)
                else:
                    # Use single loss if long captions are not available
                    if single_loss_fn is None:
                        raise ValueError("Single loss function must be provided")
                    loss = single_loss_fn(image_features, text_features)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar
            total_loss += loss.item()
            avg_loss = total_loss / (batch_idx + 1)
            progress_bar.set_postfix(loss=f"{avg_loss:.4f}")

        # Print average loss for the epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Save checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f"mobileclip_finetuned_epoch{epoch+1}_last.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    return model


## Running the training loop

In [None]:
images_dir = os.path.join(BASE_DATA_DESTINATION, FLICKR8K_IMAGES_FOLDER_NAME)
captions_file = os.path.join(BASE_DATA_DESTINATION, CAPTIONS_JSON_FILENAME)
checkpoint_dir = "checkpoints"
"""
# Train the model
trained_model = train_model(
    images_dir=images_dir,
    captions_file="/content/LightVision/data/captions.txt",
    checkpoint_dir=checkpoint_dir,
    device=DEVICE,
    batch_size=256,
    learning_rate=1e-4,
    num_epochs=0,
    num_workers=0,
    pull_from_json=False,
    long_clip_loss_fn=long_clip_loss,
    single_loss_fn=single_loss
)
"""
# Train the model
trained_model = train_model(
    images_dir=images_dir,
    captions_file="/content/LightVision/data/all_captions.json",
    checkpoint_dir=checkpoint_dir,
    device=DEVICE,
    batch_size=128,
    learning_rate=1e-4,
    num_epochs=1,
    num_workers=0,
    pull_from_json=True,
    long_clip_loss_fn=long_clip_loss,
    single_loss_fn=single_loss
)

# Train the model
trained_model = train_model(
    images_dir=images_dir,
    captions_file="/content/LightVision/data/new_file.json",
    checkpoint_dir=checkpoint_dir,
    device=DEVICE,
    batch_size=128,
    learning_rate=1e-4,
    num_epochs=1,
    num_workers=0,
    pull_from_json=True,
    long_clip_loss_fn=long_clip_loss,
    single_loss_fn=single_loss
)

model.eval()


Loaded 1074 samples from /content/LightVision/data/all_captions.json.


  with torch.cuda.amp.autocast():
Epoch 1/1: 100%|██████████| 8/8 [00:10<00:00,  1.29s/it, loss=1.0873]


Epoch 1/1, Loss: 1.0873
Checkpoint saved: checkpoints/mobileclip_finetuned_epoch1_last.pt
Loaded 5840 samples from /content/LightVision/data/new_file.json.


Epoch 1/1: 100%|██████████| 45/45 [00:57<00:00,  1.29s/it, loss=1.0028]


Epoch 1/1, Loss: 1.0028
Checkpoint saved: checkpoints/mobileclip_finetuned_epoch1_last.pt


CLIP(
  (image_encoder): MCi(
    (model): FastViT(
      (patch_embed): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (1): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
        )
        (2): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (network): ModuleList(
        (0): Sequential(
          (0): RepMixerBlock(
            (token_mixer): RepMixer(
              (reparam_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
            )
            (convffn): ConvFFN(
              (con

Load the trained checkpoint

In [None]:
# If not trained, load the trained checkpoint /content/checkpoints/mobileclip_finetuned_epoch4.pt
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'mobileclip_finetuned_epoch1_last.pt')
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

# Evaluate the model with short captions

In [None]:



# Example evaluation
test_image_path = "/content/ChatGPT Image 11 May 2025 01_24_53.png"
test_texts = ["A man."
, "A man sitting on a bench."
, "A man sitting on a red bench in a park."
, "A man sitting on a red bench in a park holding a yellow umbrella."
, "A man sitting on a red bench in a park holding a yellow umbrella while feeding pigeons."]


test_image = preprocess(Image.open(test_image_path).convert('RGB')).unsqueeze(0).to(DEVICE)
test_text = tokenizer(test_texts).to(DEVICE)

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(test_image)
    text_features = model.encode_text(test_text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

torch.set_printoptions(sci_mode=False, precision=4)
print("Label probabilities after training:", text_probs)

"[0.5692, 0.4296, 0.0012]]"


Label probabilities after training: tensor([[    0.0001,     0.0145,     0.0701,     0.4362,     0.4791]],
       device='cuda:0')


  with torch.no_grad(), torch.cuda.amp.autocast():


'[0.5692, 0.4296, 0.0012]]'

Base: [[0.9789,0.0004,  0.0206]]
pos embed: [[0.4075, 0.3124, 0.2801]
after: ([[0.7447, 0.1197, 0.1356]]

# Evalute the model with long captions

In [None]:
import json

def remove_first_n_words(text, n=3):
    return ' '.join(text.split()[n:])

with open("/content/captions_database-10.json", "r") as f:
    data = json.load(f)

for v in data.values():
    v["long_caption"] = remove_first_n_words(v.pop("long_detailed"))

with open("new_file.json", "w") as f:
    json.dump(data, f, indent=2)