In [1]:
from dataset import Pixt_Dataset
from dataset.transform import Pixt_ImageTransform

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from PIL import Image

import random
import clip
from tqdm import tqdm

# train dataset class instance

In [2]:
img_dir = "./data/"
annotation_dir = "./data/annotation/annotation_merged_remove_gap.csv"
classes_ko_dir = "./data/annotation/all_class_list_ko.pt"
classes_en_dir = "./data/annotation/all_class_list_en.pt"
image_transform = Pixt_ImageTransform()
train_dataset = Pixt_Dataset(img_dir, annotation_dir, image_transform)

# train dataloader class instance

In [3]:
def collate_fn(samples):
    input_data = {}
    input_data["image_tensor"] = torch.stack([sample['image_tensor'] for sample in samples], dim=0)
    input_data["text_ko"] = [sample['text_ko'] for sample in samples]
    return input_data

train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    batch_size=16,
    persistent_workers=False,
    collate_fn=collate_fn
)

In [4]:
classes_ko_dir = "./data/annotation/annotation_merged_remove_gap_class_ko.pt"
classes_en_dir = "./data/annotation/annotation_merged_remove_gap_class_en.pt"

tags_ko_all_list = torch.load(classes_ko_dir)
tags_en_all_list = torch.load(classes_en_dir)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model, _ = clip.load("RN50", device=device)
model.load_state_dict((torch.load("model.pt", map_location="cuda")))

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)

cuda


In [6]:
def _get_text_and_target_tensor(text_ko: list[list], max_length: int) -> list[torch.Tensor]:
    # true label인 tag 담기
    text_input_ko_list = []
    for tags_ko in text_ko: # batch size 만큼 iteration
        for tag_ko in tags_ko: # data sample 당 tag 개수만큼 iteration
            text_input_ko_list.append(tag_ko)
    text_input_ko_list = list(set(text_input_ko_list))
    
    # false label with random sampling 인 tag 담기
    while True:
        random_sample = random.sample(tags_ko_all_list, 1)[0]
        if random_sample not in text_input_ko_list:
            text_input_ko_list.append(random_sample)
        if len(text_input_ko_list) == max_length:
            break
    
    # 한국어(text_input_ko_list)에서 영어(text_input_en_list)로 번역하기
    text_input_en_list = [tags_en_all_list[tags_ko_all_list.index(tag_ko)] for tag_ko in text_input_ko_list]
    # tokenize 수행 및 tensor 변환
    text_tensor = torch.cat([clip.tokenize(f"a photo of a {c}") for c in text_input_en_list])

    # 한국어(text_ko)에서 영어(text_en)로 번역하기
    text_en = []
    for tags_ko in text_ko: # batch size 만큼 iteration
        tmp = []
        for tag_ko in tags_ko: # data sample 당 tag 개수만큼 iteration
            tmp.append(tags_en_all_list[tags_ko_all_list.index(tag_ko)])
        text_en.append(tmp)

    # target tensor 생성
    target_tensor_list = []
    for tags_en in text_en:
        target_tensor = torch.zeros_like(torch.empty(max_length))
        for tag_en in tags_en:
            if tag_en in text_input_en_list:
                target_tensor[text_input_en_list.index(tag_en)] = 1
        target_tensor_list.append(target_tensor)
    target_tesnor = torch.stack(target_tensor_list, dim=0)
    
    return text_tensor, target_tesnor

In [7]:
epochs = 5
max_length = 500
model.train()
for epoch in range(epochs):
    one_epoch_loss = []
    for batch in tqdm(train_dataloader):
        image_tensor = batch["image_tensor"].to(device) # torch.Tensor (16 x 3 x 224 x 224)
        text_ko = batch["text_ko"] # list[list] (16 x 가변적)
        text_tensor, target_tensor = _get_text_and_target_tensor(text_ko, max_length)
        text_tensor = text_tensor.to(device)
        target_tensor = target_tensor.to(device)

        logits_per_image, logits_per_text = model(image_tensor, text_tensor)
        loss = (loss_func(logits_per_image, target_tensor) + loss_func(logits_per_text.T, target_tensor)) / 2
        one_epoch_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("epoch :", epoch+1, "loss :", sum(one_epoch_loss) / len(one_epoch_loss))

100%|██████████| 372/372 [33:39<00:00,  5.43s/it]


epoch : 1 loss : 66.49728951402889


100%|██████████| 372/372 [43:03<00:00,  6.95s/it]  


epoch : 2 loss : 63.715391261603244


100%|██████████| 372/372 [37:59<00:00,  6.13s/it] 


epoch : 3 loss : 62.965434330765916


100%|██████████| 372/372 [42:13<00:00,  6.81s/it]


epoch : 4 loss : 62.50462743287446


100%|██████████| 372/372 [44:06<00:00,  7.12s/it]

epoch : 5 loss : 62.22277791013





In [None]:
torch.save(model.state_dict(),"model.pt")

In [11]:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model, preprocess = clip.load('RN50', device)
# model.load_state_dict((torch.load("model.pt", map_location="cuda")))

classes_list = torch.load("./data/annotation/annotation_merged_remove_gap_class_en.pt")
classes_list = [tag_ko.lower() for tag_ko in classes_list]
classes_list = sorted(set(classes_list))
text_input = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes_list]).to(device)

image_number = 110
file_path = "./data/dataset3/"+ str(image_number) + ".webp"
print(file_path)
Image.open(file_path).show()

Image_transform = Pixt_ImageTransform()
image_input = Image_transform(Image.open(file_path).convert("RGB")).float().unsqueeze(0).to(device)
print(image_input.shape, text_input.shape)

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_input)

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(10)

print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{classes_list[index]:>16s}: {100 * value.item():.2f}%")

./data/dataset3/110.webp
torch.Size([1, 3, 224, 224]) torch.Size([5691, 77])

Top predictions:

          spring: 3.99%
          nature: 2.68%
     countryside: 2.31%
           sight: 2.20%
           seoul: 1.89%
           water: 1.54%
    spring water: 1.52%
            pine: 1.34%
               0: 1.34%
         objects: 1.22%
