In [1]:
%%bash
pip install -q transformers[torch]
pip install -q Pillow
pip install -q wandb
pip install -q huggingface_hub

In [2]:
import os
import pandas as pd
import json
import torch
import wandb

from tqdm import tqdm
from typing import Optional
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from unidecode import unidecode
from transformers import AutoProcessor, BlipForConditionalGeneration, AutoConfig



In [3]:
model_name = "Salesforce/blip-image-captioning-base"

In [4]:
content_file = open("/kaggle/input/flickr30k/captions.txt").readlines()
content_file.pop(0)
def remove_special_character(text):
    text = text.replace("'", "")
    text = text.replace('"', '')
    return text
for i in tqdm(range(len(content_file))):
    content_file[i] = remove_special_character(content_file[i])
data = [tuple(file.split(', ', 1)) for file in content_file]
df = pd.DataFrame(data, columns=['image', 'caption'])
df = df.dropna()

100%|██████████| 158915/158915 [00:00<00:00, 920176.07it/s]


In [5]:
df = df.drop_duplicates(subset=['image'])

In [6]:
train_df = df.iloc[: int(0.8 * len(df))]
valid_df = df.iloc[int(0.8 * len(df)) : int(0.8 * len(df)) + int(0.1 * len(df))]
test_df = df.iloc[int(0.8 * len(df)) + int(0.1 * len(df)) : ]

In [7]:
test_df.to_csv("test_dataset.csv", index = False)

In [8]:
os.environ["HUGGINGFACE_TOKEN"] = "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"
os.environ["WANDB_KEY"] = "45883d116d879df59569bd98b2cffb64bc20c0c6"
os.environ["WANDB_PROJECT"] = "image-captioning"

In [9]:
wandb.login(key="45883d116d879df59569bd98b2cffb64bc20c0c6")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [10]:
config = AutoConfig.from_pretrained(model_name)

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

In [11]:
model = BlipForConditionalGeneration(config = config)

In [12]:
class ImageDataset(Dataset):
    def __init__(
        self,
        model_name: Optional[str],
        data_directory: Optional[str],
        df: pd.DataFrame,
        max_length = 82
    ):
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.data_directory = data_directory
        self.df = df
        self.max_length = max_length
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index: Optional[int]):
        absolute_file_path = os.path.join(self.data_directory, 
                                           self.df.iloc[index]['image'])
        caption = self.df.iloc[index]['caption']
        
        images = Image.open(absolute_file_path)
        
        outputs = self.processor(text = caption, images = images, 
                                padding = "max_length", return_tensors="pt", 
                                 truncation = True, max_length = self.max_length)
        
        outputs = {k:v.squeeze() for k,v in outputs.items()}
        
        return outputs

In [13]:
train_dataset = ImageDataset(model_name = model_name,
                            data_directory = "/kaggle/input/flickr30k/Images",
                            df = train_df)

preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [14]:
valid_dataset = ImageDataset(model_name = model_name,
                            data_directory = "/kaggle/input/flickr30k/Images",
                            df = valid_df)

In [15]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size = 8)
valid_loader = DataLoader(valid_dataset, shuffle=True, batch_size = 8)

In [16]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [17]:
wandb.init(entity="9h53-sportivefy", name="blip-baseline")

[34m[1mwandb[0m: Currently logged in as: [33mhungsvdut[0m ([33m9h53-sportivefy[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.16.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20231219_172633-vq75uydb[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mblip-baseline[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/9h53-sportivefy/image-captioning[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/9h53-sportivefy/image-captioning/runs/vq75uydb[0m


In [18]:
from tqdm import tqdm

global_step = 0
for epoch in tqdm(range(10)):
    
    if (epoch + 1) % 2 == 0:
        torch.save(model.state_dict(), f'/kaggle/working/checkpoint_{epoch}.pth')
    
    wandb.log({"train/epoch": epoch})
    
    train_loss = 0
    valid_loss = 0
    
    for idx, batch in tqdm(enumerate(train_loader)):
        global_step += 1
        input_ids = batch.pop("input_ids")
        input_ids = input_ids.to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        attention_mask = attention_mask,
                        labels=input_ids)
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        wandb.log({"train/loss": loss.item()})
        wandb.log({"train/global_step": global_step})
        
    with torch.no_grad():
        for idx, batch in tqdm(enumerate(valid_loader)):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            attention_mask = batch.pop("attention_mask").to(device)
            
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            attention_mask = attention_mask,
                            labels=input_ids)
            loss = outputs.loss
            valid_loss += loss.item()
            
    wandb.log({"valid/loss": valid_loss})
    
    
    wandb.log({"train/loss_per_epoch": train_loss / len(train_loader)})
    
    print(f"Train Loss: {train_loss / len(train_loader)}")
    print(f"Valid Loss: {valid_loss / len(valid_loader)}")

  0%|          | 0/10 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:06,  6.15s/it][A
2it [00:07,  3.09s/it][A
3it [00:08,  2.09s/it][A
4it [00:08,  1.62s/it][A
5it [00:09,  1.37s/it][A
6it [00:10,  1.22s/it][A
7it [00:11,  1.13s/it][A
8it [00:12,  1.06s/it][A
9it [00:13,  1.02s/it][A
10it [00:14,  1.01it/s][A
11it [00:15,  1.03it/s][A
12it [00:16,  1.04it/s][A
13it [00:17,  1.05it/s][A
14it [00:18,  1.06it/s][A
15it [00:19,  1.06it/s][A
16it [00:20,  1.07it/s][A
17it [00:21,  1.07it/s][A
18it [00:21,  1.07it/s][A
19it [00:22,  1.08it/s][A
20it [00:23,  1.08it/s][A
21it [00:24,  1.09it/s][A
22it [00:25,  1.08it/s][A
23it [00:26,  1.08it/s][A
24it [00:27,  1.08it/s][A
25it [00:28,  1.09it/s][A
26it [00:29,  1.09it/s][A
27it [00:30,  1.09it/s][A
28it [00:31,  1.10it/s][A
29it [00:32,  1.10it/s][A
30it [00:32,  1.09it/s][A
31it [00:33,  1.09it/s][A
32it [00:34,  1.10it/s][A
33it [00:35,  1.09it/s][A
34it [00:36,  1.09it/s][A
35it [00:37,  1.09it/s][A
36i

Train Loss: 2.3879710032753168
Valid Loss: 2.394446706053001



0it [00:00, ?it/s][A
1it [00:00,  1.15it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.17it/s][A
4it [00:03,  1.15it/s][A
5it [00:04,  1.16it/s][A
6it [00:05,  1.17it/s][A
7it [00:06,  1.16it/s][A
8it [00:06,  1.17it/s][A
9it [00:07,  1.17it/s][A
10it [00:08,  1.17it/s][A
11it [00:09,  1.17it/s][A
12it [00:10,  1.16it/s][A
13it [00:11,  1.17it/s][A
14it [00:12,  1.17it/s][A
15it [00:12,  1.16it/s][A
16it [00:13,  1.17it/s][A
17it [00:14,  1.16it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.17it/s][A
20it [00:17,  1.17it/s][A
21it [00:18,  1.17it/s][A
22it [00:18,  1.16it/s][A
23it [00:19,  1.16it/s][A
24it [00:20,  1.16it/s][A
25it [00:21,  1.16it/s][A
26it [00:22,  1.17it/s][A
27it [00:23,  1.16it/s][A
28it [00:24,  1.16it/s][A
29it [00:24,  1.16it/s][A
30it [00:25,  1.16it/s][A
31it [00:26,  1.16it/s][A
32it [00:27,  1.16it/s][A
33it [00:28,  1.16it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:30,  1.16it/s][A
37it [00:31,  

Train Loss: 2.1895692613598086
Valid Loss: 2.320931103361312



0it [00:00, ?it/s][A
1it [00:00,  1.17it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.14it/s][A
4it [00:03,  1.15it/s][A
5it [00:04,  1.16it/s][A
6it [00:05,  1.15it/s][A
7it [00:06,  1.15it/s][A
8it [00:06,  1.15it/s][A
9it [00:07,  1.15it/s][A
10it [00:08,  1.16it/s][A
11it [00:09,  1.16it/s][A
12it [00:10,  1.15it/s][A
13it [00:11,  1.14it/s][A
14it [00:12,  1.14it/s][A
15it [00:13,  1.13it/s][A
16it [00:13,  1.14it/s][A
17it [00:14,  1.15it/s][A
18it [00:15,  1.15it/s][A
19it [00:16,  1.15it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.16it/s][A
22it [00:19,  1.16it/s][A
23it [00:19,  1.15it/s][A
24it [00:20,  1.15it/s][A
25it [00:21,  1.15it/s][A
26it [00:22,  1.16it/s][A
27it [00:23,  1.16it/s][A
28it [00:24,  1.16it/s][A
29it [00:25,  1.16it/s][A
30it [00:26,  1.16it/s][A
31it [00:26,  1.16it/s][A
32it [00:27,  1.16it/s][A
33it [00:28,  1.16it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:32,  

Train Loss: 2.119908590855258
Valid Loss: 2.284950708923627



0it [00:00, ?it/s][A
1it [00:00,  1.19it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.17it/s][A
4it [00:03,  1.17it/s][A
5it [00:04,  1.17it/s][A
6it [00:05,  1.16it/s][A
7it [00:06,  1.16it/s][A
8it [00:06,  1.15it/s][A
9it [00:07,  1.16it/s][A
10it [00:08,  1.16it/s][A
11it [00:09,  1.16it/s][A
12it [00:10,  1.16it/s][A
13it [00:11,  1.15it/s][A
14it [00:12,  1.15it/s][A
15it [00:12,  1.16it/s][A
16it [00:13,  1.16it/s][A
17it [00:14,  1.17it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.15it/s][A
20it [00:17,  1.13it/s][A
21it [00:18,  1.14it/s][A
22it [00:19,  1.15it/s][A
23it [00:19,  1.15it/s][A
24it [00:20,  1.15it/s][A
25it [00:21,  1.16it/s][A
26it [00:22,  1.15it/s][A
27it [00:23,  1.15it/s][A
28it [00:24,  1.15it/s][A
29it [00:25,  1.16it/s][A
30it [00:25,  1.15it/s][A
31it [00:26,  1.15it/s][A
32it [00:27,  1.15it/s][A
33it [00:28,  1.15it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:31,  

Train Loss: 2.067170470718619
Valid Loss: 2.264414364668592



0it [00:00, ?it/s][A
1it [00:00,  1.16it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.17it/s][A
4it [00:03,  1.15it/s][A
5it [00:04,  1.15it/s][A
6it [00:05,  1.15it/s][A
7it [00:06,  1.16it/s][A
8it [00:06,  1.16it/s][A
9it [00:07,  1.16it/s][A
10it [00:08,  1.16it/s][A
11it [00:09,  1.16it/s][A
12it [00:10,  1.16it/s][A
13it [00:11,  1.15it/s][A
14it [00:12,  1.15it/s][A
15it [00:12,  1.16it/s][A
16it [00:13,  1.16it/s][A
17it [00:14,  1.16it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.16it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.17it/s][A
22it [00:18,  1.17it/s][A
23it [00:19,  1.16it/s][A
24it [00:20,  1.16it/s][A
25it [00:21,  1.16it/s][A
26it [00:22,  1.16it/s][A
27it [00:23,  1.16it/s][A
28it [00:24,  1.16it/s][A
29it [00:25,  1.16it/s][A
30it [00:25,  1.16it/s][A
31it [00:26,  1.16it/s][A
32it [00:27,  1.16it/s][A
33it [00:28,  1.16it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.15it/s][A
36it [00:31,  1.15it/s][A
37it [00:31,  

Train Loss: 2.019688594825015
Valid Loss: 2.2553729387983004



0it [00:00, ?it/s][A
1it [00:00,  1.16it/s][A
2it [00:01,  1.17it/s][A
3it [00:02,  1.17it/s][A
4it [00:03,  1.17it/s][A
5it [00:04,  1.17it/s][A
6it [00:05,  1.16it/s][A
7it [00:06,  1.16it/s][A
8it [00:06,  1.16it/s][A
9it [00:07,  1.16it/s][A
10it [00:08,  1.15it/s][A
11it [00:09,  1.16it/s][A
12it [00:10,  1.15it/s][A
13it [00:11,  1.16it/s][A
14it [00:12,  1.16it/s][A
15it [00:12,  1.15it/s][A
16it [00:13,  1.15it/s][A
17it [00:14,  1.16it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.16it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.16it/s][A
22it [00:19,  1.16it/s][A
23it [00:19,  1.13it/s][A
24it [00:20,  1.12it/s][A
25it [00:21,  1.12it/s][A
26it [00:22,  1.13it/s][A
27it [00:23,  1.14it/s][A
28it [00:24,  1.14it/s][A
29it [00:25,  1.14it/s][A
30it [00:26,  1.15it/s][A
31it [00:26,  1.15it/s][A
32it [00:27,  1.16it/s][A
33it [00:28,  1.16it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:32,  

Train Loss: 1.9708111906021637
Valid Loss: 2.2537215534766117



0it [00:00, ?it/s][A
1it [00:00,  1.17it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.16it/s][A
4it [00:03,  1.15it/s][A
5it [00:04,  1.16it/s][A
6it [00:05,  1.15it/s][A
7it [00:06,  1.15it/s][A
8it [00:06,  1.16it/s][A
9it [00:07,  1.16it/s][A
10it [00:08,  1.16it/s][A
11it [00:09,  1.16it/s][A
12it [00:10,  1.16it/s][A
13it [00:11,  1.16it/s][A
14it [00:12,  1.16it/s][A
15it [00:12,  1.15it/s][A
16it [00:13,  1.15it/s][A
17it [00:14,  1.15it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.15it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.16it/s][A
22it [00:19,  1.16it/s][A
23it [00:19,  1.16it/s][A
24it [00:20,  1.15it/s][A
25it [00:21,  1.15it/s][A
26it [00:22,  1.15it/s][A
27it [00:23,  1.15it/s][A
28it [00:24,  1.15it/s][A
29it [00:25,  1.15it/s][A
30it [00:25,  1.15it/s][A
31it [00:26,  1.15it/s][A
32it [00:27,  1.15it/s][A
33it [00:28,  1.16it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:32,  

Train Loss: 1.919715670228342
Valid Loss: 2.265723194908257



0it [00:00, ?it/s][A
1it [00:00,  1.16it/s][A
2it [00:01,  1.16it/s][A
3it [00:02,  1.15it/s][A
4it [00:03,  1.15it/s][A
5it [00:04,  1.14it/s][A
6it [00:05,  1.15it/s][A
7it [00:06,  1.15it/s][A
8it [00:06,  1.16it/s][A
9it [00:07,  1.15it/s][A
10it [00:08,  1.14it/s][A
11it [00:09,  1.14it/s][A
12it [00:10,  1.15it/s][A
13it [00:11,  1.15it/s][A
14it [00:12,  1.15it/s][A
15it [00:13,  1.16it/s][A
16it [00:13,  1.16it/s][A
17it [00:14,  1.16it/s][A
18it [00:15,  1.16it/s][A
19it [00:16,  1.17it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.16it/s][A
22it [00:19,  1.16it/s][A
23it [00:19,  1.15it/s][A
24it [00:20,  1.16it/s][A
25it [00:21,  1.14it/s][A
26it [00:22,  1.14it/s][A
27it [00:23,  1.14it/s][A
28it [00:24,  1.12it/s][A
29it [00:25,  1.13it/s][A
30it [00:26,  1.13it/s][A
31it [00:27,  1.12it/s][A
32it [00:27,  1.13it/s][A
33it [00:28,  1.14it/s][A
34it [00:29,  1.15it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:32,  

Train Loss: 1.8638555823950393
Valid Loss: 2.2851392509949267



0it [00:00, ?it/s][A
1it [00:00,  1.17it/s][A
2it [00:01,  1.17it/s][A
3it [00:02,  1.16it/s][A
4it [00:03,  1.16it/s][A
5it [00:04,  1.15it/s][A
6it [00:05,  1.15it/s][A
7it [00:06,  1.15it/s][A
8it [00:06,  1.15it/s][A
9it [00:07,  1.15it/s][A
10it [00:08,  1.15it/s][A
11it [00:09,  1.15it/s][A
12it [00:10,  1.15it/s][A
13it [00:11,  1.15it/s][A
14it [00:12,  1.15it/s][A
15it [00:13,  1.15it/s][A
16it [00:13,  1.13it/s][A
17it [00:14,  1.14it/s][A
18it [00:15,  1.15it/s][A
19it [00:16,  1.15it/s][A
20it [00:17,  1.15it/s][A
21it [00:18,  1.15it/s][A
22it [00:19,  1.16it/s][A
23it [00:19,  1.16it/s][A
24it [00:20,  1.16it/s][A
25it [00:21,  1.16it/s][A
26it [00:22,  1.16it/s][A
27it [00:23,  1.16it/s][A
28it [00:24,  1.16it/s][A
29it [00:25,  1.16it/s][A
30it [00:25,  1.16it/s][A
31it [00:26,  1.16it/s][A
32it [00:27,  1.16it/s][A
33it [00:28,  1.15it/s][A
34it [00:29,  1.16it/s][A
35it [00:30,  1.16it/s][A
36it [00:31,  1.16it/s][A
37it [00:32,  

Train Loss: 1.8055737626443835
Valid Loss: 2.3132688795501863



0it [00:00, ?it/s][A
1it [00:00,  1.13it/s][A
2it [00:01,  1.15it/s][A
3it [00:02,  1.14it/s][A
4it [00:03,  1.14it/s][A
5it [00:04,  1.13it/s][A
6it [00:05,  1.13it/s][A
7it [00:06,  1.13it/s][A
8it [00:07,  1.13it/s][A
9it [00:07,  1.13it/s][A
10it [00:08,  1.13it/s][A
11it [00:09,  1.13it/s][A
12it [00:10,  1.13it/s][A
13it [00:11,  1.13it/s][A
14it [00:12,  1.14it/s][A
15it [00:13,  1.13it/s][A
16it [00:14,  1.13it/s][A
17it [00:14,  1.14it/s][A
18it [00:15,  1.15it/s][A
19it [00:16,  1.16it/s][A
20it [00:17,  1.16it/s][A
21it [00:18,  1.16it/s][A
22it [00:19,  1.16it/s][A
23it [00:20,  1.16it/s][A
24it [00:20,  1.16it/s][A
25it [00:21,  1.16it/s][A
26it [00:22,  1.15it/s][A
27it [00:23,  1.15it/s][A
28it [00:24,  1.15it/s][A
29it [00:25,  1.14it/s][A
30it [00:26,  1.13it/s][A
31it [00:27,  1.14it/s][A
32it [00:28,  1.12it/s][A
33it [00:28,  1.12it/s][A
34it [00:29,  1.11it/s][A
35it [00:30,  1.11it/s][A
36it [00:31,  1.10it/s][A
37it [00:32,  

Train Loss: 1.747359359553116
Valid Loss: 2.3441520620830096



