In [1]:
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8
device="cuda:0"
import io
import os
import math
import random
from types import SimpleNamespace
import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from fastprogress import progress_bar, master_bar
from IPython.display import display
from diffusers import (
    AutoencoderKL,
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
)
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import free_memory
from diffusers.utils import make_image_grid
from transformers import AutoTokenizer, CLIPTextModel, T5EncoderModel
from huggingface_hub import create_repo, upload_folder

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [2]:
# dataset = datasets.load_dataset("danjacobellis/LSDIR_512_f16c12", split="train")
dataset = datasets.load_dataset("danjacobellis/LSDIR_540", split="train")

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/85 [00:00<?, ?it/s]

In [3]:
config = SimpleNamespace()
config.pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
config.output_dir = "./flux_controlnet_codec_enhancer"
config.resolution = 512
config.batch_size = 2
config.num_double_layers = 4
config.num_single_layers = 0
config.learning_rate = 1e-5
config.max_train_steps = 15000
config.checkpointing_steps = 200
config.validation_steps = 100
config.gradient_accumulation_steps = 4
config.seed = 42
config.push_to_hub = True
config.hub_model_id = "danjacobellis/FLUX.1-controlnet-codec-enhancer"
config.num_validation_images = 1
config.validation_prompts = [
    "blurry image",
]
config.validation_images = [
    "compressed.png",
]
config.total_steps = config.max_train_steps * config.gradient_accumulation_steps
config.tracker_project_name = "flux_controlnet_codec_enhancer"
config.num_train_epochs = math.ceil(config.max_train_steps / (dataset.num_rows // config.batch_size))
config.guidance_scale = 3.5

In [4]:
os.makedirs(config.output_dir, exist_ok=True)

In [5]:
tokenizer_one = AutoTokenizer.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="tokenizer"
)
tokenizer_two = AutoTokenizer.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="tokenizer_2"
)
text_encoder_one = CLIPTextModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two = T5EncoderModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder_2"
)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="vae"
)
flux_transformer = FluxTransformer2DModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="transformer"
)
flux_controlnet = FluxControlNetModel.from_transformer(
    flux_transformer,
    attention_head_dim=flux_transformer.config["attention_head_dim"],
    num_attention_heads=flux_transformer.config["num_attention_heads"],
    num_layers=config.num_double_layers,
    num_single_layers=config.num_single_layers,
)
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = noise_scheduler

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
vae.requires_grad_(False)
flux_transformer.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
flux_controlnet.train();

In [8]:
vae.to(device, dtype=torch.float32)
flux_transformer.to(device, dtype=torch.float32)
text_encoder_one.to(device, dtype=torch.float32)
text_encoder_two.to(device, dtype=torch.float32)
flux_controlnet.to(device, dtype=torch.float32)

# Initialize pipeline for utility functions
pipeline = FluxControlNetPipeline(
    scheduler=noise_scheduler,
    vae=vae,
    text_encoder=text_encoder_one,
    tokenizer=tokenizer_one,
    text_encoder_2=text_encoder_two,
    tokenizer_2=tokenizer_two,
    transformer=flux_transformer,
    controlnet=flux_controlnet,
).to(device)

In [9]:
optimizer = torch.optim.AdamW(
    flux_controlnet.parameters(),
    lr=config.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8
)

def lr_sched(i_step, config):
    t = i_step / config.total_steps
    return config.learning_rate * (1 - ((np.cos(np.pi * t)) ** 2))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda i_step: lr_sched(i_step, config)
)

In [10]:
image_transforms = transforms.Compose([
    transforms.CenterCrop(config.resolution),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [11]:
def preprocess(example):
    image = image_transforms(example["image"])
    return {
        "pixel_values": image,
        "conditioning_pixel_values": image,
        "caption": "high-quality enhanced image"  # Generic caption as LSDIR lacks captions
    }

with torch.no_grad():
    train_dataset = dataset.select(range(100)).map(preprocess, remove_columns=["image"])
    train_dataset.set_format("torch")

def compute_embeddings(batch):
    captions = batch["caption"]
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(captions, prompt_2=captions)
    text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)  # [bs, 512, 3]
    return {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "text_ids": text_ids
    }

In [12]:
train_dataset = train_dataset.map(
    compute_embeddings, remove_columns=["caption"]
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [13]:
def collate_fn(batch):
    pixel_values = torch.stack([example["pixel_values"] for example in batch]).to(torch.float32)
    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in batch]).to(torch.float32)
    prompt_ids = torch.stack([example["prompt_embeds"] for example in batch]).to(torch.float32)
    pooled_prompt_embeds = torch.stack([example["pooled_prompt_embeds"] for example in batch]).to(torch.float32)
    text_ids = torch.stack([example["text_ids"] for example in batch]).to(torch.float32)
    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "prompt_ids": prompt_ids,
        "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
    }

dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    drop_last=True
)

In [14]:
def log_validation(step):
    flux_controlnet.eval()
    pipeline.controlnet = flux_controlnet
    generator = torch.Generator(device=device).manual_seed(config.seed)
    image_logs = []

    for v_prompt, v_image_path in zip(config.validation_prompts, config.validation_images):
        v_image = Image.open(v_image_path).convert("RGB").resize((config.resolution, config.resolution))
        v_image_tensor = conditioning_transforms(v_image).to(device, dtype=torch.float32)
        images = []
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(v_prompt, prompt_2=v_prompt)
        for _ in range(config.num_validation_images):
            with torch.no_grad():
                image = pipeline(
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    control_image=v_image_tensor.unsqueeze(0),
                    num_inference_steps=28,
                    controlnet_conditioning_scale=0.7,
                    guidance_scale=3.5,
                    generator=generator,
                ).images[0]
            images.append(image)
        image_logs.append({"validation_image": v_image, "images": images, "validation_prompt": v_prompt})

    # Display validation results
    for log in image_logs:
        grid = make_image_grid([log["validation_image"]] + log["images"], 1, config.num_validation_images + 1)
        plt.figure(figsize=(15, 5))
        plt.imshow(np.array(grid))
        plt.title(f"Step {step}: {log['validation_prompt']}")
        plt.axis("off")
        plt.show()

    flux_controlnet.train()
    free_memory()
    return image_logs

In [15]:
# Sigma function
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

In [16]:
mb = master_bar(range(config.max_train_steps // config.batch_size))
learning_rates = [optimizer.param_groups[0]['lr']]
losses = []
global_step = 0
flux_controlnet.train()

for i_step in mb:
    pb = progress_bar(dataloader, parent=mb)
    for i_batch, batch in enumerate(pb):
        pixel_values = batch["pixel_values"].to(device)
        conditioning_pixel_values = batch["conditioning_pixel_values"].to(device)
        prompt_ids = batch["prompt_ids"].to(device)
        unet_added_conditions = {
            "pooled_prompt_embeds": batch["unet_added_conditions"]["pooled_prompt_embeds"].to(device),
            "time_ids": batch["unet_added_conditions"]["time_ids"].to(device),
        }

        # Encode images to latents
        with torch.no_grad():
            pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
            pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
            pixel_latents = pipeline._pack_latents(
                pixel_latents_tmp,
                pixel_values.shape[0],
                pixel_latents_tmp.shape[1],
                pixel_latents_tmp.shape[2],
                pixel_latents_tmp.shape[3],
            )

            control_latents_tmp = vae.encode(conditioning_pixel_values).latent_dist.sample()
            control_latents_tmp = (control_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
            control_image = pipeline._pack_latents(
                control_latents_tmp,
                conditioning_pixel_values.shape[0],
                control_latents_tmp.shape[1],
                control_latents_tmp.shape[2],
                control_latents_tmp.shape[3],
            )

        # Prepare latent image IDs using unpacked latents
        latent_image_ids = pipeline._prepare_latent_image_ids(
            batch_size=pixel_latents_tmp.shape[0],
            height=pixel_latents_tmp.shape[2] // 2,
            width=pixel_latents_tmp.shape[3] // 2,
            device=device,
            dtype=pixel_latents.dtype,
        )

        # Sample noise and timesteps
        noise = torch.randn_like(pixel_latents).to(device)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (pixel_latents.shape[0],), device=device).long()

        # Add noise
        sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
        noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise

        # Guidance
        guidance_vec = torch.full(
            (noisy_model_input.shape[0],),
            config.guidance_scale,  # e.g., 3.5 or your desired scale
            device=device,
            dtype=noisy_model_input.dtype
        )

        # ControlNet forward
        controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
            hidden_states=noisy_model_input,
            controlnet_cond=control_image,
            timestep=timesteps / 1000,
            guidance=guidance_vec,
            pooled_projections=unet_added_conditions["pooled_prompt_embeds"],
            encoder_hidden_states=prompt_ids,
            txt_ids=unet_added_conditions["time_ids"][0],
            img_ids=latent_image_ids,
            return_dict=False,
        )

        # Flux transformer forward
        noise_pred = flux_transformer(
            hidden_states=noisy_model_input,
            timestep=timesteps / 1000,
            guidance=guidance_vec,
            pooled_projections=unet_added_conditions["pooled_prompt_embeds"],
            encoder_hidden_states=prompt_ids,
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
            txt_ids=unet_added_conditions["time_ids"],
            img_ids=latent_image_ids,
            return_dict=False,
        )[0]

        # Compute loss
        loss = F.mse_loss(noise_pred, (noise - pixel_latents), reduction="mean")
        losses.append(loss.item())

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        learning_rates.append(optimizer.param_groups[0]['lr'])

        global_step += 1
        pb.comment = f"Loss: {loss.item():.4f}, LR: {learning_rates[-1]:.2g}"

        # Validation and checkpointing
        current_step = i_step * config.batch_size + i_batch
        if global_step % config.validation_steps == 0:
            log_validation(global_step)
        if global_step % config.checkpointing_steps == 0:
            torch.save(flux_controlnet.state_dict(), os.path.join(config.output_dir, f"checkpoint_{global_step}.pth"))

        if global_step >= config.max_train_steps:
            break
    if global_step >= config.max_train_steps:
        break

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

ValueError: not enough values to unpack (expected 6, got 2)