In [1]:
import math
import os
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

from diffusers.optimization import get_scheduler
from google.cloud import storage

from model import FontDiffuserModel, FontDiffuserModelDPM, FontDiffuserDPMPipeline
from build import build_unet, build_style_encoder, build_content_encoder, build_ddpm_scheduler
from args import SampleArgs
from utils import x0_from_epsilon, reNormalize_img, normalize_mean_std, save_model, load_model

pygame 2.5.2 (SDL 2.28.2, Python 3.10.14)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Build the networks and schedulers from the sampling config; their trained
# weights are loaded afterwards through load_model(bucket, ...).
args = SampleArgs()
unet = build_unet(args=args)
style_encoder = build_style_encoder(args=args)
content_encoder = build_content_encoder(args=args)
noise_scheduler = build_ddpm_scheduler(args)
# storage.Client's first positional parameter is the GCP *project* ID, not a
# bucket name, so passing args.bucket_name there mislabelled the project.
# Let the client infer the project from the environment (Application Default
# Credentials); pass project="..." explicitly if it cannot be inferred.
storage_client = storage.Client()
bucket = storage_client.bucket(args.bucket_name)

Load the down block  DownBlock2D
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  DownBlock2D
Load the up block  UpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  UpBlock2D
Param count for Ds initialized parameters: 20591296
Get CG-GAN Style Encoder!
Param count for Ds initialized parameters: 1187008
Get CG-GAN Content Encoder!


In [3]:
# Restore trained weights into the freshly built networks.
# load_model(bucket, path) presumably downloads a state dict from the GCS
# bucket; TODO confirm against utils.load_model.
content_encoder.load_state_dict(load_model(bucket, args.content_encoder_path))
style_encoder.load_state_dict(load_model(bucket, args.style_encoder_path))
unet.load_state_dict(load_model(bucket, args.unet_path))

model = FontDiffuserModelDPM(
    unet=unet,
    style_encoder=style_encoder,
    content_encoder=content_encoder)

# This notebook only samples. torch modules start in training mode, so switch
# to eval() to disable training-time behavior such as dropout. Both .cuda() and
# .eval() return the module itself.
model = model.cuda().eval()

train_scheduler = build_ddpm_scheduler(args=args)

pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale,
    )

In [4]:
from PIL import Image

# Side length (pixels) every image is resized to before entering the model.
resolution = 96

# Glyph images: square resize, convert to a [0, 1] tensor, then shift/scale
# to [-1, 1] via (x - 0.5) / 0.5.
normal_transform = transforms.Compose(
    [
        transforms.Resize((resolution,) * 2),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Mask images: square resize and [0, 1] tensor conversion only;
# mask_transform binarizes the result.
mt = transforms.Compose(
    [transforms.Resize((resolution,) * 2), transforms.ToTensor()]
)

def mask_transform(img, threshold=0.5):
    """Load a mask image as a strictly binary tensor.

    Args:
        img: image passed to the module-level ``mt`` transform, which resizes
            it and converts it to a [0, 1] tensor.
        threshold: pixels strictly above this value become 1; all others 0.

    Returns:
        Tensor with the shape and dtype of ``mt(img)`` containing only 0 and 1.
    """
    mask = mt(img)
    # The previous two in-place assignments (<0.5 -> 0, >0.5 -> 1) left pixels
    # equal to exactly 0.5 unchanged. A single comparison guarantees a binary
    # result.
    return (mask > threshold).to(mask.dtype)

In [5]:
# Experiment configuration: change the data root or fonts here.
# NOTE(review): several later cells still spell these font names out inside
# file-name patterns; keep them in sync when changing fonts.
data_root = "/home/jupyter/ai_font/data/exp0820"
content_font = "시스템굴림"   # glyphs loaded as content images
style_font = "플레이브밤비"   # glyphs inpainted and used as style references

cfd = f"{data_root}/processed/train_whole/{content_font}"      # content glyph images
sfd = f"{data_root}/processed/train_assembled/{style_font}"    # style reference glyphs
ifd = f"{data_root}/processed/train_whole/{style_font}"        # source glyphs to inpaint
mfd = f"{data_root}/processed/train_masks/{style_font}"        # per-glyph _m1/_m2/_m3 masks
savefd = f"{data_root}/report/over/{style_font}"               # generated outputs
tag = 'over'  # tag embedded in generated file names

In [6]:
# One syllable per .png in ifd. This assumes every name ends in
# "<syllable>.png" (cf. the "플레이브밤비__closing__{letter}.png" pattern used
# when loading), so the syllable is the character just before ".png".
# Sorting makes the processing order independent of os.listdir's arbitrary
# ordering, so reruns visit letters in the same order.
letters = sorted(name[-5] for name in os.listdir(ifd) if name.endswith(".png"))

In [7]:
# Number of .png glyph files found in ifd (79 in the recorded run).
len(letters)

79

In [8]:
def get_all_korean():
    """Return every precomposed Hangul syllable, '가' (U+AC00) to '힣' (U+D7A3).

    The syllables come out in code-point order: 19 initials * 21 medials *
    28 finals = 11172 entries.
    """
    # Precomposed Hangul syllables occupy one contiguous Unicode block, so a
    # code-point range replaces the previous search that incremented UTF-8
    # byte sequences and retried on every decode error.
    first, last = ord('가'), ord('힣')
    return [chr(code_point) for code_point in range(first, last + 1)]


# Hangul syllable code points are laid out as
#   0xAC00 + initial * 588 + medial * 28 + final      (588 = 21 * 28)
# with final index 0 meaning "no final consonant".
# mapper: (initial index, medial index, final index) -> syllable.
mapper = {}
for letter in get_all_korean():
    offset = ord(letter) - ord('가')
    ch1, rest = divmod(offset, 588)
    ch2, ch3 = divmod(rest, 28)
    mapper[(ch1, ch2, ch3)] = letter
    
def target_letters(mfd, letter, mapper, font="플레이브밤비"):
    """Select medial-vowel swap targets for a syllable that has a final consonant.

    ``letter`` qualifies when it is a precomposed Hangul syllable whose medial
    is one of ㅏㅐㅑㅒㅣ, which has a final consonant, and whose three mask
    files ``{mfd}/{font}__{letter}_m1.png`` / ``_m2`` / ``_m3`` all exist.

    Args:
        mfd: directory holding the per-glyph mask images.
        letter: single-character Hangul syllable used as the inpainting source.
        mapper: dict mapping (initial, medial, final) jamo indices to syllables.
        font: font-name prefix of the mask file names.

    Returns:
        ``(targets, mask_image, unmask_image)``. For a qualifying letter,
        ``targets`` are the syllables with the same initial and final but a
        different medial from ㅏㅐㅑㅒㅣ, ``mask_image`` is ``m1 * m3`` and
        ``unmask_image`` is ``m2`` (each loaded through ``mask_transform``).
        Otherwise ``([], None, None)``.
    """
    # Medial (vowel) indices in Unicode order, for reference:
    #   ㅏㅐㅑㅒㅓㅔㅕㅖㅣ = 0,1,2,3,4,5,6,7,20
    #   ㅗㅛㅜㅠㅡ       = 8,12,13,17,18
    #   ㅘㅙㅚㅝㅞㅟㅢ   = 9,10,11,14,15,16,19
    moeum_list = [0, 1, 2, 3, 20]  # ㅏㅐㅑㅒㅣ

    # Characters outside 가..힣 have no (initial, medial, final) decomposition.
    offset = ord(letter) - ord('가')
    if not 0 <= offset < 11172:
        return [], None, None
    ch1, rest = divmod(offset, 588)  # 588 = 21 medials * 28 finals
    ch2, ch3 = divmod(rest, 28)

    def mask_path(index):
        return f"{mfd}/{font}__{letter}_m{index}.png"

    if ch2 not in moeum_list or ch3 == 0:
        return [], None, None
    if not all(os.path.exists(mask_path(i)) for i in (1, 2, 3)):
        return [], None, None

    def load_mask(path):
        # Close the file handle once the mask tensor has been built.
        with Image.open(path) as img:
            return mask_transform(img.convert("RGB"))

    # NOTE(review): presumably m1/m2/m3 are the initial/medial/final component
    # masks, which would make m2 the region being swapped here. TODO confirm.
    mask_image = load_mask(mask_path(1)) * load_mask(mask_path(3))
    unmask_image = load_mask(mask_path(2))
    targets = [mapper[(ch1, l2, ch3)] for l2 in moeum_list if l2 != ch2]
    return targets, mask_image, unmask_image


In [9]:
# Inpaint every target glyph of every qualifying source letter and save the
# results under savefd; target_letters decides which letters qualify.
os.makedirs(savefd, exist_ok=True)
for letter in tqdm(letters):
    targets, mask_image, unmask_image = target_letters(mfd, letter, mapper)
    if mask_image is None:
        # Letter is not part of this experiment (or its masks are missing).
        continue

    # Source glyph to inpaint plus its masks, each as a batch of one on the GPU.
    inpaint_image = normal_transform(
        Image.open(f"{ifd}/플레이브밤비__closing__{letter}.png").convert("RGB"))
    mask_image, unmask_image, inpaint_image = (
        t.unsqueeze(0).cuda() for t in (mask_image, unmask_image, inpaint_image))

    for tetter in targets:
        style_path = f"{sfd}/플레이브밤비__closing__{tetter}.png"
        if not os.path.exists(style_path):
            # Skip targets that have no style reference glyph.
            continue
        content_path = f"{cfd}/시스템굴림__closing__{tetter}.png"
        content_image, style_image = (
            normal_transform(Image.open(p).convert("RGB")).unsqueeze(0).cuda()
            for p in (content_path, style_path))

        images = pipe.generate(
            content_images=content_image,
            style_images=style_image,
            mask_images=mask_image,
            inpaint_images=inpaint_image,
            unmask_images=unmask_image,
            batch_size=content_image.shape[0],
            order=args.order,
            num_inference_step=args.num_inference_steps,
            content_encoder_downsample_size=args.content_encoder_downsample_size,
            t_start=args.t_start,
            t_end=args.t_end,
            dm_size=args.content_image_size,
            algorithm_type=args.algorithm_type,
            skip_type=args.skip_type,
            method=args.method,
            correcting_x0_fn=args.correcting_x0_fn,
        )
        images[0].save(f"{savefd}/플레이브밤비__{tag}__{letter}_{tetter}.png")

  0%|          | 0/79 [00:00<?, ?it/s]

  style_img_feature, _, style_residual_features = self.style_encoder(style_images)
  content_img_feture, content_residual_features = self.content_encoder(content_images)
  style_content_feature, style_content_res_features = self.content_encoder(style_images)
  out = self.unet(


In [10]:
def target_letters(mfd, letter, mapper, font="플레이브밤비"):
    """Select medial-vowel swap targets for a syllable without a final consonant.

    ``letter`` qualifies when it is a precomposed Hangul syllable whose medial
    is one of ㅏㅐㅑㅒㅣ, which has no final consonant, and whose mask files
    ``{mfd}/{font}__{letter}_m1.png`` and ``_m2.png`` both exist.

    Args:
        mfd: directory holding the per-glyph mask images.
        letter: single-character Hangul syllable used as the inpainting source.
        mapper: dict mapping (initial, medial, final) jamo indices to syllables.
        font: font-name prefix of the mask file names.

    Returns:
        ``(targets, mask_image, unmask_image)``. For a qualifying letter,
        ``targets`` are the syllables with the same initial, no final, and a
        different medial from ㅏㅐㅑㅒㅣ; ``mask_image`` is ``m1`` and
        ``unmask_image`` is ``m2`` (each loaded through ``mask_transform``).
        Otherwise ``([], None, None)``.
    """
    # Medial (vowel) indices in Unicode order, for reference:
    #   ㅏㅐㅑㅒㅓㅔㅕㅖㅣ = 0,1,2,3,4,5,6,7,20
    #   ㅗㅛㅜㅠㅡ       = 8,12,13,17,18
    #   ㅘㅙㅚㅝㅞㅟㅢ   = 9,10,11,14,15,16,19
    moeum_list = [0, 1, 2, 3, 20]  # ㅏㅐㅑㅒㅣ

    # Characters outside 가..힣 have no (initial, medial, final) decomposition.
    offset = ord(letter) - ord('가')
    if not 0 <= offset < 11172:
        return [], None, None
    ch1, rest = divmod(offset, 588)  # 588 = 21 medials * 28 finals
    ch2, ch3 = divmod(rest, 28)

    def mask_path(index):
        return f"{mfd}/{font}__{letter}_m{index}.png"

    if ch2 not in moeum_list or ch3 != 0:
        return [], None, None
    if not all(os.path.exists(mask_path(i)) for i in (1, 2)):
        return [], None, None

    def load_mask(path):
        # Close the file handle once the mask tensor has been built.
        with Image.open(path) as img:
            return mask_transform(img.convert("RGB"))

    mask_image = load_mask(mask_path(1))
    unmask_image = load_mask(mask_path(2))
    targets = [mapper[(ch1, l2, ch3)] for l2 in moeum_list if l2 != ch2]
    return targets, mask_image, unmask_image


In [11]:
# Inpaint every target glyph of every qualifying source letter and save the
# results under savefd; target_letters decides which letters qualify.
os.makedirs(savefd, exist_ok=True)
for letter in tqdm(letters):
    targets, mask_image, unmask_image = target_letters(mfd, letter, mapper)
    if mask_image is None:
        # Letter is not part of this experiment (or its masks are missing).
        continue

    # Source glyph to inpaint plus its masks, each as a batch of one on the GPU.
    inpaint_image = normal_transform(
        Image.open(f"{ifd}/플레이브밤비__closing__{letter}.png").convert("RGB"))
    mask_image, unmask_image, inpaint_image = (
        t.unsqueeze(0).cuda() for t in (mask_image, unmask_image, inpaint_image))

    for tetter in targets:
        style_path = f"{sfd}/플레이브밤비__closing__{tetter}.png"
        if not os.path.exists(style_path):
            # Skip targets that have no style reference glyph.
            continue
        content_path = f"{cfd}/시스템굴림__closing__{tetter}.png"
        content_image, style_image = (
            normal_transform(Image.open(p).convert("RGB")).unsqueeze(0).cuda()
            for p in (content_path, style_path))

        images = pipe.generate(
            content_images=content_image,
            style_images=style_image,
            mask_images=mask_image,
            inpaint_images=inpaint_image,
            unmask_images=unmask_image,
            batch_size=content_image.shape[0],
            order=args.order,
            num_inference_step=args.num_inference_steps,
            content_encoder_downsample_size=args.content_encoder_downsample_size,
            t_start=args.t_start,
            t_end=args.t_end,
            dm_size=args.content_image_size,
            algorithm_type=args.algorithm_type,
            skip_type=args.skip_type,
            method=args.method,
            correcting_x0_fn=args.correcting_x0_fn,
        )
        images[0].save(f"{savefd}/플레이브밤비__{tag}__{letter}_{tetter}.png")

  0%|          | 0/79 [00:00<?, ?it/s]

In [12]:
def target_letters(mfd, letter, mapper, font="플레이브밤비"):
    """Select ㅜ/ㅠ/ㅡ medial swap targets for a syllable without a final consonant.

    ``letter`` qualifies when it is a precomposed Hangul syllable whose medial
    is one of ㅜㅠㅡ, which has no final consonant, and whose mask files
    ``{mfd}/{font}__{letter}_m1.png`` and ``_m2.png`` both exist.

    Args:
        mfd: directory holding the per-glyph mask images.
        letter: single-character Hangul syllable used as the inpainting source.
        mapper: dict mapping (initial, medial, final) jamo indices to syllables.
        font: font-name prefix of the mask file names.

    Returns:
        ``(targets, mask_image, unmask_image)``. For a qualifying letter,
        ``targets`` are the syllables with the same initial, no final, and a
        different medial from ㅜㅠㅡ; ``mask_image`` is ``m1`` and
        ``unmask_image`` is ``m2`` (each loaded through ``mask_transform``).
        Otherwise ``([], None, None)``.
    """
    # Medial (vowel) indices in Unicode order, for reference:
    #   ㅏㅐㅑㅒㅓㅔㅕㅖㅣ = 0,1,2,3,4,5,6,7,20
    #   ㅗㅛㅜㅠㅡ       = 8,12,13,17,18
    #   ㅘㅙㅚㅝㅞㅟㅢ   = 9,10,11,14,15,16,19
    moeum_list = [13, 17, 18]  # ㅜㅠㅡ

    # Characters outside 가..힣 have no (initial, medial, final) decomposition.
    offset = ord(letter) - ord('가')
    if not 0 <= offset < 11172:
        return [], None, None
    ch1, rest = divmod(offset, 588)  # 588 = 21 medials * 28 finals
    ch2, ch3 = divmod(rest, 28)

    def mask_path(index):
        return f"{mfd}/{font}__{letter}_m{index}.png"

    if ch2 not in moeum_list or ch3 != 0:
        return [], None, None
    if not all(os.path.exists(mask_path(i)) for i in (1, 2)):
        return [], None, None

    def load_mask(path):
        # Close the file handle once the mask tensor has been built.
        with Image.open(path) as img:
            return mask_transform(img.convert("RGB"))

    mask_image = load_mask(mask_path(1))
    unmask_image = load_mask(mask_path(2))
    targets = [mapper[(ch1, l2, ch3)] for l2 in moeum_list if l2 != ch2]
    return targets, mask_image, unmask_image

In [13]:
# Inpaint every target glyph of every qualifying source letter and save the
# results under savefd; target_letters decides which letters qualify.
os.makedirs(savefd, exist_ok=True)
for letter in tqdm(letters):
    targets, mask_image, unmask_image = target_letters(mfd, letter, mapper)
    if mask_image is None:
        # Letter is not part of this experiment (or its masks are missing).
        continue

    # Source glyph to inpaint plus its masks, each as a batch of one on the GPU.
    inpaint_image = normal_transform(
        Image.open(f"{ifd}/플레이브밤비__closing__{letter}.png").convert("RGB"))
    mask_image, unmask_image, inpaint_image = (
        t.unsqueeze(0).cuda() for t in (mask_image, unmask_image, inpaint_image))

    for tetter in targets:
        style_path = f"{sfd}/플레이브밤비__closing__{tetter}.png"
        if not os.path.exists(style_path):
            # Skip targets that have no style reference glyph.
            continue
        content_path = f"{cfd}/시스템굴림__closing__{tetter}.png"
        content_image, style_image = (
            normal_transform(Image.open(p).convert("RGB")).unsqueeze(0).cuda()
            for p in (content_path, style_path))

        images = pipe.generate(
            content_images=content_image,
            style_images=style_image,
            mask_images=mask_image,
            inpaint_images=inpaint_image,
            unmask_images=unmask_image,
            batch_size=content_image.shape[0],
            order=args.order,
            num_inference_step=args.num_inference_steps,
            content_encoder_downsample_size=args.content_encoder_downsample_size,
            t_start=args.t_start,
            t_end=args.t_end,
            dm_size=args.content_image_size,
            algorithm_type=args.algorithm_type,
            skip_type=args.skip_type,
            method=args.method,
            correcting_x0_fn=args.correcting_x0_fn,
        )
        images[0].save(f"{savefd}/플레이브밤비__{tag}__{letter}_{tetter}.png")

  0%|          | 0/79 [00:00<?, ?it/s]

In [14]:
def target_letters(mfd, letter, mapper, font="플레이브밤비"):
    """Select final-consonant swap targets for a syllable that has a final consonant.

    ``letter`` qualifies when it is a precomposed Hangul syllable whose medial
    is anything except ㅜ/ㅠ, which has a final consonant, and whose three mask
    files ``{mfd}/{font}__{letter}_m1.png`` / ``_m2`` / ``_m3`` all exist.

    Args:
        mfd: directory holding the per-glyph mask images.
        letter: single-character Hangul syllable used as the inpainting source.
        mapper: dict mapping (initial, medial, final) jamo indices to syllables.
        font: font-name prefix of the mask file names.

    Returns:
        ``(targets, mask_image, unmask_image)``. For a qualifying letter,
        ``targets`` are the syllables with the same initial and medial but
        each of the other 26 final consonants; ``mask_image`` is ``m1 * m2``
        and ``unmask_image`` is ``m3`` (each loaded through
        ``mask_transform``). Otherwise ``([], None, None)``.
    """
    # Eligible medial (vowel) indices: every medial except ㅜ (13) and ㅠ (17).
    moeum_list = [0, 1, 2, 3, 4, 5, 6, 7, 20]  # ㅏㅐㅑㅒㅓㅔㅕㅖㅣ
    moeum_list += [8, 12, 18]  # ㅗㅛㅡ
    moeum_list += [9, 10, 11, 14, 15, 16, 19]  # ㅘㅙㅚㅝㅞㅟㅢ
    # Final-consonant (jongseong) indices; 0 means "no final" and is excluded.
    jong_list = range(1, 28)

    # Characters outside 가..힣 have no (initial, medial, final) decomposition.
    offset = ord(letter) - ord('가')
    if not 0 <= offset < 11172:
        return [], None, None
    ch1, rest = divmod(offset, 588)  # 588 = 21 medials * 28 finals
    ch2, ch3 = divmod(rest, 28)

    def mask_path(index):
        return f"{mfd}/{font}__{letter}_m{index}.png"

    if ch2 not in moeum_list or ch3 == 0:
        return [], None, None
    if not all(os.path.exists(mask_path(i)) for i in (1, 2, 3)):
        return [], None, None

    def load_mask(path):
        # Close the file handle once the mask tensor has been built.
        with Image.open(path) as img:
            return mask_transform(img.convert("RGB"))

    mask_image = load_mask(mask_path(1)) * load_mask(mask_path(2))
    unmask_image = load_mask(mask_path(3))
    # Bug fix: the targets previously iterated moeum_list (medial indices) as
    # though it held final-consonant indices, silently skipping finals such as
    # ㄾ (13), ㅂ (17) and ㅇㅈㅊㅋㅌㅍㅎ (21-27).
    targets = [mapper[(ch1, ch2, l3)] for l3 in jong_list if l3 != ch3]
    return targets, mask_image, unmask_image


In [15]:
# Inpaint every target glyph of every qualifying source letter and save the
# results under savefd; target_letters decides which letters qualify.
os.makedirs(savefd, exist_ok=True)
for letter in tqdm(letters):
    targets, mask_image, unmask_image = target_letters(mfd, letter, mapper)
    if mask_image is None:
        # Letter is not part of this experiment (or its masks are missing).
        continue

    # Source glyph to inpaint plus its masks, each as a batch of one on the GPU.
    inpaint_image = normal_transform(
        Image.open(f"{ifd}/플레이브밤비__closing__{letter}.png").convert("RGB"))
    mask_image, unmask_image, inpaint_image = (
        t.unsqueeze(0).cuda() for t in (mask_image, unmask_image, inpaint_image))

    for tetter in targets:
        style_path = f"{sfd}/플레이브밤비__closing__{tetter}.png"
        if not os.path.exists(style_path):
            # Skip targets that have no style reference glyph.
            continue
        content_path = f"{cfd}/시스템굴림__closing__{tetter}.png"
        content_image, style_image = (
            normal_transform(Image.open(p).convert("RGB")).unsqueeze(0).cuda()
            for p in (content_path, style_path))

        images = pipe.generate(
            content_images=content_image,
            style_images=style_image,
            mask_images=mask_image,
            inpaint_images=inpaint_image,
            unmask_images=unmask_image,
            batch_size=content_image.shape[0],
            order=args.order,
            num_inference_step=args.num_inference_steps,
            content_encoder_downsample_size=args.content_encoder_downsample_size,
            t_start=args.t_start,
            t_end=args.t_end,
            dm_size=args.content_image_size,
            algorithm_type=args.algorithm_type,
            skip_type=args.skip_type,
            method=args.method,
            correcting_x0_fn=args.correcting_x0_fn,
        )
        images[0].save(f"{savefd}/플레이브밤비__{tag}__{letter}_{tetter}.png")

  0%|          | 0/79 [00:00<?, ?it/s]