In [1]:
%cd ..

/Users/hunarbatra/Hunar/oxford/dphil/svt-llava


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from SemCLIP.semclip import SemCLIP
from SemCLIP.image_utils import DEVICE

# Instantiate SemCLIP from the "openai/clip-vit-base-patch32" checkpoint name,
# with attention pooling and a 512-dimensional projection, on DEVICE.
semclip = SemCLIP(
    model_name="openai/clip-vit-base-patch32",
    pool_type='attention',
    projection_dim=512,
    device=DEVICE,
)

In [3]:
import torch

from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# Load the image-caption dataset by its Hub repository id. Later cells index
# the result by split name ('train', 'validation').
dataset = load_dataset("hunarbatra/CLIP-LLaVA-Instruct-COCO-13k")

## Download Data Files

In [17]:
import os
import requests
from datasets import load_dataset


# Reload the dataset and drop its 'downloaded_img' column in one step; the
# images are fetched again from their COCO URLs by download_image() below.
dataset = load_dataset("hunarbatra/CLIP-LLaVA-Instruct-COCO-13k").remove_columns("downloaded_img")

# Directory that download_image() writes the fetched image files into
# (created on first run, left untouched if it already exists).
save_dir = "data/CLIP-LLaVA-Instruct-COCO-13k/"
os.makedirs(save_dir, exist_ok=True)

def download_image(record):
    """Download the image referenced by one dataset record into ``save_dir``.

    Args:
        record: Mapping with a ``'coco_url'`` entry (source URL) and an
            ``'image'`` entry (file name used for the local copy).

    Returns:
        The same record with ``'local_path'`` set to the path of the saved
        file, or to ``None`` when the download failed for any reason
        (connection error, timeout, HTTP error status, write error).
    """
    url = record['coco_url']
    file_name = os.path.join(save_dir, f"{record['image']}")
    # Stream into a sibling temp file first so that an interrupted download
    # never leaves a truncated image at the final path.
    tmp_name = file_name + ".part"
    try:
        # The timeout keeps a stalled connection from hanging a worker forever;
        # the context manager releases the connection when we are done.
        with requests.get(url, stream=True, timeout=30) as response:
            # An HTTP error status is a failure: previously such a response wrote
            # no file but still recorded local_path as if it had succeeded.
            response.raise_for_status()
            with open(tmp_name, 'wb') as f:
                for chunk in response.iter_content(1024):
                    f.write(chunk)
        os.replace(tmp_name, file_name)
        record['local_path'] = file_name
    except Exception as e:
        print(f"Failed to download {url}: {e}")
        # Best-effort cleanup of a partially written temp file.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        record['local_path'] = None
    return record

# Apply download_image to the dataset's records using 8 worker processes;
# download_image adds a 'local_path' field to each record it returns.
dataset = dataset.map(download_image, num_proc=8)

# Optional (disabled): persist the resulting dataset to disk for later reuse.
# dataset.save_to_disk("CLIP-LLaVA-Instruct-COCO-13k-no-img")


Map (num_proc=8):   0%|          | 0/13430 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/3365 [00:00<?, ? examples/s]

## Finetune Model with local data/

In [None]:
import clip

from tqdm import tqdm


# Shuffled loaders with batches of 64 over the train and validation splits.
train_loader, test_loader = (
    DataLoader(dataset[split], batch_size=64, shuffle=True)
    for split in ('train', 'validation')
)

# Take the model (moved onto DEVICE) and the processor off the SemCLIP wrapper.
model = semclip.model.to(DEVICE)
processor = semclip.processor

def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32.

    The conversion is done in place on ``p.data`` / ``p.grad.data``; nothing
    is returned.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with
            ``.data`` and ``.grad`` attributes (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Frozen or unused parameters have no gradient after backward();
        # dereferencing ``p.grad.data`` on them would raise AttributeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
        
def create_batches(dataset, batch_size):
    """Lazily yield consecutive slices of ``dataset`` of length ``batch_size``.

    The last slice is shorter when ``len(dataset)`` is not a multiple of
    ``batch_size``; an empty ``dataset`` yields nothing.
    """
    offsets = range(0, len(dataset), batch_size)
    yield from (dataset[lo:lo + batch_size] for lo in offsets)
        
if DEVICE == "cpu":
    model.float() # convert the model params to float if using CPU

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2) # weight decay adds L2 regularization to the optimizer

# The loop below scores the image->text and text->image logits separately.
# The original cell referenced loss_img / loss_txt without ever defining them
# (NameError); both directions share the same cross-entropy criterion.
loss = torch.nn.CrossEntropyLoss()
loss_img = loss
loss_txt = loss

# Training loop
num_epochs = 1
batch_size = 64
# NOTE(review): the download cell writes images under
# "data/CLIP-LLaVA-Instruct-COCO-13k/"; confirm how get_semclip_embeddings
# resolves `images_folder` so that this path points at the same files.
data_dir = 'CLIP-LLaVA-Instruct-COCO-13k'

# Ceiling division so that a final partial batch is counted by the progress bar.
num_batches = (len(dataset['train']) + batch_size - 1) // batch_size

for epoch in range(num_epochs):
    model.train()
    train_loader = create_batches(dataset['train'], batch_size)

    pbar = tqdm(train_loader, total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        optimizer.zero_grad()

        image_batch = batch['image']
        text_batch = batch['caption']

        # Forward pass through the model. This must run with autograd enabled:
        # under torch.no_grad() the logits carry no graph and backward() raises
        # "element 0 of tensors does not require grad".
        logits_per_image, logits_per_text = semclip.get_semclip_embeddings(images=image_batch, captions=text_batch, images_folder=data_dir)

        # Compute the loss: the i-th image is the positive for the i-th caption.
        ground_truth = torch.arange(len(image_batch), dtype=torch.long, device=DEVICE)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass
        total_loss.backward()

        # if the device is CPU, directly update the model
        if DEVICE == "cpu":
            optimizer.step()
        else:
            # Step the optimizer on float32 params/grads, then hand the model
            # back to clip's weight conversion for the next forward pass.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        # Update the progress bar with the current loss
        pbar.set_postfix(Loss=total_loss.item())


## Finetune Model without downloading images locally

In [None]:
import clip

from tqdm import tqdm

from PIL import Image
import cv2
import numpy as np

# Take the model off the SemCLIP wrapper and move it onto DEVICE for training.
model = semclip.model.to(DEVICE)

def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32.

    The conversion is done in place on ``p.data`` / ``p.grad.data``; nothing
    is returned.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with
            ``.data`` and ``.grad`` attributes (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Parameters that received no gradient (frozen, or unused in the
        # forward pass) have ``grad is None``; skip them instead of crashing.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
        
def create_batches(dataset, batch_size):
    """Generate successive ``batch_size``-long slices of ``dataset``, in order.

    A trailing remainder is yielded as a shorter final slice; nothing is
    yielded for an empty ``dataset``.
    """
    slice_starts = range(0, len(dataset), batch_size)
    yield from (dataset[begin:begin + batch_size] for begin in slice_starts)
        
def pil_to_cv2(pil_image):
    """Convert a PIL image into an OpenCV-style BGR ``numpy`` array.

    Args:
        pil_image: A PIL image in any mode, or an RGB array-like of shape
            (H, W, 3). Objects exposing a ``mode`` other than ``"RGB"`` are
            converted to RGB first.

    Returns:
        A new C-contiguous array of shape (H, W, 3) with channels in BGR order.
    """
    # Grayscale images produce a 2-D array (the channel slice below would
    # raise IndexError) and RGBA images would come out as ABGR, so normalise
    # every non-RGB mode to three RGB channels before reversing.
    if getattr(pil_image, "mode", "RGB") != "RGB":
        pil_image = pil_image.convert("RGB")
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR; copy() materialises the reversed view as a
    # contiguous array rather than a negative-stride view.
    return open_cv_image[:, :, ::-1].copy()
        
if DEVICE == "cpu":
    model.float() # convert the model params to float if using CPU

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2) # weight decay adds L2 regularization to the optimizer

# The loop below scores the image->text and text->image logits separately.
# The original cell referenced loss_img / loss_txt without ever defining them
# (NameError); both directions share the same cross-entropy criterion.
loss = torch.nn.CrossEntropyLoss()
loss_img = loss
loss_txt = loss

# Training loop
num_epochs = 1
batch_size = 64
# data_dir = 'CLIP-LLaVA-Instruct-COCO-13k'

# This cell reads the in-memory 'downloaded_img' column. The "Download Data
# Files" cell removes that column from `dataset`, so when the notebook is run
# top to bottom the split is reloaded here instead of failing with a KeyError.
train_split = dataset['train']
if 'downloaded_img' not in train_split.column_names:
    train_split = load_dataset("hunarbatra/CLIP-LLaVA-Instruct-COCO-13k")['train']

# Ceiling division so that a final partial batch is counted by the progress bar.
num_batches = (len(train_split) + batch_size - 1) // batch_size

for epoch in range(num_epochs):
    model.train()
    train_loader = create_batches(train_split, batch_size)

    pbar = tqdm(train_loader, total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        optimizer.zero_grad()

        image_batch_pil = batch['downloaded_img']
        text_batch = batch['caption']

        # Convert the batch of PIL images to OpenCV images
        image_batch_cv2 = [pil_to_cv2(img) for img in image_batch_pil]

        # Forward pass through the model. This must run with autograd enabled:
        # under torch.no_grad() the logits carry no graph and backward() raises
        # "element 0 of tensors does not require grad".
        logits_per_image, logits_per_text = semclip.get_semclip_embeddings_direct_img(images=image_batch_cv2, captions=text_batch)

        # Compute the loss: the i-th image is the positive for the i-th caption.
        # Sized from this cell's own batch (the original used `image_batch`,
        # which is not defined in this cell).
        ground_truth = torch.arange(len(image_batch_cv2), dtype=torch.long, device=DEVICE)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass
        total_loss.backward()

        # if the device is CPU, directly update the model
        if DEVICE == "cpu":
            optimizer.step()
        else:
            # Step the optimizer on float32 params/grads, then hand the model
            # back to clip's weight conversion for the next forward pass.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        # Update the progress bar with the current loss
        pbar.set_postfix(Loss=total_loss.item())


In [None]:
# Put the trained model back on the SemCLIP wrapper so that the upload below
# operates on the updated weights.
semclip.model = model

# Upload via the wrapper's helper.
# NOTE(review): presumably publishes to the Hub under 'hunarbatra/semclip-v1' --
# confirm against SemCLIP.upload_model_to_hf_hub.
semclip.upload_model_to_hf_hub(model_name='semclip-v1', hf_name='hunarbatra')