In [1]:
import json
import time
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
import torch.optim as optim
from transformers import logging
import matplotlib.pyplot as plt
from PIL import Image
import requests
from transformers import (
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
    ViTFeatureExtractor,
    BertTokenizer,
)
import numpy as np
logging.set_verbosity_error()

  from .autonotebook import tqdm as notebook_tqdm


[2023-10-19 14:48:59,141] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
class CFG:
    debug = False
    captions_path = "."
    max_text_tokens_length = 128
    text_backbone = 'bert-base-uncased'
    image_backbone = 'google/vit-base-patch16-224'
    image_path = "./dataset/flickr30k_images"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 32
    max_epochs = 20
    max_bad_epochs = 9
    patience = 3
    factor = 0.1

In [3]:
df = pd.read_csv("./dataset/flickr30k_images/results.csv", delimiter="|")
df.columns = ['image', 'caption_number', 'caption']
df['caption'] = df['caption'].str.lstrip()
df['caption_number'] = df['caption_number'].str.lstrip()
df.loc[19999, 'caption_number'] = "4"
df.loc[19999, 'caption'] = "A dog runs across the grass ."
ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
df['id'] = ids
df.to_csv("captions.csv", index=False)
df.head()
def make_train_valid_dfs():
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

In [4]:
class CLIPDataset(torch.utils.data.Dataset):  
    def __init__(self,image_files, captions, processor):
        self.image_files = image_files
        self.captions = list(captions) 
        self.processor = processor
    def __getitem__(self,idx):
        caption = self.captions[idx]
        image = Image.open(f"{CFG.image_path}/{self.image_files[idx]}") 
        encoded_pair = self.processor(text=[caption], images=[image], return_tensors="pt", max_length=CFG.max_text_tokens_length, padding='max_length', truncation=True)
        return encoded_pair
    
    def __len__(self):
        return len(self.captions)

def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)
train_df, valid_df = make_train_valid_dfs()
train_ds = CLIPDataset(train_df["image"].values,train_df["caption"].values, processor=processor)
valid_ds = CLIPDataset(valid_df["image"].values,valid_df["caption"].values, processor=processor)
train_dataloader = torch.utils.data.DataLoader(train_ds, collate_fn=collate_fn, batch_size=CFG.batch_size)
val_dataloader = torch.utils.data.DataLoader(valid_ds, collate_fn=collate_fn, batch_size=CFG.batch_size)



In [5]:
def train_epoch(model, train_loader, optimizer, epoch, max_epochs):
    model.train()
    nb_batches = len(train_loader)
    tqdm_object = tqdm(train_loader, total=len(train_loader))   
    epoch_loss = 0.0
    for i, batch in enumerate(tqdm_object):
      outputs = model(
          input_ids=batch['input_ids'].squeeze().to(CFG.device),
          attention_mask=batch['attention_mask'].squeeze().to(CFG.device),
          pixel_values=batch['pixel_values'].squeeze().to(CFG.device),
          return_loss=True)
      loss, logits_per_image = outputs.loss, outputs.logits_per_image  # this is the image-text similarity score
      epoch_loss += loss.item()
      loss.backward()
      optimizer.step()
      tqdm_object.set_postfix(
          batch="{}/{}".format(i+1,nb_batches),
          train_loss=loss.item(),
          lr=get_lr(optimizer)
          )
    epoch_loss = epoch_loss / nb_batches
    return epoch_loss

def valid_epoch(model, dev_loader, epoch, max_epochs):
    model.eval()
    nb_batches = len(dev_loader)
    tqdm_object = tqdm(dev_loader, total=len(dev_loader))
    epoch_loss = 0.0   
    for i, batch in enumerate(tqdm_object):
      outputs = model(
          input_ids=batch['input_ids'].squeeze().to(CFG.device),
          attention_mask=batch['attention_mask'].squeeze().to(CFG.device),
          pixel_values=batch['pixel_values'].squeeze().to(CFG.device),
          return_loss=True)
      loss, logits_per_image = outputs.loss, outputs.logits_per_image  # this is the image-text similarity score
      epoch_loss += loss.item()
      tqdm_object.set_postfix(
          batch="{}/{}".format(i+1,nb_batches),
          dev_loss=loss.item(),
          )
    epoch_loss = epoch_loss / nb_batches
    return epoch_loss

def learning_loop(model):
    model.to(CFG.device)
    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=CFG.patience, factor=CFG.factor)

    best_dev_score = float('inf')
    train_history = []
    dev_history = []
    nb_bad_epochs = 0

    print("Learning phase")
    print('Used device:', CFG.device)
    print("--------------")
    for epoch in range(1, CFG.max_epochs+1):

        print("Epoch {:03d}/{:03d}".format(epoch, CFG.max_epochs))

        if nb_bad_epochs >= CFG.max_bad_epochs:
            print("Epoch {:03d}/{:03d}: exiting training after too many bad epochs.".format(epoch, CFG.max_epochs))
            torch.save(model.state_dict(), "final.pt")
            break

        else:

            epoch_start_time = time.time()

            epoch_train_loss = train_epoch(model=model, train_loader=train_dataloader, optimizer=optimizer, epoch=epoch, max_epochs=CFG.max_epochs)
            epoch_dev_score = valid_epoch(model=model, dev_loader=val_dataloader, epoch=epoch, max_epochs=CFG.max_epochs)

            duration = time.time() - epoch_start_time

            lr_scheduler.step(epoch_dev_score)

            train_history.append(epoch_train_loss)
            dev_history.append(epoch_dev_score)

            if epoch_dev_score < best_dev_score:
                nb_bad_epochs = 0
                best_dev_score = epoch_dev_score
                torch.save(model.state_dict(), "best.pt")
                print("Finished epoch {:03d}/{:03d} - Train loss: {:.7f} - Valid loss: {:.7f} - SAVED (NEW) BEST MODEL. Duration: {:.3f} s".format(
                epoch, CFG.max_epochs, epoch_train_loss, epoch_dev_score, duration))
            else:
                nb_bad_epochs += 1
                print("Finished epoch {:03d}/{:03d} - Train loss: {:.7f} - Valid loss: {:.7f} - NUMBER OF BAD EPOCH.S: {}. Duration: {:.3f} s".format(
                epoch, CFG.max_epochs, epoch_train_loss, epoch_dev_score, nb_bad_epochs, duration))

    history = {'train':train_history,'dev':dev_history}
    return history
  
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

def plot_history(history):
    train_history = history['train']
    dev_history = history['dev']
    plt.plot(list(range(1, len(train_history)+1)), train_history, label="train loss")
    plt.plot(list(range(1, len(dev_history)+1)), dev_history, label="dev loss")
    plt.xticks(list(range(1, len(train_history)+1)))
    plt.xlabel("epoch")
    plt.legend()

In [6]:
clip = VisionTextDualEncoderModel.from_vision_text_pretrained(CFG.image_backbone, CFG.text_backbone)

In [7]:
%%time
history = learning_loop(clip)

Learning phase
Used device: cuda
--------------
Epoch 001/020


100%|██████████| 3973/3973 [42:14<00:00,  1.57it/s, batch=3973/3973, lr=0.001, train_loss=3.43]
100%|██████████| 994/994 [06:43<00:00,  2.46it/s, batch=994/994, dev_loss=1.39]


Finished epoch 001/020 - Train loss: 3.4663736 - Valid loss: 3.4636442 - SAVED (NEW) BEST MODEL. Duration: 2937.668 s
Epoch 002/020


100%|██████████| 3973/3973 [40:57<00:00,  1.62it/s, batch=3973/3973, lr=0.001, train_loss=3.43]
100%|██████████| 994/994 [06:02<00:00,  2.74it/s, batch=994/994, dev_loss=1.39]


Finished epoch 002/020 - Train loss: 3.4657281 - Valid loss: 3.4636442 - SAVED (NEW) BEST MODEL. Duration: 2819.360 s
Epoch 003/020


100%|██████████| 3973/3973 [39:22<00:00,  1.68it/s, batch=3973/3973, lr=0.001, train_loss=3.43]
100%|██████████| 994/994 [06:07<00:00,  2.70it/s, batch=994/994, dev_loss=1.39]


Finished epoch 003/020 - Train loss: 3.4657280 - Valid loss: 3.4636442 - NUMBER OF BAD EPOCH.S: 1. Duration: 2729.844 s
Epoch 004/020


100%|██████████| 3973/3973 [39:21<00:00,  1.68it/s, batch=3973/3973, lr=0.001, train_loss=3.43]
100%|██████████| 994/994 [05:53<00:00,  2.81it/s, batch=994/994, dev_loss=1.39]


Finished epoch 004/020 - Train loss: 3.4657282 - Valid loss: 3.4636442 - NUMBER OF BAD EPOCH.S: 2. Duration: 2715.327 s
Epoch 005/020


100%|██████████| 3973/3973 [39:16<00:00,  1.69it/s, batch=3973/3973, lr=0.001, train_loss=3.43]
100%|██████████| 994/994 [05:52<00:00,  2.82it/s, batch=994/994, dev_loss=1.39]


Finished epoch 005/020 - Train loss: 3.4657282 - Valid loss: 3.4636442 - NUMBER OF BAD EPOCH.S: 3. Duration: 2709.018 s
Epoch 006/020


100%|██████████| 3973/3973 [39:12<00:00,  1.69it/s, batch=3973/3973, lr=0.0001, train_loss=3.43]
100%|██████████| 994/994 [05:51<00:00,  2.83it/s, batch=994/994, dev_loss=1.39]


Finished epoch 006/020 - Train loss: 3.4657282 - Valid loss: 3.4636442 - NUMBER OF BAD EPOCH.S: 4. Duration: 2704.367 s
Epoch 007/020


100%|██████████| 3973/3973 [39:03<00:00,  1.70it/s, batch=3973/3973, lr=0.0001, train_loss=3.43]
100%|██████████| 994/994 [05:50<00:00,  2.83it/s, batch=994/994, dev_loss=1.39]


Finished epoch 007/020 - Train loss: 3.4657282 - Valid loss: 3.4636442 - NUMBER OF BAD EPOCH.S: 5. Duration: 2694.562 s
Epoch 008/020


100%|██████████| 3973/3973 [39:04<00:00,  1.69it/s, batch=3973/3973, lr=0.0001, train_loss=3.43]
 60%|█████▉    | 593/994 [03:31<02:22,  2.81it/s, batch=593/994, dev_loss=3.47]


KeyboardInterrupt: 