In [1]:
import math
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

from diffusers.optimization import get_scheduler
from google.cloud import storage

from dataset import FontDataset, CollateFN
from model import FontDiffuserModel
from criterion import ContentPerceptualLoss
from build import build_unet, build_style_encoder, build_content_encoder, build_ddpm_scheduler
from args import TrainPhase1Args
from utils import x0_from_epsilon, reNormalize_img, normalize_mean_std, save_model

pygame 2.6.0 (SDL 2.28.4, Python 3.10.14)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class TrainPhase1Args:
    """Plain attribute bag holding the phase-1 training configuration.

    NOTE(review): this notebook-local class shadows the ``TrainPhase1Args``
    imported from ``args`` in the import cell; every later reference in the
    notebook resolves to this definition instead.

    Args:
        r: Image resolution. The same value is stored as ``resolution``,
            ``style_image_size`` and ``content_image_size``.
    """

    def __init__(self, r):
        side = r

        # ---- Experiment-specific settings ("My Configs") ----
        self.bucket_name = "leo_font"
        self.savepath = "exp0604/phase1"
        # NOTE(review): absolute, machine-specific path.
        self.datapath = "/home/jupyter/ai_font/data"
        self.scr = False
        self.num_neg = None
        self.experiment_name = "phase1"
        self.resolution = side
        self.content_font = '시스템굴림'

        # ---- Settings taken as given ----
        # Network / noise schedule
        self.unet_channels = (64, 128, 256, 512)
        self.beta_scheduler = "scaled_linear"

        # Adam optimizer and gradient clipping
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 0.01
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.seed = 123

        # Encoder input sizes / widths
        self.style_image_size = side
        self.content_image_size = side
        self.content_encoder_downsample_size = 3
        self.channel_attn = True
        self.content_start_channel = 64
        self.style_start_channel = 64

        # Loss weights, schedule and logging
        self.train_batch_size = 8
        self.perceptual_coefficient = 0.01
        self.offset_coefficient = 0.5
        self.max_train_steps = 5 * 440_000
        self.ckpt_interval = 40_000
        self.gradient_accumulation_steps = 1
        self.log_interval = 50
        self.learning_rate = 0.0001
        self.lr_scheduler = "linear"
        self.lr_warmup_steps = 10_000
        self.drop_prob = 0.1
        self.mixed_precision = "no"

In [3]:
# Build every model component at a single resolution `r` (the smoke-test cell
# below also uses `r` for its input tensors).
r = 128
args = TrainPhase1Args(r)

# `args.seed` is configured but was never applied. Seed torch's RNG (all
# devices) *before* building the networks, so weight initialisation and the
# later random draws (noise, timesteps) repeat across runs.
# NOTE(review): the build_* helpers may also draw from other RNGs
# (numpy / random); seed those too if they do.
torch.manual_seed(args.seed)

unet = build_unet(args=args)
style_encoder = build_style_encoder(args=args)
content_encoder = build_content_encoder(args=args)
noise_scheduler = build_ddpm_scheduler(args)

Load the down block  DownBlock2D
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  DownBlock2D
Load the up block  UpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  UpBlock2D
Param count for Ds initialized parameters: 20591296
Get CG-GAN Style Encoder!
Param count for Ds initialized parameters: 19541696
Get CG-GAN Content Encoder!


In [4]:
# Wrap the three networks in the diffusion model and move it to the default
# CUDA device (`Module.cuda()` moves parameters in place and returns the module).
model = FontDiffuserModel(
    unet=unet,
    style_encoder=style_encoder,
    content_encoder=content_encoder,
).cuda()

In [5]:
# Smoke test: push one tiny batch of constant images through the full model to
# check that all components wire together at resolution `r`.
batch_shape = [2, 3, r, r]
content_images = torch.ones(batch_shape).cuda()
style_images = torch.ones(batch_shape).cuda()
target_images = torch.ones(batch_shape).cuda()

# Forward diffusion: one random timestep per image, then add Gaussian noise.
noise = torch.randn_like(target_images)
bsz = target_images.shape[0]
# Read the step count through the scheduler's `config`; accessing it directly as
# `noise_scheduler.num_train_timesteps` is deprecated in diffusers and emits a
# "direct config name access" deprecation warning.
timesteps = torch.randint(
    0,
    noise_scheduler.config.num_train_timesteps,
    (bsz,),
    device=target_images.device,
    dtype=torch.long,
)

noisy_target_images = noise_scheduler.add_noise(target_images, noise, timesteps)

noise_pred, offset_out_sum = model(
    x_t=noisy_target_images,
    timesteps=timesteps,
    style_images=style_images,
    content_images=content_images,
    content_encoder_downsample_size=args.content_encoder_downsample_size)

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
  style_img_feature, _, _ = self.style_encoder(style_images)


[f.shape for f in style_structure_features] [torch.Size([2, 3, 128, 128]), torch.Size([2, 64, 64, 64]), torch.Size([2, 128, 32, 32]), torch.Size([2, 256, 16, 16]), torch.Size([2, 512, 8, 8]), torch.Size([2, 1024, 4, 4])]
[f.shape for f in style_structure_features] [torch.Size([2, 3, 128, 128]), torch.Size([2, 64, 64, 64]), torch.Size([2, 128, 32, 32]), torch.Size([2, 256, 16, 16]), torch.Size([2, 512, 8, 8]), torch.Size([2, 1024, 4, 4])]


  content_img_feature, content_residual_features = self.content_encoder(content_images)
  style_content_feature, style_content_res_features = self.content_encoder(style_images)
  out = self.unet(


In [6]:
# Size of the first dimension of the `gnorm_s` weight in up block 1
# (presumably its channel count — confirm against the module definition).
unet.up_blocks[1].sc_interpreter_offsets[0].gnorm_s.weight.size(0)

128

In [7]:
# 2x2 max-pool with stride 2 on a (2, 256, 32, 32) tensor of ones; halves H and W.
a = F.max_pool2d(torch.ones(2, 256, 32, 32), kernel_size=2, stride=2, padding=0)

In [8]:
# Display the pooled tensor's shape (a torch.Size).
a.size()

torch.Size([2, 256, 16, 16])