In [1]:
import math
from tqdm.auto import tqdm

In [2]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
from diffusers.optimization import get_scheduler
from google.cloud import storage

# The GCS client and `bucket` handle are created in the model-setup cell,
# once `args.bucket_name` is available. A client built here from the literal
# "bucket_name" placeholder was always overwritten before use.

In [4]:
import sys
import importlib

def call_module(nm, path):
    """Import the Python source file at ``path`` as a module named ``nm``.

    The module is registered in ``sys.modules[nm]`` before its body runs, so
    ``import nm`` statements executed afterwards (including inside other
    modules loaded this way) resolve to it.

    Args:
        nm: Name to register the module under in ``sys.modules``.
        path: Filesystem path of the ``.py`` file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no loader can be found for ``path`` (e.g. it does not
            have a Python source/bytecode/extension suffix).
        Any exception raised while executing the module body. In that case the
            half-initialised module is removed from ``sys.modules``, mirroring
            what the regular import system does.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    spec = importlib.util.spec_from_file_location(nm, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {nm!r} from {path!r}", name=nm, path=path)

    module = importlib.util.module_from_spec(spec)
    sys.modules[nm] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a broken, partially executed module importable under `nm`.
        sys.modules.pop(nm, None)
        raise
    return module

# Load the experiment's project modules directly from their source files.
# Module handles carry a `_module` suffix so they are not shadowed when the
# next cell rebinds `model` (to a FontDiffuserModel instance) and `args`
# (to a TrainPhase1Args instance). Load order is kept as-is because each
# module is registered in sys.modules under its short name as it is loaded.
EXPERIMENT_DIR = "/home/jupyter/ai_font/experiments/font_diffuser_oskar"

dataset_module = call_module('dataset', f"{EXPERIMENT_DIR}/dataset.py")
FontDataset = dataset_module.FontDataset
CollateFN = dataset_module.CollateFN

model_module = call_module('model', f"{EXPERIMENT_DIR}/model.py")
FontDiffuserModel = model_module.FontDiffuserModel

criterion_module = call_module('criterion', f"{EXPERIMENT_DIR}/criterion.py")
ContentPerceptualLoss = criterion_module.ContentPerceptualLoss

build_module = call_module('build', f"{EXPERIMENT_DIR}/build.py")
build_unet = build_module.build_unet
build_style_encoder = build_module.build_style_encoder
build_content_encoder = build_module.build_content_encoder
build_ddpm_scheduler = build_module.build_ddpm_scheduler

args_module = call_module('args', f"{EXPERIMENT_DIR}/args.py")
TrainPhase1Args = args_module.TrainPhase1Args

utils_module = call_module('utils', f"{EXPERIMENT_DIR}/utils.py")
save_args_to_yaml = utils_module.save_args_to_yaml
x0_from_epsilon = utils_module.x0_from_epsilon
reNormalize_img = utils_module.reNormalize_img
normalize_mean_std = utils_module.normalize_mean_std

pygame 2.5.2 (SDL 2.28.2, Python 3.10.14)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [5]:

# Experiment configuration, then the pieces FontDiffuser is assembled from:
# UNet, style encoder, content encoder and the DDPM noise scheduler.
args = TrainPhase1Args()

unet = build_unet(args=args)
style_encoder = build_style_encoder(args=args)
content_encoder = build_content_encoder(args=args)
noise_scheduler = build_ddpm_scheduler(args)

model = FontDiffuserModel(
    unet=unet,
    style_encoder=style_encoder,
    content_encoder=content_encoder,
)

# GCS handles; `bucket` is the target that `save_model` uploads checkpoints to.
# NOTE(review): the first parameter of `storage.Client` is the GCP *project*
# id, so the bucket name is being passed as the project here — confirm that
# this is intended rather than a real project id.
storage_client = storage.Client(project=args.bucket_name)
bucket = storage_client.bucket(args.bucket_name)

Load the down block  DownBlock2D
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  MCADownBlock2D
The style_attention cross attention dim in Down Block 1 layer is 1024
The style_attention cross attention dim in Down Block 2 layer is 1024
Load the down block  DownBlock2D
Load the up block  UpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  StyleRSIUpBlock2D
Load the up block  UpBlock2D
Param count for Ds initialized parameters: 20591296
Get CG-GAN Style Encoder!
Param count for Ds initialized parameters: 1187008
Get CG-GAN Content Encoder!


In [6]:
def save_model(state_dict, model_name):
    """Serialize ``state_dict`` with ``torch.save`` and stream it into GCS.

    The object is written to
    ``<args.save_path>/<args.experiment_name>__<model_name>.pth`` inside the
    module-level ``bucket``.

    Args:
        state_dict: Anything ``torch.save`` accepts, typically a module's
            ``state_dict()``.
        model_name: Suffix identifying the checkpoint, e.g. ``"unet_10000"``.
    """
    object_name = "{}/{}__{}.pth".format(args.save_path, args.experiment_name, model_name)
    target_blob = bucket.blob(object_name)
    # ignore_flush=True turns the GCS writer's flush() into a no-op instead of
    # an error, for file consumers (like torch.save's zip writer) that call it.
    with target_blob.open("wb", ignore_flush=True) as out_file:
        torch.save(state_dict, out_file)

In [7]:
# Content perceptual loss (instantiating it downloads VGG16 weights on first use).
perceptual_loss = ContentPerceptualLoss()

train_font_dataset = FontDataset(path=args.path)
train_dataloader = DataLoader(
    train_font_dataset,
    shuffle=True,
    batch_size=args.train_batch_size,
    collate_fn=CollateFN(),
)

# Move the model to the GPU *before* the optimizer captures its parameters,
# as the PyTorch docs for `Module.cuda()` recommend; any later `.cuda()` call
# on the model is then effectively a no-op.
model = model.cuda()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

# Warmup and total step counts are scaled by gradient_accumulation_steps,
# i.e. the scheduler is meant to be advanced once per micro-batch.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/jupyter/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/ai_font/data/zipfiles/raw/size96/seen/'

In [None]:
# Place every parameter and buffer on the current CUDA device; the training
# loop moves each batch there with `.cuda()` to match.
model = model.to("cuda")

In [None]:
CKPT_INTERVAL = 10000  # optimizer updates between checkpoint uploads


def save_checkpoint(step):
    """Upload UNet, style-encoder and content-encoder weights tagged with ``step``.

    Names follow the ``<component>_<step>`` scheme, e.g. ``unet_10000``.
    """
    save_model(unet.state_dict(), f"unet_{step}")
    save_model(style_encoder.state_dict(), f"style_encoder_{step}")
    save_model(content_encoder.state_dict(), f"content_encoder_{step}")


grad_accum_steps = args.gradient_accumulation_steps
num_batches_per_epoch = len(train_dataloader)
if num_batches_per_epoch == 0:
    raise ValueError("train_dataloader is empty; check args.path and the dataset")
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / grad_accum_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

progress_bar = tqdm(range(args.max_train_steps))
progress_bar.set_description("Steps")

global_step = 0  # optimizer updates performed so far (what max_train_steps counts)
last_saved_step = None

model.train()
optimizer.zero_grad()  # start from clean gradients even if this cell is re-run
for epoch in range(num_train_epochs):
    for step, samples in enumerate(train_dataloader):
        content_images = samples["content_image"].cuda()
        style_images = samples["style_image"].cuda()
        target_images = samples["target_image"].cuda()
        nonorm_target_images = samples["nonorm_target_image"].cuda()
        device = target_images.device
        bsz = target_images.shape[0]

        # Forward diffusion: noise each target image at a random timestep.
        noise = torch.randn_like(target_images)
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bsz,), device=device, dtype=torch.long)
        noisy_target_images = noise_scheduler.add_noise(target_images, noise, timesteps)

        # Classifier-free training: independently for each sample, with
        # probability args.drop_prob, replace both conditioning images with all-ones.
        drop = torch.bernoulli(torch.full((bsz,), float(args.drop_prob), device=device)).bool()
        content_images[drop] = 1
        style_images[drop] = 1

        # Predict the noise residual.
        noise_pred, offset_out_sum = model(
            x_t=noisy_target_images,
            timesteps=timesteps,
            style_images=style_images,
            content_images=content_images,
            content_encoder_downsample_size=args.content_encoder_downsample_size)
        diff_loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        offset_loss = offset_out_sum / 2

        # Content perceptual loss on the predicted clean image.
        pred_original_sample_norm = x0_from_epsilon(
            scheduler=noise_scheduler,
            noise_pred=noise_pred,
            x_t=noisy_target_images,
            timesteps=timesteps)
        pred_original_sample = reNormalize_img(pred_original_sample_norm)
        norm_pred_ori = normalize_mean_std(pred_original_sample)
        norm_target_ori = normalize_mean_std(nonorm_target_images)
        percep_loss = perceptual_loss.calculate_loss(
            generated_images=norm_pred_ori,
            target_images=norm_target_ori,
            device=device)

        loss = (diff_loss
                + args.perceptual_coefficient * percep_loss
                + args.offset_coefficient * offset_loss)

        # Divide so the gradients summed over an accumulation window average the batch losses.
        (loss / grad_accum_steps).backward()
        progress_bar.set_postfix(step_loss=loss.detach().item(), lr=lr_scheduler.get_last_lr()[0])

        # Keep accumulating until the window is full or the epoch runs out of batches.
        is_last_batch = step + 1 == num_batches_per_epoch
        if (step + 1) % grad_accum_steps != 0 and not is_last_batch:
            continue

        # NOTE(review): no gradient clipping is applied before the update.
        optimizer.step()
        # The scheduler's step counts were multiplied by gradient_accumulation_steps
        # when it was built, so advance it once per micro-batch in this window.
        for _ in range(step % grad_accum_steps + 1):
            lr_scheduler.step()
        optimizer.zero_grad()

        global_step += 1
        progress_bar.update(1)
        if global_step % CKPT_INTERVAL == 0:
            save_checkpoint(global_step)
            last_saved_step = global_step

        # Quit
        if global_step >= args.max_train_steps:
            break
    # Leave the epoch loop too; otherwise each remaining epoch would run
    # another training step before the inner check fires.
    if global_step >= args.max_train_steps:
        break

# Always keep the final weights, even when training did not stop on a CKPT_INTERVAL boundary.
if last_saved_step != global_step:
    save_checkpoint(global_step)