In [2]:
# Switch the kernel's working directory to the parent folder before any other cell runs.
# NOTE: `..` is relative, so this cell is not idempotent - running it again moves up one more level.
%cd ..

/Users/hunarbatra/Hunar/oxford/dphil/svt-llava


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [10]:
# !pip3 install -r requirements.txt
# !pip3 install supervision torch tqdm fire datasets opencv-python openai-clip huggingface-hub torch python-dotenv clip torchvision Pillow pandas numpy matplotlib transformers
# !pip install --upgrade notebook
# !pip install --upgrade ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

In [11]:
# !bash sam_model_setup.sh
# !pip3 install -q 'git+https://github.com/facebookresearch/segment-anything.git'

In [3]:
from SemCLIP.semclip import SemCLIP
from SemCLIP.image_utils import DEVICE, create_batches, pil_to_cv2
from SemCLIP.model_utils import convert_models_to_fp32, convert_models_to_fp16


# Instantiate the SemCLIP wrapper used by every later cell. It is configured with the
# "openai/clip-vit-base-patch32" checkpoint name, attention pooling and a 512-d projection,
# and placed on whatever device SemCLIP.image_utils selected as DEVICE.
semclip = SemCLIP(
    model_name="openai/clip-vit-base-patch32",
    pool_type='attention',
    projection_dim=512,
    device=DEVICE,
)

In [5]:
import torch

from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

from config import dataset_mapper

# Load the dataset whose identifier config.dataset_mapper stores under 'COCO-13k'.
# The training cells below iterate dataset['train'] through create_batches and read the
# 'downloaded_img', 'image' and 'caption' entries of each batch.
dataset = load_dataset(dataset_mapper['COCO-13k'])

## Download Data Files Locally [Not required to run]

In [None]:
import os
import requests
from datasets import load_dataset


# Load the dataset and drop its embedded image column in one step; the images are
# re-fetched from their COCO URLs by download_image instead.
# NOTE: this rebinds the notebook-wide name `dataset` to a copy WITHOUT 'downloaded_img'.
# The "without downloading images locally" training cell reads batch['downloaded_img'],
# so reload the original dataset before running that cell if this one was executed.
dataset = load_dataset("hunarbatra/CLIP-LLaVA-Instruct-COCO-13k").remove_columns("downloaded_img")

# Target folder for the downloaded image files (created if missing).
save_dir = "data/CLIP-LLaVA-Instruct-COCO-13k/"
os.makedirs(save_dir, exist_ok=True)

def download_image(record, timeout=30):
    """Download one record's image into ``save_dir`` and note where it was stored.

    Args:
        record: Dataset row (dict-like) providing ``'coco_url'`` (the source URL)
            and ``'image'`` (the file name used for the local copy).
        timeout: Seconds to wait on the HTTP request before giving up, so a single
            stalled connection cannot hang a ``dataset.map`` worker forever.

    Returns:
        The same record with ``'local_path'`` set to the path of the written file,
        or to ``None`` when the download failed (request error or non-200 status).
    """
    url = record['coco_url']
    file_name = os.path.join(save_dir, f"{record['image']}")
    # Stream into a sibling ".part" file first so an interrupted transfer never
    # leaves a truncated image under the final name.
    partial_name = file_name + ".part"

    # Only report a path once the file has really been written.
    record['local_path'] = None
    try:
        # The context manager releases the streamed connection back to the pool.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code == 200:
                with open(partial_name, 'wb') as f:
                    for chunk in response.iter_content(1024):
                        f.write(chunk)
                os.replace(partial_name, file_name)
                record['local_path'] = file_name
            else:
                print(f"Failed to download {url}: HTTP {response.status_code}")
    except Exception as e:
        print(f"Failed to download {url}: {e}")
        record['local_path'] = None
        # Best-effort cleanup of a half-written temporary file.
        if os.path.exists(partial_name):
            os.remove(partial_name)
    return record

# Apply download_image to every record using 8 worker processes. Each record gains a
# 'local_path' field (see download_image for when it is None), and `dataset` is rebound
# to the mapped result.
dataset = dataset.map(download_image, num_proc=8)

# dataset.save_to_disk("CLIP-LLaVA-Instruct-COCO-13k-no-img")


Map (num_proc=8):   0%|          | 0/13430 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/3365 [00:00<?, ? examples/s]

## Finetune Model with Local Data [Optional: run only if the image files were downloaded to data/]

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import Image


# Fine-tune SemCLIP on images stored on disk.
#
# Everything this loop needs is defined inside this cell, so it no longer depends on
# names (`image_batch_cv2`, `loss_fn`) leaked from the other training cell.
import os  # imported here because the other `import os` lines live in optional cells

semclip.model.to(DEVICE)

if DEVICE == "cpu":
    semclip.model.float() # convert the model params to float if using CPU

optimizer = torch.optim.Adam(semclip.model.parameters(), lr=1e-5)

# Same name as in the other training cell; the loop below calls `loss_fn`.
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 1
batch_size = 64
data_dir = 'CLIP-LLaVA-Instruct-COCO-13k'
# Assumes the images were written by the download cell to data/<data_dir>/<record['image']>
# (its save_dir is "data/CLIP-LLaVA-Instruct-COCO-13k/") - TODO confirm for other layouts.
image_dir = os.path.join("data", data_dir)

for epoch in range(num_epochs):
    semclip.model.train()
    train_loader = create_batches(dataset['train'], batch_size)

    pbar = tqdm(train_loader, total=len(dataset['train']) // batch_size, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        optimizer.zero_grad()

        image_batch = batch['image']  # file names of the images in this batch
        text_batch = batch['caption']

        if DEVICE != "cpu":
            convert_models_to_fp32(semclip.model)

        # Load the images from disk, then run the forward pass through the model.
        # Any failure (unreadable file, model error) skips the batch.
        try:
            image_batch_cv2 = []
            for image_name in image_batch:
                image_cv2 = cv2.imread(os.path.join(image_dir, image_name))
                if image_cv2 is None:  # cv2.imread signals a missing/corrupt file with None
                    raise FileNotFoundError(f"could not read image {image_name} from {image_dir}")
                image_batch_cv2.append(image_cv2)
            image_embeddings, text_embeddings = semclip.get_semclip_embeddings_direct_img(images=image_batch_cv2, captions=text_batch)
        except Exception as e:
            print(f"error: {e}; images batch being processed: {batch['image']}")
            continue

        # Process final embeddings (normalize, compute logits)
        logits_per_image, logits_per_text = semclip.process_final_embeddings(image_embeddings, text_embeddings)

        # Compute the loss. Targets are sized from the logits themselves so a short
        # final batch still gets matching labels (pair i matches pair i).
        ground_truth = torch.arange(len(logits_per_text), device=logits_per_text.device)
        text_loss = loss_fn(logits_per_text, ground_truth) # contrastive loss
        image_loss = loss_fn(logits_per_image, ground_truth) # contrastive loss
        total_loss = (text_loss + image_loss) / 2.0

        # Backward pass
        total_loss.backward()

        # if the device is CPU, directly update the model
        if DEVICE == "cpu":
            optimizer.step()
        else:
            # Step in fp32, then return the weights to fp16.
            convert_models_to_fp32(semclip.model)
            optimizer.step()
            convert_models_to_fp16(semclip.model)

        # Update the progress bar with the current loss
        pbar.set_postfix(Loss=total_loss.item())

## Finetune Model without downloading images locally [with HF data]

In [14]:
import os
import cv2
import torch
import wandb

import torch.nn as nn
import numpy as np

from tqdm import tqdm
from PIL import Image


# ---- Experiment tracking -------------------------------------------------------------
# The W&B API key is taken from the environment rather than written into the notebook.
wandb.login(key=os.getenv("WANDB_API_KEY"))

semclip.model.to(DEVICE)

# ---- Hyperparameters -----------------------------------------------------------------
# NOTE(review): earlier comments labelled these "huggingface trainer defaults", but
# beta2=0.98, eps=1e-6 and weight_decay=0.2 look like the CLIP paper's ViT settings
# (Trainer defaults are believed to be lr=5e-5, betas=(0.9, 0.999), eps=1e-8, wd=0.0,
# 3 epochs) - confirm the intended source.
learning_rate = 1e-4
betas = (0.9, 0.98)
epsilon = 1e-6
weight_decay = 0.2  # regularisation against overfitting on this small dataset (0.001 was also tried)
num_epochs = 1
batch_size = 1

# Snapshot of the settings above, attached to the W&B run.
wandb_config = dict(
    learning_rate=learning_rate,
    betas=betas,
    epsilon=epsilon,
    weight_decay=weight_decay,
    num_epochs=num_epochs,
    batch_size=batch_size,
)

# ---- Model, optimizer, loss, data ----------------------------------------------------
if DEVICE == "cpu":
    semclip.model.float()  # parameters are converted to float when running on CPU

optimizer = torch.optim.AdamW(
    semclip.model.parameters(),
    lr=learning_rate,
    betas=betas,
    eps=epsilon,
    weight_decay=weight_decay,
)

loss_fn = torch.nn.CrossEntropyLoss()

train_loader = create_batches(dataset['train'], batch_size)

# ---- Checkpointing / resume ----------------------------------------------------------
resume_training = False
train_name = 'semclip-v1-test'
checkpoint_dir = "model_ckpts"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{train_name}_checkpoint.pth")

# Defaults for a fresh run; overwritten below when a checkpoint is restored.
start_epoch, start_batch, wandb_run_id = 0, 0, None

if resume_training and os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    semclip.model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Continue with the batch after the last one that was checkpointed.
    start_epoch, start_batch = checkpoint['epoch'], checkpoint['batch'] + 1
    wandb_run_id = checkpoint['wandb_run_id']
    print(f"Resuming training from epoch {start_epoch}, batch {start_batch}")

if not wandb_run_id:
    # Fresh run: remember its id so checkpoints can re-attach to it later.
    wandb_run = wandb.init(project="semclip", name=train_name, config=wandb_config)
    wandb_run_id = wandb_run.id
else:
    # Restored run: re-attach to the W&B run recorded in the checkpoint.
    wandb.init(project="semclip", name=train_name, id=wandb_run_id, resume="must", config=wandb_config)

# Contrastive fine-tuning loop over the Hugging Face dataset (images come from the
# 'downloaded_img' entry of each batch). A checkpoint is written and the loss logged to W&B
# after every batch so an interrupted run can be resumed.
for epoch in range(start_epoch, num_epochs):
    semclip.model.train()

    # Rebuild the batch iterator for every epoch, as the local-data training cell does.
    # If create_batches returns a one-shot generator, reusing a single instance would
    # leave every epoch after the first empty - TODO confirm its return type.
    train_loader = create_batches(dataset['train'], batch_size)

    pbar = tqdm(train_loader, total=len(dataset['train']) // batch_size, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, batch in enumerate(pbar):
        if epoch == start_epoch and batch_idx < start_batch:
            continue  # Skip batches already processed in the current epoch if resuming

        optimizer.zero_grad()

        image_batch_pil = batch['downloaded_img']
        text_batch = batch['caption']

        # Convert the batch of PIL images to OpenCV images
        image_batch_cv2 = [pil_to_cv2(img) for img in image_batch_pil]

        if DEVICE != "cpu":
            convert_models_to_fp32(semclip.model)

        # Forward pass through the model; a failing batch is reported and skipped
        try:
            image_embeddings, text_embeddings = semclip.get_semclip_embeddings_direct_img(images=image_batch_cv2, captions=text_batch)
        except Exception as e:
            print(f"error: {e}; images batch being processed: {batch['image']}")
            continue

        # Process final embeddings (normalize, compute logits)
        logits_per_image, logits_per_text = semclip.process_final_embeddings(image_embeddings, text_embeddings)

        # Compute the loss. The targets are sized from the logits rather than from the
        # configured batch_size, so a short final batch still gets matching labels.
        ground_truth = torch.arange(len(logits_per_text), device=logits_per_text.device)
        text_loss = loss_fn(logits_per_text, ground_truth) # contrastive loss
        image_loss = loss_fn(logits_per_image, ground_truth) # contrastive loss
        total_loss = (text_loss + image_loss) / 2.0

        # Backward pass
        total_loss.backward()

        # if the device is CPU, directly update the model
        if DEVICE == "cpu":
            optimizer.step()
        else:
            # Step in fp32, then return the weights to fp16.
            convert_models_to_fp32(semclip.model)
            optimizer.step()
            convert_models_to_fp16(semclip.model)

        # Save checkpoint after each batch
        print(f'Saving checkpoint at... {checkpoint_path}')
        torch.save({
            'epoch': epoch,
            'batch': batch_idx,
            'model_state_dict': semclip.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'wandb_run_id': wandb_run_id,
        }, checkpoint_path)

        # Log the loss to wandb after each batch
        wandb.log({"Loss": total_loss.item()})

        # Update the progress bar with the current loss
        pbar.set_postfix(Loss=total_loss.item())




In [None]:
# Upload the model under the run name `train_name` for the Hub account 'hunarbatra'.
# NOTE(review): presumably this pushes the current in-memory weights - confirm against
# SemCLIP.upload_model_to_hf_hub before publishing.
semclip.upload_model_to_hf_hub(
    model_name=train_name,
    hf_name='hunarbatra',
)