In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from torchvision import transforms
from vocab import Vocabulary
import nltk
# Make the local NLTK data directory the first place NLTK looks for resources.
# Prepend instead of replacing `nltk.data.path`, so NLTK's default search
# locations stay usable. The membership check keeps re-runs of this cell from
# adding duplicate entries.
NLTK_DATA_DIR = 'C:\\Users\\admin\\nltk_data'
if NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.insert(0, NLTK_DATA_DIR)

class FlickrDataset(Dataset):
    """Image-caption dataset backed by a captions CSV and an image directory.

    Each CSV row yields one sample:
    - the image file named in the ``image`` column, loaded from ``root_dir``
      as RGB and optionally transformed;
    - the caption in the ``caption`` column, encoded as a ``torch.long``
      tensor of token ids wrapped in the vocabulary's ``<START>`` / ``<END>``
      ids.

    Args:
        root_dir: Directory containing the image files.
        captions_file: Path or buffer accepted by ``pd.read_csv``; must have
            ``image`` and ``caption`` columns.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
        freq_threshold: Threshold forwarded to ``Vocabulary`` when a new
            vocabulary is built (presumably the minimum word count kept; see
            ``vocab.py``). Unused when ``vocab`` is supplied.
        vocab: Optional pre-built vocabulary to reuse, e.g. the training-set
            vocabulary for a validation split, so token ids agree across
            datasets. It must expose ``stoi`` and ``numericalize``. When
            ``None``, a new vocabulary is built from this file's captions.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=3, vocab=None):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform
        self.freq_threshold = freq_threshold

        # Image file names and captions, one entry per CSV row.
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        if vocab is None:
            vocab = Vocabulary(freq_threshold)
            vocab.build_vocab(self.captions)
        self.vocab = vocab

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)`` for the row at position ``idx``.

        ``idx`` is positional, so negative indices work. numpy integers and
        0-d tensors are accepted via ``int()``.
        """
        idx = int(idx)
        image = self._load_image(idx)
        caption_ids = self._encode_caption(self.captions.iloc[idx])
        return image, caption_ids

    def _load_image(self, idx):
        """Load the image for row ``idx`` lazily, as RGB, applying ``transform`` if set."""
        img_path = os.path.join(self.root_dir, self.imgs.iloc[idx])
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

    def _encode_caption(self, caption):
        """Encode ``caption`` as ``[<START>] + token ids + [<END>]`` in a long tensor."""
        token_ids = [self.vocab.stoi["<START>"]]
        token_ids.extend(self.vocab.numericalize(caption))
        token_ids.append(self.vocab.stoi["<END>"])
        return torch.tensor(token_ids, dtype=torch.long)


[nltk_data] Downloading package punkt to C:\Users\admin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to C:\Users\admin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to C:\Users\admin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\admin/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
transforms = transforms.Compose([transforms.ToTensor()])

In [3]:
data = FlickrDataset(r'D:\git\Image_Captioning\dataset\Images',r'D:\git\Image_Captioning\dataset\captions.txt',transforms)

In [5]:
data[0]

(tensor([[[0.2275, 0.4510, 0.4157,  ..., 0.0157, 0.0157, 0.0235],
          [0.2196, 0.4510, 0.4157,  ..., 0.0196, 0.0471, 0.0118],
          [0.2039, 0.4510, 0.4235,  ..., 0.0392, 0.0275, 0.0078],
          ...,
          [0.7569, 0.8667, 0.9529,  ..., 0.6667, 0.6627, 0.6588],
          [0.7333, 0.9882, 1.0000,  ..., 0.6667, 0.6627, 0.6667],
          [0.7843, 0.7804, 0.6510,  ..., 0.6667, 0.6549, 0.6627]],
 
         [[0.2196, 0.5137, 0.4824,  ..., 0.0157, 0.0157, 0.0235],
          [0.2157, 0.5176, 0.4824,  ..., 0.0196, 0.0667, 0.0314],
          [0.2000, 0.5098, 0.4902,  ..., 0.0431, 0.0353, 0.0157],
          ...,
          [0.3059, 0.6039, 0.9569,  ..., 0.7255, 0.7216, 0.7255],
          [0.3451, 0.9373, 0.7686,  ..., 0.7294, 0.7216, 0.7216],
          [0.3922, 0.4157, 0.2588,  ..., 0.7294, 0.7216, 0.7176]],
 
         [[0.3020, 0.5098, 0.4588,  ..., 0.0078, 0.0078, 0.0157],
          [0.2863, 0.4941, 0.4588,  ..., 0.0196, 0.0510, 0.0078],
          [0.2706, 0.4824, 0.4667,  ...,