In [1]:
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8
device="cuda:0"
import torch
import torch.nn as nn
import PIL.Image
import numpy as np
import datasets
from types import SimpleNamespace
from torchvision import transforms
from diffusers import (
    AutoencoderKL,
    FluxTransformer2DModel,
    FluxControlNetModel,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetPipeline,
    FluxPipeline,
)
from accelerate import Accelerator
from tqdm.auto import tqdm
from codec import AutoEncoderND  # Your codec module
from IPython.display import display
from torchvision.transforms.v2 import ToPILImage, PILToTensor, CenterCrop

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [2]:
# Cell 3: Load Dataset and Codec
dataset = datasets.load_dataset("danjacobellis/LSDIR").map(lambda x: {"text": ""})

# Load codec model
checkpoint = torch.load('../hf/dance/LF_rgb_f16c12_v1.0.pth', map_location=device, weights_only=False)
config = checkpoint['config']
state_dict = checkpoint['state_dict']

codec_model = AutoEncoderND(
    dim=2,
    input_channels=config.input_channels,
    J=int(config.F**0.5),
    latent_dim=config.latent_dim,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode
).to(device)
codec_model.load_state_dict(state_dict)
codec_model.eval()
print(f"Codec loaded with {sum(p.numel() for p in codec_model.parameters())/1e6:.2f} M parameters")

Resolving data files:   0%|          | 0/195 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/195 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/178 [00:00<?, ?it/s]

Codec loaded with 61.93 M parameters


In [3]:
config = SimpleNamespace(
    pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",
    output_dir="controlnet-output",
    resolution=128,
    train_batch_size=2,
    num_train_epochs=10,
    checkpointing_steps=1000,
    learning_rate=1e-5,
    mixed_precision="bf16",
    proportion_empty_prompts=1.0,
    max_train_steps=None,
    gradient_accumulation_steps=1,
    seed=42
)

In [4]:
# Cell 5: Dataset Preparation
def prepare_train_dataset(dataset, resolution):
    image_transforms = transforms.Compose([
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    
    def preprocess_train(examples):
        images = [image_transforms(image.convert("RGB")) for image in examples['image']]
        examples['pixel_values'] = images
        return examples
    
    return dataset.with_transform(preprocess_train)

train_dataset = prepare_train_dataset(dataset['train'], config.resolution)

In [5]:
def compute_embeddings(batch, vae, codec_model, weight_dtype, pipeline):
    device = torch.device("cuda:0")  # Ensure this matches your intended device
    pixel_values = torch.stack(batch['pixel_values']).to(device, dtype=weight_dtype)
    
    # Compute decoded images using codec
    with torch.no_grad():
        z = codec_model.encode(pixel_values)
        latent = codec_model.quantize.compand(z).round()
        control_values = codec_model.decode(latent)
    
    # Encode both with VAE
    pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
    pixel_latents = FluxControlNetPipeline._pack_latents(
        pixel_latents_tmp - vae.config.shift_factor * vae.config.scaling_factor,
        pixel_values.shape[0], pixel_latents_tmp.shape[1], pixel_latents_tmp.shape[2], pixel_latents_tmp.shape[3]
    )
    
    control_latents_tmp = vae.encode(control_values).latent_dist.sample()
    control_latents = FluxControlNetPipeline._pack_latents(
        control_latents_tmp - vae.config.shift_factor * vae.config.scaling_factor,
        control_values.shape[0], control_latents_tmp.shape[1], control_latents_tmp.shape[2], control_latents_tmp.shape[3]
    )
    
    # Dummy text embeddings (empty prompts)
    prompt_batch = [""] * len(batch['pixel_values'])
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt_batch, prompt_2=prompt_batch)
    
    # Split text_ids into a list of tensors along the batch dimension
    text_ids_list = [text_ids[i] for i in range(text_ids.shape[0])]
    
    latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
        batch_size=pixel_values.shape[0],
        height=pixel_values.shape[2] // 2,
        width=pixel_values.shape[3] // 2,
        device=device,
        dtype=weight_dtype
    )
    
    return {
        'pixel_latents': pixel_latents,
        'control_latents': control_latents,
        'prompt_ids': prompt_embeds.to(weight_dtype),
        'pooled_prompt_embeds': pooled_prompt_embeds.to(weight_dtype),
        'time_ids': text_ids_list,
        'latent_image_ids': latent_image_ids
    }

In [6]:
def collate_fn(examples):
    return {key: torch.stack([example[key] for example in examples]) for key in examples[0].keys()}

In [7]:
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
)

In [8]:
vae = AutoencoderKL.from_pretrained(config.pretrained_model_name_or_path, subfolder="vae").to(device)
flux_transformer = FluxTransformer2DModel.from_pretrained(config.pretrained_model_name_or_path, subfolder="transformer").to(device)
flux_controlnet = FluxControlNetModel.from_transformer(flux_transformer, num_layers=4, num_single_layers=0)
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder="scheduler")

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
vae.requires_grad_(False)
flux_transformer.requires_grad_(False)
flux_controlnet.train();

In [10]:
optimizer = torch.optim.AdamW(flux_controlnet.parameters(), lr=config.learning_rate)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.train_batch_size, shuffle=True, collate_fn=collate_fn
);

In [11]:
weight_dtype = torch.bfloat16 if config.mixed_precision == "bf16" else torch.float32
pipeline = FluxPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=vae).to(device)
compute_embeddings_fn = lambda batch: compute_embeddings(batch, vae, codec_model, weight_dtype, pipeline)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, batch_size=1)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


OutOfMemoryError: CUDA out of memory. Tried to allocate 180.00 MiB. GPU 0 has a total capacity of 79.14 GiB of which 93.88 MiB is free. Including non-PyTorch memory, this process has 78.95 GiB memory in use. Of the allocated memory 78.49 GiB is allocated by PyTorch, and 94.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# flux_controlnet, optimizer, train_dataloader = accelerator.prepare(flux_controlnet, optimizer, train_dataloader)

In [None]:
# num_update_steps_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps']
# if config.max_train_steps'] is None:
#     config.max_train_steps'] = config.num_train_epochs'] * num_update_steps_per_epoch

In [None]:
# # Training loop
# global_step = 0
# mb = tqdm(range(config.num_train_epochs']), desc="Epochs")
# for epoch in mb:
#     pb = tqdm(train_dataloader, desc="Steps", leave=False)
#     for step, batch in enumerate(pb):
#         with accelerator.accumulate(flux_controlnet):
#             pixel_latents = batch['pixel_latents'].to(weight_dtype)
#             control_latents = batch['control_latents'].to(weight_dtype)
#             latent_image_ids = batch['latent_image_ids'].to(weight_dtype)
            
#             noise = torch.randn_like(pixel_latents).to(device, dtype=weight_dtype)
#             bsz = pixel_latents.shape[0]
#             timesteps = torch.sigmoid(torch.randn((bsz,), device=device, dtype=weight_dtype))
            
#             noisy_latents = (1 - timesteps.view(-1, 1, 1)) * pixel_latents + timesteps.view(-1, 1, 1) * noise
#             guidance_vec = torch.full((bsz,), 3.5, device=device, dtype=weight_dtype)
            
#             controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
#                 hidden_states=noisy_latents,
#                 controlnet_cond=control_latents,
#                 timestep=timesteps,
#                 guidance=guidance_vec,
#                 pooled_projections=batch['unet_added_conditions']['pooled_prompt_embeds'].to(weight_dtype),
#                 encoder_hidden_states=batch['prompt_ids'].to(weight_dtype),
#                 txt_ids=batch['unet_added_conditions']['time_ids'][0].to(weight_dtype),
#                 img_ids=latent_image_ids[0],
#                 return_dict=False
#             )
            
#             noise_pred = flux_transformer(
#                 hidden_states=noisy_latents,
#                 timestep=timesteps,
#                 guidance=guidance_vec,
#                 pooled_projections=batch['unet_added_conditions']['pooled_prompt_embeds'].to(weight_dtype),
#                 encoder_hidden_states=batch['prompt_ids'].to(weight_dtype),
#                 controlnet_block_samples=[s.to(weight_dtype) for s in controlnet_block_samples],
#                 controlnet_single_block_samples=[s.to(weight_dtype) for s in controlnet_single_block_samples],
#                 txt_ids=batch['unet_added_conditions']['time_ids'][0].to(weight_dtype),
#                 img_ids=latent_image_ids[0],
#                 return_dict=False
#             )[0]
            
#             loss = nn.functional.mse_loss(noise_pred.float(), noise.float())
#             accelerator.backward(loss)
#             optimizer.step()
#             optimizer.zero_grad()
            
#             if accelerator.sync_gradients:
#                 global_step += 1
#                 pb.set_postfix({'loss': loss.item()})
                
#                 if global_step % config.checkpointing_steps == 0:
#                     save_path = f"{config.output_dir}/checkpoint-{global_step}"
#                     accelerator.save_state(save_path)
#                     print(f"Saved checkpoint to {save_path}")
            
#             if global_step >= config.max_train_steps:
#                 break
    
#     if global_step >= config.max_train_steps:
#         break

# # Save final model
# accelerator.wait_for_everyone()
# if accelerator.is_main_process:
#     flux_controlnet = accelerator.unwrap_model(flux_controlnet)
#     flux_controlnet.save_pretrained(config.output_dir)
#     print(f"Model saved to {config.output_dir}")