In [1]:
import json

#For Loading Images
from PIL import Image

#For displaying a progress bar
from tqdm import tqdm

#Importing PyTorch to fine-tune our CLIP model
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

#Here, we import clip model
import clip

#Constructs a CLIP processor which wraps a CLIP image processor and a CLIP tokenizer into a single processor.
from transformers import CLIPProcessor, CLIPModel

In [2]:
# Locations of the annotation JSON file and of the image directory.
# NOTE: image_path must keep its trailing slash — image filenames are appended
# to it by plain string concatenation when the image list is built.
json_path = 'rsicd/annotations.json'
image_path = 'rsicd/images/'

In [3]:
# Parse the annotation file. The context manager closes the handle even when
# json.load raises, and the explicit encoding keeps the read independent of
# the platform's default locale encoding.
with open(json_path, encoding='utf-8') as f:
    data = json.load(f)

# Per-image annotation records; every later cell works from this list.
input_data = data['images']

In [4]:
# Hugging Face copy of CLIP ViT-B/32 together with its combined
# image-processor/tokenizer wrapper.
# NOTE(review): `model` is rebound by the clip.load(...) call in a later cell,
# so the CLIPModel instance created here is not the one that gets trained.
hf_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(hf_checkpoint)
processor = CLIPProcessor.from_pretrained(hf_checkpoint)

In [5]:
# Train on the first CUDA device when one is available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [6]:
# Load the "ViT-B/32" CLIP checkpoint through the `clip` package onto `device`.
# The call returns a pair: the model (which rebinds `model`, replacing the
# Hugging Face CLIPModel loaded earlier) and the image transform `preprocess`
# that the dataset applies to every image.
# jit=False presumably selects the non-JIT model so it can be trained — confirm
# against the clip package documentation.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)


In [7]:
class image_title_dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing image files with CLIP-tokenized captions.

    Args:
        list_image_path: Sequence of image file paths, one per sample.
        list_txt: Sequence of caption strings, aligned index-by-index with
            ``list_image_path``.

    Raises:
        ValueError: If the two sequences do not have the same length.
    """

    def __init__(self, list_image_path, list_txt):
        # Fail fast on misaligned inputs instead of silently dropping images
        # or hitting an IndexError in the middle of an epoch.
        if len(list_image_path) != len(list_txt):
            raise ValueError(
                "list_image_path and list_txt must have the same length "
                f"(got {len(list_image_path)} and {len(list_txt)})"
            )

        # Image paths are kept as given; images are only opened on access.
        self.image_path = list_image_path

        # Tokenize all captions once up front with CLIP's tokenizer; the
        # result is indexed per sample in __getitem__.
        self.title = clip.tokenize(list_txt)

    def __len__(self):
        """Return the number of samples (one per tokenized caption)."""
        return len(self.title)

    def __getitem__(self, idx):
        """Return ``(image, title)`` for sample ``idx``.

        The image is read from disk and passed through CLIP's ``preprocess``
        transform; ``title`` is the matching tokenized caption.
        """
        # The context manager releases the file handle as soon as the
        # transform has produced its output.
        with Image.open(self.image_path[idx]) as img:
            image = preprocess(img)
        title = self.title[idx]
        return image, title

In [8]:
# Build two parallel lists from the annotation records: the full path of each
# image and the caption paired with it. Only the raw text of the first
# sentence of every record is used as that image's caption.
list_image_path = [image_path + record['filename'] for record in input_data]
list_txt = [record['sentences'][0]['raw'] for record in input_data]

# Wrap the pairs in the dataset and serve them in shuffled batches of 50.
dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=50, shuffle=True)

# Cast a model's parameters (and any gradients they already hold) to FP32.
# The training loop calls this on GPU right before optimizer.step(), so the
# update is applied to FP32 values before the weights are converted back.
def convert_models_to_fp32(model):
    """Convert every parameter of ``model``, and its gradient if present, to FP32 in place.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors with
            ``.data`` and ``.grad`` attributes (e.g. a ``torch.nn.Module``).

    Parameters whose ``.grad`` is ``None`` — frozen parameters, or parameters
    that did not take part in the last backward pass — only have their data
    converted; previously they raised ``AttributeError``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # .grad stays None until a backward pass has populated it.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [9]:
# Check if the device is set to CPU
if device == "cpu":
    model.float()  # Convert the model's parameters to float if using CPU

# Prepare the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6 ,weight_decay=0.2) 
    
# Adam optimizer is used with specific hyperparameters
# lr (learning rate) is set to 5e-5, which is considered safe for fine-tuning to a new dataset
# betas are used for the optimization algorithm
# eps is a small value to prevent division by zero
# weight_decay adds L2 regularization to the optimizer

# Specify the loss function for images
loss_img = nn.CrossEntropyLoss()

# Specify the loss function for text
loss_txt = nn.CrossEntropyLoss()

In [10]:
num_epochs = 4  # Number of passes over the training data

# The progress bar advances once per epoch.
for epoch in tqdm(range(num_epochs)):
    for images, texts in train_dataloader:
        # Start every step from clean gradients.
        optimizer.zero_grad()

        # Put the batch on the same device as the model.
        images, texts = images.to(device), texts.to(device)

        # Forward pass: similarity logits in both directions.
        logits_per_image, logits_per_text = model(images, texts)

        # Sample i of the batch is the matching pair for row i, so the target
        # class of row i is simply i.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

        # Average the image-side and text-side cross-entropy losses.
        image_loss = loss_img(logits_per_image, ground_truth)
        text_loss = loss_txt(logits_per_text, ground_truth)
        total_loss = (image_loss + text_loss) / 2

        total_loss.backward()

        if device != "cpu":
            # On GPU: cast parameters and gradients to FP32, apply the update,
            # then convert the weights back with clip's helper.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
        else:
            # On CPU the update is applied directly.
            optimizer.step()

 25%|█████████████████████                                                               | 1/4 [01:39<04:57, 99.18s/it]

KeyboardInterrupt

