In [1]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from func import Flickr8kDataset, CLIPDualEncoderModel, Config

# Fix the global random seed up front so that repeated runs of the notebook are comparable.
seed_everything(42)

  warn(f"Failed to load image Python extension: {e}")
Global seed set to 42


42

In [2]:
# Root folder of the Flickr8k copy: it must contain `captions.txt` and an `Images/` directory.
path = Path.home() / 'OneDrive - Seagroup/ai/image_captioning/flickr'
image_path = path / 'Images'

df = pd.read_csv(path / 'captions.txt')
# One integer id per distinct image, numbered in order of first appearance.
# The id is derived from the file name itself rather than from row position, so it
# stays correct when an image has more or fewer than five captions or when the rows
# are not grouped by image (the positional version raised a length mismatch whenever
# the row count was not a multiple of five, and mislabelled rows silently otherwise).
df['id'] = pd.factorize(df['image'])[0]
# Full path of the image file that each caption row refers to.
df['image_path'] = [str(image_path / file_name) for file_name in df['image']]
df.head(7)

Unnamed: 0,image,caption,id,image_path
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,0,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
5,1001773457_577c3a7d70.jpg,A black dog and a spotted dog are fighting,1,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...
6,1001773457_577c3a7d70.jpg,A black dog and a tri-colored dog playing with...,1,C:\Users\Kevin\OneDrive - Seagroup\ai\image_ca...


In [3]:
# Name of the Hugging Face checkpoint whose tokenizer is used for the captions.
# NOTE(review): the text encoder is configured separately through Config.pretrain_text
# in the training cell — presumably both must name the same checkpoint; confirm.
pretrain_model = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrain_model)

In [4]:
# Split on image ids, not on caption rows. Each image has several captions, so a
# row-level split places the same picture in both partitions and the validation
# loss is then measured on images the model has already been trained on.
train_ids, test_ids = train_test_split(df['id'].unique(), test_size=.2, random_state=42)

# Shuffle the rows inside each partition (as the row-level split used to do) so that
# batches read in order — the validation loader does not shuffle — are not made of
# consecutive captions of the same image.
train = df[df['id'].isin(train_ids)].sample(frac=1, random_state=42)
test = df[df['id'].isin(test_ids)].sample(frac=1, random_state=42)

train_dataset = Flickr8kDataset(train['image_path'].values.tolist(), train['caption'].values.tolist(), tokenizer=tokenizer)
test_dataset = Flickr8kDataset(test['image_path'].values.tolist(), test['caption'].values.tolist(), tokenizer=tokenizer)

batch_size = 32
# Only the training loader shuffles; validation batches are read in dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Dual-encoder model built from the image/text checkpoints named in Config.
model = CLIPDualEncoderModel(
    Config.pretrain_image,
    Config.pretrain_text,
    batch_size=batch_size,
)

# Keep a single checkpoint under `clip/`: the one with the lowest monitored `val_loss`.
# NOTE(review): presumably `val_loss` is logged by the model's validation step — confirm in func.py.
model_checkpoint = ModelCheckpoint(dirpath='clip/', monitor='val_loss', mode='min', save_top_k=1)
# Record the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = Trainer(
    accelerator='gpu',
    max_epochs=5,
    deterministic=True,
    callbacks=[model_checkpoint, lr_monitor],
)
# The held-out split doubles as the validation set during training.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.bias', 'classifier.1.weight']
- This IS expected if you are initializing ResNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ResNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with anoth

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Report which checkpoint ModelCheckpoint kept and the monitored score it achieved.
print(f'model path: {model_checkpoint.best_model_path}')
print(f'best loss: {model_checkpoint.best_model_score.cpu().item():,.2f}')
# `load_from_checkpoint` is a classmethod that builds and returns a new module, so it
# is called on the class. Calling it on the `model` instance is rejected with a
# TypeError by Lightning 2.x and never updated that instance in older releases.
best_model = CLIPDualEncoderModel.load_from_checkpoint(model_checkpoint.best_model_path).to('cuda')

In [None]:
def create_image_embedding(model, dataloader, device='cuda'):
    """Project every image served by `dataloader` into the joint embedding space.

    Args:
        model: module exposing `image_encoder` and `image_projection`; it is
            switched to eval mode as a side effect.
        dataloader: iterable of batches, each a mapping with an "image" tensor.
        device: device the image tensors are moved to before encoding.

    Returns:
        A CPU tensor holding the projected embeddings of all batches,
        concatenated along the first dimension in iteration order.
    """
    model.eval()
    chunks = []
    progress = tqdm(dataloader, desc='Create Image Embedding')
    with torch.no_grad():
        for batch in progress:
            pixels = batch["image"].to(device)
            projected = model.image_projection(model.image_encoder(pixels))
            # Move each chunk off the accelerator straight away so that only one
            # batch of embeddings is held on `device` at a time.
            chunks.append(projected.cpu())
    return torch.cat(chunks)

# Embed the training images in dataset order. `train_dataloader` shuffles, so
# embeddings computed from it would not line up row-for-row with
# `train['image_path']`, which the retrieval cell indexes with the ranked positions.
# NOTE(review): assumes Flickr8kDataset serves items in the order of the lists it was built from — confirm.
embedding_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
image_embeddings = create_image_embedding(best_model, embedding_dataloader)

In [None]:
def _top_unique_matches(similarity, image_filenames, n):
    """Return up to `n` distinct file names, ordered from most to least similar."""
    matches, seen = [], set()
    for idx in torch.argsort(similarity, descending=True).tolist():
        if len(matches) >= n:
            break
        name = str(image_filenames[idx])
        if name not in seen:
            seen.add(name)
            matches.append(name)
    return matches


def fletch_similar(model, image_embeddings, query, image_filenames, n=10):
    """Show the `n` images whose embeddings are closest to a text query.

    Args:
        model: module exposing `text_encoder` and `text_projection`; it is
            switched to eval mode as a side effect.
        image_embeddings: 2-D tensor of image embeddings, one row per entry
            of `image_filenames` and in the same order.
        query: the search text.
        image_filenames: indexable collection of image paths aligned with
            the rows of `image_embeddings`.
        n: maximum number of distinct images to display.

    Raises:
        FileNotFoundError: if one of the selected image files cannot be read.
    """
    # Run the text tower on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    model.eval()

    encoded_query = tokenizer([query])
    batch = {k: torch.tensor(v).to(device) for k, v in encoded_query.items()}
    with torch.no_grad():
        text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        text_embeddings = model.text_projection(text_features).cpu()

    # Cosine similarity between the query and every image embedding.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

    # The same image can occur in several rows (one per caption). De-duplicate by
    # file name while walking the ranking, rather than taking every fifth hit, which
    # is only right when duplicates come in exact groups of five and which fails
    # outright when fewer than n * 5 embeddings exist.
    matches = _top_unique_matches(similarity, image_filenames, n)
    if not matches:
        print('No images to display.')
        return

    # Load everything before opening a figure so a missing file fails with a clear message.
    images = []
    for file_name in matches:
        image = cv2.imread(file_name)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_name}')
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Up to five pictures per row; the default n=10 yields the usual 2 x 5 grid at 20 x 10 inches.
    n_cols = min(5, len(images))
    n_rows = -(-len(images) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        # Hide the frame on every cell, including unused ones in the last row.
        ax.axis("off")
    for ax, image in zip(axes, images):
        ax.imshow(image)
    fig.tight_layout()

    plt.show()

In [None]:
# Use one caption drawn at random from the held-out split as the search text.
# To search with free text instead, assign it directly, e.g.:
# query = 'two people wearing hats'
query = test['caption'].sample(1).iloc[0]
print(query)
fletch_similar(
    best_model,
    image_embeddings,
    query,
    train['image_path'].to_numpy(),
)