In [None]:
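# Colab setup: uncomment the lines below to install the third-party packages imported by this notebook.
# !pip install mammut-pytorch
# !pip install vit-pytorch
# !pip install transformers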

In [3]:
import platform
import sys
import os
import random
import pandas as pd
import numpy as np
import nltk
nltk.download('punkt')
from nltk import word_tokenize
from nltk.probability import FreqDist
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchvision.models as models # 이미지
from torchvision import transforms
from PIL import Image

import sklearn as sk
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state

from transformers import AutoTokenizer

from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
from mammut_pytorch.mammut_pytorch import MaMMUT

from tqdm.auto import tqdm

In [4]:
# Report the runtime environment and choose the device used for training.
has_gpu = torch.cuda.is_available()
# `torch.has_mps` is a deprecated build-time flag; query the MPS backend so the
# flag reflects whether Apple Metal can actually be used on this machine.
has_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
# Same priority as before: MPS first, then CUDA, otherwise CPU.
device = "mps" if has_mps else "cuda" if has_gpu else "cpu"

print(f"Python Platform: {platform.platform()}")
print(f"PyTorch Version: {torch.__version__}")
print()
print(f"sklearn version {sk.__version__}")
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print("GPU is", "available" if has_gpu else "NOT AVAILABLE")
print("MPS (Apple Metal) is", "AVAILABLE" if has_mps else "NOT AVAILABLE")
print(f"Target device is {device}")

Python Platform: Linux-5.15.109+-x86_64-with-glibc2.35
PyTorch Version: 2.0.1+cu118

sklearn version 1.2.2
Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Pandas 1.5.3
GPU is available
MPS (Apple Metal) is NOT AVAILABLE
Target device is cuda


In [5]:
# Hyperparameters shared by the rest of the notebook.
CFG = dict(
    EPOCHS=100,          # upper bound on training epochs (early stopping may end sooner)
    LEARNING_RATE=1e-4,  # initial AdamW learning rate
    BATCH_SIZE=512,      # batch size for the train / validation / test loaders
    SEED=42,             # seed handed to seed_everything
    MAX_LEN=32,          # token length questions and answers are padded / truncated to
)

In [6]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for deterministic runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Read by the interpreter at start-up, so this only affects processes
    # launched after this point (not the running kernel's hash seed).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # The previous `check_random_state(seed)` call only built a RandomState that
    # was immediately discarded, so it seeded nothing and has been dropped.
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the determinism requested above; keep it off.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # Fix the random seeds

In [None]:
# Load the data.
# The dataset root is defined once; it defaults to the original Google Drive
# location and can be overridden through the DACON_DATA_DIR environment variable.
DATA_DIR = os.environ.get('DACON_DATA_DIR', '/content/drive/MyDrive/Dacon_multimodal')

train_df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
test_df = pd.read_csv(os.path.join(DATA_DIR, 'test.csv'))
sample_submission = pd.read_csv(os.path.join(DATA_DIR, 'sample_submission.csv'))
train_img_path = os.path.join(DATA_DIR, 'image', 'train')
test_img_path = os.path.join(DATA_DIR, 'image', 'test')

In [None]:
# Number of NLTK word tokens in each question (cells are stringified first).
train_df['question_word_count'] = train_df['question'].map(
    lambda text: len(word_tokenize(str(text)))
)

# Every row whose question reaches the maximum token count.
max_word_count_row = train_df.loc[
    lambda frame: frame['question_word_count'] == frame['question_word_count'].max()
]

max_word_count_row

In [None]:
# Number of NLTK word tokens in each answer (cells are stringified first).
train_df['answer_word_count'] = train_df['answer'].map(
    lambda text: len(word_tokenize(str(text)))
)

# Every row whose answer reaches the maximum token count.
max_word_count_row = train_df.loc[
    lambda frame: frame['answer_word_count'] == frame['answer_word_count'].max()
]

max_word_count_row

In [None]:
# Word frequency distribution over all training answers.
# Cast to str before joining so a numeric or missing answer cannot raise a
# TypeError (the word-count cells apply the same str() guard).
all_answers = ' '.join(train_df['answer'].astype(str))

words = word_tokenize(all_answers)

word_freq = FreqDist(words)

print(word_freq.most_common(10))

plt.figure(figsize=(10, 6))
# FreqDist.plot calls plt.show() by default, which would finish the figure
# before the labels below are applied; show=False keeps it open until the
# explicit plt.show() at the end of the cell.
word_freq.plot(30, cumulative=False, show=False)
plt.xlabel('Word')
plt.ylabel('Frequency')
plt.title('Word Frequency Distribution in Answers')
plt.xticks(rotation=45)
plt.grid(True)
plt.show()

In [None]:
# Histogram of the number of word tokens per answer.
# str() guards against non-string cells (numeric or missing answers), matching
# the other word-count cells; word_tokenize would otherwise raise on them.
train_df['word_count'] = train_df['answer'].apply(lambda x: len(word_tokenize(str(x))))

plt.figure(figsize=(10, 6))
plt.hist(train_df['word_count'], bins=10, edgecolor='k')
plt.xlabel('Number of Words')
plt.ylabel('Frequency')
plt.title('Distribution of Answer Word Counts')
plt.grid(True)
# One tick per integer word count, from 0 up to the longest answer.
plt.xticks(range(0, int(train_df['word_count'].max()) + 1, 1))

plt.show()

In [7]:
class VQADataset(Dataset):
    """Visual question answering dataset backed by a DataFrame and an image folder.

    Each row of ``df`` must provide ``image_id`` and ``question``; rows used for
    training/validation must also provide ``answer``. Images are read from
    ``<img_path>/<image_id>.jpg``.

    Args:
        df: DataFrame holding one sample per row.
        tokenizer: Tokenizer exposing ``encode_plus`` (used for questions and answers).
        transform: Callable applied to the RGB PIL image.
        img_path: Directory containing the ``.jpg`` files.
        is_test: When True, samples carry no ``answer`` entry.
        max_len: Token length to pad/truncate to. ``None`` (default) falls back
            to ``CFG['MAX_LEN']`` at access time, as before.
    """

    def __init__(self, df, tokenizer, transform, img_path, is_test=False, max_len=None):
        self.df = df
        self.tokenizer = tokenizer
        self.transform = transform
        self.img_path = img_path
        self.is_test = is_test
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows (samples) in the underlying DataFrame."""
        return len(self.df)

    def _encode(self, text):
        """Tokenize ``text`` into a 1-D tensor of ``max_len`` token ids."""
        max_len = self.max_len if self.max_len is not None else CFG['MAX_LEN']
        encoded = self.tokenizer.encode_plus(
            str(text),  # tolerate numeric / missing cells read from the CSV
            add_special_tokens=True,
            truncation=True,
            max_length=max_len,
            padding='max_length',
            return_tensors='pt',
        )
        # Drop only the leading batch axis; a bare squeeze() would also collapse
        # the sequence axis when max_len == 1.
        return encoded['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``question`` and (unless is_test) ``answer``."""
        row = self.df.iloc[idx]

        # Image
        img_name = os.path.join(self.img_path, str(row['image_id']) + '.jpg')
        # The context manager closes the file handle once the image is converted.
        with Image.open(img_name) as img:
            image = self.transform(img.convert('RGB'))

        # Question (and answer for labelled data) share one tokenization path so
        # both are padded/truncated identically; the image is returned the same
        # way in train and test mode.
        sample = {
            'image': image,
            'question': self._encode(row['question']),
        }
        if not self.is_test:
            sample['answer'] = self._encode(row['answer'])
        return sample

In [None]:
# Dataset & dataloader
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# The GPT-2 tokenizer has no padding token; register one so max_length padding works.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# len(tokenizer) includes the newly added [PAD] token.
vocab_size = len(tokenizer)

# Split the training data into train and validation sets.
# random_state pins the split to the configured seed instead of relying on
# whatever state the global NumPy generator happens to be in.
train_data, val_data = train_test_split(train_df, test_size=0.2, random_state=CFG['SEED'])

# Resize to the 256x256 input the ViT is built for, then normalize per channel.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = VQADataset(train_data, tokenizer, transform, train_img_path, is_test=False)
train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True)

val_dataset = VQADataset(val_data, tokenizer, transform, train_img_path, is_test=False)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)

test_dataset = VQADataset(test_df, tokenizer, transform, test_img_path, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)

In [9]:
# Image encoder: SimpleViT with patch dropout. image_size matches the 256x256
# Resize in the preprocessing transform, and dim matches the image_dim passed
# to MaMMUT.
vit = SimpleViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    patch_dropout=0.5,  # https://arxiv.org/abs/2212.00794
)

# Wrap the ViT so it returns embeddings only; detach=False keeps the
# embeddings attached to the autograd graph.
vit = Extractor(vit, return_embeddings_only=True, detach=False)

In [10]:
# Multimodal model built on top of the ViT image encoder, moved to the target device.
# NOTE(review): answers are padded with the tokenizer's added [PAD] token; if
# MaMMUT ignores a configurable pad id in its caption loss, it probably needs
# to be set to that id here - confirm against the mammut-pytorch API.
mammut = MaMMUT(
    dim=512,                      # model dimension
    img_encoder=vit,              # vision transformer used as image encoder, returns (batch, seq, dim) embeddings
    image_dim=1024,               # image embedding dimension (differs from the model dimension)
    num_tokens=vocab_size,        # size of the text vocabulary
    depth=6,                      # number of transformer layers
    dim_head=64,                  # dimension of each attention head
    heads=8,                      # attention head count
    caption_loss_weight=1.,       # weight of the autoregressive caption loss
    contrastive_loss_weight=2.,   # weight of the image/text CLS contrastive loss
).to(device)

In [11]:
def calculate_accuracy(logits, labels, tokenizer):
    """Exact-match accuracy between decoded predictions and decoded labels.

    Args:
        logits: Scores whose last axis is argmax-ed into predicted token ids.
        labels: Reference token ids, one sequence per sample.
        tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=...)``.

    Returns:
        Fraction of samples whose decoded prediction equals the decoded label.
    """
    token_ids = logits.argmax(dim=-1)

    decoded_preds = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    decoded_refs = tokenizer.batch_decode(labels, skip_special_tokens=True)

    hits = 0
    n_samples = 0
    for guess, reference in zip(decoded_preds, decoded_refs):
        n_samples += 1
        if guess == reference:
            hits += 1
    return hits / n_samples

In [12]:
def train_and_validate(mammut, train_loader, val_loader, optimizer, tokenizer):
    """Run one training epoch followed by one validation pass.

    Args:
        mammut: Model; called with ``return_loss=True`` it returns the training
            loss, and without it the logits.
        train_loader: Loader yielding dicts with ``image``, ``question``, ``answer``.
        val_loader: Loader yielding the same keys for validation.
        optimizer: Optimizer stepping the model parameters.
        tokenizer: Tokenizer used to decode predictions and answers.

    Returns:
        ``(avg_loss, avg_accuracy)``: mean training loss per batch and mean
        exact-match validation accuracy per batch.

    Note:
        The model is left in eval mode when the function returns.
    """
    total_loss = 0
    total_accuracy = 0

    # Enable training-time behaviour (e.g. the ViT's patch dropout).
    mammut.train()
    for train_data in tqdm(train_loader, total=len(train_loader)):
        images = train_data['image'].to(device)
        question = train_data['question'].to(device)
        answer = train_data['answer'].to(device)

        optimizer.zero_grad()

        loss = mammut(
            text=question,
            images=images,
            labels=answer,
            return_loss=True
        )

        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Switch to eval mode so validation does not run with stochastic layers
    # active; previously the model stayed in training mode.
    mammut.eval()
    with torch.no_grad():
        for val_data in tqdm(val_loader, total=len(val_loader)):
            images = val_data['image'].to(device)
            question = val_data['question'].to(device)
            answer = val_data['answer'].to(device)

            logits = mammut(
                text=question,
                images=images,
            )

            # Same decode-and-compare logic as before, via the shared helper.
            total_accuracy += calculate_accuracy(logits, answer, tokenizer)

    avg_loss = total_loss / len(train_loader)
    avg_accuracy = total_accuracy / len(val_loader)
    return avg_loss, avg_accuracy

In [13]:
def inference(mammut, loader):
    """Predict token ids for every sample in ``loader``.

    Args:
        mammut: Model returning logits when called with ``text`` and ``images``.
        loader: Loader yielding dicts with ``image`` and ``question``.

    Returns:
        List with one NumPy array of predicted token ids per sample.
    """
    preds = []
    # Eval mode disables stochastic layers (e.g. patch dropout) so predictions
    # are deterministic; previously the model mode was left unchanged.
    mammut.eval()
    with torch.no_grad():
        for data in tqdm(loader, total=len(loader)):
            images = data['image'].to(device)
            question = data['question'].to(device)

            logits = mammut(
                text = question,
                images = images,
            ) # [batch, sequence, vocab]

            # Highest-scoring token id at every sequence position.
            pred = torch.argmax(logits, dim=2)
            preds.extend(pred.cpu().numpy())

    return preds

In [14]:
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Args:
        patience: Number of consecutive non-tolerated steps before
            ``early_stop`` is set.
        delta: Tolerance; a loss up to ``best_loss + delta`` still resets
            the counter.

    Attributes are plain instance fields: ``best_loss`` (lowest loss seen),
    ``counter`` (consecutive bad steps) and ``early_stop`` (stop flag).
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def step(self, loss):
        """Record the latest loss and update ``counter`` / ``early_stop``."""
        # Written as "<=" so that a NaN loss (every comparison False) counts as
        # a bad step instead of resetting the counter.
        if loss <= self.best_loss + self.delta:
            # Keep the true minimum: a loss that is merely within the tolerance
            # must not raise the baseline, or it could drift upward
            # indefinitely without ever triggering a stop.
            self.best_loss = min(self.best_loss, loss)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
# Optimizer, plateau-based learning-rate schedule and early stopping
optimizer = optim.AdamW(mammut.parameters(), lr=CFG['LEARNING_RATE'])
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)
early_stopping = EarlyStopping(patience=4)

# Training loop
# NOTE(review): the scheduler and early stopping both monitor the average
# *training* loss returned by train_and_validate, not a validation metric.
for epoch in range(CFG['EPOCHS']):
    avg_loss, avg_accuracy = train_and_validate(
        mammut, train_loader, val_loader, optimizer, tokenizer
    )
    print(f"Epoch: {epoch+1}, Loss: {avg_loss:.4f}, val_accuracy : {avg_accuracy:.4f}")

    lr_scheduler.step(avg_loss)
    early_stopping.step(avg_loss)
    if not early_stopping.early_stop:
        continue
    print("Early stopping triggered!")
    break

In [None]:
# Inference on the test set
preds = inference(mammut, test_loader)

# Take the padding id from the tokenizer instead of hard-coding 50257, so the
# filter stays correct if the tokenizer or its added tokens ever change.
pad_token_id = tokenizer.pad_token_id

no_pad_output = []
for pred in preds:
    output = pred[pred != pad_token_id] # drop [PAD] tokens
    no_pad_output.append(tokenizer.decode(output).strip()) # token ids -> text

In [None]:
# Attach the decoded predictions as the 'answer' column and write the submission file.
sample_submission = sample_submission.assign(answer=no_pad_output)
sample_submission.to_csv('submission.csv', index=False)

In [None]:
# Read the written submission back and preview its first ten rows as a sanity check.
solution = pd.read_csv('submission.csv')
solution.head(n=10)