Alternative implementation in Refiners #25

Laurent2916 opened this issue Feb 21, 2024 · 0 comments

We are building Refiners, an open-source, PyTorch-based micro-framework designed to make it easy to train and run adapters on top of foundation models. Just wanted to let you know that StyleAligned is now supported in Refiners! (Congrats on the great work, by the way!)

For example, an equivalent of the `style_aligned_sdxl.ipynb` notebook (or the `demo_stylealigned_sdxl.py` demo) would look like this:

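  1. Follow the Refiners installation steps.
  2. Run the code snippet below, which generates five images sharing one style and saves them, tiled horizontally, to `style_aligned_example.png`: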
```python
from pathlib import Path

import torch
from PIL import Image

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_XL, StyleAlignedAdapter

# initialize Stable Diffusion XL model
sdxl = StableDiffusion_XL(
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# load weights (replace ... with the folder containing the SDXL weights)
sdxl_folder_weights = Path(...)
sdxl.lda.load_from_safetensors(sdxl_folder_weights / "sdxl-lda-fp16-fix.safetensors")
sdxl.unet.load_from_safetensors(sdxl_folder_weights / "sdxl-unet.safetensors")
sdxl.clip_text_encoder.load_from_safetensors(sdxl_folder_weights / "DoubleCLIPTextEncoder.safetensors")

# inject style aligned adapter
style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
style_aligned_adapter.inject()

set_of_prompts = [
    "a toy train. macro photo. 3d game asset",
    "a toy airplane. macro photo. 3d game asset",
    "a toy bicycle. macro photo. 3d game asset",
    "a toy car. macro photo. 3d game asset",
    "a toy boat. macro photo. 3d game asset",
]

with no_grad():
    # create (context) embeddings from prompts, with one empty negative prompt per prompt
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=set_of_prompts,
        negative_text=[""] * len(set_of_prompts),
    )

    # one set of time ids per image in the batch
    time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)

    # initialize latents (128x128 latents decode to 1024x1024 images)
    manual_seed(seed=2)
    x = torch.randn(
        (len(set_of_prompts), 4, 128, 128),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )

    # decode latents (one PIL image per prompt)
    predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x]

# tile all images horizontally (each image is 1024x1024)
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
for i, image in enumerate(predicted_images):
    merged_image.paste(image, (i * 1024, 0))
merged_image.save("style_aligned_example.png")
```
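Passing `negative_text=[""] * len(set_of_prompts)` gives every prompt an empty negative prompt, which is what classifier-free guidance needs: noise is predicted for both the negative and the actual prompt, and the two predictions are blended with a guidance scale. The sketch below shows the textbook blending formula in PyTorch. The helper names are hypothetical, and the layout of the batched prediction (negative half first, then conditional half) is an assumption for illustration, not something stated in this issue.

```python
import torch


def classifier_free_guidance(
    unconditional: torch.Tensor, conditional: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    """Textbook classifier-free guidance: move away from the unconditional prediction."""
    return unconditional + guidance_scale * (conditional - unconditional)


def guide_batched_prediction(prediction: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Split a (2 * batch, ...) prediction into two halves and blend them.

    Assumption: the first half comes from the negative (empty) prompts and the
    second half from the actual prompts.
    """
    unconditional, conditional = prediction.chunk(2, dim=0)
    return classifier_free_guidance(unconditional, conditional, guidance_scale)


# Five prompts plus five negative prompts -> a batch of 10 predictions -> 5 guided ones.
batched = torch.randn(10, 4, 128, 128)
print(guide_batched_prediction(batched, guidance_scale=5.0).shape)
```

A scale of 1.0 returns the conditional prediction unchanged and 0.0 returns the unconditional one; larger values follow the prompt more strongly.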
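The `StyleAlignedAdapter` injected into `sdxl.unet` is what makes the five images share a style. As we understand the StyleAligned paper, the core idea is shared self-attention: each image in the batch attends to its own tokens plus those of a reference image (here assumed to be the first one), after its queries and keys are AdaIN-normalized to the reference statistics. Below is a minimal single-head sketch of that idea in plain PyTorch; `adain` and `shared_self_attention` are hypothetical helpers, and this is not the Refiners implementation.

```python
import torch
import torch.nn.functional as F


def adain(x: torch.Tensor, ref: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Give `x` the per-channel mean/std of `ref`, with statistics taken over the token axis.

    x: (batch, tokens, dim); ref: (1, tokens, dim), broadcast over the batch.
    """
    x_mean = x.mean(dim=1, keepdim=True)
    x_std = x.std(dim=1, keepdim=True) + eps
    ref_mean = ref.mean(dim=1, keepdim=True)
    ref_std = ref.std(dim=1, keepdim=True) + eps
    return (x - x_mean) / x_std * ref_std + ref_mean


def shared_self_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Single-head shared self-attention in the spirit of StyleAligned (sketch).

    q, k, v: (batch, tokens, dim); image 0 of the batch is the style reference.
    """
    batch = q.shape[0]
    ref_k, ref_v = k[:1], v[:1]
    # Normalize every image's queries and keys with the reference statistics.
    q = adain(q, q[:1])
    k = adain(k, ref_k)
    # Each image attends to its own tokens plus the (un-normalized) reference tokens.
    shared_k = torch.cat([k, ref_k.expand(batch, -1, -1)], dim=1)
    shared_v = torch.cat([v, ref_v.expand(batch, -1, -1)], dim=1)
    return F.scaled_dot_product_attention(q, shared_k, shared_v)


torch.manual_seed(0)
q, k, v = (torch.randn(5, 16, 8) for _ in range(3))
print(shared_self_attention(q, k, v).shape)  # one output per image, same token count
```

Because the reference attends to two copies of its own tokens, its output matches plain self-attention; only the other images are pulled toward its style.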

A few more things:

Feedback welcome!
