# Fine-tune CLIP on TACO 100

To improve the results, we fine-tune CLIP on 100 images with descriptions and see if the results are better.

In [1]:
import clip
import datetime
import math
import os
import tqdm
import dtlpy as dl
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import umap.umap_ as umap

from PIL import Image
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

[2023-08-10 16:45:19][ERR][dtlpy:v1.81.4][services.api_client:935] POST https://gate.dataloop.ai/api/v1/sdk/check
User-Agent: dtlpy/1.81.4 CPython/3.8.10 Windows/10
Content-Length: 41
Content-Type: application/x-www-form-urlencoded
version=1.81.4&email=yaya.t%40dataloop.ai
  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
BATCH_SIZE = 32
NUM_EPOCHS = 20
MODEL_VERSION = "1_TACO100"

## Generate descriptions for images from labels

A description is generated for the dataset based on the objects detected in the image, and will be printed to help assess model efficacy.

In [120]:
# create descriptions for each image from its annotations
def create_descrip(labels: list):
    description = "a photo"
    if len(labels) != 0:
        description += " of a "
        for i, label in enumerate(labels):
            if i < len(labels) - 1:
                description += f"{label}, "
            else:
                description += f"and {label}."
    return description

def get_data(data_path):
    # data_pairs = pd.read_csv(r"C:\Users\Yaya Tang\Documents\DATASETS\TACO 100\taco_100_INPUTS_nb.csv")
    data_pairs = pd.read_csv(data_path)
    return data_pairs['filepath'], data_pairs['img_description']


descrips_path = r"C:\Users\Yaya Tang\Documents\DATASETS\TACO dataloop\taco_descriptions.csv"
try:
    dataset_df = pd.read_csv(descrips_path)
    item_labels_lookup = dict(dataset_df)

    print(f"Full TACO dataset descriptions loaded")
except FileNotFoundError:
    import dtlpy as dl
    
    # setup dtlpy
    dl.setenv('prod')
    if dl.token_expired():
        dl.login()
    
    dl_dataset = dl.datasets.get(dataset_id='64c27e74615b1c5d7d576776')
    
    # for training, adding annotations as descriptions
    all_labels = dl_dataset.labels
    new_label_names = [label.tag for label in all_labels]
    
    # create text descriptions from labels
    items = list(dl_dataset.items.list().all())
    item_labels_lookup = {}
    pbar = tqdm.tqdm(total=len(items))
    for i, item in enumerate(items):
        item_name = item.name
        annotations = item.annotations.list()
        item_labels = []
        for annotation in annotations:
            item_labels.append(str(annotation.label).split(".")[-1])
        item_labels_lookup[item_name] = item_labels
        pbar.update()

    descrips_df = pd.DataFrame(item_labels_lookup)
    descrips_df.to_csv(descrips_path)


Iterate Entity:   0%|                                                   | 0/1500 [00:00<?, ?it/s][A
Iterate Entity:   0%|                                           | 1/1500 [00:00<16:01,  1.56it/s][A
Iterate Entity:   7%|██▋                                     | 101/1500 [00:01<00:11, 119.18it/s][A
Iterate Entity:  53%|████████████████████▊                  | 801/1500 [00:01<00:00, 1014.54it/s][A
Iterate Entity: 100%|███████████████████████████████████████| 1500/1500 [00:01<00:00, 955.83it/s][A



100%|████████████████████████████████████████████████████████| 1499/1499 [04:00<00:00,  6.23it/s][A

  0%|                                                           | 1/1500 [00:00<02:41,  9.26it/s][A
  0%|                                                           | 2/1500 [00:00<02:41,  9.26it/s][A
  0%|                                                           | 3/1500 [00:00<02:41,  9.26it/s][A
  0%|▏                                                          | 4/1500 [00:00<02:41,  9.26it/s][A
  0%|▏                                                          | 5/1500 [00:00<02:41,  9.26it/s][A
  0%|▏                                                          | 6/1500 [00:00<02:43,  9.13it/s][A
  0%|▎                                                          | 7/1500 [00:00<02:42,  9.17it/s][A
  1%|▎                                                          | 8/1500 [00:00<02:44,  9.09it/s][A
  1%|▎                                                          | 9/1500 [00:00<02:43,  9

ValueError: All arrays must be of the same length

In [3]:
class image_title_dataset(Dataset):
    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        # you can tokenize everything at once in here(slow at the beginning), or tokenize it in the training loop.
        self.title = clip.tokenize(list_txt)

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        image = preprocess(Image.open(self.image_path[idx]))  # Image from PIL module
        title = self.title[idx]
        return image, title

In [4]:
# load data
random_seed = 11
torch.manual_seed(random_seed)

list_image_path, list_txt = get_data()
dataset = image_title_dataset(list_image_path, list_txt)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

In [5]:
# train model
# if model iteration name already exists, skip and load instead
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    checkpoint_path = rf"C:\Users\Yaya Tang\PycharmProjects\clip-smart-search\checkpoints\model_{MODEL_VERSION}_BEST.pt"

    model, preprocess = clip.load("ViT-B/32", device=device)  # Must set jit=False for training
    checkpoint = torch.load(checkpoint_path)

    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Model {MODEL_VERSION} loaded")
except FileNotFoundError:
    if device == "cpu":
        model.float()
    else:
        clip.model.convert_weights(model)  # Actually this line is unnecessary since clip by default already on float16
    
    # keep track of the best model
    EARLY_STOPPING = 10 
    best_loss = np.Inf
    best_iter = 0
    end_training = False
    
    # loss fxns for images and their descriptions
    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6,
                                 weight_decay=0.2)  # Params used from paper, the lr is smaller, more safe for fine tuning to new dataset
    
    for epoch in range(NUM_EPOCHS):
        pbar = tqdm.tqdm(dataloader, total=len(dataloader))
        for batch in dataloader:
            optimizer.zero_grad()
    
            images, texts = batch
            images = images.to(device)
            texts = texts.to(device)
    
            # forward pass
            logits_per_image, logits_per_text = model(images, texts)
    
            # calc loss + backprop
            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
            total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_loss.backward()
            if device == "cpu":
                optimizer.step()
            else:
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)
    
            pbar.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {total_loss.item():.4f}")
        pbar.update()
        
        
        if epoch == 0:
            best_loss = total_loss # val_loss
        elif total_loss < best_loss: # val_loss < best_loss:
            best_iter = epoch + 1
            best_loss = total_loss # val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': total_loss,  # val_loss,
            }, rf"C:\Users\Yaya Tang\PycharmProjects\clip-smart-search\checkpoints\model_{MODEL_VERSION}_BEST.pt")
    
        if ((epoch + 1) - best_iter) > EARLY_STOPPING:
            print("Early stop achieved at", epoch + 1)
            break
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss,
        }, rf"C:\Users\Yaya Tang\PycharmProjects\clip-smart-search\checkpoints\model_{MODEL_VERSION}_epoch_{epoch + 1}.pt")


Model 1_TACO100 loaded


### Re-create image embeddings

In [6]:
def is_image_file(filename):
    img_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']
    return any(filename.lower().endswith(ext) for ext in img_extensions)

# load images to embed and query
img_dir = Path(r"C:\Users\Yaya Tang\Documents\DATASETS\TACO dataloop\items\raw")
orig_img_paths = [str(img_path) for img_path in img_dir.glob("*") if is_image_file(str(img_path))]
img_paths = orig_img_paths.copy()

# prepare images
images_np = []
pbar = tqdm.tqdm(total=len(img_paths))
for img_path in img_paths:
    img_np = Image.open(img_path).convert("RGB")
    images_np.append(preprocess(img_np))
    pbar.set_description(f"Processing {img_path}...")
    pbar.update()
print(len(img_paths))

Processing C:\Users\Yaya Tang\Documents\DATASETS\TACO dataloop\items\raw\IMG_5068.JPG...: 100%|█|

1499


In [7]:
# create image features
model = model.eval()
image_input = torch.tensor(np.stack(images_np)).to(device)
with torch.no_grad():
    image_features = model.encode_image(image_input)

image_features /= image_features.norm(dim=-1, keepdim=True)

Processing C:\Users\Yaya Tang\Documents\DATASETS\TACO dataloop\items\raw\IMG_5068.JPG...: 100%|█|

### Query images with fine-tuned CLIP

We can now take the images that were returned and color them in on the UMAP plot:

In [8]:
# create query feature
QUERY_STRING = "cigarettes on the sidewalk"
query_keyword = "cigarette"
NUM_RESULTS = 20

text_tokens = clip.tokenize([QUERY_STRING]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)

## Get the K nearest images to the query

In [9]:
# get top results
results = cosine_similarity(text_features.cpu().numpy(), image_features.cpu().numpy())

results_dict = {'name': [], 'prob': [], 'filepath': []}

# reset img_paths from the last use
img_paths = orig_img_paths.copy()
pbar = tqdm.tqdm(total=len(img_paths))
for i, img_path in enumerate(img_paths):
    results_dict['name'].append(Path(img_path).name)
    results_dict['prob'].append(results[0][i])
    results_dict['filepath'].append(img_path)
    results_dict.update()
    pbar.update()

results_df = pd.DataFrame(results_dict)
results_df.sort_values(by=['prob'], ascending=False, inplace=True)

results_df = results_df.iloc[:NUM_RESULTS][['name', 'prob', 'filepath']]

results_with_labels = results_df[['name', 'prob']]
results_with_labels['labels'] = [item_labels_lookup[name] for name in results_df['name']]
results_with_labels['has_keyword'] = [1 if query_keyword in [item.lower() for item in result] else 0 for result in results_with_labels['labels']]
results_with_labels.to_csv(Path(save_dir, f'finetuned_results_{MODEL_VERSION}_{timestamp}.csv'))

print(results_with_labels[['name', 'prob']]) #, 'labels']])

# show the number of images with labels that include the keyword
num_hits = sum(results_with_labels['has_keyword'])
print(f'Number of images that have the keyword in the labels: {num_hits}')


Processing C:\Users\Yaya Tang\Documents\DATASETS\TACO dataloop\items\raw\IMG_5068.JPG...: 100%|█|[A


NameError: name 'item_labels_lookup' is not defined

## Plot returned images from query


In [None]:
# plot returned images on a grid
num_grid = math.isqrt(NUM_RESULTS)
subplot_dims = num_grid + 1 if num_grid ** 2 < NUM_RESULTS else num_grid

plt.figure(figsize=(20, 20))
plt.tight_layout()
for i, img_path in enumerate(results_df['filepath'].iloc[:NUM_RESULTS]):
    plt.subplot(subplot_dims, subplot_dims, i + 1)
    image = Image.open(img_path).convert("RGB")
    plt.text(0,-1, f'{Path(img_path).name}', verticalalignment="bottom", wrap=True)
    plt.text(0,0, f'{item_labels_lookup[Path(img_path).name]}', verticalalignment="top", wrap=True)
    plt.imshow(image)

plt.suptitle(f"Query: '{QUERY_STRING}', returned {len(results_df)}, on fine-tuned CLIP {MODEL_VERSION}\nfound {num_hits}")
plot_filename = f"clip_query_results_{MODEL_VERSION}_{timestamp}.png"
save_path = os.path.join(save_dir, plot_filename)
plt.savefig(save_path, dpi=800)
print(f'Saved query results to {save_path}')

## Visualize: dim reduction of fine-tuned features

For UMAP dimension reduction, we can take the image features (n_images, 512) to (n_images, 2)


In [None]:
####################################
# UMAP reduction and visualization #
####################################
# concatenate both image + query features and reduce with UMAP
all_features = torch.cat((image_features, text_features), 0)
reducer = umap.UMAP(random_state=42, metric='cosine')
embedding = reducer.fit_transform(all_features.cpu())

# update lists to include the query string as the last item
names = [Path(path).name for path in img_paths]
results = results_df['name'].tolist()
query_returned = ['results' if name in results else '0' for name in names]

names.append('query')
query_returned.append('query')
img_paths.append('query')

thumbs_df = pd.DataFrame(embedding, columns=['x', 'y'])
thumbs_df['filename'] = img_paths
thumbs_df['name'] = names
thumbs_df['query_returned'] = query_returned

In [None]:
plt.figure(figsize=(15, 10))
sns.scatterplot(x=thumbs_df['x'], y=thumbs_df['y'], hue=np.array(thumbs_df['query_returned']), palette="deep")
plt.axis('off')
plt.title(f'UMAP of fine-tuned CLIP features, with returned images for model {MODEL_VERSION}')
plt.show()

plot_filename = f"UMAP_clip_query_results_{MODEL_VERSION}_{timestamp}.png"
save_path = os.path.join(save_dir, plot_filename)
plt.savefig(save_path, dpi=800)
print(f'Saved query results to {save_path}')