In [8]:
from torch.utils.data import Dataset
from PIL import Image
import os

In [9]:
class ImagePromptDataset(Dataset):
    """Paired image/prompt dataset backed by two parallel directories.

    Each regular file in ``image_dir`` is treated as an image. Its prompt is
    read from the file in ``prompt_dir`` that has the same base name and a
    ``.txt`` extension (e.g. ``0001.jpg`` -> ``0001.txt``).
    """

    def __init__(self, image_dir, prompt_dir, transform=None):
        """
        Args:
            image_dir (str): Directory with all the images.
            prompt_dir (str): Directory with all the prompt text files.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.image_dir = image_dir
        self.prompt_dir = prompt_dir
        # Keep regular files only: sub-directories (for example the
        # `.ipynb_checkpoints` folder Jupyter creates) are not images and
        # would otherwise inflate __len__ and fail when opened.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the ``(image, prompt)`` pair at position ``idx``.

        Returns:
            tuple: The RGB image (passed through ``transform`` when one was
            given) and the whitespace-stripped prompt text.

        Raises:
            FileNotFoundError: If the matching ``.txt`` prompt file is missing.
        """
        file_name = self.image_files[idx]

        # Get the image file; the context manager releases the file handle as
        # soon as the pixel data has been converted.
        img_path = os.path.join(self.image_dir, file_name)
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Get the corresponding prompt. Swap only the real extension, so any
        # image suffix (.png, .JPG, ...) maps to "<stem>.txt" and a ".jpg"
        # occurring in the middle of a name is left untouched.
        stem, _ext = os.path.splitext(file_name)
        prompt_path = os.path.join(self.prompt_dir, stem + ".txt")
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompt = f.read().strip()

        # Apply transformations (if any)
        if self.transform is not None:
            image = self.transform(image)

        return image, prompt

In [10]:
from torchvision import transforms

# Image preprocessing applied to every sample:
#   1. resize to a fixed 256x256 so tensors can be batched together,
#   2. convert the image to a tensor,
#   3. normalize each of the three channels with mean 0.5 / std 0.5
#      (maps ToTensor's [0, 1] output to the [-1, 1] range).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


In [11]:
# Locations of the training images and of their matching prompt files.
image_dir = "./data-source/images/"
prompt_dir = "./data-source/prompts/"

# Map-style dataset pairing each image with its prompt text.
dataset = ImagePromptDataset(image_dir, prompt_dir, transform)



In [12]:
# Sanity check: indexing the dataset yields an (image, prompt) pair; the image
# has already been passed through `transform`.
image, prompt = dataset[0]
print(prompt)  # Show the prompt text paired with the first image


a logo for a coffee shop with a flower anime style


In [13]:
from torch.utils.data import DataLoader

# Batch the (image, prompt) pairs: 16 samples per batch, visited in shuffled order.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)


In [None]:
# Sanity-check the loader output on a single batch. One batch is enough to
# verify the tensor shape and the prompt list; walking the whole dataset would
# decode every image and flood the cell output.
for images, prompts in dataloader:
    print(images.shape)  # Check image tensor size
    print(prompts)  # Check the text prompts
    break

In [14]:
from datasets import Dataset


In [18]:
def load_data(image_dir, prompt_dir):
    """Build a Hugging Face ``Dataset`` of image paths and their prompts.

    Args:
        image_dir (str): Directory with all the images.
        prompt_dir (str): Directory with the prompt text files; the prompt for
            ``<stem>.<ext>`` is read from ``<stem>.txt``.

    Returns:
        datasets.Dataset: Two columns, ``"image"`` (path to the image file)
        and ``"text"`` (whitespace-stripped prompt), ordered by file name.

    Raises:
        FileNotFoundError: If an image has no matching prompt file.
    """
    # Imported locally under an alias: this notebook also binds the bare name
    # `Dataset` to torch.utils.data.Dataset, so relying on the global would
    # make the result depend on the order in which cells were executed.
    from datasets import Dataset as HFDataset

    # Prepare the dictionary with column-wise data
    data = {
        "image": [],
        "text": []
    }

    for img_file in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, img_file)
        # Skip sub-directories (e.g. `.ipynb_checkpoints`); they are not samples.
        if not os.path.isfile(img_path):
            continue

        # Replace only the real extension so .png/.JPG/... also resolve to
        # "<stem>.txt" instead of pointing back at the image file itself.
        stem, _ext = os.path.splitext(img_file)
        prompt_path = os.path.join(prompt_dir, stem + ".txt")

        # Load the prompt
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompt = f.read().strip()

        data["image"].append(img_path)
        data["text"].append(prompt)

    return HFDataset.from_dict(data)

In [20]:
# Build the Hugging Face dataset (image path + prompt text per row).
hf_dataset = load_data(image_dir=image_dir, prompt_dir=prompt_dir)

# Show the first row to confirm the column layout.
print(hf_dataset[0])


{'image': './data-source/images/0001.jpg', 'text': 'a logo for a coffee shop with a flower anime style'}


In [None]:
import torch
from diffusers import StableDiffusionPipeline

# Load pre-trained model from Hugging Face
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Move to GPU if available; otherwise stay on CPU instead of failing on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]
Fetching 16 files:   6%|▋         | 1/16 [00:01<00:20,  1.38s/it]

In [None]:
import torch
def collate_fn(batch):
    """Collate dataset rows into an image batch tensor and a prompt list.

    Args:
        batch: Iterable-twice sequence of rows, each with an ``"image"`` entry
            (path to an image file) and a ``"text"`` entry (prompt).

    Returns:
        tuple: ``(images, prompts)`` where ``images`` is the result of
        ``torch.stack`` over the transformed images and ``prompts`` is the
        list of ``"text"`` values in the same order.
    """
    # Load each image as RGB and run it through the module-level `transform`.
    tensors = []
    for record in batch:
        rgb_image = Image.open(record["image"]).convert("RGB")
        tensors.append(transform(rgb_image))

    # Prompts stay a plain Python list, in the same order as the images.
    captions = []
    for record in batch:
        captions.append(record["text"])

    return torch.stack(tensors), captions

# Loader over the Hugging Face dataset: rows hold image paths, so `collate_fn`
# loads and transforms the images per batch. Note this rebinds `dataloader`.
dataloader = DataLoader(
    dataset=hf_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=4,
)
