In [12]:
from src.utils import TrainingConfig
from src.utils import create_kanji_dataset

# Kanji training data, built by the project helper (a Hugging Face dataset).
dataset = create_kanji_dataset()

#### Model initialization & upload to HF
- To be compatible with the HF training script, the assembled SD model must first be uploaded to HF.

In [3]:
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from src.utils import get_transform

# Assemble a Stable Diffusion 1.5-style model for Kanji generation:
#   - pretrained SD 1.5 VAE, CLIP text encoder, tokenizer and scheduler;
#   - a freshly constructed (randomly initialized), slimmed-down UNet.
# The plan is to upload the assembled model to HF so the HF training script
# can load it directly.
# TODO: also consider uploading the dataset to HF.

model_name = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Target image side length in pixels (the preprocessing cell uses resolution = 128).
image_resolution = 128

text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
noise_scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")

# The UNet denoises VAE *latents*, not pixels. `sample_size` is therefore the
# latent side length, and StableDiffusionPipeline's default output size is
# `unet.config.sample_size * vae_scale_factor`. Passing the pixel resolution
# (128) here would make the pipeline default to 1024x1024 images. Instead,
# derive the latent size from the target pixel resolution, using the same
# scale-factor formula as the pipeline.
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
latent_size = image_resolution // vae_scale_factor

# Conditional denoising UNet, randomly initialized (not loaded from the SD checkpoint).
unet = UNet2DConditionModel(
    sample_size=latent_size,  # latent side length; image_resolution pixels after VAE decoding
    in_channels=4,  # SD 1.5 VAE latent channels
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 256),  # reduced channel widths vs. SD 1.5
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    cross_attention_dim=768,  # must match the text encoder's embedding dimension
)

In [None]:
from src.utils import evaluate_kanji_pipeline

# Render a 2x4 grid of samples from the current UNet into runs/kanji_eval.png.
# NOTE: `dataset` is rebound by later cells; on a fresh kernel, run this right
# after the Kanji dataset cell so it sees the intended data.
evaluate_kanji_pipeline(
    unet,
    dataset,
    n_rows=2,
    n_cols=4,
    seed=33,
    out_dir="runs",
    out_name="kanji_eval.png",
)

In [5]:
from datasets import load_dataset

# Debug: sanity-check the minimal training pipeline on a public captioned dataset.
# NOTE: this rebinds `dataset`, replacing the Kanji dataset loaded earlier.
dataset_name = "lambdalabs/naruto-blip-captions"
dataset = load_dataset(dataset_name)

Repo card metadata block was not found. Setting CardData to empty.


In [4]:
from datasets import load_dataset

# Load the Kanji dataset from the HF Hub by repository name.
dataset_name = "Ksgk-fy/kanji-dataset"
dataset = load_dataset(dataset_name)
resolution = 128
gray_scale = True
from torchvision import transforms
from functools import lru_cache


@lru_cache(maxsize=None)
def _build_train_transforms(size, grayscale):
    """Return the resize -> center-crop -> tensor -> normalize pipeline.

    Cached so the ``Compose`` object is built once per (size, grayscale)
    combination, instead of on every batch the dataset transform processes.
    """
    n_channels = 1 if grayscale else 3
    return transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        # Per-channel mean/std of 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5] * n_channels, [0.5] * n_channels),
    ])


def preprocess_train(examples):
    """Add a ``pixel_values`` list of normalized image tensors to a batch.

    Reads the module-level ``resolution`` and ``gray_scale`` settings at call
    time. Images are converted to single-channel ("L") when ``gray_scale`` is
    true, otherwise to "RGB", then resized/center-cropped to ``resolution``.

    Args:
        examples: batch dict from ``Dataset.with_transform``; must contain an
            ``"image"`` list of PIL images.

    Returns:
        The same dict, with ``"pixel_values"`` added.
    """
    # NOTE(review): with gray_scale=True the tensors are 1-channel
    # (e.g. [1, 128, 128]); the SD 1.5 VAE presumably expects 3-channel
    # input — confirm how the training script handles this.
    # Text tokenization is intentionally not done here (left to the training script).
    mode = "L" if gray_scale else "RGB"
    train_transforms = _build_train_transforms(resolution, gray_scale)
    examples["pixel_values"] = [
        train_transforms(image.convert(mode)) for image in examples["image"]
    ]
    return examples


train_dataset = dataset["train"].with_transform(preprocess_train)

In [6]:
first_example = train_dataset[0]  # runs preprocess_train on one example
first_example["pixel_values"].shape

torch.Size([1, 128, 128])

In [7]:
import torch
from diffusers import StableDiffusionPipeline
from src.utils import rgb_to_gray

# Assemble a full text-to-image pipeline from the components created above.
pipeline = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=noise_scheduler,
    safety_checker=None,  # no content filtering for this experiment
    feature_extractor=None,  # not needed without a safety checker
    requires_safety_checker=False,
)

# Smoke test: as constructed above, the UNet has random weights, so the image
# is noise. This only checks that the pipeline runs end to end and that the
# grayscale conversion works. The seed and output size are pinned so the check
# is reproducible and cheap, instead of falling back to the UNet's default size.
SMOKE_TEST_SEED = 33
SMOKE_TEST_SIZE = 128  # pixels; should be a multiple of the VAE downsampling factor (8 for SD 1.5)
generator = torch.Generator("cpu").manual_seed(SMOKE_TEST_SEED)
img = pipeline(
    "giraffe on a plane",
    num_inference_steps=25,
    height=SMOKE_TEST_SIZE,
    width=SMOKE_TEST_SIZE,
    generator=generator,
).images[0]
gray_img = rgb_to_gray(img)
gray_img  # visualize and check the grayscale conversion