# Conditional Glyffuser

### Define training parameters
We add `text_encoder` and `encoder_dim` parameters to the config. The batch size is also reduced, because the conditional model is larger and needs more memory per sample.

In [1]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and paths for training the conditional model.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field: values can be overridden per run, e.g.
    ``TrainingConfig(num_epochs=5)``, and the generated ``repr`` lists the
    full configuration.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "glyffuser"  # the model name
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    text_encoder: str = "google-t5/t5-small"
    encoder_dim: int = 512
    seed: int = 0

config = TrainingConfig()

Additional plumbing (the `Collator`, `GlyffuserPipeline` and `evaluate` helpers) has been moved to the `glyffuser_utils` module.

In [2]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from glyffuser_utils import Collator, GlyffuserPipeline, evaluate
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import UNet2DConditionModel, DDPMScheduler, DPMSolverMultistepScheduler
from accelerate import Accelerator
from tqdm.auto import tqdm
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train the conditional UNet with Accelerate and save pipeline checkpoints.

    For every batch, noise is added to the images at randomly drawn timesteps,
    the model predicts that noise given the batch's conditioning tensor, and
    the MSE between predicted and true noise is minimised.

    Args:
        config: Object exposing ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir``, ``num_epochs``
            and ``save_model_epochs``.
        model: Called as ``model(noisy_images, timesteps,
            encoder_hidden_states=..., return_dict=False)``; element ``[0]``
            of the result is used as the noise prediction.
        noise_scheduler: Supplies ``config.num_train_timesteps`` and
            ``add_noise`` for the forward (noising) process.
        optimizer: Optimizer stepped once per batch inside the accumulation
            context.
        train_dataloader: Yields batches where ``batch[0]`` holds the images
            and ``batch[1]`` is passed as ``encoder_hidden_states``.
        lr_scheduler: Learning-rate scheduler stepped right after the optimizer.

    Note:
        The saved pipeline uses the module-level ``inference_scheduler``, which
        is not a parameter: the cell defining it must run before this function
        is launched.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything for accelerator
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    # Read through `.config`; direct attribute access on the scheduler is deprecated in diffusers.
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    # Text prompts intended for sampling.
    # NOTE(review): nothing consumes these at the moment and the imported `evaluate`
    # helper is never called, so no sample images are produced during training and
    # `config.save_image_epochs` is unused. Confirm the signature of `evaluate` in
    # `glyffuser_utils` and re-enable sampling if that is the intent.
    texts = [
        "vicious, cruel; severely, extreme",
        "panting of cow; grunting of ox",
        "sing; folksong, ballad; rumor",
        "a quiver"
    ]

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch[0]
            bs = clean_images.shape[0]
            # Sample the noise directly on the target device instead of on CPU followed by a copy.
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            # One random timestep per image in the batch.
            timesteps = torch.randint(
                0, num_train_timesteps, (bs,), device=clean_images.device, dtype=torch.int64
            )
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(
                    noisy_images,
                    timesteps,
                    encoder_hidden_states=batch[1],
                    return_dict=False
                )[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                # Clip only on steps where gradients are synchronized (i.e. a real
                # optimizer update follows), as recommended by Accelerate when
                # gradient accumulation is in use.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Close the per-epoch bar so finished bars do not linger as open widgets.
        progress_bar.close()

        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            # Checkpoint cadence follows `save_model_epochs`; the final epoch is always saved.
            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                # Build the pipeline only when it is actually going to be saved.
                pipeline = GlyffuserPipeline(unet=accelerator.unwrap_model(model), scheduler=inference_scheduler)
                pipeline.save_pretrained(config.output_dir)

    # Finalize the trackers initialised above.
    accelerator.end_training()


  from .autonotebook import tqdm as notebook_tqdm


We add a collator, which lets the dataloader pass text embeddings to the model alongside each image batch during training.

In [3]:
# Define collator
collator = Collator(
    image_size=config.image_size,  # taken from the config so it cannot drift from the UNet's sample_size
    text_label='caption',  # This is where you specify which field to use for text
    image_label='file_name',  # field that identifies the image
    name=config.text_encoder,  # name of the text encoder set in the config
    channels='L'
)

# Define data source
# `metadata.jsonl` is resolved relative to the notebook's working directory.
dataset = load_dataset("json", data_files="metadata.jsonl")
train_dataloader = DataLoader(
    dataset['train'],
    batch_size=config.train_batch_size,
    collate_fn=collator,
    shuffle=True,
)

# Define model: a UNet with cross-attention blocks, conditioned on text-encoder output
model = UNet2DConditionModel(
    # --- image geometry ---
    sample_size=config.image_size,  # the target image resolution
    in_channels=1,  # one input channel (the collator is configured with channels='L')
    out_channels=1,  # one output channel, same as the input

    # --- backbone ---
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "CrossAttnDownBlock2D",  # the only down block with cross-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",  # the only up block with cross-attention
        "UpBlock2D",
        "UpBlock2D",
    ),

    # --- text conditioning ---
    addition_embed_type="text",  # Make it conditional
    encoder_hid_dim=config.encoder_dim,  # the hidden dimension of the encoder
    encoder_hid_dim_type="text_proj",  # how the encoder hidden states are projected (see diffusers docs)
    cross_attention_dim=config.encoder_dim,
)

# Define optimizers and schedulers
# Two diffusion schedulers: DDPM for adding noise during training, and a
# DPM-Solver multistep scheduler for the pipeline that gets saved.
inference_scheduler = DPMSolverMultistepScheduler()
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# Cosine decay with warmup, spanning one scheduler step per batch per epoch.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=config.num_epochs * len(train_dataloader),
)

# Check model parameters (trainable ones only)
total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'{total_params:,} total parameters.')

118,923,137 total parameters.


Pass everything to accelerate and run the training loop

In [None]:
from accelerate import notebook_launcher

# Positional arguments for train_loop, listed in the order of its signature.
args = (
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
)
# Run train_loop through Accelerate's notebook launcher on a single process.
notebook_launcher(train_loop, args=args, num_processes=1)