In [1]:
import sys
import os
# Directory added to the import path; presumably where the project-local
# `PadsDataset` module imported in the next cell lives — confirm.
# NOTE(review): absolute, user-specific path; a path relative to the notebook
# (or a configurable constant) would be more portable.
PADS_MODULE_DIR = '/home/j-k11s103/notebooks/pads'

# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if PADS_MODULE_DIR not in sys.path:
    sys.path.append(PADS_MODULE_DIR)

# Restrict this process to one GPU. This must be set before CUDA is
# initialised to take effect, which is why it lives in the first cell, ahead
# of the cell that imports torch.
os.environ['CUDA_VISIBLE_DEVICES'] = '8'

In [2]:
from PadsDataset import PadsDataset

In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from coca_pytorch.coca_pytorch import CoCa
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
from transformers import AutoTokenizer

# Pick the compute device once: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# Hyperparameters of the image backbone, collected in one mapping so they can
# be inspected or tweaked without touching the constructor call.
vit_config = dict(
    image_size=224,     # Input image size
    patch_size=32,      # Patch size
    num_classes=1,
    dim=1024,           # Model dimension
    depth=6,            # Number of transformer blocks
    heads=16,           # Attention heads
    mlp_dim=2048,       # MLP dimension in each transformer block
    patch_dropout=0.5,  # Patch dropout rate
)

# Wrap the ViT in an Extractor configured to return embeddings only.
# detach=False is passed so the returned embeddings are not detached
# (presumably so gradients reach the ViT during CoCa training — confirm).
vit = Extractor(SimpleViT(**vit_config), return_embeddings_only=True, detach=False)

# Kept as the last expression of the cell so the module repr is displayed.
vit.to(device)

Extractor(
  (vit): SimpleViT(
    (to_patch_embedding): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1=32, p2=32)
      (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
      (2): Linear(in_features=3072, out_features=1024, bias=True)
      (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (patch_dropout): PatchDropout()
    (transformer): Transformer(
      (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-5): 6 x ModuleList(
          (0): Attention(
            (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (to_out): Linear(in_features=1024, out_features=1024, bias=False)
          )
          (1): FeedForward(
            (net): Sequential(
              (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (1

In [5]:
# CoCa configuration. `image_dim` is 1024 to match the `dim` the ViT encoder
# above was built with; `num_tokens` is the text vocabulary size.
coca_config = {
    "dim": 512,
    "img_encoder": vit,
    "image_dim": 1024,
    "num_tokens": 32000,           # Text vocabulary size
    "unimodal_depth": 6,
    "multimodal_depth": 6,
    "dim_head": 64,
    "heads": 8,
    "caption_loss_weight": 1.0,
    "contrastive_loss_weight": 1.0,
}

# Build the model and move it to the selected device.
coca = CoCa(**coca_config)
coca = coca.to(device)

In [6]:
# Image preprocessing: conversion to a tensor only, no resizing is applied.
# NOTE(review): presumably the images on disk already have the 224x224 size
# the ViT was configured with — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

# Tokenizer loaded with the slow (non-fast) implementation.
tokenizer = AutoTokenizer.from_pretrained(
    "camiller/Korean_Llama_Tokenizer",
    use_fast=False,
)

# Paired image/text dataset; paths are relative to the notebook's working
# directory.
dataset = PadsDataset(
    img_dir="dataset/images",
    text_dir="dataset/text",
    tokenizer=tokenizer,
    transform=transform,
    data_type="Both",
)

# Shuffled mini-batches of 32 samples.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
)

In [7]:
# Sanity check: display the tokenizer's vocabulary size so it can be compared
# with the num_tokens=32000 the CoCa model was constructed with.
tokenizer.vocab_size

32000

In [8]:
def train_coca_model(coca, dataloader, epochs=1, lr=1e-4):
    """Train a CoCa model with Adam, showing a tqdm progress bar per epoch.

    Args:
        coca: Model invoked as ``coca(text=..., images=..., return_loss=True)``;
            the call must return a scalar loss tensor.
        dataloader: Iterable of batches. Each batch is indexed with the keys
            ``"image"`` and ``"text"``, and both values must support ``.to()``.
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate passed to ``torch.optim.Adam``.

    Returns:
        None. The average loss and wall-clock time of every epoch are printed.
    """
    optimizer = torch.optim.Adam(coca.parameters(), lr=lr)
    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-level `device` global being defined and in sync.
    model_device = next(coca.parameters()).device
    coca.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        start_time = time.time()

        # tqdm progress bar for each epoch
        with tqdm(dataloader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch [{epoch+1}/{epochs}]")

            for batch in tepoch:
                images = batch["image"].to(model_device)
                text_tokens = batch["text"].to(model_device)

                optimizer.zero_grad()
                loss = coca(text=text_tokens, images=images, return_loss=True)
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for both the running
                # total and the progress bar.
                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                # Update tqdm progress bar
                tepoch.set_postfix(loss=batch_loss)

        # Average over the batches actually processed, so an empty loader or
        # one without __len__ does not raise.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

In [9]:
# Cosine-similarity helper
def calculate_cosine_similarity(text_embeds, image_embeds):
    """Return the pairwise cosine-similarity matrix of two embedding sets.

    Every row of both inputs is scaled to unit L2 norm along dim 1, and the
    normalised text rows are then matrix-multiplied with the transposed
    normalised image rows.

    Args:
        text_embeds: 2-D tensor with one embedding per row.
        image_embeds: 2-D tensor with one embedding per row and the same
            number of columns as ``text_embeds``.

    Returns:
        Tensor whose entry ``[i, j]`` is the cosine similarity between text
        row ``i`` and image row ``j``.
    """
    unit_text, unit_image = (
        F.normalize(embeds, p=2, dim=1) for embeds in (text_embeds, image_embeds)
    )
    return torch.mm(unit_text, unit_image.T)

# Post-training helper: compute and print similarities for a few sample batches
def display_sample_similarity(coca, dataloader, sample_count=5, tokenizer=None):
    """Show text, image and text-to-image cosine similarity for sample batches.

    For each of the first ``sample_count`` batches this prints the text, plots
    the first image of the batch, and prints the cosine similarity between
    the first text embedding and every image embedding in the batch.

    Args:
        coca: Model invoked as
            ``coca(text=..., images=..., return_embeddings=True)``; the call
            must return a ``(text_embeds, image_embeds)`` pair.
        dataloader: Iterable of batches indexed with the keys ``"image"`` and
            ``"text"``; both values must support ``.to()``.
        sample_count: Maximum number of batches to display.
        tokenizer: Optional object with a
            ``decode(ids, skip_special_tokens=True)`` method. When given, the
            first sample's token ids are decoded to a string so the printed
            text matches the displayed image and similarity row. When omitted,
            the raw ``batch["text"]`` value is printed for the whole batch.

    Returns:
        None. The model is left in eval mode.
    """
    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-level `device` global.
    model_device = next(coca.parameters()).device
    coca.eval()
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= sample_count:
                break

            # Prepare text and images
            images = batch["image"].to(model_device)
            text_tokens = batch["text"].to(model_device)

            # Extract text and image embeddings
            text_embeds, image_embeds = coca(text=text_tokens, images=images, return_embeddings=True)
            cosine_similarity = calculate_cosine_similarity(text_embeds, image_embeds)

            # Print results
            print(f"Sample {i+1}")

            # Print the text
            print("Text:")
            if tokenizer is None:
                # batch["text"] holds token ids, not the original string.
                print(batch["text"])
            else:
                # assumes batch["text"] is a (batch, seq) tensor of token ids — TODO confirm
                print(tokenizer.decode(text_tokens[0].tolist(), skip_special_tokens=True))

            # Show the first image of the batch (moved to channel-last for imshow)
            plt.imshow(images[0].cpu().permute(1, 2, 0))
            plt.axis('off')
            plt.show()

            # Print the similarity of the first text embedding to every image
            print("Cosine Similarity:")
            print(cosine_similarity[0])

            print("-" * 50)

In [None]:
# Train the CoCa model for 5 epochs with a learning rate of 1e-4.
# NOTE(review): the recorded progress bar shows 4975 batches per epoch at about
# 1.3 batch/s, i.e. roughly an hour per epoch, so this cell is long-running.
train_coca_model(coca, dataloader, epochs=5, lr=1e-4)

Epoch [1/5]:  46%|████▌     | 2296/4975 [31:06<34:14,  1.30batch/s, loss=3.68] 

In [None]:
# Show text, first image and similarity scores for up to 5 batches. The
# dataloader was built with shuffle=True, so the batches shown differ
# between runs.
display_sample_similarity(coca, dataloader, sample_count=5)

In [None]:
# Save the trained model weights (state_dict only)
# NOTE(review): this is an absolute path at the filesystem root — confirm it
# is intended rather than a relative "models/..." path.
model_save_path = "/models/coca_model_test.pth"
# Create the parent directory first so a missing directory does not cost the
# trained weights at the very last step.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(coca.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")