In [None]:
import os
import sys
from typing import List, Optional, Tuple, Union

import accelerate
import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps
from torchvision.transforms.functional import to_pil_image, to_tensor

from pathlib import Path

# Make the current working directory (presumably the repository root) importable so the
# local `omnigen2` package resolves. sys.path entries must be `str`: the import system
# ignores non-string entries such as Path objects, so appending `root_dir` directly had
# no effect. The membership check keeps re-runs of this cell from appending duplicates.
root_dir = Path().resolve()
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline

In [None]:
def create_collage(images: List[torch.Tensor]) -> Image.Image:
    """Create a horizontal collage from a list of images.

    Images are pasted left-to-right onto a black canvas whose height is that of
    the tallest image and whose width is the sum of all widths. Shorter images
    are top-aligned.

    Args:
        images: Channel-first tensors with values in ``[-1, 1]`` (as produced by
            ``to_tensor(img) * 2 - 1`` elsewhere in this notebook). Single-channel
            images are broadcast to all three RGB channels. For 3-D tensors with
            more than three channels (e.g. RGBA), only the first three are used.

    Returns:
        The collage as a PIL image.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("create_collage() requires at least one image")

    max_height = max(img.shape[-2] for img in images)
    total_width = sum(img.shape[-1] for img in images)
    canvas = torch.zeros((3, max_height, total_width), device=images[0].device)

    current_x = 0
    for img in images:
        h, w = img.shape[-2:]
        # Drop extra channels (e.g. alpha): a 4-channel tensor cannot broadcast into
        # the 3-channel canvas.
        rgb = img[:3] if img.dim() == 3 else img
        # Map [-1, 1] back to [0, 1] and clamp. to_pil_image scales float tensors by
        # 255 and casts to uint8, so out-of-range values would wrap around rather
        # than saturate.
        canvas[:, :h, current_x:current_x + w] = (rgb * 0.5 + 0.5).clamp(0, 1)
        current_x += w

    return to_pil_image(canvas)

In [None]:
def _load_image(path) -> Image.Image:
    """Open one image file, apply its EXIF orientation, and close the file handle."""
    with Image.open(path) as img:
        # exif_transpose returns a new, loaded image (it is not in-place by default),
        # so the result stays valid after the source file is closed here.
        return ImageOps.exif_transpose(img)


def preprocess(
    input_image_path: Optional[Union[str, os.PathLike, List[str]]] = None,
) -> List[Image.Image]:
    """Load input images from a path, a list of paths, or a single directory.

    Args:
        input_image_path: One of:
            * ``None`` or empty: no images are loaded;
            * a single file path (``str`` or path-like);
            * a list of file paths, loaded in the given order;
            * a single directory (alone or as the only list element): every regular
              file in it is loaded, in sorted filename order.

    Returns:
        A list of PIL images with their EXIF orientation applied.
    """
    if not input_image_path:
        return []

    if isinstance(input_image_path, (str, os.PathLike)):
        input_image_path = [input_image_path]

    if len(input_image_path) == 1 and os.path.isdir(input_image_path[0]):
        directory = input_image_path[0]
        # os.listdir order is arbitrary. Sort so the image order (which prompts may
        # refer to as "image 1", "image 2") is reproducible, and skip subdirectories.
        paths = [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if os.path.isfile(os.path.join(directory, name))
        ]
    else:
        paths = list(input_image_path)

    return [_load_image(path) for path in paths]

**Pipeline Initialization**

In [None]:
accelerator = accelerate.Accelerator()

model_path = "OmniGen2/OmniGen2"
# Never hardcode Hugging Face access tokens in a notebook: they leak through version
# control and shared outputs. Read the token from the HF_TOKEN environment variable
# instead. When it is unset, None is passed and the Hub client applies its default
# credential lookup (e.g. a cached `huggingface-cli login`).
pipeline = OmniGen2Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),
)
pipeline = pipeline.to(accelerator.device, dtype=torch.bfloat16)

**Text-to-Image Generation**

In [None]:
# Negative prompt applied to every text-to-image sample in this cell.
negative_prompt = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"

# Each prompt below is rendered once at 1024x1024.
instructions = [
    "The sun rises slightly, the dew on the rose petals in the garden is clear, a crystal ladybug is crawling to the dew, the background is the early morning garden, macro lens.",
    "Hyperrealistic macro photograph of a whimsical rabbit sculpture, meticulously crafted from an assortment of fresh garden vegetables. Its body is formed from crisp lettuce and cabbage leaves, with vibrant carrot slices for ears, bright red radish for eyes, and delicate parsley sprigs for fur. The rabbit is sitting on a rustic, dark wood cutting board, with a few scattered water droplets glistening on its surface. Dramatic, warm studio lighting from the side casts soft shadows, highlighting the intricate textures of the vegetables. Shallow depth of field, sharp focus, cinematic food photography, 8K, bokeh background.",
]

for instruction in instructions:
    # A fresh generator seeded with 0 for each prompt keeps every sample
    # reproducible no matter how many prompts come before it.
    generator = torch.Generator(device=accelerator.device).manual_seed(0)
    results = pipeline(
        prompt=instruction,
        negative_prompt=negative_prompt,
        input_images=[],
        width=1024,
        height=1024,
        num_images_per_prompt=1,
        num_inference_steps=50,
        max_sequence_length=1024,
        text_guidance_scale=4.0,
        image_guidance_scale=1.0,
        generator=generator,
        output_type="pil",
    )

    # create_collage expects tensors in [-1, 1]; to_tensor yields [0, 1].
    vis_images = list(map(lambda img: to_tensor(img) * 2 - 1, results.images))
    output_image = create_collage(vis_images)
    display(output_image)

**Instruction-Guided Image Editing**

In [None]:
negative_prompt = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"

# (instruction, path to the single image to edit)
inputs = [
    ("Change the background to classroom.", "example_images/ComfyUI_temp_mllvz_00071_.png"),
    ("Generate a photo of an anime-style figurine placed on a desk. The figurine model should be based on the character photo provided in the attachment, accurately replicating the full-body pose, facial expression, and clothing style of the character in the photo, ensuring the entire figurine is fully presented. The overall design should be exquisite and detailed, soft gradient colors and a delicate texture, leaning towards a Japanese anime style, rich in details, with a realistic quality and beautiful visual appeal.", "example_images/RAL_0315.JPG"),
]

for instruction, image_path in inputs:
    input_images = preprocess(image_path)

    generator = torch.Generator(device=accelerator.device).manual_seed(0)
    # Unlike the other generation cells, width/height are not passed here, so the
    # pipeline's own default output sizing applies.
    results = pipeline(
        prompt=instruction,
        input_images=input_images,
        num_inference_steps=50,
        max_sequence_length=1024,
        text_guidance_scale=5.0,
        image_guidance_scale=2.0,
        negative_prompt=negative_prompt,
        num_images_per_prompt=1,
        generator=generator,
        output_type="pil",
    )

    # One row: all inputs first, then all outputs, each panel 5 inches tall with the
    # first output's aspect ratio.
    n_panels = len(input_images) + len(results.images)
    aspect = results.images[0].width / results.images[0].height
    # squeeze=False keeps `axes` 2-D even for a single panel; plain subplots(1, 1)
    # would return a bare Axes and break indexing.
    fig, axes = plt.subplots(1, n_panels, figsize=(aspect * 5 * n_panels, 5), squeeze=False)

    titled_images = [(f'Input {i+1}', img) for i, img in enumerate(input_images)]
    titled_images += [(f'Output {i+1}', img) for i, img in enumerate(results.images)]
    for ax, (title, img) in zip(axes[0], titled_images):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly; this cell creates one figure per loop iteration.
    plt.close(fig)

**In-context Generation**

In [None]:
negative_prompt = "(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar"

# (instruction, list of reference image paths). The list order defines
# "image 1" / "image 2" in the prompts.
inputs = [
    ("Please let the person in image 2 hold the toy from the first image in a parking lot.", ["example_images/04.jpg", "example_images/000365954.jpg"]),
    ("Add the bird from image 1 to the desk in image 2.", ["example_images/996e2cf6-daa5-48c4-9ad7-0719af640c17_1748848108409.png", "example_images/00066-10350085.png"]),
]

for instruction, image_paths in inputs:
    # Keep paths and loaded images under separate names rather than reusing one variable.
    input_images = preprocess(image_paths)

    generator = torch.Generator(device=accelerator.device).manual_seed(0)
    results = pipeline(
        prompt=instruction,
        input_images=input_images,
        width=1024,
        height=1024,
        num_inference_steps=50,
        max_sequence_length=1024,
        text_guidance_scale=5.0,
        image_guidance_scale=2.0,
        negative_prompt=negative_prompt,
        num_images_per_prompt=1,
        generator=generator,
        output_type="pil",
    )

    # One row: all inputs first, then all outputs, each panel 5 inches tall with the
    # first output's aspect ratio.
    n_panels = len(input_images) + len(results.images)
    aspect = results.images[0].width / results.images[0].height
    # squeeze=False keeps `axes` 2-D even for a single panel; plain subplots(1, 1)
    # would return a bare Axes and break indexing.
    fig, axes = plt.subplots(1, n_panels, figsize=(aspect * 5 * n_panels, 5), squeeze=False)

    titled_images = [(f'Input {i+1}', img) for i, img in enumerate(input_images)]
    titled_images += [(f'Output {i+1}', img) for i, img in enumerate(results.images)]
    for ax, (title, img) in zip(axes[0], titled_images):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly; this cell creates one figure per loop iteration.
    plt.close(fig)

**Visual Understanding**

In [None]:
from omnigen2.pipelines.omnigen2.pipeline_omnigen2_chat import OmniGen2ChatPipeline

# Build the chat pipeline from the already-loaded generation pipeline's components,
# passing its transformer explicitly.
chat_pipeline = OmniGen2ChatPipeline.from_pipe(pipeline=pipeline, transformer=pipeline.transformer)

# (question, path to the image the question is about)
inputs = [
    ("Please briefly describe this image.", "example_images/04.jpg"),
    ("Could you tell me the color of the woman's hat in the picture?", "example_images/000077066.jpg"),
]

for instruction, image_path in inputs:
    input_images = preprocess(image_path)

    generator = torch.Generator(device=accelerator.device).manual_seed(0)

    results = chat_pipeline(
        prompt=instruction,
        input_images=input_images,
        generator=generator,
    )

    # Show the input image(s) alongside the model's text answer.
    print("Input Image:")
    # Convert to RGB first: create_collage fills a 3-channel canvas, and to_tensor on
    # RGBA images yields 4 channels, while on palette ("P") images it yields raw
    # palette indices instead of colors.
    vis_images = [to_tensor(image.convert("RGB")) * 2 - 1 for image in input_images]
    # Use a separate name so `input_images` keeps holding the PIL inputs.
    input_collage = create_collage(vis_images)
    display(input_collage)
    print("Output Text:")
    print(results.text)