In [2]:
import json
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import clip
from transformers import CLIPProcessor, CLIPModel
from sklearn.model_selection import train_test_split
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
# from torchvision.transforms.functional as F
import torch.optim as optim  # This import is necessary for using the optim.Adam optimizer
import torch.nn as nn
from transformers import (
    Trainer,
    TrainingArguments,
    CLIPModel,
    CLIPTokenizer,
    CLIPProcessor,
    CLIPImageProcessor
)


In [3]:

# TODO — training tricks not implemented yet:
#   - gradient checkpointing
#   - exclude biases from weight decay
#   - cosine-decayed learning rate
#   - half-precision Adam statistics
#   - half-precision, stochastically rounded text-encoder weights
#
# Note: the contrastive loss needs BATCH_SIZE > 1.

# On GPU the CLIP weights stay in fp16 (mixed precision); on CPU they are cast to fp32 later.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# jit=False is required: only the non-JIT model can be trained.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)


In [4]:
# One JSON object per line; blank lines (e.g. stray empty lines at the end) are skipped
# instead of crashing json.loads.
with open('../../../novice/vlm.jsonl', 'r', encoding='utf-8') as file:
    data = [json.loads(line) for line in file if line.strip()]

In [5]:
base_path = '../../../novice/images/'

# Flatten the records: each annotation becomes its own training sample, so an image
# path is repeated once for every caption/bbox it carries.
list_image_path, list_txt, list_bboxes = [], [], []
for record in data:
    full_path = base_path + record['image']
    for ann in record['annotations']:
        list_image_path.append(full_path)
        list_txt.append(ann['caption'])
        list_bboxes.append(ann['bbox'])

In [6]:
#https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter, and its gradient when one exists, to float32 in place.

    Used around ``optimizer.step()`` when training fp16 CLIP weights: the update runs
    in fp32 and the caller casts the weights back to fp16 afterwards.

    Parameters whose ``.grad`` is ``None`` (frozen, or unused in the forward pass)
    are still cast, but have no gradient to convert.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Without this guard, a parameter with no gradient raises AttributeError on p.grad.data.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [7]:
class image_title_dataset(Dataset):
    """Per-annotation (cropped image, tokenized caption) pairs for CLIP fine-tuning.

    Args:
        list_image_path: image file path for each sample.
        list_txt: caption for each sample; all captions are tokenized up front with
            ``clip.tokenize``.
        list_bboxes: one box per sample, read as [left, top, width, height] in pixels.
        bbox_padding: extra pixels added to the box width and height, i.e. only the
            right/bottom edges grow (default 5, the previously hard-coded value).
        transform: callable applied to the cropped PIL image. ``None`` uses the global
            ``preprocess`` returned by ``clip.load``, looked up each time an item is read.
    """

    def __init__(self, list_image_path, list_txt, list_bboxes, bbox_padding=5, transform=None):
        self.image_path = list_image_path
        # Tokenizing everything once here is slow up front but keeps __getitem__ cheap.
        self.title = clip.tokenize(list_txt)
        self.bboxes = list_bboxes
        self.bbox_padding = bbox_padding
        self.transform = transform

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        """Return ``(image_tensor, token_ids)`` for sample ``idx``."""
        bbox = self.bboxes[idx]
        left, top = bbox[0], bbox[1]
        right = left + bbox[2] + self.bbox_padding
        bottom = top + bbox[3] + self.bbox_padding

        with Image.open(self.image_path[idx]) as img:
            # PIL's crop() loads the pixel data, so the crop stays valid after the file closes.
            img_cropped = img.crop((left, top, right, bottom))

        transform = self.transform if self.transform is not None else preprocess
        return transform(img_cropped), self.title[idx]

BATCH_SIZE = 64  # the contrastive loss needs more than one pair per batch
EPOCH = 5

dataset = image_title_dataset(list_image_path, list_txt, list_bboxes)
# Shuffle: list_* are built image by image, so an unshuffled loader fills each batch with
# consecutive crops of the same images and repeats the exact same batches every epoch.
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

if device == "cpu":
    model.float()
else:
    # Per the original note, CLIP weights are already fp16 here; the cast is kept as a safeguard.
    clip.model.convert_weights(model)
    
# # Test the DataLoader output
# sample_batch = next(iter(train_dataloader))
# print(sample_batch)


In [9]:
from torch.utils.tensorboard import SummaryWriter

# log_dir=None is SummaryWriter's default: event files go under ./runs/<datetime>_<hostname>.
writer = SummaryWriter(log_dir=None)

In [12]:
# Enable the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard


In [13]:
tensorboard --logdir /runs --load_fast true


In [10]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# Hyper-parameters from the CLIP paper; the small lr is safer when fine-tuning on a new dataset.
optimizer = optim.Adam(model.parameters(), lr=5e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

# Make sure the model is in training mode before optimizing (no-op if it already is).
model.train()

for epoch in range(EPOCH):
    running_loss = 0.0
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    for images, texts in pbar:
        optimizer.zero_grad()

        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # Matching image/text pairs share a batch index, so the targets are 0..B-1.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        total_loss.backward()
        running_loss += total_loss.item()

        if device == "cpu":
            optimizer.step()
        else:
            # fp16 weights: take the optimizer step in fp32, then cast back (openai/CLIP#57).
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        pbar.set_description(f"Epoch {epoch}/{EPOCH}, Loss: {total_loss.item():.4f}")

    # Log once per epoch, after running_loss has accumulated every batch. Logging inside the
    # batch loop wrote a partial sum per batch, all at the same step `epoch`.
    writer.add_scalar("Loss x epoch", running_loss / max(len(train_dataloader), 1), epoch)

writer.close()

Epoch 0/5, Loss: 0.3540: 100%|██████████| 234/234 [11:24<00:00,  2.92s/it]
Epoch 1/5, Loss: 0.2395: 100%|██████████| 234/234 [05:31<00:00,  1.42s/it]
Epoch 2/5, Loss: 0.2476: 100%|██████████| 234/234 [05:41<00:00,  1.46s/it]
Epoch 3/5, Loss: 0.2339: 100%|██████████| 234/234 [05:43<00:00,  1.47s/it]
Epoch 4/5, Loss: 0.2180: 100%|██████████| 234/234 [05:37<00:00,  1.44s/it]


In [12]:
import os

CHECKPOINT_DIR = "clip-finetune"  # just change to your preferred folder
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save does not create missing directories
torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Store a plain float rather than the graph-attached (possibly GPU) loss tensor.
        'loss': total_loss.item(),
        }, os.path.join(CHECKPOINT_DIR, "clip.pt"))

In [None]:
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)  # Must set jit=False for training
# map_location puts the tensors on the current device, so a checkpoint written on a GPU
# can also be restored on a CPU-only machine.
checkpoint = torch.load("clip-finetune/clip.pt", map_location=device)

# Use these 3 lines if you load the default model setting (not the training setting) of CLIP.
# For example, if you set context_length to 100 during training because your strings are long,
# assign 100 to checkpoint['model_state_dict']["context_length"].
# checkpoint['model_state_dict']["input_resolution"] = model.input_resolution  # default is 224
# checkpoint['model_state_dict']["context_length"] = model.context_length  # default is 77
# checkpoint['model_state_dict']["vocab_size"] = model.vocab_size
model.load_state_dict(checkpoint['model_state_dict'])


## Alternative method: fine-tuning Hugging Face `FlaxCLIPModel` with JAX/Flax

In [None]:
#!/usr/bin/env python
# coding=utf-8
import logging
import json
import os
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Callable
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch.nn import functional as F
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from transformers import FlaxCLIPModel, CLIPProcessor, TrainingArguments, is_tensorboard_available
from tqdm import tqdm

# Module-level logger; main() reports training progress through it.
logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """Which pretrained CLIP checkpoint to load, and the dtype for its weights."""

    # Hub id or local path given to FlaxCLIPModel.from_pretrained.
    model_name_or_path: Optional[str] = field(
        metadata={"help": "The model checkpoint for weights initialization."},
        default="openai/clip-vit-base-patch32",
    )
    # Name of a jax.numpy dtype attribute (e.g. "float32"); main() resolves it via getattr(jnp, ...).
    dtype: Optional[str] = field(
        metadata={"help": "Floating-point format in which the model weights should be initialized."},
        default="float32",
    )

@dataclass
class DataTrainingArguments:
    """Location of the training data.

    Note: main() passes the path of the JSONL annotation file here, not a directory.
    """

    data_dir: str = field(
        metadata={"help": "Path to the directory containing the data files."},
        default="./data",
    )

class Transform(torch.nn.Module):
    """Crop an image tensor to a bounding box and turn it into a CLIP-normalized float tensor.

    Mirrors CLIP's own preprocessing: bicubic resize of the shorter side to ``image_size``
    (aspect ratio kept), center crop to ``image_size`` x ``image_size``, then normalization
    with CLIP's channel statistics.
    """

    # CLIP's normalization statistics (also used by openai/clip-vit-base-patch32), not ImageNet's.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    @staticmethod
    def _bbox_to_crop_args(bbox):
        """Convert an [x(left), y(top), width, height] box to torchvision's (top, left, height, width)."""
        left, top, width, height = (int(round(v)) for v in bbox[:4])
        return top, left, height, width

    def forward(self, img, bbox):
        """Crop ``img`` to ``bbox``, resize keeping the aspect ratio, center-crop and normalize.

        Args:
            img: uint8 image tensor as returned by ``read_image`` (assumed (C, H, W)).
            bbox: [left, top, width, height] in pixels, the convention the PyTorch pipeline
                in this notebook uses. NOTE(review): the dataset comment says
                [top, left, width, height]; confirm against the annotation source.

        Returns:
            float32 tensor of shape (C, image_size, image_size).
        """
        # torch.nn.functional (imported as F in this cell) has no crop/resize/to_tensor;
        # these are torchvision's functional image ops.
        from torchvision.transforms import InterpolationMode
        from torchvision.transforms import functional as TF

        img = TF.crop(img, *self._bbox_to_crop_args(bbox))
        img = TF.resize(img, self.image_size, interpolation=InterpolationMode.BICUBIC, antialias=True)
        img = TF.center_crop(img, [self.image_size, self.image_size])
        # uint8 [0, 255] -> float32 [0, 1]; to_tensor() does not accept tensors.
        img = TF.convert_image_dtype(img, torch.float32)
        return TF.normalize(img, mean=self.CLIP_MEAN, std=self.CLIP_STD)

class ImageTextDataset(torch.utils.data.Dataset):
    """Image/caption pairs from a JSONL annotation file, one item per annotation.

    Each JSONL line is an object with an ``image`` file name and a list of
    ``annotations``, each carrying a ``caption`` and a ``bbox``.

    Args:
        json_file_path: path to the JSONL annotation file.
        transform: optional callable ``transform(image, bbox)`` applied in ``__getitem__``.
        image_dir: directory that ``image`` names are resolved against. Defaults to the
            directory containing ``json_file_path``. Joining onto the JSONL *file* path
            itself would yield broken paths like ``.../vlm.jsonl/img.jpg``.
    """

    def __init__(self, json_file_path, transform=None, image_dir=None):
        self.json_file_path = json_file_path
        self.transform = transform
        self.image_dir = os.path.dirname(json_file_path) if image_dir is None else image_dir
        self.data = self._load_data()

    def _load_data(self):
        """Read the JSONL file and flatten it into one record per annotation."""
        data = []
        with open(self.json_file_path, 'r', encoding='utf-8') as file:
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines
                entry = json.loads(line)
                image_path = os.path.join(self.image_dir, entry['image'])
                for annotation in entry['annotations']:
                    data.append({
                        'image_path': image_path,
                        'caption': annotation['caption'],
                        # Raw box, passed unchanged to the transform. NOTE(review): the PyTorch
                        # pipeline in this notebook reads it as [left, top, width, height].
                        'bbox': annotation['bbox'],
                    })
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, caption)``; the image is an RGB tensor, transformed if a transform is set."""
        item = self.data[idx]
        image = read_image(item['image_path'], mode=ImageReadMode.RGB)
        if self.transform is not None:
            image = self.transform(image, item['bbox'])
        return image, item['caption']

@jax.jit
def train_step(state, batch, dropout_rng):
    """Run one optimizer step on the symmetric CLIP contrastive loss.

    Args:
        state: flax ``TrainState`` whose ``apply_fn`` is ``FlaxCLIPModel.__call__``.
        batch: dict with ``pixel_values`` (N, C, H, W), ``input_ids`` and ``attention_mask``.
        dropout_rng: PRNG key for this step's dropout.

    Returns:
        ``(new_state, loss)``.
    """
    def compute_loss(params):
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
        logits_per_image = outputs.logits_per_image
        # The i-th image matches the i-th caption, so the targets are 0..N-1 in both directions.
        labels = jnp.arange(logits_per_image.shape[0])
        loss_img = optax.softmax_cross_entropy_with_integer_labels(logits_per_image, labels).mean()
        loss_txt = optax.softmax_cross_entropy_with_integer_labels(logits_per_image.T, labels).mean()
        return (loss_img + loss_txt) / 2

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    return state.apply_gradients(grads=grads), loss


def main():
    """Fine-tune ``FlaxCLIPModel`` on bbox-cropped image/caption pairs and save the weights."""
    model_args = ModelArguments()
    # NOTE(review): the PyTorch pipeline above reads the JSONL from '../../../novice/vlm.jsonl'
    # and the images from '../../../novice/images/'; confirm this path.
    data_args = DataTrainingArguments(data_dir="../../../novice/images/vlm.jsonl")

    training_args = TrainingArguments(
        output_dir="./training_output",
        num_train_epochs=3,
        per_device_train_batch_size=32,
        save_steps=10,
        save_total_limit=5
    )

    transform = Transform(image_size=224)
    dataset = ImageTextDataset(data_args.data_dir, transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=training_args.per_device_train_batch_size,
        shuffle=True,
        # Fixed batch shape -> train_step is compiled once; also avoids a 1-pair contrastive batch.
        drop_last=True,
    )

    model = FlaxCLIPModel.from_pretrained(model_args.model_name_or_path, dtype=getattr(jnp, model_args.dtype))
    processor = CLIPProcessor.from_pretrained(model_args.model_name_or_path)
    max_length = model.config.text_config.max_position_embeddings
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optax.adam(5e-5))

    rng = jax.random.PRNGKey(0)
    for epoch in range(int(training_args.num_train_epochs)):
        epoch_loss = 0.0
        num_steps = 0
        pbar = tqdm(data_loader, desc=f"Epoch {epoch}")
        for images, captions in pbar:
            rng, dropout_rng = jax.random.split(rng)
            # Pad to a fixed length so every batch has the same shape under jit.
            tokens = processor.tokenizer(
                list(captions), padding="max_length", truncation=True,
                max_length=max_length, return_tensors="np",
            )
            batch = {
                "pixel_values": images.numpy(),
                "input_ids": tokens["input_ids"],
                "attention_mask": tokens["attention_mask"],
            }
            state, loss = train_step(state, batch, dropout_rng)
            loss_value = float(loss)
            epoch_loss += loss_value
            num_steps += 1
            pbar.set_postfix(loss=f"{loss_value:.4f}")
        if num_steps:
            logger.info(f"Epoch {epoch}, Loss: {epoch_loss / num_steps}")

    model.save_pretrained(training_args.output_dir, params=state.params)

if __name__ == "__main__":
    # Without a handler, logger.info output would be dropped.
    logging.basicConfig(level=logging.INFO)
    main()
