In [12]:
from src.utils import create_kanji_dataset, TrainingConfig

# Kanji training data; the project helper returns a Hugging Face `datasets` object.
dataset = create_kanji_dataset()

#### Model initialization & upload to HF
- To be compatible with the Hugging Face training script, the Stable Diffusion model must first be uploaded to the HF Hub.

In [3]:
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from src.utils import get_transform

# Goal: pair a randomly initialised UNet with the pretrained Stable Diffusion 1.5
# VAE, text encoder, tokenizer and scheduler, then upload the result to HF so the
# HF training script can run against it directly.
# (An earlier plan targeted SD 3.5-medium; everything below loads SD 1.5.)
# TODO: consider uploading the kanji dataset to HF as well.

model_name = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Pretrained SD 1.5 components -- everything except the UNet.
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")

# NOTE(review): the HF text-to-image training example creates its
# `noise_scheduler` with DDPMScheduler (imported above but unused here).
# Confirm whether PNDMScheduler is intended for training or only for inference.
noise_scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")

# The UNet is deliberately NOT loaded from `model_name` (that would be
# `UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")`);
# a smaller, randomly initialised conditional UNet is constructed below instead.

# Conditional denoising UNet: randomly initialised with a reduced channel budget.
# Latent channel count and cross-attention width are read from the loaded VAE and
# text encoder instead of being hard-coded (4 and 768 for SD 1.5), so the UNet
# cannot drift out of sync with the components it is paired with.
# NOTE(review): `sample_size` is the *latent* side length, not the pixel size.
# The VAE downsamples by 2 ** (len(vae.config.block_out_channels) - 1)
# (presumably 8 for the SD 1.5 VAE), so 128 here would mean ~1024px images --
# confirm against the kanji image size produced by `get_transform`.
import torch

UNET_INIT_SEED = 0  # fixed so re-running this cell reproduces the same (uploaded) weights

# fork_rng confines the seeding to this block: the global CPU RNG state is
# restored afterwards, so later cells are unaffected. devices=[] skips forking
# CUDA RNG state; the UNet is built on CPU here.
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(UNET_INIT_SEED)
    unet = UNet2DConditionModel(
        sample_size=128,  # latent resolution; see NOTE above
        in_channels=vae.config.latent_channels,
        out_channels=vae.config.latent_channels,
        layers_per_block=2,
        block_out_channels=(64, 128, 256, 256),  # reduced number of channels
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        # Must equal the width of the text-encoder embeddings fed to cross-attention.
        cross_attention_dim=text_encoder.config.hidden_size,
    )

In [None]:
from src.utils import evaluate_kanji_pipeline

# Baseline evaluation of the current (randomly initialised) UNet on the kanji
# dataset; output presumably goes to runs/kanji_eval.png as a 2x4 grid.
evaluate_kanji_pipeline(
    unet,
    dataset,
    n_rows=2,
    n_cols=4,
    seed=33,
    out_dir="runs",
    out_name="kanji_eval.png",
)

In [5]:
# Debug: load the reference dataset of the minimal training pipeline so this
# notebook's setup can be compared against a known-good configuration.
# NOTE(review): this rebinds `dataset`, replacing the kanji dataset built in the
# first cell, so `evaluate_kanji_pipeline(unet, dataset, ...)` sees whichever
# cell ran last. Prefer a distinct name (e.g. `naruto_dataset`) once no later
# cell relies on `dataset` being this one.
from datasets import load_dataset

dataset_name = "lambdalabs/naruto-blip-captions"
dataset = load_dataset(path=dataset_name)

Repo card metadata block was not found. Setting CardData to empty.
