In [None]:
import os
from PIL import Image
import torch
from src.customID.pipeline_flux import FluxPipeline
from src.customID.transformer_flux import FluxTransformer2DModel
from src.customID.model import CustomIDModel

def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` grid image.

    Images are laid out in row-major order: index ``i`` lands in row
    ``i // cols`` and column ``i % cols``. The cell size is taken from the
    first image; every image is pasted at the top-left corner of its cell
    without being resized.

    Args:
        imgs: Sequence of images exposing ``.size`` as ``(width, height)``
            and accepted by ``Image.paste``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
        ValueError: If ``imgs`` is empty (there is no image to take the
            cell size from).
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )
    # A 0-cell grid passes the count check above; fail with a clear message
    # instead of an IndexError on imgs[0].
    if len(imgs) == 0:
        raise ValueError("image_grid requires at least one image")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# Execution device and tensor dtype shared by the models loaded in this cell
# (both are also passed to CustomIDModel later).
_DEVICE = "cuda:0"
_DTYPE = torch.bfloat16

# Local checkpoint directory; you can also use `black-forest-labs/FLUX.1-dev`.
model_path = "pretrained_ckpt/flux.1-dev"

# Load the project's FluxTransformer2DModel from the "transformer" subfolder
# and move it to the target device.
transformer = FluxTransformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=_DTYPE,
)
transformer = transformer.to(_DEVICE)

# Build the pipeline around that transformer, then move the pipeline as well.
pipe = FluxPipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=_DTYPE,
)
pipe = pipe.to(_DEVICE)

In [None]:
# Token count handed to CustomIDModel as its last positional argument.
# NOTE(review): presumably the number of identity tokens the checkpoint was
# trained with — confirm against CustomIDModel before changing it.
num_token = 64

# Custom-ID weights applied on top of the FLUX pipeline.
trained_ckpt = "pretrained_ckpt/FLUX-customID.pt"

customID_model = CustomIDModel(
    pipe,
    trained_ckpt,
    _DEVICE,
    _DTYPE,
    num_token,
)

In [None]:
# --- Generation settings -------------------------------------------------
num_samples = 3             # forwarded to generate(); also the grid column count below
gs = 3.5                    # forwarded as guidance_scale
_seed = 2024                # forwarded as seed
h = 1024                    # forwarded as height
w = 1024                    # forwarded as width
img_path = "img/man1.jpg"   # reference image, passed to generate() as a path
p = "A man wearing a classic leather jacket leans against a vintage motorcycle, surrounded by autumn leaves swirling in the breeze."

# --- Run the custom-ID pipeline -------------------------------------------
images = customID_model.generate(
    pil_image=img_path,
    prompt=p,
    num_samples=num_samples,
    height=h,
    width=w,
    seed=_seed,
    num_inference_steps=28,
    guidance_scale=gs,
)

# Image.resize takes (width, height) in that order, so pass (w, h); with
# (h, w) the reference image would be distorted for non-square settings.
# NOTE(review): face_img is not included in the grid below — confirm whether
# it was meant to be shown next to the generated samples.
face_img = [Image.open(img_path).resize((w, h))]

# One row with one column per sample.
# NOTE(review): assumes generate() returns exactly num_samples images;
# image_grid asserts on a count mismatch.
grid = image_grid(images, 1, num_samples)
grid