# Set up

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hYHb0FTdKQCXZs3qCwVZnSuVGrZU2Z1w?usp=sharing)

Install a JDK (Java runtime), which the VnCoreNLP word segmenter (`py_vncorenlp`) used below requires.

In [None]:
# Install a headless JDK 11 (presumably required by py_vncorenlp / VnCoreNLP, loaded below).
!apt-get install -y openjdk-11-jdk-headless

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive', force_remount=True)

# # change "/content/gdrive/MyDrive/"  to "/mydrive so you can use directly /mydrive"
# !ln -s /content/gdrive/MyDrive/ /mydrive

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1hYHb0FTdKQCXZs3qCwVZnSuVGrZU2Z1w?usp=sharing)

In [None]:
# Packages used later: timm (image encoders), transformers (text encoders),
# gdown (Google Drive downloads), py_vncorenlp (Vietnamese word segmentation).
!pip install timm --quiet
!pip install transformers --quiet

!pip install gdown --quiet
!pip install py_vncorenlp --quiet

In [None]:
import os
import json
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
import requests

import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertTokenizer, AutoModel, AutoTokenizer, AutoConfig


## Config

In [None]:
import py_vncorenlp

# Automatically download VnCoreNLP components from the original repository
# and save them in some local machine folder
# NOTE(review): no save_dir is passed to download_model() or VnCoreNLP(); confirm
# the default location is what you want (py_vncorenlp's README passes an absolute path).
py_vncorenlp.download_model()

# Load the word and sentence segmentation component
rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"])

# Quick sanity check of the segmenter on a sample Vietnamese sentence.
text = "Ông Nguyễn Khắc Chúc  đang làm việc tại Đại học Quốc gia Hà Nội. Bà Lan, vợ ông Chúc, cũng làm việc tại đây."

output = rdrsegmenter.word_segment(text)

# word_segment returns one string per sentence (see the expected output below),
# with the syllables of multi-syllable words joined by "_".
print(' '.join(output))
# ['Ông Nguyễn_Khắc_Chúc đang làm_việc tại Đại_học Quốc_gia Hà_Nội .', 'Bà Lan , vợ ông Chúc , cũng làm_việc tại đây .']


In [None]:
class CFG:
    """Global hyper-parameters and paths, read as class attributes (CFG.x).

    Several fields start as None and are filled in by the "Detailed CFG" cells
    (encoder names, embedding sizes, tokenizer, segmenter). Classes that use CFG
    values as default arguments read them when their cell runs, so those cells
    must run first.
    """
    debug = False  # not referenced by the visible cells
    image_path = "../tmp/images"  # relative to the current working directory
    captions_path = "."  # not referenced by the visible cells
    batch_size = 32
    num_workers = 4  # DataLoader worker processes
    head_lr = 1e-3  # learning rate of the projection heads
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3  # applied only to the projection heads (see main())
    patience = 1  # ReduceLROnPlateau patience, in epochs
    factor = 0.8  # ReduceLROnPlateau learning-rate multiplier
    epochs = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # image_model appears unused; later cells set CFG.image_encoder_model instead.
    image_model = None
    image_embedding = None  # feature size produced by the image encoder

    text_encoder_model = None  # Hugging Face model id
    text_embedding = None  # hidden size of the text encoder
    text_tokenizer = None  # Hugging Face id the tokenizer is loaded from
    
    max_length = 70  # captions are truncated to this many tokens
    segmenter = None  # optional callable applied to every caption before tokenizing

    text_encoder_pretrained = True
    image_encoder_pretrained = True
    text_encoder_trainable = True
    image_encoder_trainable = True
    temperature = 1.0  # scales the logits / soft targets in CLIPModel.forward

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1  # ignored by the active single-layer ProjectionHead
    # projection_dim = 256
    projection_dim = 512
    dropout = 0.1

    # NOTE: the bare string below is a Vietnamese to-do list, not a docstring:
    # "Use ResNet / No pretraining / Review the dataset (image size) / Add more datasets".
    '''
    Dùng Resnet
    Ko pretrain
    Xem lại dataset (image size)
    Thêm dataset
    '''

### Detailed CFG

In [None]:
# Short display names -> Hugging Face ids for the candidate Vietnamese text encoders.
__text_models__ = {
    "PhoBERT-base": "vinai/phobert-base-v2",
    "PhoBERT-large": "vinai/phobert-large",
    "ViT5-base": "VietAI/vit5-base",
    "ViT5-large": "VietAI/vit5-large"
}
# Short display names -> timm model names for the candidate image encoders.
__image_models__ = {
    "ViT-S": "vit_small_patch16_224",
    "ViT-B": "vit_base_patch16_224",
    "ViT-L": "vit_large_patch16_224",
    "ViT-H": "vit_huge_patch16_224",
    "ResNet50": "resnet50"
}

In [None]:
__text_models__ = {
    "PhoBERT-base": "vinai/phobert-base-v2",
    "PhoBERT-large": "vinai/phobert-large",
    "ViT5-base": "VietAI/vit5-base",
    "ViT5-large": "VietAI/vit5-large"
}

# Choose the text encoder (Hugging Face id) and whether it is pretrained / fine-tuned.
text_encoder_model = __text_models__["PhoBERT-large"]
CFG.text_encoder_pretrained = True
CFG.text_encoder_trainable = True

####################################################################################
# The tokenizer is loaded from the same Hugging Face id as the encoder.
CFG.text_encoder_model = CFG.text_tokenizer = text_encoder_model

# PhoBERT models get captions word-segmented by VnCoreNLP before tokenization.
if "pho" in text_encoder_model:
    CFG.segmenter = lambda sentence: ' '.join(rdrsegmenter.word_segment(sentence))

# Hidden size of each known encoder; an unknown id leaves CFG.text_embedding unchanged.
CFG.text_embedding = {
    "vinai/phobert-base-v2": 768,
    "vinai/phobert-large": 1024,
    "VietAI/vit5-base": 768,
    "VietAI/vit5-large": 1024,
}.get(text_encoder_model, CFG.text_embedding)


In [None]:
__image_models__ = {
    "ViT-S": "vit_small_patch16_224",
    "ViT-B": "vit_base_patch16_224",
    "ViT-L": "vit_large_patch16_224",
    "ViT-H": "vit_huge_patch16_224",
    "ResNet50": "resnet50"
}
# Choose the image encoder (timm name) and whether it is pretrained / fine-tuned.
image_encoder_model = __image_models__["ResNet50"]
CFG.image_encoder_pretrained = True
CFG.image_encoder_trainable = False

####################################################################################
CFG.image_encoder_model = image_encoder_model

# Size of the pooled feature vector each backbone produces; it must be a real
# number because ProjectionHead builds nn.Linear(image_embedding, ...).
# vit_small_patch16_224 has embed_dim 384 (it was mapped to None before).
# An unknown name leaves CFG.image_embedding unchanged.
CFG.image_embedding = {
    "vit_small_patch16_224": 384,
    "vit_base_patch16_224": 768,
    "vit_large_patch16_224": 1024,
    "vit_huge_patch16_224": 1280,
    "resnet50": 2048,
}.get(image_encoder_model, CFG.image_embedding)
    

## Utils

In [None]:
class AvgMeter:
    """Running, count-weighted average of a scalar metric such as a loss."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Record `val` as observed `count` times (e.g. a batch mean and the batch size)."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or None if it has none."""
    return next((group["lr"] for group in optimizer.param_groups), None)


# Datasets

In [None]:
import shutil

def move_images(source, destination=CFG.image_path, prefix=""):
    """Copy every entry of `source` into `destination` as `prefix + name`.

    Despite the name, the files are copied (shutil.copyfile), not moved, so
    `source` is left intact; an existing file of the same name in
    `destination` is replaced. Prints how many entries were copied.
    """
    entries = os.listdir(source)
    for entry in entries:
        src = os.path.join(source, entry)
        dst = os.path.join(destination, prefix + entry)
        shutil.copyfile(src, dst)

    print(f"{len(entries)} images copied from {source} to {destination} successfully.")

In [None]:
# Work one level up, in ./tmp; every dataset's images are collected in tmp/images,
# which is what CFG.image_path ("../tmp/images") points at from here.
# (`!mkdir` only prints an error if the folder already exists.)
%cd ..
!mkdir tmp
%cd tmp
!mkdir images

In [None]:
!pwd

### coco

In [None]:
# All datasets share the images/ folder; a per-dataset prefix keeps file names apart.
COCO_PREFIX = "coco-"

In [None]:
# Fetch and extract the COCO 2014 train/val images, deleting each zip afterwards.
!wget http://images.cocodataset.org/zips/train2014.zip
!wget http://images.cocodataset.org/zips/val2014.zip
!unzip -q train2014.zip
!rm train2014.zip
!unzip -q val2014.zip
!rm val2014.zip

In [None]:
# Sanity check: list the working directory and count the extracted images per split.
print(os.listdir('.'))
print(len(os.listdir('val2014')))
print(len(os.listdir('train2014')))

In [None]:
!pip install jsonlines
import jsonlines

In [None]:
# Presumably the caption files read below (cocopathT.jsonl / cocopathD.jsonl).
!gdown 10AbcXZaQmgUeKz6aRsHNV8HFMAni2qc4
!gdown 1Ldvmxa9sykv805nJ4-PqtjbUvLKcrgz7

In [None]:
# Copy the COCO train/val images into images/ as "coco-<numeric image id><ext>".
# COCO 2014 file names look like COCO_train2014_000000000009.jpg: the numeric id
# is the part after the last underscore (leading zeros dropped by int()).
for split in ("train2014", "val2014"):
    file_names = os.listdir(split)
    total = len(file_names)  # counted rather than hard-coded
    for i, file_name in enumerate(file_names, start=1):
        print(f"\r{split}: {i}/{total}", end='')

        stem, ext = os.path.splitext(file_name)
        image_id = int(stem.rsplit('_', 1)[-1])

        shutil.copy(
            os.path.join(split, file_name),
            os.path.join("images", f"{COCO_PREFIX}{image_id}{ext}"),
        )
    print()

In [None]:
# Read cocopathT.jsonl: one JSON object per line with "image" and "caption" keys.
image_list = []
caption_list = []

with jsonlines.open('cocopathT.jsonl') as reader:
    L = 566435  # expected number of lines; only used by the progress counter
    for i, line in enumerate(reader, start=1):
        print(f"\r{i}/{L}", end='')
        image_list.append(line['image'])
        caption_list.append(line['caption'])

df_coco = pd.DataFrame({'image': image_list, 'caption': caption_list})
df_coco.head(20)

In [None]:
# Read cocopathD.jsonl (used for validation): one {"image", "caption"} object per line.
image_list = []
caption_list = []

with jsonlines.open('cocopathD.jsonl') as reader:
    L = 25000  # expected number of lines; only used by the progress counter
    for i, line in enumerate(reader, start=1):
        print(f"\r{i}/{L}", end='')
        image_list.append(line['image'])
        caption_list.append(line['caption'])

df_coco_val = pd.DataFrame({'image': image_list, 'caption': caption_list})
df_coco_val.head(20)

### flickr

In [None]:
# Work inside flickr/ until the images have been copied (a later cell runs `%cd ..`).
!mkdir flickr
%cd flickr
FLICKR_PREFIX = "flickr-"  # prefix for Flickr8k images in the shared images/ folder

In [None]:
!pip install kaggle
!export KAGGLE_USERNAME=ducngg
!export KAGGLE_KEY=179be8b8664502e51504151a35dec9c4

In [None]:
# Flickr8k images and the captions dataset that presumably provides captions_vi.txt (parsed below).
!kaggle datasets download adityajn105/flickr8k
!kaggle datasets download trungit/flickr8k-vi-caps

In [None]:
!unzip flickr8k-vi-caps.zip
!rm flickr8k-vi-caps.zip

In [None]:
!unzip -q flickr8k.zip
!rm flickr8k.zip

In [None]:
# captions_vi.txt: one "<image file>\t<caption>" pair per line.
with open('captions_vi.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()

image_list = []
caption_list = []

for line in lines:
    line = line.strip()
    if not line:
        continue  # tolerate blank lines (e.g. a trailing empty line)
    # Split on the first tab only, so a stray tab inside a caption stays in the caption
    # instead of raising "too many values to unpack".
    image, caption = line.split('\t', 1)
    image_list.append(FLICKR_PREFIX + image.strip())
    caption_list.append(caption.strip())

df_flickr = pd.DataFrame({'image': image_list, 'caption': caption_list})
df_flickr.head(20)

In [None]:
# Copy the Flickr8k images into the shared images/ folder with the "flickr-" prefix.
move_images('Images', destination='../images', prefix=FLICKR_PREFIX)

In [None]:
%cd ..

### uit-viic

In [None]:
!mkdir uit-viic
%cd uit-viic
# Not used by the visible cells: UIT-ViIC rows point at remote COCO URLs instead of local files.
UITVIIC_PREFIX = "uitviic-"

In [None]:
!gdown 1YexKrE6o0UiJhFWpE8M5LKoe6-k3AiM4

In [None]:
!unzip UIT-ViIC-20200417T021508Z-001.zip
!rm UIT-ViIC-20200417T021508Z-001.zip

In [None]:
# COCO-style annotation file: "images" entries (id, coco_url, ...) and
# "annotations" entries (image_id, caption).
with open('UIT-ViIC/uitviic_captions_train2017.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
images = data['images']
captions = data['annotations']

In [None]:
# Unlike the other datasets, 'image' holds a remote URL (coco_url), not a local
# file name; CLIPDataset downloads such entries over HTTP.
images = list(map(lambda item: {'id': item['id'],'image': item['coco_url']}, images))
captions = list(map(lambda item: {'id': item['image_id'], 'caption': item['caption']}, captions))

images_df = pd.DataFrame(images)
captions_df = pd.DataFrame(captions)

# Inner join on the image id: one row per caption.
df_uitviic = pd.merge(images_df, captions_df, on='id')
df_uitviic

In [None]:
# NOTE: df_uitviic is not included in data_df (see the Summary section).
df_uitviic = df_uitviic.drop(['id'], axis=1)
df_uitviic

In [None]:
%cd ..

### ktvic

In [None]:
!mkdir ktvic
%cd ktvic
KTVIC_PREFIX = "ktvic-"  # prefix for KTVIC images in the shared images/ folder

In [None]:
# Presumably the train/test annotation JSONs and image archives used below.
!gdown 11bwkfj8Qr9AIGSDe_xxlwj_asgwa8nNL
!gdown 1xi5uyB_8obamsnv0COq4wsT90DBQg7-h
!gdown 1GM8uMB4P93TPE6lZe2YcQnexzBxkl8c-
!gdown 1ntMeBhf-Nut88fJfXEnEOvsnjG_Y7gKo

In [None]:
!unzip -q train-images.zip
!rm train-images.zip
!unzip -q public-test-images.zip
!rm public-test-images.zip

In [None]:
# KTVIC train annotations: "images" (id, filename) and "annotations" (image_id, caption).
with open('train_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

images = data['images']
captions = data['annotations']

In [None]:
images = list(map(lambda item: {'id': item['id'], 'image': KTVIC_PREFIX + item['filename']}, images))
captions = list(map(lambda item: {'id': item['image_id'], 'caption': item['caption']}, captions))

images_df = pd.DataFrame(images)
captions_df = pd.DataFrame(captions)

# Inner join on the image id: one row per caption.
df_ktvic = pd.merge(images_df, captions_df, on='id')
df_ktvic

In [None]:
df_ktvic = df_ktvic.drop(['id'], axis=1)
df_ktvic

In [None]:
# KTVIC public-test annotations (same schema); used as a validation set.
with open('test_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

images = data['images']
captions = data['annotations']

In [None]:
images = list(map(lambda item: {'id': item['id'], 'image': KTVIC_PREFIX + item['filename']}, images))
captions = list(map(lambda item: {'id': item['image_id'], 'caption': item['caption']}, captions))

images_df = pd.DataFrame(images)
captions_df = pd.DataFrame(captions)

df_ktvic_val = pd.merge(images_df, captions_df, on='id')
df_ktvic_val 

In [None]:
df_ktvic_val = df_ktvic_val.drop(['id'], axis=1)
df_ktvic_val

In [None]:
# Train and public-test images get the same prefix, so a file name present in both
# folders would be overwritten by the test copy (move_images replaces existing files).
move_images('train-images', destination='../images', prefix=KTVIC_PREFIX)
move_images('public-test-images', destination='../images', prefix=KTVIC_PREFIX)

In [None]:
%cd ..

### open-viic

In [None]:
!mkdir open-ViIC
%cd open-ViIC
OPENVIIC_PREFIX = "openviic-"  # prefix for OpenViIC images in the shared images/ folder

In [None]:
# Presumably the train/dev/test annotation JSONs and images.zip.
# (Unlike the other datasets, images.zip is not deleted after extraction.)
!gdown 10E0cuWBaTgTvRj-bibTdfyABDDxCjXci
!gdown 1rovWCFcA6s0CXQD6SkHisUjgcdheHVQy
!gdown 1YfLMn-yRYN2ZT2CtdHeeRMPol0mVcWtw
!gdown 1tedlrYhlBUMV7TeurW2DhbpSrX_1F9QD
!unzip -q images.zip

In [None]:
# Annotation format: {image_file: {"captions": [...], ...}, ...}; one row per caption.
# UTF-8 is explicit, as in the other dataset cells: open()'s default encoding is
# platform-dependent and the captions are Vietnamese.
with open('uit-openviic-annotation-train.json', 'r', encoding='utf-8') as f:
    json_data = json.load(f)

data = [
    {"image": OPENVIIC_PREFIX + image_file, "caption": caption}
    for image_file, image_data in json_data.items()
    for caption in image_data["captions"]
]

df_openviic = pd.DataFrame(data)
df_openviic

In [None]:
# Dev split (used for validation); same format as the train file, one row per caption.
# UTF-8 is explicit, as in the other dataset cells: open()'s default encoding is
# platform-dependent and the captions are Vietnamese.
with open('uit-openviic-annotation-dev.json', 'r', encoding='utf-8') as f:
    json_data = json.load(f)

data = [
    {"image": OPENVIIC_PREFIX + image_file, "caption": caption}
    for image_file, image_data in json_data.items()
    for caption in image_data["captions"]
]

df_openviic_val = pd.DataFrame(data)
df_openviic_val

In [None]:
# Test split; same format as the train file, one row per caption.
# UTF-8 is explicit, as in the other dataset cells: open()'s default encoding is
# platform-dependent and the captions are Vietnamese.
with open('uit-openviic-annotation-test.json', 'r', encoding='utf-8') as f:
    json_data = json.load(f)

data = [
    {"image": OPENVIIC_PREFIX + image_file, "caption": caption}
    for image_file, image_data in json_data.items()
    for caption in image_data["captions"]
]

df_openviic_test = pd.DataFrame(data)
df_openviic_test

In [None]:
# Copy the OpenViIC images into the shared images/ folder with the "openviic-" prefix.
move_images('images', destination='../images', prefix=OPENVIIC_PREFIX)

In [None]:
%cd ..

### Summary

In [None]:
# Number of image files collected from all datasets.
len(os.listdir("images"))

In [None]:
# Training table: Flickr8k, KTVIC train, OpenViIC train and the cocopathT.jsonl COCO captions.
# df_uitviic (remote-URL images) is not included.
data_df = pd.concat([df_flickr, df_ktvic, df_openviic, df_coco], ignore_index=True)
data_df

In [None]:
# data_df_val = (df_ktvic_val, df_openviic_val, df_coco_val)

In [None]:
# Report every training row whose image file is missing from images/.
for index, image in data_df['image'].items():
    if not os.path.exists("images/" + image):
        print("Fail", index)

In [None]:
# Missing values per column, then across the whole table (the last expression is displayed).
null_counts = data_df.isna().sum()
total_null_count = null_counts.sum()
total_null_count

# Components

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """(image, caption) pairs for contrastive training.

    Captions are word-segmented (when CFG.segmenter is set) and tokenized once in
    __init__. Images are loaded lazily in __getitem__ from CFG.image_path, or
    downloaded when the file name starts with "http".
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, image_filenames must contain repeated
        file names.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)

        # segment Vietnamese words
        if CFG.segmenter:
            self.captions = [CFG.segmenter(caption) for caption in self.captions]
        print(f"Sample caption: {self.captions[0]}")

        # Tokenize every caption up front; padding=True pads to the longest caption,
        # which truncation keeps at most CFG.max_length tokens long.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __len__(self):
        return len(self.captions)

    def _load_rgb_image(self, filename):
        """Read `filename` (local name under CFG.image_path, or an http URL) as an RGB array.

        Raises:
            RuntimeError: if a URL download does not return HTTP 200.
            FileNotFoundError: if OpenCV cannot read or decode the image.
        """
        if filename.startswith("http"):
            # A timeout keeps a stalled server from hanging a DataLoader worker forever.
            response = requests.get(filename, timeout=30)
            if response.status_code != 200:
                raise RuntimeError(
                    f"Failed to download image from URL: {filename} "
                    f"(HTTP {response.status_code})"
                )
            image_data = np.frombuffer(response.content, dtype=np.uint8)
            image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
        else:
            image = cv2.imread(f"{CFG.image_path}/{filename}")
        # OpenCV returns None for unreadable/undecodable images instead of raising,
        # which would otherwise surface as an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {filename}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        # Load failures raise with the file name rather than returning None,
        # which the default collate_fn cannot batch anyway.
        image = self._load_rgb_image(self.image_filenames[idx])

        # Apply transformations
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW
        item['caption'] = self.captions[idx]

        return item


def get_transforms(mode="train"):
    """Build the image pipeline: resize to CFG.size x CFG.size, then A.Normalize.

    `mode` is kept for API symmetry only: training and evaluation currently use
    the same augmentation-free pipeline.
    """
    steps = [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)
    

## Image Encoder

In [None]:
class ImageEncoder(nn.Module):
    """
    Encode images into one feature vector each with a timm backbone.

    The classifier head is dropped (num_classes=0) and features are
    average-pooled (global_pool="avg"); the feature size must match
    CFG.image_embedding. Defaults are read from CFG when this cell runs.
    """

    def __init__(
        self,
        model_name=CFG.image_encoder_model,
        pretrained=CFG.image_encoder_pretrained,
        trainable=CFG.image_encoder_trainable,
    ):
        super().__init__()

        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        # Freeze (or unfreeze) the whole backbone. The default now comes from
        # CFG.image_encoder_trainable; it used to read image_encoder_pretrained,
        # so setting image_encoder_trainable = False had no effect.
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

## Text Encoder

The **CLS** and **SEP** tokens mark the start and end of a sentence. To get a representation of the whole sentence (as the BERT and DistilBERT papers point out), we use the final hidden state of the CLS token.

In [None]:
class TextEncoder(nn.Module):
    """Encode tokenized captions with a Hugging Face encoder.

    forward() returns the last hidden state of the first (CLS) token. Defaults
    are read from CFG when this cell runs.
    NOTE(review): encoder-decoder ids such as VietAI/vit5-* load a full seq2seq
    model via AutoModel and may need decoder inputs; confirm before switching.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.text_encoder_pretrained, trainable=CFG.text_encoder_trainable):
        super().__init__()

        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            # AutoModel cannot be instantiated with AutoModel(...); from_config builds
            # the architecture with freshly initialised (random) weights.
            self.model = AutoModel.from_config(AutoConfig.from_pretrained(model_name))

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

## Projection Head

In [None]:

# For 1 layer
class ProjectionHead(nn.Module):
    """Project encoder features into the shared `projection_dim` space.

    A linear projection is followed by a GELU -> Linear -> Dropout branch that is
    added back as a residual, then LayerNorm. `n_layers` is accepted for
    signature compatibility but ignored by this single-layer variant.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout,
        n_layers=CFG.num_projection_layers
    ):
        super().__init__()
        # Attribute names are the state_dict keys of saved checkpoints; keep them.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        residual = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(projected + residual)

    
# # For 2 layer above
# class ProjectionHead(nn.Module):
#     def __init__(
#         self,
#         embedding_dim,
#         projection_dim=CFG.projection_dim,
#         dropout=CFG.dropout,
#         n_layers=CFG.num_projection_layers
#     ):
#         super().__init__()
#         self.projection_layers = nn.ModuleList(
#             [nn.Linear(embedding_dim, projection_dim)] + \
#             [nn.Linear(projection_dim, projection_dim) for _ in range(n_layers-1)]
#         )
#         self.gelu = nn.GELU()
#         self.fc = nn.Linear(projection_dim, projection_dim)
#         self.dropout = nn.Dropout(dropout)
#         self.layer_norm = nn.LayerNorm(projection_dim)

#     def forward(self, x):
#         projected = x
#         for projection_layer in self.projection_layers:
#             projected = projection_layer(projected)
#             projected = self.gelu(projected)
#         x = self.fc(projected)
#         x = self.dropout(x)
#         x = x + projected
#         x = self.layer_norm(x)
#         return x


## CLIP

In [None]:
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a CLIP-style contrastive loss.

    Attribute names (image_encoder, text_encoder, image_projection,
    text_projection) are the prefixes of the saved state_dict keys.
    Default arguments are read from CFG when this cell runs.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the mean symmetric contrastive loss for one batch.

        `batch` must provide "image", "input_ids" and "attention_mask" (as built
        by CLIPDataset); row i of each describes the same (image, caption) pair.
        """
        # Getting Image and Text Features
        image_features = self.image_encoder(
            batch["image"]
        )
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        # logits[i, j]: similarity of text i and image j, divided by the temperature.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        # Soft targets instead of an identity matrix: softmax over the averaged
        # image-image and text-text similarities, presumably so near-duplicate
        # pairs in a batch (e.g. two captions of one image) are not pure negatives.
        # NOTE(review): targets multiply by the temperature while logits divide by it;
        # identical when temperature == 1.0 (the CFG value) — confirm intent.
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # Text->image and image->text directions (cross_entropy is defined below).
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Computes -sum(targets * log_softmax(preds)) over the last dimension.

    Args:
        preds: unnormalised logits.
        targets: same shape as `preds`; each row a probability distribution.
        reduction: 'none' (one loss per row), 'mean' or 'sum'.

    Raises:
        ValueError: for any other `reduction` (instead of silently returning None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction: {reduction!r}")

In [None]:
# A simple Example
# Row-wise softmax of a random similarity (Gram) matrix: every row sums to 1.
# CLIPModel builds its soft targets the same way.
batch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
print(F.softmax(out, dim=-1))


# Train

In [None]:
def make_train_valid_dfs():
    """Return (train_df, valid_df).

    Training uses the module-level data_df unchanged; validation concatenates
    the KTVIC test, OpenViIC dev and cocopathD.jsonl COCO caption tables.
    """
    valid_parts = [df_ktvic_val, df_openviic_val, df_coco_val]
    valid_df = pd.concat(valid_parts, ignore_index=True)
    train_df = data_df

    print(f"Train: {len(train_df)} rows")
    print(f"Valid: {len(valid_df)} rows")
    return train_df, valid_df

def build_loaders(dataframe, tokenizer, mode):
    """Wrap `dataframe` (columns "image", "caption") in a CLIPDataset and DataLoader.

    Only mode == "train" shuffles; other modes keep row order, which the
    embedding and evaluation code relies on to align results with the dataframe.
    """
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )
    is_train = mode == "train"
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_train,
    )

Create the training and validation data loaders.

In [None]:
# The tokenizer comes from the same Hugging Face id as the text encoder.
tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
train_df, valid_df = make_train_valid_dfs()
# Building a loader segments and tokenizes every caption up front (CLIPDataset.__init__).
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch; return an AvgMeter of the sample-weighted loss.

    With step == "batch" the LR scheduler is stepped after every optimizer step;
    any other value leaves scheduling to the caller.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for raw_batch in progress:
        # Raw caption strings cannot be moved to the device and the model ignores them.
        batch = {
            key: tensor.to(CFG.device)
            for key, tensor in raw_batch.items()
            if key != "caption"
        }

        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        meter.update(loss.item(), batch["image"].size(0))
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter


@torch.no_grad()
def valid_epoch(model, valid_loader):
    """Evaluate `model` on `valid_loader`; return an AvgMeter of the sample-weighted loss.

    Runs under torch.no_grad() itself, so callers need not wrap it. Switching the
    model to eval mode (model.eval()) is the caller's responsibility.
    """
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        # The running average is shown in the progress bar; no per-batch print,
        # which flooded the output and broke the tqdm line.
        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


def main(train_loader, valid_loader):
    """Train a fresh CLIPModel for CFG.epochs epochs.

    After every epoch the model is evaluated on `valid_loader`. Whenever the
    validation loss improves, the weights are written to best.pt. The learning
    rate is reduced on validation-loss plateaus (ReduceLROnPlateau).
    """
    model = CLIPModel().to(CFG.device)

    # Show how much of each encoder will actually train (frozen parameters have
    # requires_grad=False). Printing .parameters() itself only shows a generator object.
    for label, encoder in (("image_encoder", model.image_encoder),
                           ("text_encoder", model.text_encoder)):
        total = sum(p.numel() for p in encoder.parameters())
        trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
        print(f"{label}: {trainable:,} / {total:,} trainable parameters")

    # Separate learning rates per component; weight decay only on the projection heads.
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # ReduceLROnPlateau needs the validation loss, so it is stepped once per epoch below.
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        print(f"train_loss: {train_loss.avg:.4f}  valid_loss: {valid_loss.avg:.4f}")

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")

        lr_scheduler.step(valid_loss.avg)


In [None]:
# best.pt is saved in (and later loaded from) the working directory. CFG.image_path
# is relative ("../tmp/images"); from /kaggle/working it resolves to /kaggle/tmp/images,
# presumably the folder created earlier — confirm if the notebook started elsewhere.
%cd /kaggle/working

In [None]:
main(train_loader, valid_loader)

In [None]:
# Download link for the saved checkpoint.
from IPython.display import FileLink
FileLink(r'best.pt')

In [None]:
# !gdown 1jQqI9YjJFDPORcPQ3Q2fb-fCFz1lE_YF -O best.pt

In [None]:
valid_loss = 0
def val():
    """Re-evaluate the saved best.pt checkpoint on the validation set.

    Stores the resulting AvgMeter in the module-level `valid_loss` and prints it.
    """
    global valid_loss
    _, valid_df = make_train_valid_dfs()

    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load("best.pt", map_location=CFG.device))
    # Disable dropout so the loss matches the validation pass in main().
    model.eval()

    with torch.no_grad():
        valid_loss = valid_epoch(model, valid_loader)
        print(valid_loss)

# val()

In [None]:
# Planned experiment grid (not referenced by any other visible cell).
EXPERIMENTS = {
    'text-encoder': ['PhoBERT-base', 'PhoBERT-large'],
    'image-encoder': ['ViT-base', 'ResNet'],
    'epochs': [2, 3],
    'hidden-size': [256, 512]
}

# Inference

### Getting Image and Text Embeddings

In [None]:
def _load_model_and_loader(valid_df, model_path):
    """Load an eval-mode CLIPModel from `model_path` plus an unshuffled loader over `valid_df`."""
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
    loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    return model, loader


def get_image_embeddings(valid_df, model_path):
    """Return (model, projected image embeddings), one row per row of `valid_df`, in order."""
    model, loader = _load_model_and_loader(valid_df, model_path)
    with torch.no_grad():
        chunks = [
            model.image_projection(model.image_encoder(batch["image"].to(CFG.device)))
            for batch in tqdm(loader)
        ]
    return model, torch.cat(chunks)


def get_text_embeddings(valid_df, model_path):
    """Return (model, projected caption embeddings), one row per row of `valid_df`, in order."""
    model, loader = _load_model_and_loader(valid_df, model_path)
    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader):
            features = model.text_encoder(
                input_ids=batch['input_ids'].to(CFG.device),
                attention_mask=batch['attention_mask'].to(CFG.device),
            )
            chunks.append(model.text_projection(features))
    return model, torch.cat(chunks)

In [None]:
_, valid_df = make_train_valid_dfs()
# Image embeddings for the combined validation set and for each subset.
# Every call reloads best.pt and rebuilds (re-tokenizes) its dataset.
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

_, image_embeddings_ktvic_val = get_image_embeddings(df_ktvic_val, "best.pt")
_, image_embeddings_openviic_val = get_image_embeddings(df_openviic_val, "best.pt")
_, image_embeddings_coco_val = get_image_embeddings(df_coco_val, "best.pt")

In [None]:
# Text embeddings must come from the same checkpoint as the image embeddings
# (best.pt); pairing them with another checkpoint's image space makes the
# "all" retrieval and similarity numbers meaningless.
_, text_embeddings = get_text_embeddings(valid_df, "best.pt")

_, text_embeddings_ktvic_val = get_text_embeddings(df_ktvic_val, "best.pt")
_, text_embeddings_openviic_val = get_text_embeddings(df_openviic_val, "best.pt")
_, text_embeddings_coco_val = get_text_embeddings(df_coco_val, "best.pt")

## Evaluate

In [None]:
valid_df

In [None]:
# Total parameter count of the loaded CLIP model (encoders + projection heads).
sum(p.numel() for p in model.parameters())

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# !gdown 1UC2TCE4qL5-wHy9JXUTB2dCGFLX4KE6O -O aodai.jpg

In [None]:
def get_tensor_from_path(path):
    """Load an image file as a (3, CFG.size, CFG.size) float tensor.

    Applies the same resize/normalize pipeline as the datasets (get_transforms()).

    Raises:
        FileNotFoundError: if OpenCV cannot read `path`.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread returns None instead of raising; fail with the path in the message.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms()(image=image)['image']
    # HWC -> CHW. (No debug print of the whole tensor.)
    return torch.tensor(image).permute(2, 0, 1).float()

In [None]:
def remove_duplicates(lst):
    """Return the items of `lst` without repeats, keeping each item's first-seen position."""
    # dict keys are unique and preserve insertion order.
    return list(dict.fromkeys(lst))

def find_matches(model, database_embeddings, image_filenames, text=None, image_path=None, n=25, k=200, unique=True):
    """Plot (and save) the database images closest to a text and/or image query.

    Args:
        model: a CLIPModel whose encoders and projection heads embed the query.
        database_embeddings: image embeddings to search, one row per entry of
            `image_filenames` (e.g. from get_image_embeddings).
        image_filenames: file names under CFG.image_path, aligned with the rows.
        text: optional text query (word-segmented with CFG.segmenter when set).
        image_path: optional path of a query image.
        n: maximum number of matches shown (the grid has 5 x 5 = 25 cells).
        k: number of nearest rows fetched before de-duplication.
        unique: drop repeated file names (an image appears once per caption).

    With both queries, the two L2-normalized embeddings are averaged. The grid is
    saved as text + image_path + ".res.png" (using whichever parts were given).

    Raises:
        ValueError: if neither `text` nor `image_path` is given.
    """
    if text is None and image_path is None:
        raise ValueError("No query")

    text_embeddings_n = None
    if text is not None:
        # Same preprocessing as the training captions (CLIPDataset).
        if CFG.segmenter:
            text = CFG.segmenter(text)
        tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_text = tokenizer([text], truncation=True, max_length=CFG.max_length)

        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_text.items()
        }
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

    image_embeddings_n = None
    if image_path is not None:
        image_tensor = get_tensor_from_path(image_path)
        # Inference only: don't build an autograd graph for the query image.
        with torch.no_grad():
            image_features = model.image_encoder(image_tensor.unsqueeze(0).to(CFG.device))
            image_embeddings = model.image_projection(image_features)
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)

    if image_embeddings_n is not None and text_embeddings_n is not None:
        query_embeddings_n = torch.mean(torch.stack([image_embeddings_n, text_embeddings_n]), dim=0)
        name = text + image_path + '.res.png'
    elif image_embeddings_n is not None:
        query_embeddings_n = image_embeddings_n
        name = image_path + '.res.png'
    else:
        query_embeddings_n = text_embeddings_n
        name = text + '.res.png'

    database_embeddings_n = F.normalize(database_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings_n @ database_embeddings_n.T

    # torch.topk raises when k exceeds the number of database rows.
    k = min(k, database_embeddings_n.shape[0])
    values, indices = torch.topk(dot_similarity.squeeze(0), k)
    print(values, indices)
    matches = [image_filenames[idx] for idx in indices.tolist()]
    if unique:
        matches = remove_duplicates(matches)
    matches = matches[:n]

    _, axes = plt.subplots(5, 5, figsize=(10, 10))
    for ax in axes.flatten():
        ax.axis("off")  # also hides empty cells when fewer than 25 matches remain
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{CFG.image_path}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)

    plt.savefig(name)
    plt.show()

def eval_accuracy(text_embeddings, image_embeddings, df, k=5, n_candidates=250):
    """Text-to-image retrieval accuracy at k (recall@k) over a caption table.

    Row i of `text_embeddings`, `image_embeddings` and `df` describe the same
    (caption, image) pair, so an image appears once per caption. For each caption
    the `n_candidates` most similar rows are fetched, their image names are
    de-duplicated in rank order, and the caption counts as correct if its own
    image is among the first `k` distinct images.

    Returns:
        Fraction of captions whose image was retrieved.

    Raises:
        ValueError: if there are no text embeddings, or `df` has fewer rows than them.
    """
    num_queries = text_embeddings.shape[0]
    if num_queries == 0:
        raise ValueError("eval_accuracy needs at least one text embedding")

    image_names = df['image'].to_numpy()  # positional, like df.iloc
    if len(image_names) < num_queries:
        raise ValueError("df has fewer rows than text_embeddings")

    similarity_matrix = torch.matmul(text_embeddings, image_embeddings.T)
    # torch.topk raises if asked for more rows than exist (small evaluation sets).
    n_candidates = min(n_candidates, image_embeddings.shape[0])
    _, topk_indices = similarity_matrix.topk(n_candidates, dim=1)

    # (num_queries, n_candidates) array of ranked image names.
    ranked_names = image_names[topk_indices.cpu().numpy()]

    correct_predictions = 0
    for true_image, candidates in zip(image_names, ranked_names):
        top_images = list(dict.fromkeys(candidates))[:k]  # de-duplicate, keep rank order
        if true_image in top_images:
            correct_predictions += 1

    return correct_predictions / num_queries
def eval_avg_cossim(text_embeddings, image_embeddings):
    """Mean cosine similarity over every (text row, image row) pair.

    This is the mean of the full similarity matrix, not just the matched diagonal.
    """
    similarities = cosine_similarity(text_embeddings.cpu(), image_embeddings.cpu())
    return similarities.mean()

Below we evaluate the embeddings and run an example query. The results:

In [None]:
valid_df.iloc[0]['caption']

In [None]:
# Evaluation sets: text embeddings (T), image embeddings (I) and the caption
# table (df) whose rows they are aligned with.
VALIDS = [
    {
        'name': 'all',
        'T': text_embeddings,
        'I': image_embeddings,
        'df': valid_df
    },
    {
        'name': 'ktvic',
        'T': text_embeddings_ktvic_val,
        'I': image_embeddings_ktvic_val,
        'df': df_ktvic_val
    },
    {
        'name': 'openviic',
        'T': text_embeddings_openviic_val,
        'I': image_embeddings_openviic_val,
        'df': df_openviic_val
    },
    {
        'name': 'coco',
        'T': text_embeddings_coco_val,
        'I': image_embeddings_coco_val,
        'df': df_coco_val
    }
]

In [None]:
# Text-to-image retrieval accuracy at K for each validation set.
# K is named once so the variable names and printed labels match what is measured.
K = 10
topk_accuracies = {}
for SET in VALIDS:
    topk_accuracies[SET['name']] = eval_accuracy(SET['T'], SET['I'], SET['df'], k=K)
for name, value in topk_accuracies.items():
    print(f"{name}: top-{K} accuracy = {value}")

In [None]:
# Mean text-image cosine similarity for each validation set.
cosine_similarities = {SET['name']: eval_avg_cossim(SET['T'], SET['I']) for SET in VALIDS}
for name, value in cosine_similarities.items():
    print(f"{name}: {value}")

In [None]:
# Cosine similarity between two KTVIC *caption* embeddings (rows 700 and 100):
# a text-text sanity check, not text-image.
cos_sim = eval_avg_cossim(text_embeddings_ktvic_val[700:701], text_embeddings_ktvic_val[100:101])
cos_sim

In [None]:
# Presumably downloads a query image (e.g. cute-girl.jpg, used below).
!gdown 1mBkJv0Z-3BKpVyHjkn1NVPW3CtzIaZ8p 

In [None]:
!gdown 1FjVaYDIykU2EhWLJZMsjGoyPeb5cxd4i

In [None]:
# Combined text + image query against the full validation image set.
find_matches(
    model,
    image_embeddings,
    image_filenames=valid_df['image'].values,
    text="đi siêu thị",
    image_path='cute-girl.jpg',
    n=25
)

# nhiều người - chua.jpg 

![](./images/dance.png)