In [None]:
import os
import sys

from typing import List, Tuple
from PIL import Image

import torch
from torchvision.transforms.functional import to_pil_image, to_tensor

import accelerate

from pathlib import Path
# Absolute path of the directory the notebook is launched from (expected to be
# the repository root so that the `omnigen2` package below can be imported).
root_dir = Path().resolve()

# sys.path entries must be plain strings: non-str objects such as pathlib.Path
# are ignored by the import system. The membership check keeps the cell
# idempotent when it is re-run in the same kernel.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

In [None]:
def create_collage(images: List[torch.Tensor]) -> Image.Image:
    """Create a horizontal collage from a list of images.

    Args:
        images: Non-empty list of image tensors whose last two dimensions are
            (height, width). Values are expected in [-1, 1]; they are mapped
            to [0, 1] with ``x * 0.5 + 0.5`` before being drawn.

    Returns:
        A PIL image built from a 3-channel canvas as tall as the tallest input
        and as wide as the sum of all input widths. Images are placed left to
        right, aligned to the top edge; uncovered pixels stay zero (black).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("create_collage() requires at least one image")

    max_height = max(img.shape[-2] for img in images)
    total_width = sum(img.shape[-1] for img in images)
    canvas = torch.zeros((3, max_height, total_width), device=images[0].device)

    current_x = 0
    for img in images:
        h, w = img.shape[-2:]
        # Clamp after de-normalising so values slightly outside [0, 1] cannot
        # wrap around when the float canvas is converted to 8-bit pixels.
        canvas[:, :h, current_x:current_x + w] = (img * 0.5 + 0.5).clamp(0, 1)
        current_x += w

    return to_pil_image(canvas)

In [None]:
def preprocess(input_image_path: List[str] = []) -> List[Image.Image]:
    """Open the input images referenced by ``input_image_path``.

    Args:
        input_image_path: A single path string or a list of path strings.
            When exactly one path is given and it is a directory, the files
            directly inside that directory are used instead. The default list
            is only read, never mutated.

    Returns:
        The opened PIL images, or an empty list when no path is supplied.
    """
    if not input_image_path:
        return []

    # Accept a bare string as shorthand for a one-element list.
    if isinstance(input_image_path, str):
        input_image_path = [input_image_path]

    if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
        directory = input_image_path[0]
        # os.listdir returns entries in arbitrary order; sort them so the
        # image order (and therefore the generation result) is reproducible.
        # Sub-directories and hidden files (e.g. .ipynb_checkpoints,
        # .DS_Store) are skipped because they cannot be opened as images.
        image_files = [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        ]
    else:
        image_files = list(input_image_path)

    return [Image.open(path) for path in image_files]

**Pipeline Initialization**

In [None]:
# Location of the pretrained pipeline. Set OMNIGEN2_MODEL_PATH to point at a
# different checkpoint; the default keeps the original location.
model_path = os.environ.get(
    "OMNIGEN2_MODEL_PATH",
    "/share_2/luoxin/projects/OmniGen2/pretrained_models/omnigen2_pipe",
)

accelerator = accelerate.Accelerator()
pipeline = OmniGen2Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # Credentials must not be committed with the notebook: read the access
    # token from the environment (None when HF_TOKEN is not set).
    token=os.environ.get("HF_TOKEN"),
)
pipeline = pipeline.to(accelerator.device, dtype=torch.bfloat16)

**Text-to-image generation**

In [None]:
# Fixed seed on the accelerator device so repeated runs give the same samples.
generator = torch.Generator(device=accelerator.device)
generator.manual_seed(223)

instruction = "A dog running in the park"
negative_prompt = ""

# Text-to-image: no reference images are passed to the pipeline.
results = pipeline(
    prompt=instruction,
    negative_prompt=negative_prompt,
    input_images=[],
    height=1024,
    width=1024,
    num_inference_steps=28,
    max_sequence_length=1024,
    text_guidance_scale=5.0,
    image_guidance_scale=1.0,
    num_images_per_prompt=3,
    generator=generator,
    output_type="pil",
)

# create_collage undoes a [-1, 1] normalisation, so rescale the tensors first.
vis_images = [2 * to_tensor(image) - 1 for image in results.images]
output_image = create_collage(vis_images)

display(output_image)

**Editing with instruction**

In [None]:
# Fixed seed on the accelerator device so repeated runs give the same samples.
generator = torch.Generator(device=accelerator.device)
generator.manual_seed(223)

instruction = "Add a beautiful girl with long flowing hair seated beside the teddy bear on the park bench."
negative_prompt = ""

# Single reference image that the instruction should edit.
input_images = preprocess("example_images/02.jpg")

results = pipeline(
    prompt=instruction,
    negative_prompt=negative_prompt,
    input_images=input_images,
    height=1024,
    width=1024,
    num_inference_steps=28,
    max_sequence_length=1024,
    text_guidance_scale=5.0,
    image_guidance_scale=1.8,
    num_images_per_prompt=3,
    generator=generator,
    output_type="pil",
)

# Optional: uncomment to show the input images side by side.
# input_tensors = [to_tensor(image) * 2 - 1 for image in input_images]
# input_collage = create_collage(input_tensors)
# display(input_collage)

# create_collage undoes a [-1, 1] normalisation, so rescale the tensors first.
vis_images = [2 * to_tensor(image) - 1 for image in results.images]
output_image = create_collage(vis_images)
display(output_image)

**Subject-Driven Editing**

In [None]:
# Fixed seed on the accelerator device so repeated runs give the same samples.
generator = torch.Generator(device=accelerator.device)
generator.manual_seed(223)

instruction = "The car toy and the bear toy are placed on the luxury hotel bed."
negative_prompt = ""

# Passing a directory makes preprocess open every file found inside it.
input_images = preprocess("example_images")

results = pipeline(
    prompt=instruction,
    negative_prompt=negative_prompt,
    input_images=input_images,
    height=1024,
    width=1024,
    num_inference_steps=28,
    max_sequence_length=1024,
    text_guidance_scale=5.0,
    image_guidance_scale=1.8,
    num_images_per_prompt=3,
    generator=generator,
    output_type="pil",
)

# Optional: uncomment to show the input images side by side.
# input_tensors = [to_tensor(image) * 2 - 1 for image in input_images]
# input_collage = create_collage(input_tensors)
# display(input_collage)

# create_collage undoes a [-1, 1] normalisation, so rescale the tensors first.
vis_images = [2 * to_tensor(image) - 1 for image in results.images]
output_image = create_collage(vis_images)

display(output_image)