In [1]:
%env CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8
device="cuda:0"
import torch
import datasets
import io
from types import SimpleNamespace
from codec import AutoEncoderND
from torchvision.transforms.v2 import Compose, CenterCrop, ToTensor, Normalize, ToPILImage

env: CUDA_VISIBLE_DEVICES=MIG-768d9c1d-110f-52e2-b0a2-3252f78280f8


In [2]:
config = SimpleNamespace()
config.resolution = 512
config.codec_checkpoint = "../hf/dance/LF_rgb_f16c12_v1.0.pth"

In [3]:
checkpoint = torch.load(config.codec_checkpoint, map_location=device, weights_only=False)
codec_config = checkpoint['config']
state_dict = checkpoint['state_dict']

codec = AutoEncoderND(
    dim=2,
    input_channels=codec_config.input_channels,
    J=int(codec_config.F**0.5),
    latent_dim=codec_config.latent_dim,
    lightweight_encode=codec_config.lightweight_encode,
    lightweight_decode=codec_config.lightweight_decode
).to(device)
codec.load_state_dict(state_dict)
codec.eval();

In [None]:
dataset = datasets.load_dataset("danjacobellis/LSDIR_540",split='train')

transform = Compose([
    CenterCrop(config.resolution),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def gen_conditioning_img(sample):
    x = transform(sample['image']).to(device).unsqueeze(0)
    img = ToPILImage()(x[0]/2 + 0.5)
    buff = io.BytesIO()
    img.save(buff, format='WEBP', lossless=True)
    webp_bytes = buff.getbuffer()

    with torch.no_grad():
        xhat = codec.decode(codec.quantize.compand(codec.encode(x)).round()).clamp(-1,1)

    img2 = ToPILImage()(xhat[0]/2 + 0.5)
    buff2 = io.BytesIO()
    img2.save(buff2, format='WEBP', lossless=True)
    webp_bytes2 = buff2.getbuffer()
    return {
        'image': webp_bytes,
        'conditioning_image': webp_bytes2
    }

new_dataset = dataset.map(gen_conditioning_img,writer_batch_size=500).cast_column('image',datasets.Image()).cast_column('conditioning_image',datasets.Image())

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/89 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/85 [00:00<?, ?it/s]



Map:   0%|          | 0/84991 [00:00<?, ? examples/s]

In [None]:
new_dataset.push_to_hub("danjacobellis/LSDIR_512_f16c12", split='train')

In [24]:
logger.info("Pre-processing dataset to generate conditioning images...")


03/06/2025 20:32:03 - INFO - __main__ - Pre-processing dataset to generate conditioning images...


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [None]:
# logger.info(f"Uploading pre-processed dataset to {config.preprocessed_dataset_name}")
# new_dataset.push_to_hub(config.preprocessed_dataset_name, split='train')
# logger.info("Dataset pre-processing complete")

In [27]:
accelerator = Accelerator(
    mixed_precision="no",  # Use float32 as specified
    log_with="tensorboard",
    project_config=ProjectConfiguration(project_dir=config.output_dir, logging_dir=config.logging_dir)
)
logger.info(accelerator.state)

NameError: name 'ProjectConfiguration' is not defined

In [None]:


# Load Flux models and tokenizers
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="vae"
).to(device, dtype=torch.float32)
flux_transformer = FluxTransformer2DModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="transformer"
).to(device, dtype=torch.float32)
flux_controlnet = FluxControlNetModel.from_transformer(
    flux_transformer,
    num_layers=config.num_double_layers,
    num_single_layers=config.num_single_layers
).to(device, dtype=torch.float32)
tokenizer_one = AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer_two = AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path, subfolder="tokenizer_2")
text_encoder_one = CLIPTextModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder"
).to(device, dtype=torch.float32)
text_encoder_two = T5EncoderModel.from_pretrained(
    config.pretrained_model_name_or_path, subfolder="text_encoder_2"
).to(device, dtype=torch.float32)

# Freeze non-ControlNet models
vae.requires_grad_(False)
flux_transformer.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
flux_controlnet.train()

# Define the training command (to be run via accelerate)
training_command = (
    f"accelerate launch train_controlnet_flux.py "
    f"--pretrained_model_name_or_path={config.pretrained_model_name_or_path} "
    f"--dataset_name={config.preprocessed_dataset_name} "
    f"--image_column=image "
    f"--conditioning_image_column=conditioning_image "
    f"--caption_column=text "
    f"--resolution={config.resolution} "
    f"--train_batch_size={config.train_batch_size} "
    f"--mixed_precision=no "  # Float32 precision
    f"--num_double_layers={config.num_double_layers} "
    f"--num_single_layers={config.num_single_layers} "
    f"--max_train_steps={config.max_train_steps} "
    f"--checkpointing_steps={config.checkpointing_steps} "
    f"--validation_steps={config.validation_steps} "
    f"--learning_rate={config.learning_rate} "
    f"--output_dir={config.output_dir} "
    f"--logging_dir={config.logging_dir} "
    f"--push_to_hub "
    f"--hub_model_id={config.hub_model_id}"
)

logger.info("Training command prepared:")
logger.info(training_command)

# Note: Save this command to a script or run it in a terminal after setting up accelerate config
# For notebook execution, you might need to use ! or os.system() if running directly:
# ! {training_command}
# Alternatively, save and run separately after `accelerate config`

# ## Cell 6: Post-Training Notes

# After training, the ControlNet model will be saved to `config.output_dir` and pushed to the Hugging Face Hub
# at `config.hub_model_id`. You can use it for inference with the FluxControlNetPipeline as shown in the tutorial.

# Example inference (after training):
"""
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
pipe = FluxControlNetPipeline.from_pretrained(
    config.pretrained_model_name_or_path,
    controlnet=FluxControlNetModel.from_pretrained(config.output_dir),
    torch_dtype=torch.float32
).to(device)
# Add your inference code here with a conditioning image from the codec
"""