In [1]:
import json
import matplotlib
from pprint import pprint


with open('./annotations/annotations/captions_train2014.json') as f:
    json_array = json.load(f)
    print(json_array.keys())
    print(len(json_array['images']))
    print(sorted(json_array['images'], key=lambda x : x['id'])[0])
    print(sorted(json_array['annotations'], key=lambda x : x['image_id'])[0])
    

dict_keys(['info', 'images', 'licenses', 'annotations'])
82783
{'license': 3, 'file_name': 'COCO_train2014_000000000009.jpg', 'coco_url': 'http://images.cocodataset.org/train2014/COCO_train2014_000000000009.jpg', 'height': 480, 'width': 640, 'date_captured': '2013-11-19 20:40:11', 'flickr_url': 'http://farm5.staticflickr.com/4026/4622125393_84c1fdb8d6_z.jpg', 'id': 9}
{'image_id': 9, 'id': 661611, 'caption': 'Closeup of bins of food that include broccoli and bread.'}


In [2]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

import os
import math
import json
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm

class ImageCaptionDataset(Dataset):
    def __init__(self, annotations_path, processor, mode):
        # annotations = dict_keys(['info', 'images', 'licenses', 'annotations'])
        # self.mode = os.path.splitext(os.path.basename(annotations_path))[0].split('_')[-1]
        self.mode = mode
        self.json_array = json.load(open(annotations_path))
        self.dataset = {
            'image_paths' : [],
            'captions' : []
        }
        # sort forr indexing
        for image_json, annotation_json in zip(sorted(self.json_array['images'], key=lambda x : x['id']), sorted(self.json_array['annotations'], key=lambda x : x['image_id'])):
            # image_paths = Path(f'{self.mode}/{self.mode}' + image_json['file_name'])
            # captions = annotation_json['captions']
            
            # add image_paths, caption for indexing
            self.dataset['image_paths'].append(Path(f'../{self.mode}/{self.mode}/' + image_json['file_name']))
            self.dataset['captions'].append(annotation_json['caption']) 
        
        # prepare processor for Natural Language Preprocessing
        self.processor = processor

        
    def __len__(self):
        return len(self.dataset['image_paths'])
    
    def __getitem__(self, idx):
        image = Image.open(self.dataset['image_paths'][idx])
        caption = self.dataset['captions'][idx]
        # 사용할 processor가 batch 단위를 생각하고 encoding을 하기 때문에 squeeze()를 해야함.
        encoding = self.processor(images=image, text=caption, padding="max_length", return_tensors="pt")

        encoding = {k : v.squeeze() for k, v in encoding.items()}
        return encoding

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

In [4]:
train_dataset = ImageCaptionDataset('../custom_dataset/custom_json.json', processor, 'custom_dataset')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

In [5]:
device = "cuda" if torch.cuda.is_available() else "CPU"
model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [6]:
min_total_loss = math.inf
for epoch in range(5):
    print("Epoch : ", epoch)
    total_loss = 0
    with tqdm(total=len(train_dataloader)) as pbar:
        for idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)
            
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                outputs = model(input_ids=input_ids,
                                pixel_values = pixel_values,
                                labels=input_ids)
            
            loss = outputs.loss
            
            pbar.set_postfix(loss = loss.item())
            pbar.update(1)
            total_loss += loss.item()    
            loss.backward()
            
            optimizer.step()
            optimizer.zero_grad()
    if min_total_loss > total_loss:
        torch.save(model.state_dict(), f"BLIP_trainin_ver-{epoch}-{total_loss:.2f}.pt")

Epoch :  0


100%|██████████| 33/33 [00:15<00:00,  2.14it/s, loss=6.6] 


Epoch :  1


100%|██████████| 33/33 [00:14<00:00,  2.21it/s, loss=2.44]


Epoch :  2


100%|██████████| 33/33 [00:14<00:00,  2.24it/s, loss=1.42]


Epoch :  3


100%|██████████| 33/33 [00:14<00:00,  2.22it/s, loss=1.39]


Epoch :  4


100%|██████████| 33/33 [00:15<00:00,  2.20it/s, loss=1.38]
