In [1]:
# Prints a fixed greeting and defines no names.
print("hello", "world")

hello world


In [None]:
import os
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from transformers import (
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
    GPT2TokenizerFast,
)

# 1) Dataset
class ImageTextDataset(Dataset):
    """Paired image/caption dataset read from one flat directory.

    Every ``<stem>.png`` file in ``data_dir`` is one sample; its caption is
    read from the sibling ``<stem>.txt`` file (UTF-8, surrounding whitespace
    stripped).

    Each item is a dict with:
      * ``pixel_values`` -- the feature extractor's ``pixel_values`` output with
        the leading batch dimension squeezed away (for a ViT extractor this is
        presumably ``(3, H, W)``).
      * ``labels`` -- caption token ids padded/truncated to ``max_length``;
        padded positions are set to ``IGNORE_INDEX`` (-100) so a cross-entropy
        loss with the default ``ignore_index`` skips them.

    Args:
        data_dir: Directory containing the ``.png``/``.txt`` pairs.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=..., return_tensors="pt")`` returning an
            object with a ``pixel_values`` tensor.
        tokenizer: HF-style tokenizer returning ``input_ids`` and
            ``attention_mask``; it must have a pad token configured.
        max_length: Fixed label length (pad/truncate target).
    """

    # Label value ignored by torch.nn.CrossEntropyLoss (default ignore_index).
    IGNORE_INDEX = -100

    def __init__(self, data_dir, feature_extractor, tokenizer, max_length=128):
        self.data_dir = data_dir
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Sample ids are the .png file stems, sorted lexicographically so the
        # index -> file mapping is stable across runs.
        self.ids = sorted(
            os.path.splitext(fn)[0]
            for fn in os.listdir(data_dir)
            if fn.endswith(".png")
        )

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # Use the stem exactly as found on disk. Zero-padding it (the old
        # ``.zfill(3)``) turned e.g. "1" into "001", a file that may not exist.
        stem = self.ids[idx]

        # Load the image; the context manager releases the file handle, and
        # convert() hands back an independent, fully loaded image.
        img_path = os.path.join(self.data_dir, f"{stem}.png")
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        pixel_values = self.feature_extractor(
            images=image, return_tensors="pt"
        ).pixel_values.squeeze(0)

        # Load the caption, closing the file deterministically.
        txt_path = os.path.join(self.data_dir, f"{stem}.txt")
        with open(txt_path, "r", encoding="utf-8") as fh:
            caption = fh.read().strip()

        encoding = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        labels = encoding.input_ids.squeeze(0)
        # Mask padding by attention_mask rather than by token id: the pad token
        # may share its id with a real token (e.g. GPT-2 pad == eos), and
        # unmasked padding would dominate the loss.
        padding = encoding.attention_mask.squeeze(0) == 0
        labels = labels.masked_fill(padding, self.IGNORE_INDEX)

        return {"pixel_values": pixel_values, "labels": labels}


# 2) Prepare model + tokenizer + feature extractor
# Seed torch before anything random happens: the cross-attention weights that
# are newly initialized when the encoder and decoder are joined, and the
# shuffle order of the DataLoader below, both depend on torch's RNG.
seed = 42
torch.manual_seed(seed)

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "gpt2",
)

feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# GPT2 has no pad token by default, so set it to eos:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# Decoder sequences start from GPT-2's BOS token.
model.config.decoder_start_token_id = tokenizer.bos_token_id

# 3) DataLoader
data_dir = "data"
dataset = ImageTextDataset(
    data_dir, feature_extractor, tokenizer, max_length=128
)
# Fail fast with a readable message; an empty dataset would otherwise surface
# as an opaque sampler error or a division by zero in the epoch average.
if len(dataset) == 0:
    raise ValueError(f"No .png files found in {data_dir!r}; nothing to train on.")

# NOTE(review): worker processes must be able to import ImageTextDataset. With
# a class defined in a notebook this can fail on spawn-based platforms
# (Windows/macOS); drop num_workers to 0 if the DataLoader errors there.
dataloader = DataLoader(
    dataset, batch_size=8, shuffle=True, num_workers=4
)

# 4) Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in dataloader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # .item() detaches the scalar so the graph is not kept alive across steps.
        total_loss += loss.item()

    # Safe: the guard above ensures the DataLoader yields at least one batch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} — avg loss: {avg_loss:.4f}")

# 5) (optional) Save your fine-tuned model, tokenizer and feature extractor
# together so the directory can be reloaded as one unit.
output_dir = "vit2gpt2-finetuned"
model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.c

cuda


In [1]:

import torch

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda
