Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions olive/data/component/sd_lora/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,22 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from olive.data.component.sd_lora import (
aspect_ratio_bucketing,
auto_caption,
auto_tagging,
dataset,
image_filtering,
image_resizing,
preprocess_chain,
)

# Names exported by this package: exactly the submodules imported above.
__all__ = [
"aspect_ratio_bucketing",
"auto_caption",
"auto_tagging",
"dataset",
"image_filtering",
"image_resizing",
"preprocess_chain",
]
75 changes: 75 additions & 0 deletions olive/data/component/sd_lora/aspect_ratio_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,81 @@ def aspect_ratio_bucketing(
except Exception as e:
logger.warning("Failed to process %s: %s", image_path, e)

# Process class images for DreamBooth (if present)
if hasattr(dataset, "class_image_paths") and dataset.class_image_paths:
logger.info("Processing %d class images for DreamBooth", len(dataset.class_image_paths))

# Prepare class images output directory
class_output_dir = None
if output_dir:
class_output_dir = Path(output_dir) / "class_images"
class_output_dir.mkdir(parents=True, exist_ok=True)

for i, class_path in enumerate(dataset.class_image_paths):
class_path = Path(class_path) # noqa: PLW2901
try:
with Image.open(class_path) as img:
orig_w, orig_h = img.size
orig_aspect = orig_w / orig_h

# Find best matching bucket
best_bucket = _find_best_bucket(orig_w, orig_h, buckets)
bucket_w, bucket_h = best_bucket

# Calculate crop coordinates
crops_coords_top_left = _calculate_crop_coords(
orig_w, orig_h, bucket_w, bucket_h, crop_to_bucket, crop_position
)

final_path = str(class_path)

# Resize class image if requested
if resize_images and class_output_dir:
if class_path.suffix:
out_name = f"class_{i:06d}{class_path.suffix}"
else:
out_name = f"class_{i:06d}.jpg"
out_path = class_output_dir / out_name

if not overwrite and out_path.exists():
final_path = str(out_path)
else:
if img.mode != "RGB":
img = img.convert("RGB") # noqa: PLW2901

resized = resize_image(
img,
bucket_w,
bucket_h,
resize_mode=resize_mode,
crop_position=crop_position,
fill_color=fill_color,
resample_filter=resample_filter,
)

if out_path.suffix:
resized.save(out_path, quality=95)
else:
resized.save(out_path, format="JPEG", quality=95)
final_path = str(out_path)

# Update class image path in dataset
dataset.class_image_paths[i] = Path(final_path)

# Store bucket assignment for class image
bucket_assignments[final_path] = {
"bucket": best_bucket,
"original_size": (orig_w, orig_h),
"aspect_ratio": orig_aspect,
"crops_coords_top_left": crops_coords_top_left,
}
bucket_counts[best_bucket] += 1

except Exception as e:
logger.warning("Failed to process class image %s: %s", class_path, e)

logger.info("Processed %d class images", len(dataset.class_image_paths))

# Log bucket distribution
logger.info("Bucket distribution:")
for bucket, count in sorted(bucket_counts.items(), key=lambda x: -x[1])[:10]:
Expand Down
135 changes: 116 additions & 19 deletions olive/data/component/sd_lora/image_resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ def image_resizing(
"""
from PIL import Image

from olive.data.component.sd_lora.utils import calculate_cover_size

# Validate resize_mode early
resize_mode = ResizeMode(resize_mode)
resize_mode_enum = ResizeMode(resize_mode)
resample_filter = get_resample_filter(resample_mode)
crop_to_bucket = resize_mode_enum == ResizeMode.COVER

# Prepare output directory if specified
if output_dir:
Expand All @@ -62,22 +65,59 @@ def image_resizing(

processed_count = 0
skipped_count = 0

for i, item in enumerate(dataset):
image_path = Path(item["image_path"])

# Determine output path
if output_dir:
out_path = Path(output_dir) / image_path.name
bucket_assignments = {}
target_bucket = (target_resolution, target_resolution)

def _calculate_crop_coords(orig_w: int, orig_h: int) -> tuple[int, int]:
    """Return the ``(top, left)`` crop offset used for SDXL time embeddings.

    When the resize mode does not crop to the bucket, the offset is ``(0, 0)``.
    Otherwise the offset is measured on the image after it has been scaled with
    ``calculate_cover_size`` to cover the square target resolution.

    Args:
        orig_w: Width of the original image.
        orig_h: Height of the original image.

    Returns:
        The crop offset as ``(top, left)``.

    """
    if not crop_to_bucket:
        return (0, 0)

    scaled_w, scaled_h = calculate_cover_size(orig_w, orig_h, target_resolution, target_resolution)
    anchor = CropPosition(crop_position)

    # Amount by which the scaled image overhangs the target on each axis.
    extra_w = scaled_w - target_resolution
    extra_h = scaled_h - target_resolution

    # Centered on both axes by default (also the fallback for any other
    # position); an edge anchor overrides only the axis it pins.
    top, left = extra_h // 2, extra_w // 2
    if anchor == CropPosition.TOP:
        top = 0
    elif anchor == CropPosition.BOTTOM:
        top = extra_h
    elif anchor == CropPosition.LEFT:
        left = 0
    elif anchor == CropPosition.RIGHT:
        left = extra_w

    return (top, left)

def _process_image(image_path: Path, out_path: Path, prefix: str = "") -> Optional[str]:
"""Process a single image and return the final path."""
nonlocal processed_count, skipped_count

try:
# Get original size for bucket assignment (needed even if skipping resize)
with Image.open(image_path) as img:
orig_w, orig_h = img.size

# Check if already processed
if not overwrite and out_path.exists() and out_path != image_path:
skipped_count += 1
# Still need to add bucket assignment for skipped files
crops_coords = _calculate_crop_coords(orig_w, orig_h)
bucket_assignments[str(out_path)] = {
"bucket": target_bucket,
"original_size": (orig_w, orig_h),
"aspect_ratio": orig_w / orig_h,
"crops_coords_top_left": crops_coords,
}
return str(out_path)

with Image.open(image_path) as img:
# Convert to RGB if necessary
if img.mode != "RGB":
Expand All @@ -87,7 +127,7 @@ def image_resizing(
img,
target_resolution,
target_resolution,
resize_mode=resize_mode,
resize_mode=resize_mode_enum,
crop_position=crop_position,
fill_color=fill_color,
resample_filter=resample_filter,
Expand All @@ -96,13 +136,70 @@ def image_resizing(
result.save(out_path, quality=95)
processed_count += 1

# Update dataset path if output location changed
if out_path != image_path:
dataset.image_paths[i] = out_path
# Store bucket assignment
crops_coords = _calculate_crop_coords(orig_w, orig_h)
bucket_assignments[str(out_path)] = {
"bucket": target_bucket,
"original_size": (orig_w, orig_h),
"aspect_ratio": orig_w / orig_h,
"crops_coords_top_left": crops_coords,
}

return str(out_path)

except Exception as e:
logger.warning("Failed to resize %s: %s", image_path, e)
logger.warning("Failed to resize %s%s: %s", prefix, image_path, e)
return None

# Process instance images
for i, item in enumerate(dataset):
image_path = Path(item["image_path"])

# Determine output path
if output_dir:
out_path = Path(output_dir) / image_path.name
else:
out_path = image_path

final_path = _process_image(image_path, out_path)

# Update dataset path if output location changed
if final_path and out_path != image_path:
if hasattr(dataset, "set_image_path"):
dataset.set_image_path(i, out_path)
elif hasattr(dataset, "image_paths"):
dataset.image_paths[i] = out_path

logger.info("Resized %d instance images, skipped %d", processed_count, skipped_count)

# Process class images for DreamBooth (if present)
if hasattr(dataset, "class_image_paths") and dataset.class_image_paths:
logger.info("Processing %d class images for DreamBooth", len(dataset.class_image_paths))

class_processed = 0
class_output_dir = None
if output_dir:
class_output_dir = Path(output_dir) / "class_images"
class_output_dir.mkdir(parents=True, exist_ok=True)

for i, class_path in enumerate(dataset.class_image_paths):
class_path = Path(class_path) # noqa: PLW2901

if class_output_dir:
out_path = class_output_dir / f"class_{i:06d}{class_path.suffix or '.jpg'}"
else:
out_path = class_path

final_path = _process_image(class_path, out_path, prefix="class image ")

if final_path:
class_processed += 1
dataset.class_image_paths[i] = Path(final_path)

logger.info("Processed %d class images", class_processed)

logger.info("Resized %d images, skipped %d", processed_count, skipped_count)
# Store bucket assignments in dataset (for compatibility with aspect_ratio_bucketing)
dataset.bucket_assignments = bucket_assignments
dataset.buckets = [target_bucket]

return dataset
24 changes: 15 additions & 9 deletions olive/data/container/image_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,23 @@ def _convert_hf_dataset(self, dataset, image_column: str, caption_column: Option

return HuggingFaceImageDataset(dataset, image_column, caption_column)

def pre_process(self, dataset):
"""Run preprocessing with HuggingFace dataset support."""
# Check if this is a HuggingFace dataset and convert if needed
def load_dataset(self):
    """Load the dataset, handling ImageDataContainer-specific params first.

    ``image_column`` and ``caption_column`` are held back from the load params
    while the parent loader runs so they are not forwarded to it. They are put
    back afterwards so the shared configuration is left unchanged and repeated
    calls resolve the same column names.

    Returns:
        The dataset returned by the parent ``load_dataset``. When
        ``_is_huggingface_dataset()`` is true, it is first converted with
        ``_convert_hf_dataset`` using the configured columns.

    """
    params = self.config.load_dataset_config.params

    # Temporarily remove only the keys that were actually configured, so the
    # restore step below never adds keys that were not there to begin with.
    held_back = {key: params.pop(key) for key in ("image_column", "caption_column") if key in params}
    image_column = held_back.get("image_column", "image")
    caption_column = held_back.get("caption_column")

    try:
        # Load the raw dataset without the container-specific params
        dataset = super().load_dataset()
    finally:
        # Restore even on failure: the config object outlives this call.
        params.update(held_back)

    # Convert to HuggingFaceImageDataset if needed
    if self._is_huggingface_dataset():
        logger.info(
            "Converting HuggingFace dataset: image_column=%s, caption_column=%s",
            image_column,
            caption_column,
        )
        dataset = self._convert_hf_dataset(dataset, image_column, caption_column)

    return dataset
33 changes: 31 additions & 2 deletions olive/model/handler/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

from olive.common.utils import StrEnumBase
Expand Down Expand Up @@ -61,6 +62,9 @@ def __init__(
model_attributes: Additional model attributes.

"""
if not self.is_valid_model(model_path):
raise ValueError(f"The provided model_path '{model_path}' is not a valid diffusion model.")

super().__init__(
framework=Framework.PYTORCH,
model_file_format=ModelFileFormat.PYTORCH_ENTIRE_MODEL,
Expand All @@ -73,6 +77,33 @@ def __init__(
self.load_kwargs = load_kwargs or {}
self._pipeline = None

@classmethod
def is_valid_model(cls, model_path: str) -> bool:
    """Check if the path is a valid diffusion model.

    Diffusion models are identified by the presence of a model_index.json file.

    Args:
        model_path: Local path or HuggingFace model ID.

    Returns:
        True if the path points to a valid diffusion model. False otherwise,
        including when ``model_path`` is empty, when it is an existing local
        path that is not a directory, or when the Hub lookup fails.

    """
    # An empty path would otherwise resolve to the current working directory
    # and be validated against whatever happens to be there.
    if not model_path:
        return False

    # Local directory
    path = Path(model_path)
    if path.is_dir():
        return (path / "model_index.json").is_file()

    # Something local that is not a directory (e.g. a single checkpoint file)
    # cannot contain model_index.json; do not query the Hub for it.
    if path.exists():
        return False

    # HuggingFace model ID - try to fetch model_index.json
    try:
        from huggingface_hub import hf_hub_download

        hf_hub_download(str(model_path), "model_index.json")
    except Exception as exc:
        # Best-effort check: any failure means "not valid", but keep the reason
        # available for debugging (missing package, network error, no such file).
        logging.getLogger(__name__).debug("Could not find model_index.json for '%s': %s", model_path, exc)
        return False
    return True

@property
def adapter_path(self) -> Optional[str]:
"""Return the path to the LoRA adapter."""
Expand All @@ -86,8 +117,6 @@ def size_on_disk(self) -> int:
Returns 0 if unable to compute (e.g., for HuggingFace Hub IDs).
"""
try:
from pathlib import Path

model_path = Path(self.model_path)
if not model_path.exists():
# Remote model (HuggingFace Hub ID)
Expand Down
10 changes: 10 additions & 0 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@
"supported_quantization_encodings": [ ],
"dataset": "dataset"
},
"SDLoRA": {
"module_path": "olive.passes.diffusers.lora.SDLoRA",
"supported_providers": [ "*" ],
"supported_accelerators": [ "gpu" ],
"supported_precisions": [ "*" ],
"extra_dependencies": [ "sd-lora" ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"dataset": "dataset_required"
},
"QNNContextBinaryGenerator": {
"module_path": "olive.passes.qnn.context_binary_generator.QNNContextBinaryGenerator",
"supported_providers": [ "*" ],
Expand Down
Empty file.
Loading
Loading