In [1]:
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8
device="cuda:0"
import io
import os
import math
import random
from types import SimpleNamespace
import torch
import torch.nn as nn
import torch.nn.functional as F
import datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from fastprogress import progress_bar, master_bar
from IPython.display import display
from diffusers import (
    AutoencoderKL,
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
)
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.training_utils import free_memory
from diffusers.utils import make_image_grid
from transformers import AutoTokenizer, CLIPTextModel, T5EncoderModel
from huggingface_hub import create_repo, upload_folder

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [2]:
# dataset = datasets.load_dataset("danjacobellis/LSDIR_512_f16c12", split="train")
dataset = datasets.load_dataset("danjacobellis/LSDIR_540", split="train")

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/85 [00:00<?, ?it/s]

In [3]:
config = SimpleNamespace()
config.pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
config.output_dir = "./flux_controlnet_codec_enhancer"
config.resolution = 512
config.batch_size = 2
config.num_double_layers = 4
config.num_single_layers = 0
config.learning_rate = 1e-5
config.max_train_steps = 15000
config.checkpointing_steps = 200
config.validation_steps = 100
config.gradient_accumulation_steps = 4
config.seed = 42
config.push_to_hub = True
config.hub_model_id = "danjacobellis/FLUX.1-controlnet-codec-enhancer"
config.num_validation_images = 1
config.validation_prompts = [
    "blurry image",
]
config.validation_images = [
    "compressed.png",
]
config.total_steps = config.max_train_steps * config.gradient_accumulation_steps
config.tracker_project_name = "flux_controlnet_codec_enhancer"

In [4]:
os.makedirs(config.output_dir, exist_ok=True)

In [5]:
tokenizer_one = AutoTokenizer.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="tokenizer"
)
tokenizer_two = AutoTokenizer.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="tokenizer_2"
)
text_encoder_one = CLIPTextModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder"
)
text_encoder_two = T5EncoderModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder_2"
)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="vae"
)
flux_transformer = FluxTransformer2DModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="transformer"
)
flux_controlnet = FluxControlNetModel.from_transformer(
    flux_transformer,
    attention_head_dim=flux_transformer.config["attention_head_dim"],
    num_attention_heads=flux_transformer.config["num_attention_heads"],
    num_layers=config.num_double_layers,
    num_single_layers=config.num_single_layers,
)
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="scheduler"
)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
vae.requires_grad_(False)
flux_transformer.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
flux_controlnet.train();

In [8]:
vae.to(device, dtype=torch.float32)
flux_transformer.to(device, dtype=torch.float32)
text_encoder_one.to(device, dtype=torch.float32)
text_encoder_two.to(device, dtype=torch.float32)
flux_controlnet.to(device, dtype=torch.float32)

# Initialize pipeline for utility functions
pipeline = FluxControlNetPipeline(
    scheduler=noise_scheduler,
    vae=vae,
    text_encoder=text_encoder_one,
    tokenizer=tokenizer_one,
    text_encoder_2=text_encoder_two,
    tokenizer_2=tokenizer_two,
    transformer=flux_transformer,
    controlnet=flux_controlnet,
).to(device)

In [9]:
optimizer = torch.optim.AdamW(
    flux_controlnet.parameters(),
    lr=config.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8
)

def lr_sched(i_step, config):
    t = i_step / config.total_steps
    return config.learning_rate * (1 - ((np.cos(np.pi * t)) ** 2))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda i_step: lr_sched(i_step, config)
)

In [10]:
image_transforms = transforms.Compose([
    transforms.CenterCrop(config.resolution),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [11]:
def preprocess(example):
    image = image_transforms(example["image"])
    return {
        "pixel_values": image,
        "conditioning_pixel_values": image,
        "caption": "high-quality enhanced image"  # Generic caption as LSDIR lacks captions
    }

with torch.no_grad():
    train_dataset = dataset.select(range(100)).map(preprocess, remove_columns=["image"])
    train_dataset.set_format("torch")

def compute_embeddings(batch):
    captions = batch["caption"]
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(captions, prompt_2=captions)
    text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)  # [bs, 512, 3]
    return {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "text_ids": text_ids
    }

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [12]:
train_dataset = train_dataset.map(
    compute_embeddings, remove_columns=["caption"]
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [13]:
def collate_fn(batch):
    pixel_values = torch.stack([example["pixel_values"] for example in batch]).to(torch.float32)
    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in batch]).to(torch.float32)
    prompt_ids = torch.stack([example["prompt_embeds"] for example in batch]).to(torch.float32)
    pooled_prompt_embeds = torch.stack([example["pooled_prompt_embeds"] for example in batch]).to(torch.float32)
    text_ids = torch.stack([example["text_ids"] for example in batch]).to(torch.float32)
    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "prompt_ids": prompt_ids,
        "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
    }

dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    drop_last=True
)

In [14]:
def log_validation(step):
    flux_controlnet.eval()
    pipeline.controlnet = flux_controlnet
    generator = torch.Generator(device=device).manual_seed(config.seed)
    image_logs = []

    for v_prompt, v_image_path in zip(config.validation_prompts, config.validation_images):
        v_image = Image.open(v_image_path).convert("RGB").resize((config.resolution, config.resolution))
        v_image_tensor = conditioning_transforms(v_image).to(device, dtype=torch.float32)
        images = []
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(v_prompt, prompt_2=v_prompt)
        for _ in range(config.num_validation_images):
            with torch.no_grad():
                image = pipeline(
                    prompt_embeds=prompt_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    control_image=v_image_tensor.unsqueeze(0),
                    num_inference_steps=28,
                    controlnet_conditioning_scale=0.7,
                    guidance_scale=3.5,
                    generator=generator,
                ).images[0]
            images.append(image)
        image_logs.append({"validation_image": v_image, "images": images, "validation_prompt": v_prompt})

    # Display validation results
    for log in image_logs:
        grid = make_image_grid([log["validation_image"]] + log["images"], 1, config.num_validation_images + 1)
        plt.figure(figsize=(15, 5))
        plt.imshow(np.array(grid))
        plt.title(f"Step {step}: {log['validation_prompt']}")
        plt.axis("off")
        plt.show()

    flux_controlnet.train()
    free_memory()
    return image_logs

In [23]:
global_step = 0
learning_rates = [optimizer.param_groups[0]["lr"]]
losses = []
mb = master_bar(range(config.max_train_steps // config.gradient_accumulation_steps))

for epoch in mb:
    pb = progress_bar(dataloader, parent=mb)
    for i_batch, batch in enumerate(pb):
        if global_step >= config.max_train_steps:
            break

        # Accumulate gradients
        for _ in range(config.gradient_accumulation_steps):
            with torch.no_grad():
                pixel_values = batch["pixel_values"].to(device, dtype=torch.float32)
                pixel_latents = FluxControlNetPipeline._pack_latents(
                    vae.encode(pixel_values).latent_dist.sample(),
                    config.batch_size, 16, config.resolution, config.resolution
                )
                control_image = FluxControlNetPipeline._pack_latents(
                    vae.encode(batch["conditioning_pixel_values"].to(device, dtype=torch.float32)).latent_dist.sample(),
                    config.batch_size, 16, config.resolution, config.resolution
                )
                latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
                    config.batch_size, config.resolution // 2, config.resolution // 2, device, torch.float32
                )

            noise = torch.randn_like(pixel_latents, device=device, dtype=torch.float32)
            bsz = pixel_latents.shape[0]
            timesteps = torch.rand(bsz, device=device) * noise_scheduler.config.num_train_timesteps
            noisy_latents = (1 - timesteps.view(-1, 1, 1)) * pixel_latents + timesteps.view(-1, 1, 1) * noise
            guidance_vec = torch.full((bsz,), 3.5, device=device, dtype=torch.float32)

            controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
                hidden_states=noisy_latents,
                controlnet_cond=control_image,
                timestep=timesteps / 1000,
                guidance=guidance_vec,
                pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(device, dtype=torch.float32),
                encoder_hidden_states=batch["prompt_ids"].to(device, dtype=torch.float32),
                txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(device, dtype=torch.float32),
                img_ids=latent_image_ids,
                return_dict=False,
            )

            noise_pred = flux_transformer(
                hidden_states=noisy_latents,
                timestep=timesteps / 1000,
                guidance=guidance_vec,
                pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(device, dtype=torch.float32),
                encoder_hidden_states=batch["prompt_ids"].to(device, dtype=torch.float32),
                controlnet_block_samples=controlnet_block_samples,
                controlnet_single_block_samples=controlnet_single_block_samples,
                txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(device, dtype=torch.float32),
                img_ids=latent_image_ids,
                return_dict=False,
            )[0]

            loss = F.mse_loss(noise_pred, (noise - pixel_latents), reduction="mean")
            loss.backward()

        # Optimizer step
        torch.nn.utils.clip_grad_norm_(flux_controlnet.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Logging
        losses.append(loss.item())
        learning_rates.append(optimizer.param_groups[0]["lr"])
        pb.comment = f"Loss: {losses[-1]:.4f}, LR: {learning_rates[-1]:.2e}"
        global_step += 1

        # Checkpointing
        if global_step % config.checkpointing_steps == 0:
            save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
            flux_controlnet.save_pretrained(save_path)
            print(f"Saved checkpoint to {save_path}")

        # Validation
        if global_step % config.validation_steps == 0:
            image_logs = log_validation(global_step)

    if global_step >= config.max_train_steps:
        break

# In[8]: Save Final Model
flux_controlnet.save_pretrained(config.output_dir)
final_image_logs = log_validation(global_step)

if config.push_to_hub:
    repo_id = create_repo(repo_id=config.hub_model_id, exist_ok=True).repo_id
    readme = f"""
# FLUX.1 ControlNet for Codec Enhancement
Trained on {config.pretrained_model_name_or_path} to enhance outputs of a neural image codec.
Dataset: danjacobellis/LSDIR_512_f16c12
"""
    for i, log in enumerate(final_image_logs):
        grid = make_image_grid([log["validation_image"]] + log["images"], 1, config.num_validation_images + 1)
        grid.save(os.path.join(config.output_dir, f"example_{i}.png"))
        readme += f"\n![Example {i}](example_{i}.png)\nPrompt: {log['validation_prompt']}\n"

    with open(os.path.join(config.output_dir, "README.md"), "w") as f:
        f.write(readme)

    upload_folder(
        repo_id=repo_id,
        folder_path=config.output_dir,
        commit_message="Final trained ControlNet",
        ignore_patterns=["checkpoint-*"],
    )

print("Training completed and model saved.")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

ValueError: not enough values to unpack (expected 6, got 2)