In [1]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from typing import List, Dict
import torch

class TextDataset(Dataset):
    """Map-style dataset over an in-memory list of strings.

    Every item is a dict holding the raw string and the index it was
    requested with.

    Args:
        texts: The strings to serve, one per sample.
        max_length: Stored on the instance as ``max_length``. Nothing in this
            class reads it, so texts are returned unmodified.
    """

    def __init__(self, texts: List[str], max_length: int = 512):
        self.texts, self.max_length = texts, max_length

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``{'text': self.texts[idx], 'index': idx}``."""
        sample = dict(text=self.texts[idx], index=idx)
        return sample

class ImageDataset(Dataset):
    """Map-style dataset over an in-memory list of PIL images.

    Every item is a dict holding the image (resized when its ``size`` differs
    from ``target_size``) and the index it was requested with.

    Args:
        images: The PIL images to serve, one per sample.
        target_size: Value compared against ``image.size`` and handed to
            ``image.resize``; defaults to ``(224, 224)``.
    """

    def __init__(self, images: List[Image.Image], target_size=(224, 224)):
        self.images, self.target_size = images, target_size

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{'image': <image at idx, resized if needed>, 'index': idx}``."""
        picture = self.images[idx]
        needs_resize = picture.size != self.target_size
        # Only the returned value is swapped for the resized image;
        # self.images is never reassigned here.
        return {
            'image': picture.resize(self.target_size) if needs_resize else picture,
            'index': idx,
        }

def get_text_dataloader(texts: List[str], batch_size: int = 32, shuffle: bool = False,
                        max_length: int = 512):
    """Wrap a list of strings in a ``TextDataset`` and return a ``DataLoader`` over it.

    Args:
        texts: The strings to serve, one per sample.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
        max_length: Forwarded to ``TextDataset`` (previously the factory always
            used the dataset's default of 512 and callers could not change it).

    Returns:
        A ``DataLoader`` built with the given ``batch_size`` and ``shuffle``.
    """
    dataset = TextDataset(texts, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def _collate_image_batch(samples):
    """Batch ``ImageDataset`` samples without trying to stack the PIL images.

    Images are gathered into a plain list; indices are packed into a tensor.
    """
    return {
        'image': [sample['image'] for sample in samples],
        'index': torch.tensor([sample['index'] for sample in samples]),
    }


def get_image_dataloader(images: List[Image.Image], batch_size: int = 32, shuffle: bool = False,
                         target_size=(224, 224)):
    """Wrap a list of PIL images in an ``ImageDataset`` and return a ``DataLoader`` over it.

    Args:
        images: The PIL images to serve, one per sample.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
        target_size: Forwarded to ``ImageDataset``; defaults to ``(224, 224)``.

    Returns:
        A ``DataLoader`` whose batches are dicts with ``'image'`` (a list of
        PIL images) and ``'index'`` (a tensor of sample indices).
    """
    dataset = ImageDataset(images, target_size=target_size)
    # DataLoader's default collate raises TypeError on PIL.Image objects, so
    # the loader could not be iterated without an explicit collate_fn.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_collate_image_batch,
    )

In [2]:
# Text example: three short strings served two per batch
texts = [
    "sample text 1",
    "sample text 2",
    "sample text 3",
]
text_loader = get_text_dataloader(texts=texts, batch_size=2)

In [5]:
# Number of batches: 3 texts at batch_size=2 -> 2 (the second batch holds the one leftover text)
len(text_loader)

2