In [None]:
!pip install timm
!pip install transformers
!pip install sentence-transformers
!pip install numba
!pip install tensorboard
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive/My Drive/CLIP Project
%matplotlib inline
%load_ext tensorboard
# !tensorboard --logdir=runs --bind_all # http://localhost:6006/

Collecting timm
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m1.4/2.2 MB[0m [31m46.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.2/2.2 MB[0m [31m49.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.16.2-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.5/268.5 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 

In [None]:
# !unzip 'Train/resized_train.zip' -d ''

In [None]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt

import torch
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
from sentence_transformers import SentenceTransformer, util
# from numba import cuda
# device = cuda.get_current_device()
# device.reset()
%tensorboard --logdir runs

  from tqdm.autonotebook import tqdm


## Config

*A note on config and CFG: I wrote the code as Python scripts and then converted it into a Jupyter Notebook. In the Python scripts, config is a normal Python file that holds all the hyperparameters; in the Jupyter Notebook, it's a class defined at the beginning of the notebook that keeps all the hyperparameters.*

In [None]:
class CFG:
    """Hyperparameters and paths for the CLIP/CyCLIP notebook, read as class attributes.

    Defining this class has side effects: it instantiates the timm backbone
    ``model_name`` with pretrained weights (kept as ``m``) and runs it on a
    random dummy image to derive ``input_size`` and ``image_embedding``.
    """
    debug = False  # not read by the visible notebook code
    image_path = 'resized_train' #'Train/resized_train'
    # NOTE(review): make_train_valid_dfs hard-codes this same path instead of reading it.
    captions_path = 'Train/caption_prediction_train.csv'
    # TransferImageEncoder state dict loaded by ImageEncoder as the backbone initialization.
    transfer_path = 'Pre-Train/pre_train_100.pt'
    batch_size = 64
    validation_ratio = 0.2  # fraction of rows held out in make_train_valid_dfs
    num_workers = 2
    # Per-parameter-group learning rates for the AdamW optimizer built in main().
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    logit_scale_lr = 1e-4
    weight_decay = 1e-3  # applied only to the projection-head param group in main()
    # ReduceLROnPlateau settings; main() steps it once per epoch on validation loss.
    patience = 1
    factor = 0.8
    epochs = 20
    # Weights of the CyCLIP in-modal (1) and cross-modal (2) consistency losses in CLIPModel.forward.
    cylambda1 = 0.001
    cylambda2 = 0.001
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # writer = SummaryWriter(comment="convnext_nano")
    writer_comment = '-CyClip_learnable-0.01'  # suffix of the TensorBoard run directory

    model_name = 'convnext_nano'
    # Created only to read its pretrained config and feature width below.
    m = timm.create_model(model_name, pretrained=True)
    # 3-element input shape, used below as the dims of a dummy image batch.
    # NOTE(review): images are actually resized to `size` (below), not to this; confirm they agree.
    input_size = m.pretrained_cfg['input_size']
    num_classes = 100 # Number of concepts
    # Backbone feature width: dim 1 of forward_features on one dummy image.
    # Assumed to equal the pooled ImageEncoder output width — TODO confirm for other backbones.
    image_embedding = m.forward_features(torch.randn(1, input_size[0], input_size[1], input_size[2])).shape[1]
    # forward_features(torch.randn(2, 3, 299, 299)).shape[1] # 299, 299
    # image_embedding = 112
    text_encoder_model = 'distilbert-base-uncased'
    text_embedding = 768  # width of the CLS vector returned by TextEncoder
    text_tokenizer = 'distilbert-base-uncased'
    max_length = 200  # tokenizer truncation length used by CLIPDataset
    samples = 10000  # only referenced by a commented-out line in make_train_valid_dfs

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    # CLIPModel initializes its learnable logit_scale to this value and multiplies
    # logits by exp(logit_scale), so 0.07 starts at a scale of about 1.07.
    logit_scale_init_value = 0.07 # 2.6592
    temperature = 0.07  # passed to CLIPModel.__init__, which does not use it

    # image size
    size = 224  # Resize target used by get_transforms

    # for projection head; used for both image and text encoders
    num_projection_layers = 1  # not read by ProjectionHead in the visible code
    projection_dim = 256
    dropout = 0.1

    # Sentence-embedding cosine similarity at or above which num_of_correct_preds counts a hit.
    similarity_threshold = 0.8
    global_epoch = 0  # main() sets this to the current 1-based epoch


## Utils

In [None]:
class AvgMeter:
    """Running tracker of a count-weighted average loss and an accuracy ratio.

    After every ``update``, ``avg_loss == loss_sum / count`` and
    ``accuracy == correct_predictions / count``.
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Set every accumulator and derived statistic back to zero."""
        for attr in ("avg_loss", "accuracy", "loss_sum", "correct_predictions", "count"):
            setattr(self, attr, 0)

    def update(self, val, corr, count=1):
        """Fold in a batch: loss ``val`` weighted by ``count`` samples, ``corr`` correct hits."""
        self.count = self.count + count
        self.loss_sum = self.loss_sum + val * count
        self.correct_predictions = self.correct_predictions + corr
        total = self.count
        self.avg_loss = self.loss_sum / total
        self.accuracy = self.correct_predictions / total

    def __repr__(self):
        return "{}: {:.4f}, {:.4f}".format(self.name, self.avg_loss, self.accuracy)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group, or None if it has none."""
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group["lr"]


## Dataset

As you can see in the title image of this article, we need to encode both images and their describing texts. So, the dataset needs to **return both images and texts**. Of course we are not going to feed raw text to our text encoder! We will use the **DistilBERT** model (which is smaller than BERT but performs nearly as well as BERT) from the **HuggingFace** library as our text encoder; so, we need to **tokenize** the sentences (captions) with the DistilBERT tokenizer and then feed the token ids (input_ids) and the attention masks to DistilBERT. Therefore, the dataset needs to take care of the tokenization as well. Below you can see the dataset's code; the next paragraph explains the most important things that are happening in it.

In the **\_\_init\_\_** we receive a tokenizer object, which is actually a HuggingFace tokenizer; this tokenizer is loaded when running the model. We pad and truncate the captions to a specified max_length. In the **\_\_getitem\_\_** we first load an encoded caption, which is a dictionary with keys input_ids and attention_mask, and make tensors out of its values. After that we load the corresponding image, transform and augment it (if there is any augmentation!), make it a tensor and put it in the dictionary with "image" as the key. Finally we put the raw text of the caption in the dictionary with the key "caption"; it is used for visualization and by `num_of_correct_preds` to score retrievals.

I did not use additional data augmentations but you can add them if you want to improve the model's performance.

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Pairs each image with its tokenized caption for CLIP-style training.

    Each item is a dict with the tokenizer outputs (e.g. ``input_ids`` and
    ``attention_mask``) as tensors, the transformed image as a float CHW tensor
    under ``"image"`` and the raw caption string under ``"caption"``.
    """

    def __init__(self, dataframe, tokenizer, transforms):
        """
        Args:
            dataframe: frame with an ``ID`` column (image file stem; ``.jpg`` is
                appended) and a ``caption`` column. Image IDs and captions must
                have the same length; if an image has several captions, its ID
                must be repeated once per caption.
            tokenizer: HuggingFace tokenizer. All captions are tokenized once
                here, padded to the longest one and truncated to
                ``CFG.max_length`` tokens.
            transforms: albumentations pipeline called as ``transforms(image=...)``.
        """
        self.image_filenames = dataframe['ID'].values
        self.captions = dataframe['caption'].values
        self.encoded_captions = tokenizer(
            list(self.captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image_path = f"{CFG.image_path}/{self.image_filenames[idx]}.jpg"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None;
            # fail here with the path rather than with an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # HWC -> CHW float tensor for the image encoder.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)



def get_transforms(mode="train"):
    """Build the albumentations pipeline for ``mode``.

    Training and evaluation currently share one deterministic pipeline: resize
    to ``CFG.size`` x ``CFG.size``, then normalize with ``max_pixel_value=255``.
    ``mode`` is kept so that training-only augmentation can be added later.
    """
    steps = [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)

## Image Encoder

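The image encoder code is straightforward. I'm using the PyTorch Image Models library (timm) here, which makes many different image models available, from ResNets to EfficientNets and many more. The original tutorial used a ResNet50 as the image encoder. This notebook uses `CFG.model_name` (`convnext_nano`), and `ImageEncoder` initializes its backbone from the transfer-learned checkpoint at `CFG.transfer_path` (a `TransferImageEncoder` state dict with a `CFG.num_classes`-way concept head). You can easily use the torchvision library for ResNets if you don't want to install a new library.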

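The code encodes each image to a fixed-size vector whose size is the model's number of output channels. For ResNet50 the vector size would be **2048**; for the backbone used here it is `CFG.image_embedding`, computed in the config cell. This is the pooled output of the model with its classifier removed (`num_classes=0`, `global_pool="avg"`, i.e. adaptive average pooling as in nn.AdaptiveAvgPool2d()).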

In [None]:
class TransferImageEncoder(nn.Module):
    """timm backbone topped with a ``CFG.num_classes``-way classification head.

    ImageEncoder rebuilds this module to load the state dict stored at
    ``CFG.transfer_path``.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained,
            num_classes=CFG.num_classes,
            global_pool="avg",
        )
        # Freeze or unfreeze every parameter in one call.
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        """Return the classification logits for the image batch ``x``."""
        return self.model(x)

In [None]:
class ImageEncoder(nn.Module):
    """Encode images to a fixed-size feature vector.

    The backbone weights come from the transfer-learning checkpoint at
    ``CFG.transfer_path`` (a TransferImageEncoder state dict). The pooled,
    classifier-free head comes from a fresh ``model_name`` model created with
    ``num_classes=0``.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        # Classifier-free model (global avg pool, no logits layer). Only its head
        # is kept; it is grafted onto the transfer-learned backbone below.
        self.model = timm.create_model(model_name, pretrained, num_classes=0, global_pool="avg")
        for p in self.model.parameters():
            p.requires_grad = trainable

        # Forward the caller's model_name/trainable. Previously the CFG defaults
        # were always used, so e.g. trainable=False left the backbone trainable.
        # pretrained=False: load_state_dict (strict by default) overwrites every
        # weight, so downloading pretrained weights here would be wasted work.
        self.transfer_model = TransferImageEncoder(
            model_name=model_name, pretrained=False, trainable=trainable
        ).to(CFG.device)
        self.transfer_model.load_state_dict(torch.load(CFG.transfer_path, map_location=CFG.device))

        self.transfer_model.model.head = self.model.head
        # transfer_model stays registered so saved state dicts keep both the
        # "model.*" and "transfer_model.model.*" keys; both names refer to the
        # same module, so parameters are shared rather than duplicated.
        self.model = self.transfer_model.model

    def forward(self, x):
        """Return the pooled feature vector for the image batch ``x``."""
        return self.model(x)

## Text Encoder

As I mentioned before, I'll use DistilBERT as the text encoder. As with its bigger brother BERT, two special tokens are added to the actual input tokens: **CLS** and **SEP**, which mark the start and end of a sentence. To grab the whole representation of a sentence (as the related BERT and DistilBERT papers point out), we use the final representation of the CLS token and hope that it captures the overall meaning of the sentence (caption). Thought of this way, it is similar to what we did with images when we converted them into a fixed-size vector.

In the case of DistilBERT (and also BERT) the output hidden representation for each token is a vector with size **768**. So, the whole caption will be encoded in the CLS token representation whose size is 768.

In [None]:
class TextEncoder(nn.Module):
    """DistilBERT wrapper returning the last hidden state of the first ([CLS]) token.

    With ``pretrained=False`` a randomly initialized default-config DistilBERT
    is built and ``model_name`` is ignored.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)

        # Position of the CLS token, whose hidden state serves as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at ``target_token_idx`` for every sequence in the batch."""
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden[:, self.target_token_idx, :]

## Projection Head

I used [Keras code example implementation](https://keras.io/examples/nlp/nl_image_search/) of projection head to write the following in PyTorch.
Now that we have encoded both our images and texts into fixed-size vectors (2048 for a ResNet50 image encoder — `CFG.image_embedding` for the backbone used here — and 768 for text), we need to bring (project) them into a **new world** (!) with the **same dimension** for both images and texts. That lets us compare them, push apart non-matching images and texts, and pull together those that match. So, the following code brings the image vectors and the 768-dimensional text vectors into a 256-dimensional (projection_dim) world, where we can **compare** them.

"embedding_dim" is the size of the input vector (`CFG.image_embedding` for images — 2048 in the ResNet50 version — and 768 for texts), and "projection_dim" is the size of the output vector, which is 256 in our case. For the details of this part you can refer to the CLIP paper.

In [None]:
class ProjectionHead(nn.Module):
    """Project an encoder feature vector into the shared ``projection_dim`` space.

    With ``P = projection(x)``, computes ``LayerNorm(Dropout(fc(GELU(P))) + P)``:
    a linear projection, then a residual GELU -> linear -> dropout branch, then
    layer normalization.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        # Registration order is kept stable (it fixes state_dict order and init order).
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(branch + projected)

## CLIP

In [None]:
class CLIPModel(nn.Module):
    """Image-text model trained with the CLIP loss plus CyCLIP consistency terms.

    The loss is the symmetric cross-entropy CLIP loss (summed over the batch)
    plus CyCLIP's in-modal and cross-modal consistency losses, weighted by
    ``CFG.cylambda1`` and ``CFG.cylambda2``.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Args:
            temperature: accepted for backward compatibility but unused; the
                learnable ``logit_scale`` (initialized to
                ``CFG.logit_scale_init_value``) is used instead.
            image_embedding: width of the image encoder output.
            text_embedding: width of the text encoder output.
        """
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        # Logits are multiplied by exp(logit_scale) in forward().
        self.logit_scale = nn.Parameter(torch.ones([]) * CFG.logit_scale_init_value)

    def forward(self, batch):
        """Compute the training loss and the number of correct text->image retrievals.

        Args:
            batch: dict with ``image``, ``input_ids`` and ``attention_mask``
                tensors on the model's device, plus the raw ``caption`` strings.

        Returns:
            ``(loss, correct_preds)``: the scalar loss tensor and the int count
            returned by ``num_of_correct_preds``.
        """
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        batch_size = batch['input_ids'].shape[0]
        # Multiplicative logit scale (called "temperature" in the CyCLIP code).
        temperature = self.logit_scale.exp()

        # Getting Image and Text Embeddings (with same dimension).
        # NOTE(review): embeddings are not L2-normalized here, while find_matches
        # normalizes them at inference time; confirm this is intended.
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Cross-modal contrastive alignment (CLIP): matching pairs lie on the diagonal.
        logits_text_per_image = (image_embeddings @ text_embeddings.T) * temperature
        logits_image_per_text = logits_text_per_image.T
        target = torch.arange(batch_size, device=logits_text_per_image.device)
        contrastive_loss = (
            F.cross_entropy(logits_text_per_image, target, reduction="sum")
            + F.cross_entropy(logits_image_per_text, target, reduction="sum")
        ) / 2

        # In-modal consistency (CyCLIP): image-image and text-text similarities should agree.
        logits_image_per_image = temperature * image_embeddings @ image_embeddings.t()
        logits_text_per_text = temperature * text_embeddings @ text_embeddings.t()
        inmodal_cyclic_loss = (logits_image_per_image - logits_text_per_text).square().mean() / (temperature * temperature) * batch_size

        # Cross-modal consistency (CyCLIP): sim(image_i, text_j) should match sim(image_j, text_i).
        crossmodal_cyclic_loss = (logits_text_per_image - logits_image_per_text).square().mean() / (temperature * temperature) * batch_size

        cyclic_loss = CFG.cylambda1 * inmodal_cyclic_loss + CFG.cylambda2 * crossmodal_cyclic_loss

        loss = contrastive_loss + cyclic_loss

        # Accuracy is only a metric: detach so it does not extend the autograd graph.
        correct_preds = num_of_correct_preds(logits_image_per_text.detach(), batch['caption'])

        return loss, correct_preds


_SENTENCE_MODEL = None


def _get_sentence_model():
    """Load the caption-similarity SentenceTransformer once and reuse it on later calls."""
    global _SENTENCE_MODEL
    if _SENTENCE_MODEL is None:
        _SENTENCE_MODEL = SentenceTransformer('bert-base-nli-mean-tokens')
    return _SENTENCE_MODEL


def _quoted_medical_caption(caption):
    """Return the text between the first pair of double quotes, or the whole caption if unquoted."""
    parts = caption.split('"')
    return parts[1] if len(parts) > 1 else caption


def num_of_correct_preds(logits, captions):
    """Count rows whose top-scoring column has a caption semantically close to the row's own.

    Row ``i`` of ``logits`` scores caption/text ``i`` against every image in the
    batch. The prediction is the argmax column. The prediction counts as
    correct when the sentence-embedding cosine similarity between the medical
    captions of the predicted item and of item ``i`` is at least
    ``CFG.similarity_threshold``. The medical caption is the quoted part that
    add_concept_to_captions produces, or the full caption when there are no quotes.

    Args:
        logits: 2-D tensor with one row per caption in ``captions``.
        captions: sequence of caption strings for the batch.

    Returns:
        Number of correct predictions as a Python int.
    """
    if logits.shape[0] == 0:
        return 0

    texts = [_quoted_medical_caption(caption) for caption in captions]
    # Encode the whole batch at once instead of one pair per row.
    embeddings = _get_sentence_model().encode(texts, convert_to_tensor=True)

    # argmax is unaffected by softmax, so it is taken directly on the logits.
    preds = logits.detach().argmax(dim=-1).to(embeddings.device)
    rows = torch.arange(preds.shape[0], device=embeddings.device)
    similarity = F.cosine_similarity(embeddings[preds], embeddings[rows], dim=-1)
    return int((similarity >= CFG.similarity_threshold).sum().item())


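*Note: the `CLIPModel` above does not build a target matrix. It uses the standard CLIP objective — `nn.CrossEntropyLoss(reduction="sum")` with `torch.arange(batch_size)` targets in both directions — plus the CyCLIP in-modal and cross-modal consistency terms weighted by `CFG.cylambda1` and `CFG.cylambda2`. The next two paragraphs come from the original tutorial, which computed the full target matrix.*

Now that we've got our targets matrix, we will use simple cross entropy to calculate the actual loss. I've written the full matrix form of cross entropy as a function which you can see at the bottom of the code block. Okay! We are done! Wasn't it simple?! Alright, you can ignore the next paragraph but if you are curious, there is an important note in it.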

**Here's why I didn't use a simpler approach**: I need to admit that there's a simpler way to calculate this loss in PyTorch; by doing this: nn.CrossEntropyLoss()(logits, torch.arange(batch_size)). Why I did not use it here? For 2 reasons. 1- The dataset we are using has multiple captions for a single image; so, there is the possibility that two identical images with their similar captions exist in a batch (it is rare but it can happen). Taking the loss with this easier method will ignore this possibility and the model learns to pull apart two representations (assume them different)  that are actually the same. Obviously, we don't want this to happen so I calculated the whole target matrix in a way that takes care of these edge cases. 2- Doing it the way I did, gave me a better understanding of what is happening in this loss function; so, I thought it would give you a better intuition as well!

## Train

Here are some functions that help us build the train and valid dataloaders and the model, and then train and evaluate the model on them. There's not much going on here; just a simple training loop and utility functions.

In [None]:
def pre_process_data(df, caption_max=100):
    """Keep only rows whose caption has at most ``caption_max`` whitespace-separated words.

    Rows with a missing caption are dropped too, instead of raising. The result
    is a new DataFrame that keeps the original index, so later in-place caption
    edits do not touch ``df``.
    """
    word_counts = df['caption'].str.split().str.len()
    # NaN word counts (missing captions) compare False and are filtered out.
    return df[word_counts <= caption_max].copy()

In [None]:
def _join_concept_names(names):
    """Join names as English prose: "a", "a and b", "a, b and c"."""
    if len(names) <= 1:
        return "".join(names)
    return ", ".join(names[:-1]) + " and " + names[-1]


def add_concept_to_captions(df_train_captions,
                            concepts_path='Train/concept_detection_train.csv',
                            concept_names_path='Train/concepts.csv'):
    """Prefix every caption with a sentence naming the image's concepts.

    Each caption becomes
    ``The concept of this image is <a, b and c>. The caption of the medical image is: "<caption>".``
    Images with no concepts (missing ID, empty or NaN ``cuis``) get
    ``This image has no concepts. The caption of the medical image is: "<caption>".``
    so the original caption can always be recovered with ``split('"')[1]``.

    Args:
        df_train_captions: frame with ``ID`` and ``caption`` columns. Its
            ``caption`` column is overwritten in place.
        concepts_path: tab-separated file with ``ID`` and ``cuis`` (``;``-joined CUIs).
        concept_names_path: tab-separated file with ``concept`` and ``concept_name``.

    Returns:
        ``df_train_captions`` itself (the same object, modified in place).

    Raises:
        KeyError: if a CUI is missing from ``concept_names_path``.
    """
    df_train_concepts = pd.read_csv(concepts_path, sep='\t')
    df_concepts = pd.read_csv(concept_names_path, sep='\t')

    # Build the lookups once (first occurrence wins, as the old `.values[0]` did)
    # instead of scanning both frames for every row.
    cuis_by_id = (df_train_concepts.drop_duplicates(subset='ID')
                  .set_index('ID')['cuis'].to_dict())
    name_by_cui = (df_concepts.drop_duplicates(subset='concept')
                   .set_index('concept')['concept_name'].to_dict())

    new_captions = []
    for img_id, orig_caption in zip(df_train_captions['ID'], df_train_captions['caption']):
        raw_cuis = cuis_by_id.get(img_id)
        # ''.split(';') == [''], so empty entries are filtered; NaN means no concepts.
        cuis = [cui for cui in raw_cuis.split(';') if cui] if isinstance(raw_cuis, str) else []
        quoted = " The caption of the medical image is: \"" + orig_caption + "\"." # sentence_infront_medical_caption

        if not cuis:
            new_captions.append("This image has no concepts." + quoted)
            continue

        missing = [cui for cui in cuis if cui not in name_by_cui]
        if missing:
            raise KeyError(f"CUIs not found in {concept_names_path}: {missing}")
        concept_text = _join_concept_names([name_by_cui[cui] for cui in cuis])
        new_captions.append("The concept of this image is " + concept_text + "." + quoted)

    df_train_captions['caption'] = new_captions
    return df_train_captions

# Quick sanity check of add_concept_to_captions on the first 10 caption rows:
# print one augmented caption, then the original caption recovered from
# between its double quotes (the same split('"')[1] extraction num_of_correct_preds uses).
dataframe = pd.read_csv('Train/caption_prediction_train.csv', sep='\t')
dataframe = dataframe.head(10)
dataframe = add_concept_to_captions(dataframe)
print(dataframe['caption'][3])  # label 3, i.e. the 4th row since head() keeps the 0..9 index
print(dataframe['caption'][3].split('\"')[1])

In [None]:
def make_train_valid_dfs(add_concepts=False):
    """Load the caption CSV, drop over-long captions and split rows into train/valid.

    Args:
        add_concepts: if True, prefix each caption with the image's concept
            names (see add_concept_to_captions).

    Returns:
        ``(train_dataframe, valid_dataframe)``. Each is re-indexed from 0, with
        the previous index kept in an ``index`` column. The validation rows are
        a seeded (42) random ``CFG.validation_ratio`` fraction of row positions,
        so the split depends only on the filtered row count, not on ``add_concepts``.
    """
    dataframe = pd.read_csv('Train/caption_prediction_train.csv', sep='\t')
    # dataframe = dataframe.head(CFG.samples)
    dataframe = pre_process_data(dataframe)
    dataframe = add_concept_to_captions(dataframe) if add_concepts else dataframe

    image_ids = np.arange(dataframe.shape[0])
    # Seeds the global NumPy RNG (as before) so the split is reproducible.
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(CFG.validation_ratio * len(image_ids)), replace=False
    )
    # Boolean mask instead of a per-element `not in valid_ids` scan (O(n^2)).
    # Train rows stay in ascending order, exactly as before.
    is_valid = np.zeros(len(image_ids), dtype=bool)
    is_valid[valid_ids] = True
    train_ids = image_ids[~is_valid]

    train_dataframe = dataframe.iloc[train_ids, :].reset_index()
    valid_dataframe = dataframe.iloc[valid_ids, :].reset_index()
    return train_dataframe, valid_dataframe


def build_loaders(dataframe, tokenizer, mode):
    """Wrap ``dataframe`` in a CLIPDataset and return a DataLoader over it.

    ``mode`` selects the transform pipeline via get_transforms. Batches are
    shuffled only when ``mode == "train"``.
    """
    is_train = mode == "train"
    dataset = CLIPDataset(
        dataframe=dataframe,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_train,
    )

Here's a handy function to train our model. There's not much happening here; just loading the batches, feeding them to the model and stepping the optimizer and lr_scheduler.

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step, train_interval, epoch, writer):
    """Run one optimization pass over ``train_loader`` and return its AvgMeter.

    ``lr_scheduler.step()`` is called after every batch only when
    ``step == "batch"``. Running loss and accuracy are written to TensorBoard
    every ``train_interval`` batches and on the last batch, at global step
    ``len(train_loader) * epoch + batch_index``.
    """
    metrics = AvgMeter()
    num_batches = len(train_loader)
    progress = tqdm(train_loader, total=num_batches)

    for batch_idx, batch in enumerate(progress):
        # Move tensors to the device; raw caption strings are passed through unchanged.
        device_batch = {key: value.to(CFG.device) for key, value in batch.items() if key != "caption"}
        device_batch['caption'] = batch['caption']
        loss, correct_preds = model(device_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        metrics.update(loss.item(), correct_preds, batch["image"].size(0))

        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % train_interval == 0 or is_last_batch:
            global_step = num_batches * epoch + batch_idx
            writer.add_scalar("Train/Loss", metrics.avg_loss, global_step)
            writer.add_scalar("Train/Accuracy", metrics.accuracy, global_step)

        progress.set_postfix(train_loss=metrics.avg_loss, accuracy=metrics.accuracy, lr=get_lr(optimizer))
    return metrics


def valid_epoch(model, valid_loader, epoch, writer):
    """Evaluate ``model`` over ``valid_loader`` and log epoch-level metrics.

    Autograd is not disabled here; main() calls this under ``torch.no_grad()``.
    Loss and accuracy are written to TensorBoard at step ``epoch + 1``.
    """
    metrics = AvgMeter()
    progress = tqdm(valid_loader, total=len(valid_loader))

    for batch in progress:
        device_batch = {key: value.to(CFG.device) for key, value in batch.items() if key != "caption"}
        device_batch['caption'] = batch['caption']
        loss, correct_preds = model(device_batch)

        metrics.update(loss.item(), correct_preds, batch["image"].size(0))
        progress.set_postfix(valid_loss=metrics.avg_loss, accuracy=metrics.accuracy)

    logged_step = epoch + 1
    writer.add_scalar("Validation/Loss", metrics.avg_loss, logged_step)
    writer.add_scalar("Validation/Accuracy", metrics.accuracy, logged_step)
    return metrics


def main(save_path=None):
    """Train CLIPModel for ``CFG.epochs`` epochs, logging to TensorBoard.

    Args:
        save_path: if given, the model's state dict is written there whenever
            the validation loss improves. With the default None, no checkpoint
            is written (saving was previously commented out, although
            "Saved Best Model!" was still printed).
    """
    train_df, valid_df = make_train_valid_dfs(add_concepts=True)
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay},
        {"params": model.logit_scale, "lr": CFG.logit_scale_lr}
    ]
    # Weight decay applies only to the projection-head group (set above).
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    step = "epoch"  # the plateau scheduler is stepped once per epoch on validation loss
    train_interval = 25 # Intermediate batch sampling

    writer = SummaryWriter(comment=CFG.writer_comment)
    best_loss = float('inf')
    try:
        for epoch in range(CFG.epochs):
            print(f"Epoch: {epoch + 1}, Temperature: {model.logit_scale}")

            CFG.global_epoch = epoch + 1
            model.train()
            train_epoch(model, train_loader, optimizer, lr_scheduler, step, train_interval, epoch, writer)

            model.eval()
            with torch.no_grad():
                valid_metrics = valid_epoch(model, valid_loader, epoch, writer)

            if valid_metrics.avg_loss < best_loss:
                best_loss = valid_metrics.avg_loss
                if save_path is not None:
                    torch.save(model.state_dict(), save_path)
                    print("Saved Best Model!")
                else:
                    print(f"New best validation loss: {best_loss:.4f} (not saved; pass save_path to keep it)")

            lr_scheduler.step(valid_metrics.avg_loss)
    finally:
        # Flush and close even if training is interrupted, so logged scalars are kept.
        writer.flush()
        writer.close()

Running the next cell starts training the model. Put the kernel in GPU mode. Every epoch should take about 8 minutes on a GPU if you are using the 8k version of the original tutorial's dataset (even one epoch is enough!). It can take some seconds before training actually starts, because all the captions in the train and valid datasets are tokenized once up front, so please don't stop it! Everything is working fine.

In [None]:
%tensorboard --logdir runs
image = None
while(image is None):
    image = cv2.imread(f"{CFG.image_path}/ImageCLEFmedCaption_2022_train_053449.jpg") # FIX colab cv2 error
print(image.shape)
main()

## Inference

Okay! We are done with training the model. Now, we need to do inference which in our case will be giving the model a piece of text and want it to retrieve the most relevant images from an unseen validation (or test) set.

### Getting Image Embeddings

In this function, we are loading the model that we saved after training, feeding it images in validation set and returning the image_embeddings with shape (valid_set_size, 256) and the model itself.

In [None]:
def get_image_embeddings(valid_df, model_path):
    """Load a trained CLIPModel from ``model_path`` and embed every image in ``valid_df``.

    Returns:
        ``(model, embeddings)``. ``model`` is in eval mode. ``embeddings``
        concatenates the projected image embeddings in loader order, which is
        unshuffled because the loader is built with ``mode="valid"``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    state_dict = torch.load(model_path, map_location=CFG.device)
    model.load_state_dict(state_dict)
    model.eval()

    with torch.no_grad():
        embedding_batches = [
            model.image_projection(model.image_encoder(batch["image"].to(CFG.device)))
            for batch in tqdm(valid_loader)
        ]
    return model, torch.cat(embedding_batches)

In [None]:
# NOTE(review): `image` is not used below; this looks like a leftover probe read
# of a known training image — confirm it can be removed.
image = cv2.imread(f"{CFG.image_path}/ImageCLEFmedCaption_2022_train_053449.jpg")
# Rebuild the held-out split. make_train_valid_dfs picks validation rows from
# the filtered row count with seed 42, so add_concepts (False here) changes only
# the caption text, not which rows are held out.
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "63_percent.pt")

### Finding Matches

This function does the final task that we wished our model would be capable of: it gets the model, image_embeddings, and a text query. It will display the most relevant images from the validation set! Isn't it amazing? Let's see how it performs after all!

In [None]:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """Display the ``n`` distinct images whose embeddings best match a text ``query``.

    Args:
        model: trained CLIPModel; its text encoder and text projection are used.
        image_embeddings: projected image embeddings, one row per entry of
            ``image_filenames`` (e.g. from get_image_embeddings).
        query: free-text query.
        image_filenames: image IDs (file stems; ``.jpg`` is appended), aligned
            with ``image_embeddings``.
        n: number of distinct images to show, laid out in a near-square grid.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    import math

    if n < 1:
        raise ValueError(f"n must be positive, got {n}")

    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # Cosine similarity between the query and every image.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

    # Keep the first n distinct filenames in descending-similarity order. This
    # replaces `topk(n * 5)[::5]`, a heuristic for datasets with five captions
    # per image. When each ID appears once it showed only every fifth-best
    # match, and topk failed with fewer than n * 5 images.
    matches = []
    seen = set()
    for idx in torch.argsort(similarity, descending=True).tolist():
        name = image_filenames[idx]
        if name in seen:
            continue
        seen.add(name)
        matches.append(name)
        if len(matches) == n:
            break

    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)
    _, axes = plt.subplots(nrows, ncols, figsize=(10, 10), squeeze=False)
    for ax in axes.flatten():
        ax.axis("off")
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{CFG.image_path}/{match}.jpg")
        if image is None:
            # Keep the grid usable if one file is unreadable.
            ax.set_title(f"missing: {match}", fontsize=8)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)

    plt.show()

In [None]:
# Retrieve and display the 9 validation images that best match the query "elbow".
find_matches(
    model,
    image_embeddings,
    "elbow",
    valid_df['ID'].values,
    n=9,
)