In [12]:
from src.utils import TrainingConfig, create_kanji_dataset

# Build the project's kanji dataset (a Hugging Face dataset) through the shared helper.
dataset = create_kanji_dataset()

In [5]:
# Use the special token "Kanji" for DreamBooth concept tuning first, then use caption-image matching for text-to-image fine-tuning.
from datasets import load_dataset
from tqdm import tqdm
import os

SOURCE_DATASET = "Ksgk-fy/expanded-kanji-dataset"
TARGET_DATASET = "Ksgk-fy/concept-kanji-dataset"
IMAGE_DIR = "data/kanji_expand_image"

# Load the raw kanji dataset and make sure the image export directory exists.
dataset = load_dataset(SOURCE_DATASET)
os.makedirs(IMAGE_DIR, exist_ok=True)

# Special concept token inserted into every caption.
special_token = "Kanji"
CAPTION_TEMPLATE = "an image of {special_token} meaning {meaning}"


def transform_text(meaning):
    """Wrap a raw kanji meaning in ``CAPTION_TEMPLATE`` together with the concept token."""
    return CAPTION_TEMPLATE.format(special_token=special_token, meaning=meaning)


# Rewrite the "text" column of every split with the caption template.
dataset = dataset.map(lambda example: {"text": transform_text(example["text"])})

# Export train images as kanji_{idx}.png, skipping files already on disk.
# Rows are fetched one at a time and only when the file is missing; reading
# `dataset["train"]["image"]` can decode the whole image column (17k+ images) up front.
train_split = dataset["train"]
for idx in tqdm(range(len(train_split))):
    image_path = os.path.join(IMAGE_DIR, f"kanji_{idx}.png")
    if not os.path.exists(image_path):
        train_split[idx]["image"].save(image_path)

dataset.push_to_hub(TARGET_DATASET)

100%|██████████| 17444/17444 [01:05<00:00, 267.85it/s]


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/17444 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/175 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/3752 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/38 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/Ksgk-fy/concept-kanji-dataset/commit/d5239864024c88edd66f0f30c40344da68508b7c', commit_message='Upload dataset', commit_description='', oid='d5239864024c88edd66f0f30c40344da68508b7c', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/Ksgk-fy/concept-kanji-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='Ksgk-fy/concept-kanji-dataset'), pr_revision=None, pr_num=None)

In [8]:
from datasets import Dataset 
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
import torch 
import os 
from pathlib import Path
_NO_PROMPT = object()  # sentinel so the legacy one-argument call keeps its identity behaviour


def tokenize_prompt(tokenizer, prompt=_NO_PROMPT, tokenizer_max_length=None):
    """Tokenize ``prompt`` into fixed-length ``input_ids`` / ``attention_mask`` tensors.

    The dataset classes in this notebook call this as
    ``tokenize_prompt(tokenizer, text, tokenizer_max_length=...)`` and read
    ``.input_ids`` / ``.attention_mask`` from the result, so it pads to a fixed length,
    truncates, and asks the tokenizer for PyTorch tensors.

    Args:
        tokenizer: a callable Hugging Face-style tokenizer; assumed to expose
            ``model_max_length`` when ``tokenizer_max_length`` is not given.
        prompt: text (or list of texts) to encode. If omitted, the first argument is
            returned unchanged, matching the old ``lambda x: x`` placeholder.
        tokenizer_max_length: pad/truncate length; defaults to ``tokenizer.model_max_length``.

    Returns:
        Whatever the tokenizer returns for the call (a ``BatchEncoding`` for HF tokenizers).
    """
    if prompt is _NO_PROMPT:
        return tokenizer

    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

def _list_image_files(root):
    """Return the regular, non-hidden files directly inside ``root`` in natural order.

    ``Path.iterdir`` yields entries in an arbitrary, filesystem-dependent order, and plain
    string sorting puts ``kanji_10.png`` before ``kanji_2.png``. Comparing embedded integers
    numerically restores ``kanji_{idx}.png`` export order, so image ``i`` lines up with
    caption ``i``. Sub-directories and dot-files (e.g. ``.DS_Store``) are skipped because they
    are not images.
    """
    import re

    def natural_key(path):
        # re.split with a capturing group alternates text / digit runs: odd slots are digits.
        parts = re.split(r"(\d+)", path.name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)], path.name

    files = [p for p in Path(root).iterdir() if p.is_file() and not p.name.startswith(".")]
    return sorted(files, key=natural_key)


class InceptionDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Instance images come from ``instance_data_root`` (listed in natural filename order). If
    ``dataset_name`` is given, the ``text_field`` column of that Hugging Face dataset's
    ``train`` split provides a per-image caption matched by position; otherwise every image
    uses ``instance_prompt``. Optional class images/prompt are attached to every item.

    NOTE: other notebook cells also define ``InceptionDataset``; the most recently executed
    definition is the one in effect.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
        dataset_name=None,
        text_field="text",
    ):
        """
        Args:
            instance_data_root: directory of instance image files.
            instance_prompt: concept prompt, tokenized for every item.
            tokenizer: tokenizer forwarded to ``tokenize_prompt``.
            class_data_root: optional directory of class images; created if missing.
            class_prompt: prompt paired with class images.
            class_num: optional cap on how many class images are used.
            size: target resolution for resize + crop.
            center_crop: center crop if True, random crop otherwise.
            encoder_hidden_states: precomputed instance embeddings; skips tokenization.
            class_prompt_encoder_hidden_states: precomputed class-prompt embeddings.
            tokenizer_max_length: forwarded to ``tokenize_prompt``.
            dataset_name: optional Hugging Face dataset id providing per-image captions.
            text_field: caption column in that dataset.

        Raises:
            ValueError: instance root missing or empty, class root empty, or fewer
                captions than instance images.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        # Natural order keeps kanji_{i}.png aligned with caption i (captions are positional).
        self.instance_images_path = _list_image_files(self.instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # __getitem__ indexes with `index % num_instance_images`; fail early instead.
            raise ValueError(f"Instance images root {self.instance_data_root} contains no image files.")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = _list_image_files(self.class_data_root)
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            if self.num_class_images == 0:
                raise ValueError(f"Class images root {self.class_data_root} contains no image files.")
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Load instance captions from the HF dataset (optional).
        self.instance_captions = None
        if dataset_name:
            from datasets import load_dataset

            captions = load_dataset(dataset_name)["train"][text_field]
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if len(captions) < self.num_instance_images:
                raise ValueError(
                    f"Dataset has {len(captions)} captions but {self.num_instance_images} images"
                )
            self.instance_captions = captions

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy, closing the file handle."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_idx = index % self.num_instance_images

        example["instance_images"] = self.image_transforms(
            self._load_rgb(self.instance_images_path[instance_idx])
        )

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            # Concept prompt: the shared ``instance_prompt`` tokenized like any caption.
            concept_inputs = tokenize_prompt(
                self.tokenizer,
                self.instance_prompt,
                tokenizer_max_length=self.tokenizer_max_length,
            )
            example["concept_prompt_ids"] = concept_inputs.input_ids
            example["concept_attention_mask"] = concept_inputs.attention_mask

            if self.instance_captions is not None:
                # Positional match; assumes files were exported in caption order (kanji_{idx}.png).
                instance_inputs = tokenize_prompt(
                    self.tokenizer,
                    self.instance_captions[instance_idx],
                    tokenizer_max_length=self.tokenizer_max_length,
                )
            else:
                # No per-image captions: fall back to the concept prompt.
                instance_inputs = concept_inputs
            example["instance_prompt_ids"] = instance_inputs.input_ids
            example["instance_attention_mask"] = instance_inputs.attention_mask

        if self.class_data_root is not None:
            example["class_images"] = self.image_transforms(
                self._load_rgb(self.class_images_path[index % self.num_class_images])
            )

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example

In [9]:
class InceptionDataset(Dataset):
    """Caption/image dataset read directly from a Hugging Face dataset's ``train`` split.

    Each item contains the row's ``image`` (converted to RGB, resized, cropped and normalised
    to [-1, 1]). Unless ``encoder_hidden_states`` is supplied, it also contains the tokenized
    ``text`` caption and its attention mask.

    NOTE: other notebook cells also define ``InceptionDataset``; the most recently executed
    definition is the one in effect.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        """
        Args:
            tokenizer: tokenizer forwarded to ``tokenize_prompt``.
            dataset_name: Hugging Face dataset id whose ``train`` split has ``image`` and
                ``text`` columns.
            size: target resolution for resize + crop.
            center_crop: center crop if True, random crop otherwise.
            encoder_hidden_states: precomputed prompt embeddings; when set, captions are not
                tokenized and this value is returned as ``instance_prompt_ids``.
            tokenizer_max_length: forwarded to ``tokenize_prompt``.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        # Load dataset directly from Hugging Face.
        from datasets import load_dataset

        self.dataset = load_dataset(dataset_name)["train"]
        self._length = len(self.dataset)

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}

        # One lookup per item: a HF `Dataset[i]` returns the full row with features decoded,
        # so separate lookups for "image" and "text" decoded the image twice.
        row = self.dataset[index]

        instance_image = row["image"]
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            instance_inputs = tokenize_prompt(
                self.tokenizer,
                row["text"],
                tokenizer_max_length=self.tokenizer_max_length,
            )
            example["instance_prompt_ids"] = instance_inputs.input_ids
            example["instance_attention_mask"] = instance_inputs.attention_mask

        return example

In [10]:
class InceptionDataset(Dataset):
    """Hugging Face-backed caption/image dataset with optional class images.

    Instance images and captions come from the ``image`` / ``text`` columns of
    ``dataset_name``'s ``train`` split. If ``class_data_root`` is given, each item also gets a
    class image from that directory and the tokenized ``class_prompt``. When
    ``class_prompt_encoder_hidden_states`` is provided, it is used instead of the tokenized
    prompt. The length is the larger of the two collections; the shorter one is cycled with
    ``index % len``.

    NOTE: other notebook cells also define ``InceptionDataset``; the most recently executed
    definition is the one in effect.
    """

    def __init__(
        self,
        tokenizer,
        dataset_name,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        """
        Args:
            tokenizer: tokenizer forwarded to ``tokenize_prompt``.
            dataset_name: Hugging Face dataset id whose ``train`` split has ``image`` and
                ``text`` columns.
            size: target resolution for resize + crop.
            center_crop: center crop if True, random crop otherwise.
            encoder_hidden_states: precomputed instance embeddings; skips caption tokenization.
            class_data_root: optional directory of class images; created if missing.
            class_prompt: prompt paired with class images.
            class_num: optional cap on how many class images are used.
            class_prompt_encoder_hidden_states: precomputed class-prompt embeddings.
            tokenizer_max_length: forwarded to ``tokenize_prompt``.

        Raises:
            ValueError: empty ``train`` split while class images are used, empty class
                directory, or class images without any class prompt.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states

        # Load dataset directly from Hugging Face.
        from datasets import load_dataset

        self.dataset = load_dataset(dataset_name)["train"]
        self._length = len(self.dataset)

        # Class image handling.
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            # Sorted so class image i is the same file on every run / machine.
            self.class_images_path = sorted(p for p in self.class_data_root.iterdir() if p.is_file())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            # __getitem__ uses both counts as modulo divisors; fail here with a clear message.
            if self.num_class_images == 0:
                raise ValueError(f"Class images root {self.class_data_root} contains no files.")
            if len(self.dataset) == 0:
                raise ValueError(f"Dataset {dataset_name} has an empty 'train' split.")
            if class_prompt is None and class_prompt_encoder_hidden_states is None:
                raise ValueError(
                    "class_prompt or class_prompt_encoder_hidden_states is required with class_data_root."
                )
            self._length = max(self.num_class_images, len(self.dataset))
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}

        # One lookup per item: a HF `Dataset[i]` returns the full row with features decoded,
        # so separate lookups for "image" and "text" decoded the image twice.
        row = self.dataset[index % len(self.dataset)]

        instance_image = row["image"]
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            instance_inputs = tokenize_prompt(
                self.tokenizer,
                row["text"],
                tokenizer_max_length=self.tokenizer_max_length,
            )
            example["instance_prompt_ids"] = instance_inputs.input_ids
            example["instance_attention_mask"] = instance_inputs.attention_mask

        # Handle class images if provided.
        if self.class_data_root is not None:
            class_path = self.class_images_path[index % self.num_class_images]
            # The context manager closes the file; convert() returns a loaded in-memory copy.
            with Image.open(class_path) as raw_class_image:
                class_image = raw_class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer,
                    self.class_prompt,
                    tokenizer_max_length=self.tokenizer_max_length,
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example

In [13]:
from transformers import AutoTokenizer

# Tokenizer stored in the "tokenizer" subfolder of the Stable Diffusion v1-4 repository.
tokenizer = AutoTokenizer.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="tokenizer",
)

# Training dataset built from the concept-token dataset pushed to the Hub earlier.
dataset = InceptionDataset(
    tokenizer=tokenizer,
    dataset_name="Ksgk-fy/concept-kanji-dataset",
)

