In [None]:
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer, Trainer, TrainingArguments

In [None]:
# Load the competition metadata tables. Later cells read the 'unique Id'
# column from both frames and the 'transcription' column from the training frame.
train_df = pd.read_csv("/kaggle/input/ai-of-god-3/train.csv")
test_df = pd.read_csv("/kaggle/input/ai-of-god-3/test.csv")

In [None]:
class CustomDataset(Dataset):
    """Map rows of a metadata frame to model-ready image (and label) tensors.

    Every row must provide a ``'unique Id'`` value; rows used for training must
    also provide ``'transcription'``.

    Args:
        dataframe: pandas DataFrame with one row per image.
        base_dir: Directory that contains the image files.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=image, return_tensors="pt")`` that
            returns a mapping of tensors with a leading batch dimension.
        tokenizer: Optional tokenizer used to encode ``'transcription'`` into
            ``labels``. Ignored when ``is_test`` is true.
        is_test: When true, images are looked up as
            ``<base_dir>/Page_<page>/L_<line>.png`` (page and line are the
            second and third ``'_'``-separated fields of ``'unique Id'``) and
            no labels are produced. Otherwise images are looked up as
            ``<base_dir>/<unique Id>.png``.
        max_length: Optional label length passed to the tokenizer. ``None``
            keeps the previous behaviour, where ``padding="max_length"`` pads to
            the tokenizer's own model maximum; pass a smaller value to avoid
            very long, mostly-padding label tensors.
    """

    # Label value that Hugging Face / PyTorch cross-entropy losses ignore.
    IGNORE_INDEX = -100

    def __init__(self, dataframe, base_dir, feature_extractor, tokenizer=None, is_test=False, max_length=None):
        self.df = dataframe
        self.base_dir = base_dir
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.is_test = is_test
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (samples) in the underlying frame."""
        return len(self.df)

    def _image_path(self, row):
        """Build the image path for ``row`` according to the train/test layout."""
        if self.is_test:
            page, line = row['unique Id'].split('_')[1:3]
            return os.path.join(self.base_dir, f"Page_{page}", f"L_{line}.png")
        return os.path.join(self.base_dir, f"{row['unique Id']}.png")

    def __getitem__(self, idx):
        """Return a dict of tensors for sample ``idx``.

        The dict holds the feature extractor's outputs with the batch dimension
        removed and, for training samples with a tokenizer, a ``labels`` tensor
        whose padded positions are set to ``IGNORE_INDEX``.

        Raises:
            FileNotFoundError: If the image file for the row does not exist.
        """
        row = self.df.iloc[idx]
        image_path = self._image_path(row)

        # Returning None here would only fail later, inside batch collation,
        # with an unrelated-looking error; fail at the source instead.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image {image_path} not found.")

        # convert() yields an independent, fully loaded image, so the file
        # handle can be released right away.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")

        features = self.feature_extractor(images=image, return_tensors="pt")
        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse any other dimension that happens to have size 1.
        sample = {key: val.squeeze(0) for key, val in features.items()}

        if self.tokenizer is not None and not self.is_test:
            encoded = self.tokenizer(
                row['transcription'],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            labels = encoded["input_ids"].squeeze(0)
            # Mask padding via the attention mask rather than the pad id, so a
            # pad token that doubles as EOS does not hide real tokens.
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                labels = labels.masked_fill(attention_mask.squeeze(0) == 0, self.IGNORE_INDEX)
            sample["labels"] = labels

        return sample

In [None]:
# Pretrained image-captioning checkpoint used as the starting point for fine-tuning.
model_name = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# CustomDataset tokenizes with padding="max_length", which needs a pad token.
# Fall back to EOS only when the checkpoint's tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Training with `labels` derives the decoder inputs by shifting the labels,
# which needs these two config ids; fill them in only if they are missing.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id
if model.config.decoder_start_token_id is None:
    model.config.decoder_start_token_id = tokenizer.bos_token_id

# Create dataset instances (the test set is built without a tokenizer, so it yields no labels)
train_dataset = CustomDataset(train_df, "/kaggle/input/ai-of-god-3/train_images", feature_extractor, tokenizer)
test_dataset = CustomDataset(test_df, "/kaggle/input/ai-of-god-3/test_images", feature_extractor, is_test=True)

# Create DataLoaders for manual iteration
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

In [None]:
# Start from a fully frozen network: no parameter receives gradients.
for weight in model.parameters():
    weight.requires_grad_(False)

# Re-enable gradients only for the last block of the decoder's transformer stack.
final_decoder_block = model.decoder.transformer.h[-1]
for weight in final_decoder_block.parameters():
    weight.requires_grad_(True)


In [None]:
# Hold out part of the labelled training data for the per-epoch evaluation.
# test_dataset is built with is_test=True, so its samples carry no `labels`
# and no evaluation loss can be computed from it.
VAL_FRACTION = 0.1
SPLIT_SEED = 42  # fixed seed so the split is the same on every run

num_val = max(1, int(len(train_dataset) * VAL_FRACTION))
train_split, val_split = random_split(
    train_dataset,
    [len(train_dataset) - num_val, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_dir='./logs',
)

# Trainer takes the datasets directly: train on the larger split, evaluate on
# the held-out labelled split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=val_split
)


In [None]:
# Run fine-tuning with the Trainer configured in the previous cell.
trainer.train()