In [5]:
# Restrict CUDA to a single device before torch is imported. The value is a MIG
# instance identifier.
# NOTE(review): this UUID looks host-specific — update it when running elsewhere.
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8
# Device string used by every `.to(device)` call in this notebook.
device="cuda:0"
import os
import torch
import torch.nn.functional as F
import numpy as np
import datasets
from types import SimpleNamespace
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
    FluxControlNetModel,
    FluxControlNetPipeline,
)
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    T5EncoderModel,
)
from accelerate import Accelerator
from IPython.display import display
from torchvision.transforms.v2 import ToPILImage

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [6]:
# Single source of truth for every tunable used by the cells below.
config = SimpleNamespace(
    **{
        # Pretrained base model and training dataset (Hugging Face Hub ids)
        "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
        "dataset_name": "danjacobellis/LSDIR_512_f16c12",
        # Dataset column names
        "image_column": "image",
        "conditioning_image_column": "conditioning_image",
        "caption_column": "text",
        # Output location and crop size used by the image transforms
        "output_dir": "controlnet_flux_output",
        "resolution": 512,
        # Optimisation schedule
        "learning_rate": 1e-5,
        "max_train_steps": 15000,
        "validation_steps": 100,
        "checkpointing_steps": 200,
        "train_batch_size": 2,
        "gradient_accumulation_steps": 1,
        # Block counts handed to FluxControlNetModel.from_transformer
        "num_double_layers": 4,
        "num_single_layers": 0,
        # Seed, validation set size and guidance value
        "seed": 42,
        "num_validation_images": 2,
        "guidance_scale": 3.5,
    }
)

In [None]:
# Create output directory
os.makedirs(config.output_dir, exist_ok=True)

# ## Cell 3: Load Dataset and Add Dummy Captions

dataset = datasets.load_dataset(config.dataset_name, split="train")
# Captions are not used, so attach an empty-string caption to every sample.
# With `batched=True` the mapped function must return one value per row (a
# list as long as the batch); returning a bare "" makes `datasets` raise a
# TypeError. `with_indices=True` supplies the batch length without having to
# read any image column.
dataset = dataset.map(
    lambda batch, indices: {config.caption_column: [""] * len(indices)},
    batched=True,
    with_indices=True,
)
print(f"Dataset loaded with {len(dataset)} samples.")

In [7]:
# `config.seed` was declared but never applied; seed PyTorch's global RNGs so
# that dataloader shuffling and noise/timestep sampling are repeatable across
# kernel restarts.
torch.manual_seed(config.seed)

accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision="no",  # Use float32
)

# Tokenizers live in the "tokenizer" and "tokenizer_2" subfolders of the base repo
tokenizer_one, tokenizer_two = (
    AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path, subfolder=tokenizer_dir)
    for tokenizer_dir in ("tokenizer", "tokenizer_2")
)

# Text encoders: CLIP from "text_encoder", T5 from "text_encoder_2"
text_encoder_one = CLIPTextModel.from_pretrained(config.pretrained_model_name_or_path, subfolder="text_encoder")
text_encoder_two = T5EncoderModel.from_pretrained(config.pretrained_model_name_or_path, subfolder="text_encoder_2")

# Image autoencoder and the Flux transformer backbone
vae = AutoencoderKL.from_pretrained(config.pretrained_model_name_or_path, subfolder="vae")
flux_transformer = FluxTransformer2DModel.from_pretrained(config.pretrained_model_name_or_path, subfolder="transformer")

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Derive the ControlNet from the loaded Flux transformer; the number of
# double and single blocks comes from `config`.
flux_controlnet = FluxControlNetModel.from_transformer(
    transformer=flux_transformer,
    num_single_layers=config.num_single_layers,
    num_layers=config.num_double_layers,
)

In [9]:
# The scheduler configuration is read from the base repo's "scheduler" subfolder
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    pretrained_model_name_or_path=config.pretrained_model_name_or_path,
    subfolder="scheduler",
)

In [10]:
# Assemble a pipeline from the components loaded above; it is used for prompt
# encoding, latent packing helpers and validation sampling further down.
pipeline = FluxControlNetPipeline(
    # denoiser and its ControlNet
    transformer=flux_transformer,
    controlnet=flux_controlnet,
    # image autoencoder and sampler
    vae=vae,
    scheduler=noise_scheduler,
    # first text tower
    tokenizer=tokenizer_one,
    text_encoder=text_encoder_one,
    # second text tower
    tokenizer_2=tokenizer_two,
    text_encoder_2=text_encoder_two,
)

In [17]:
# Frozen components: move to the GPU in float32 and switch off gradients
for frozen_module in (vae, flux_transformer, text_encoder_one, text_encoder_two):
    frozen_module.to(device, dtype=torch.float32).requires_grad_(False)

# The ControlNet is the only module put into training mode
flux_controlnet.to(device, dtype=torch.float32).train();

In [12]:
# PIL image -> centre crop at the training resolution -> tensor in [-1, 1]
image_transforms = transforms.Compose(
    [
        transforms.CenterCrop(size=config.resolution),
        transforms.ToTensor(),  # values in [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 -> [-1, 1]
    ]
)

In [13]:
def preprocess_batch(examples):
    """Turn the target and conditioning images of a batch into tensors.

    Used as the dataset's `with_transform` callback. The callback's return
    value is what each example consists of, so every incoming column is kept
    and the two tensor columns are added next to them; returning only the
    tensors would drop the caption/embedding columns that `compute_embeddings`
    and `collate_fn` read.

    Args:
        examples: mapping of column name -> list of values; the columns named
            by `config.image_column` and `config.conditioning_image_column`
            must hold PIL images.

    Returns:
        A dict with all incoming columns plus
        "pixel_values" and "conditioning_pixel_values" (lists of tensors
        produced by `image_transforms`).
    """
    def _to_tensors(column):
        # Force RGB so that every image yields the same channel count.
        return [image_transforms(image.convert("RGB")) for image in examples[column]]

    batch = dict(examples)
    batch["pixel_values"] = _to_tensors(config.image_column)
    batch["conditioning_pixel_values"] = _to_tensors(config.conditioning_image_column)
    return batch

# Compute text embeddings (all empty prompts)
def compute_embeddings(batch, pipeline):
    """Encode the captions of `batch` with the pipeline's text encoders.

    Args:
        batch: mapping that holds the captions under `config.caption_column`.
        pipeline: object exposing `encode_prompt(prompts, prompt_2=...)` that
            returns `(prompt_embeds, pooled_prompt_embeds, text_ids)`.

    Returns:
        Dict with float32 "prompt_embeds", "pooled_prompt_embeds" and
        "text_ids"; the text ids are repeated once per caption along a new
        leading axis.
    """
    prompts = batch[config.caption_column]
    embeds, pooled_embeds, ids = pipeline.encode_prompt(prompts, prompt_2=prompts)
    # One copy of the ids per caption so every row of the batch carries its own
    per_row_ids = ids.unsqueeze(0).expand(len(prompts), -1, -1)
    named_outputs = (
        ("prompt_embeds", embeds),
        ("pooled_prompt_embeds", pooled_embeds),
        ("text_ids", per_row_ids),
    )
    return {key: tensor.to(dtype=torch.float32) for key, tensor in named_outputs}

# Collate function for DataLoader
def collate_fn(examples):
    """Stack per-example fields into batched tensors for the training loop.

    Columns written by `Dataset.map` (the prompt embeddings and text ids) are
    read back as nested Python lists rather than tensors, and `torch.stack`
    rejects lists, so every field goes through `torch.as_tensor` first.
    Fields that already are tensors pass through unchanged.

    Args:
        examples: list of dicts, each with "pixel_values",
            "conditioning_pixel_values", "prompt_embeds",
            "pooled_prompt_embeds" and "text_ids".

    Returns:
        Dict with "pixel_values", "conditioning_pixel_values", "prompt_ids"
        (the stacked prompt embeddings) and "unet_added_conditions", itself a
        dict with "pooled_prompt_embeds" and "time_ids" (the stacked text ids).
    """
    def _stack(key):
        return torch.stack([torch.as_tensor(example[key]) for example in examples])

    return {
        "pixel_values": _stack("pixel_values"),
        "conditioning_pixel_values": _stack("conditioning_pixel_values"),
        "prompt_ids": _stack("prompt_embeds"),
        "unet_added_conditions": {
            "pooled_prompt_embeds": _stack("pooled_prompt_embeds"),
            "time_ids": _stack("text_ids"),
        },
    }

In [None]:
with accelerator.main_process_first():
    # Precompute the text embeddings on the raw dataset. This has to happen
    # before the image transform is attached: once a transform is set, `map`
    # hands the transformed batch to its function, and that batch would not
    # contain the caption column that `compute_embeddings` reads.
    # NOTE(review): every caption is the empty string, so all rows receive the
    # same embeddings; encoding once and reusing the result would avoid
    # storing a copy per row.
    train_dataset = dataset.map(
        lambda batch: compute_embeddings(batch, pipeline),
        batched=True,
        batch_size=50,
    )

# Attach the on-the-fly image preprocessing last. Limiting the transform to the
# two image columns and passing `output_all_columns=True` keeps the precomputed
# embedding columns in every example that reaches `collate_fn`.
train_dataset = train_dataset.with_transform(
    preprocess_batch,
    columns=[config.image_column, config.conditioning_image_column],
    output_all_columns=True,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=config.train_batch_size,
    num_workers=0,  # Adjust based on your system
)

In [16]:
# AdamW over the ControlNet's parameters only
optimizer = torch.optim.AdamW(
    params=flux_controlnet.parameters(),
    lr=config.learning_rate,
    weight_decay=1e-2,
    betas=(0.9, 0.999),
    eps=1e-8,
)

# A multiplicative factor of 1.0 leaves the learning rate unchanged
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
    optimizer=optimizer,
    total_iters=config.max_train_steps,
    factor=1.0,
)

In [None]:
# Hand the trainable model, optimizer, dataloader and LR scheduler to
# Accelerate and rebind each name to the object it returns.
(
    flux_controlnet,
    optimizer,
    train_dataloader,
    lr_scheduler,
) = accelerator.prepare(flux_controlnet, optimizer, train_dataloader, lr_scheduler)

In [None]:
# The first `num_validation_images` rows of the dataset serve as a fixed validation set
validation_samples = dataset.select(range(config.num_validation_images))

# Conditioning image of each validation row, paired with an empty prompt
validation_images = []
for sample in validation_samples:
    validation_images.append(sample[config.conditioning_image_column])
validation_prompts = ["" for _ in validation_images]

def log_validation(step):
    """Generate one image per validation sample and display the results inline.

    For every validation sample three images are displayed in order: the
    original target image, the conditioning image and the generated image.

    Args:
        step: label printed in the "Validation at step ..." header (the
            training loop passes the global step; the final call passes
            "final").
    """
    pipeline.controlnet = accelerator.unwrap_model(flux_controlnet)
    pipeline.to(device)
    to_pil = ToPILImage()

    def _to_display(tensor):
        # Undo the [-1, 1] normalisation of `image_transforms` for display.
        return to_pil(tensor.clamp(-1, 1) / 2 + 0.5)

    images = []
    # Prompt encoding and sampling are inference only; no graph is needed.
    with torch.no_grad():
        for val_image, val_prompt in zip(validation_images, validation_prompts):
            control_tensor = image_transforms(val_image.convert("RGB")).to(device).unsqueeze(0)
            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                [val_prompt], prompt_2=[val_prompt]
            )
            gen_image = pipeline(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                control_image=control_tensor,
                # Generate at the size of the cropped control tensor. Without
                # these the pipeline uses its own default size, which does not
                # match a control tensor of `config.resolution`.
                height=config.resolution,
                width=config.resolution,
                num_inference_steps=28,
                controlnet_conditioning_scale=0.7,
                guidance_scale=config.guidance_scale,
            ).images[0]
            images.append(gen_image)

    print(f"Validation at step {step}:")
    originals = [sample[config.image_column] for sample in validation_samples]
    for orig, cond, gen in zip(originals, validation_images, images):
        display(_to_display(image_transforms(orig.convert("RGB"))))  # Original
        # Show this row's own conditioning image (previously the tensor left
        # over from the last generation iteration was shown for every row).
        display(_to_display(image_transforms(cond.convert("RGB"))))  # Condition
        display(gen)  # Generated

In [None]:
total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
print("***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Total optimization steps = {config.max_train_steps}")
print(f"  Total train batch size = {total_batch_size}")

progress_bar = tqdm(range(config.max_train_steps), desc="Steps")
global_step = 0


def _encode_to_packed_latents(images):
    """VAE-encode `images`, apply the VAE shift/scale and pack the result.

    Returns the packed latents together with the unpacked latent tensor, whose
    shape is needed to build the latent image ids.
    """
    latents = vae.encode(images).latent_dist.sample()
    packed = pipeline._pack_latents(
        (latents - vae.config.shift_factor) * vae.config.scaling_factor,
        images.shape[0],
        latents.shape[1],
        latents.shape[2],
        latents.shape[3],
    )
    return packed, latents


# Keep cycling through the dataloader until `max_train_steps` optimisation
# steps have been taken; a single pass would silently stop early whenever the
# dataloader holds fewer batches than that.
while global_step < config.max_train_steps:
    batches_this_epoch = 0
    for batch in train_dataloader:
        batches_this_epoch += 1
        with accelerator.accumulate(flux_controlnet):
            # Encode target and conditioning images to packed latents
            pixel_values = batch["pixel_values"].to(device, dtype=torch.float32)
            pixel_latents, pixel_latents_tmp = _encode_to_packed_latents(pixel_values)

            control_values = batch["conditioning_pixel_values"].to(device, dtype=torch.float32)
            control_image, _ = _encode_to_packed_latents(control_values)

            latent_image_ids = pipeline._prepare_latent_image_ids(
                batch_size=pixel_latents_tmp.shape[0],
                height=pixel_latents_tmp.shape[2] // 2,
                width=pixel_latents_tmp.shape[3] // 2,
                device=device,
                dtype=torch.float32,
            )

            # Sample noise and timesteps; sigmoid of a standard normal keeps
            # every timestep strictly inside (0, 1)
            noise = torch.randn_like(pixel_latents, device=device, dtype=torch.float32)
            bsz = pixel_latents.shape[0]
            timesteps = torch.sigmoid(torch.randn((bsz,), device=device, dtype=torch.float32))

            # Flow matching: linear interpolation between data and noise
            t = timesteps.view(-1, 1, 1)
            noisy_latents = (1 - t) * pixel_latents + t * noise
            guidance_vec = torch.full((bsz,), config.guidance_scale, device=device, dtype=torch.float32)

            # Text conditioning, shared by the ControlNet and the transformer
            added_conditions = batch["unet_added_conditions"]
            pooled_projections = added_conditions["pooled_prompt_embeds"].to(device, dtype=torch.float32)
            encoder_hidden_states = batch["prompt_ids"].to(device, dtype=torch.float32)
            # The text ids are identical for every row, so the first row is used
            txt_ids = added_conditions["time_ids"][0].to(device, dtype=torch.float32)

            # ControlNet forward pass
            controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
                hidden_states=noisy_latents,
                controlnet_cond=control_image,
                timestep=timesteps,
                guidance=guidance_vec,
                pooled_projections=pooled_projections,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=latent_image_ids,
                return_dict=False,
            )

            # Either output is None when the ControlNet has no blocks of that
            # kind (the case for single blocks with `num_single_layers=0`);
            # iterating over None would raise a TypeError.
            if controlnet_block_samples is not None:
                controlnet_block_samples = [
                    sample.to(dtype=torch.float32) for sample in controlnet_block_samples
                ]
            if controlnet_single_block_samples is not None:
                controlnet_single_block_samples = [
                    sample.to(dtype=torch.float32) for sample in controlnet_single_block_samples
                ]

            # Transformer forward pass with ControlNet outputs
            noise_pred = flux_transformer(
                hidden_states=noisy_latents,
                timestep=timesteps,
                guidance=guidance_vec,
                pooled_projections=pooled_projections,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_block_samples=controlnet_block_samples,
                controlnet_single_block_samples=controlnet_single_block_samples,
                txt_ids=txt_ids,
                img_ids=latent_image_ids,
                return_dict=False,
            )[0]

            # Compute loss against the flow-matching target (noise - data)
            loss = F.mse_loss(noise_pred, (noise - pixel_latents), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            progress_bar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]})

            # Checkpointing
            if global_step % config.checkpointing_steps == 0:
                save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)
                print(f"Saved checkpoint to {save_path}")

            # Validation
            if global_step % config.validation_steps == 0:
                log_validation(global_step)

        if global_step >= config.max_train_steps:
            break

    # An empty dataloader would otherwise spin in the outer loop forever.
    if batches_this_epoch == 0:
        raise RuntimeError("train_dataloader yielded no batches; nothing to train on.")

progress_bar.close()

# ## Cell 10: Save Final Model

# Synchronise all processes before anything is written to disk
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Save the unwrapped model rather than the object returned by `prepare`
    flux_controlnet = accelerator.unwrap_model(flux_controlnet)
    flux_controlnet.save_pretrained(save_directory=config.output_dir)
    print("Final model saved to {}".format(config.output_dir))

# Perform final validation
log_validation(step="final")

accelerator.end_training()