In [1]:
import argparse
import copy
import itertools
import logging
import math
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image, ImageDraw
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast

import diffusers
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxFillPipeline,
    FluxTransformer2DModel,
)
from diffusers.utils import load_image
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    free_memory,
)
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module

2025-02-18 02:13:07.382428: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-18 02:13:07.615522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739812387.707831  109304 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739812387.734496  109304 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-18 02:13:07.958682: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
instance_data_root = Path("./sks")

In [9]:
paths = [path for path in list(Path(instance_data_root).iterdir())]

In [10]:
paths

[PosixPath('sks/005.jpg'),
 PosixPath('sks/001.jpg'),
 PosixPath('sks/004.jpg'),
 PosixPath('sks/001.txt'),
 PosixPath('sks/002.jpg'),
 PosixPath('sks/001_mask.png'),
 PosixPath('sks/002.txt'),
 PosixPath('sks/003.jpg'),
 PosixPath('sks/003_mask.png')]

In [40]:
name_to_mask = {}
name_to_image = {}
name_to_prompt = {}

mask_postfix = "_mask"

for path in sorted(paths):
    file_name_with_ext = path.name
    file_name, ext = os.path.splitext(file_name_with_ext)
    if file_name.lower().endswith(mask_postfix):
        name_to_mask[file_name[:-len(mask_postfix)]] = path
    elif ext.lower() == ".txt":
        name_to_prompt[file_name] = path
    else:
        name_to_image[file_name] = path
            

In [41]:
name_to_mask

{'001': PosixPath('sks/001_mask.png'), '003': PosixPath('sks/003_mask.png')}

In [42]:
name_to_image

{'001': PosixPath('sks/001.jpg'),
 '002': PosixPath('sks/002.jpg'),
 '003': PosixPath('sks/003.jpg'),
 '004': PosixPath('sks/004.jpg'),
 '005': PosixPath('sks/005.jpg')}

In [43]:
name_to_prompt

{'001': PosixPath('sks/001.txt'), '002': PosixPath('sks/002.txt')}

In [72]:
def generate_full_mask(size):
    mask = Image.new("L", size, 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle((0, 0, size[0], size[1]), fill=255)
    return mask

In [73]:
instance_prompt = "a sks dog"
instance_images = []
pil_images = []
mask_images = []
custom_instance_prompts = []
pixel_values = []

for name in sorted(name_to_image):
    try:
        instance_image = Image.open(name_to_image[name])
        instance_images.append(instance_image)
        
        if name in name_to_mask:
            mask_image = Image.open(name_to_mask[name])
            mask_images.append(mask_image)
        else:
            mask_image = generate_full_mask(instance_image.size)
            mask_images.append(mask_image)
            
        if name in name_to_prompt:
            with open(name_to_prompt[name], "r") as prompt_file:
                prompt_text = prompt_file.read()
                custom_instance_prompts.append(prompt_text)
        else:
            custom_instance_prompts.append(instance_prompt)
    except (IOError, AttributeError) as e:
        print(e)

In [74]:
instance_images

[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1815x1967>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2796x2656>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2476x2612>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2732x2736>]

In [78]:
mask_images

[<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=1815x1967>,
 <PIL.Image.Image image mode=L size=2469x2558>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=2796x2656>,
 <PIL.Image.Image image mode=L size=2476x2612>,
 <PIL.Image.Image image mode=L size=2732x2736>]

In [79]:
custom_instance_prompts

['sks dog,tongue',
 'sks dog licking his lip',
 'a sks dog',
 'a sks dog',
 'a sks dog']

In [82]:
size = 1024
center_crop = False
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

In [88]:
pil_images = []
for image in instance_images:
    image = exif_transpose(image)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    image = train_resize(image)


    else:
        y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
        image = crop(image, y1, x1, h, w)
    pil_images.append(image.copy())
    image = train_transforms(image)
    pixel_values.append(image)

NameError: name 'random_flip' is not defined

In [87]:
def prepare_mask_and_masked_image(image, mask):
    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)
    
    return mask, masked_image

In [92]:
image_new = train_resize(instance_images[0])

In [94]:
image_new.size

(1024, 1109)

In [162]:
mask, masked_image = prepare_mask_and_masked_image(instance_images[0], mask_images[0])

In [163]:
mask.shape

torch.Size([1, 1, 1967, 1815])

In [164]:
masked_image.shape

torch.Size([1, 3, 1967, 1815])

In [165]:
masked_image = masked_image.numpy()
masked_image = masked_image[0].transpose(1, 2, 0)
masked_image = (masked_image + 1.0) * 127.5

In [166]:
masked_image = np.clip(masked_image, 0, 255).astype(np.uint8) 

In [167]:
image = Image.fromarray(masked_image)

In [170]:
mask.shape

torch.Size([1, 1, 1967, 1815])