In [1]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from func import FlickrDataset, CLIPDualEncoderModel, Config

# Seed Python, NumPy and torch. workers=True additionally makes Lightning derive
# seeds for DataLoader worker processes, so loading with num_workers > 0 is
# reproducible from run to run.
seed_everything(42, workers=True)

  warn(f"Failed to load image Python extension: {e}")
Global seed set to 42


42

In [2]:
# Dataset root (captions file plus image folder) under the user's home directory.
path = Path.home() / 'OneDrive - Seagroup/ai/image_captioning/flickr30k'
image_path = path / 'flickr30k_images'

# The captions file is pipe-delimited; give its three columns short names.
df = pd.read_csv(path / 'results.csv', sep='|')
df.columns = ['image', 'caption_number', 'caption']
# Drop the leading whitespace from every field.
df = df.apply(lambda column: column.str.lstrip())
# NOTE(review): row 19999 is overwritten by hand - presumably that line of
# results.csv is malformed; confirm against the raw file.
df.loc[19999, ['caption_number', 'caption']] = ['4', 'A dog runs across the grass .']
# Integer id per distinct image file name, and the full path to each image.
df['id'] = df['image'].factorize()[0]
df['image_path'] = df['image'].map(lambda name: str(image_path / name))

print(df.shape)
df.head(7)

(158915, 5)


Unnamed: 0,image,caption_number,caption,id,image_path
0,1000092795.jpg,0,Two young guys with shaggy hair look at their ...,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
1,1000092795.jpg,1,"Two young , White males are outside near many ...",0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
4,1000092795.jpg,4,Two friends enjoy time spent together .,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
5,10002456.jpg,0,Several men in hard hats are operating a giant...,1,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
6,10002456.jpg,1,Workers look down from up above on a piece of ...,1,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...


In [3]:
# Checkpoint name used to load the caption tokenizer.
# NOTE(review): the model cell takes its text encoder from Config.pretrain_text -
# confirm that both refer to the same checkpoint.
pretrain_model = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrain_model)

In [4]:
# Split by image id (not by caption row) so that all captions of one image end
# up on the same side of the split; 20% of the ids are held out.
image_ids = range(0, df['id'].max() + 1)
valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
# Use a set for the membership test: `in` on a NumPy array rescans the whole
# array for every id, which is quadratic over all ids.
valid_id_set = set(valid_ids.tolist())
train_ids = [id_ for id_ in image_ids if id_ not in valid_id_set]

train = df[df['id'].isin(train_ids)].reset_index(drop=True)
test = df[df['id'].isin(valid_ids)].reset_index(drop=True)
# Every caption row must land in exactly one of the two splits.
assert len(train) + len(test) == len(df), "train/test split does not cover df exactly"
print(train.shape, test.shape)

train_dataset = FlickrDataset(train['image_path'].values.tolist(), train['caption'].values.tolist(), tokenizer=tokenizer)
test_dataset = FlickrDataset(test['image_path'].values.tolist(), test['caption'].values.tolist(), tokenizer=tokenizer)

batch_size = 32
num_workers = 4
# Only the training loader shuffles; the held-out loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

(127135, 5) (31780, 5)


In [None]:
# Instantiate the model from the image/text settings held in Config.
model = CLIPDualEncoderModel(Config.pretrain_image, Config.pretrain_text, batch_size=batch_size)

# Keep only the single checkpoint with the lowest monitored 'val_loss' under clip/.
model_checkpoint = ModelCheckpoint(
    dirpath='clip/',
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)
# Log the learning rate once per optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer_options = dict(
    accelerator='gpu',
    max_epochs=5,
    callbacks=[model_checkpoint, lr_monitor],
    deterministic=True,
)
trainer = Trainer(**trainer_options)
# The held-out split doubles as the validation set during fitting.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ResNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ResNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with anoth

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [None]:
# Report the checkpoint that ModelCheckpoint ranked best and its monitored score.
print(f'model path: {model_checkpoint.best_model_path}')
print(f'best loss: {model_checkpoint.best_model_score.cpu().item():,.2f}')
# load_from_checkpoint is a classmethod that builds a fresh module from the file.
# Lightning 2.x raises a TypeError when it is invoked through an instance, so
# call it on the class rather than on `model`.
best_model = CLIPDualEncoderModel.load_from_checkpoint(model_checkpoint.best_model_path).to('cuda')

In [None]:
def create_image_embedding(model, dataloader, device='cuda'):
    """Run every batch of `dataloader` through the model's image branch.

    Args:
        model: object exposing `eval()`, `image_encoder` and `image_projection`.
            It is switched to eval mode and left that way.
        dataloader: iterable of batches; each batch is indexed with "image" and
            the value is moved to `device` before encoding.
        device: device the image tensors are moved to. The model itself is not
            moved here, so it must already live on this device.

    Returns:
        A single CPU tensor: the projected embeddings of all batches
        concatenated in iteration order.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        progress = tqdm(dataloader, desc='Create Image Embedding')
        for batch in progress:
            pixels = batch["image"].to(device)
            projected = model.image_projection(model.image_encoder(pixels))
            # Move each chunk off the device right away so only one batch of
            # embeddings is held there at a time.
            chunks.append(projected.cpu())
    return torch.cat(chunks)

# Embed the training images in dataset order. `train_dataloader` shuffles, so
# embeddings produced from it would not line up row-for-row with
# `train['image_path']`, which the retrieval step indexes positionally.
# Assumes FlickrDataset serves items in the order of the lists it was built from.
embedding_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
image_embeddings = create_image_embedding(best_model, embedding_dataloader)

In [None]:
def fletch_similar(model, image_embeddings, query, image_filenames, n=10, captions_per_image=5):
    """Plot the images whose embeddings are most similar to a text query.

    The query is tokenized with the notebook-level `tokenizer`, encoded with the
    model's text branch, and compared to `image_embeddings` by cosine similarity.

    Args:
        model: module exposing `text_encoder` and `text_projection`; it is
            switched to eval mode so the query is encoded deterministically.
        image_embeddings: 2-D tensor with one row per entry of `image_filenames`,
            in the same order.
        query: text to search for.
        image_filenames: sequence of image paths aligned with `image_embeddings`.
        n: number of images to display.
        captions_per_image: stride used to thin the ranked hits. The notebook's
            frame holds several caption rows per image (five in the displayed
            data), so taking every `captions_per_image`-th hit avoids showing
            the same image repeatedly.

    Raises:
        ValueError: if `captions_per_image` is smaller than 1.
        FileNotFoundError: if a matched image cannot be read from disk.
    """
    if captions_per_image < 1:
        raise ValueError("captions_per_image must be >= 1")

    model.eval()
    # Run the query on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    encoded_query = tokenizer([query])
    batch = {k: torch.tensor(v).to(device) for k, v in encoded_query.items()}
    with torch.no_grad():
        text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        text_embeddings = model.text_projection(text_features).cpu()

    # Cosine similarity = dot product of L2-normalised vectors.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # reshape(-1) keeps a 1-D score vector even when there is a single embedding,
    # and k is clamped so topk never asks for more hits than exist.
    scores = dot_similarity.reshape(-1)
    k = max(0, min(n * captions_per_image, scores.numel()))
    _, indices = scores.topk(k)
    matches = [image_filenames[idx] for idx in indices[::captions_per_image].tolist()]

    # Size the grid from the number of matches: five per row, 5 inches per row
    # (n=10 gives the 2x5, 20x10 figure).
    cols = 5
    rows = max(1, -(-len(matches) // cols))
    fig, axes = plt.subplots(rows, cols, figsize=(20, 5 * rows))
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        if i < len(matches):
            image = cv2.imread(matches[i])
            if image is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError(f"could not read image: {matches[i]}")
            # OpenCV loads BGR; matplotlib expects RGB.
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Hide ticks and frame on every cell, including unused ones.
        ax.axis("off")
    fig.tight_layout()

    plt.show()

In [None]:
# Use one randomly drawn training caption as the search text; assign any
# free-text string instead (e.g. 'two people wearing hats') to try a custom query.
query = train['caption'].sample(n=1).iloc[0]
print(query)
fletch_similar(best_model, image_embeddings, query, image_filenames=train['image_path'].values)