# In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class ImageCaptionDataset(Dataset):
    """Dataset of (image, caption) pairs listed in a '|'-delimited text file.

    Each non-blank line of the captions file has the form
    ``<image file name>|<caption>``. Only the first ``|`` separates the two
    fields, so a caption may itself contain ``|`` characters.

    Args:
        img_dir: Directory that the image file names are joined onto.
        captions_file: Path of the text file listing image names and captions.
        transform: Optional callable applied to the RGB PIL image before it
            is returned from ``__getitem__``.
        encoding: Text encoding used to read ``captions_file``.

    Raises:
        ValueError: If a non-blank line contains no ``|`` separator.
    """

    def __init__(self, img_dir, captions_file, transform=None, encoding='utf-8'):
        self.img_dir = img_dir
        self.transform = transform
        # List of (image file name, caption) tuples in file order.
        self.captions = []
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(captions_file, 'r', encoding=encoding) as file:
            for line_no, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Split on the first '|' only, so captions may contain '|'.
                image, separator, caption = line.partition('|')
                if not separator:
                    raise ValueError(
                        f"{captions_file}: line {line_no} has no '|' "
                        f"separator: {line!r}"
                    )
                self.captions.append((image, caption))

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the pair at position ``idx``.

        The image is loaded from ``img_dir``, converted to RGB and, when a
        transform was supplied, passed through it. The caption is returned
        as the raw string read from the captions file.
        """
        img_name, caption = self.captions[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns a new, fully loaded image, so the underlying
        # file handle can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, caption

# Example usage: build the training dataset and a shuffled batch loader.
# The paths below point at a local copy of the Flickr8k files.
training_img_dir = 'D:/PhD file/image caption/image caption model with app/Flickr8k_Dataset/training_Dataset'
training_captions_file = 'D:/PhD file/image caption/image caption model with app/Flickr8k_text (1)/training.txt'

# Resize every image to 224x224, then convert it to a tensor.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

dataset = ImageCaptionDataset(
    img_dir=training_img_dir,
    captions_file=training_captions_file,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


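# Cell output: ModuleNotFoundError: No module named 'torch'
# (PyTorch is not installed in the active environment; install it, e.g.
# `pip install torch torchvision`, before running this cell.)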