<a href="https://colab.research.google.com/github/ljkrajewski/jupyter_notebooks/blob/main/flux/GUFlux.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Grand Unified Flux
With much love to @camenduru for the [flux-jupyter](https://github.com/camenduru/flux-jupyter) repository.

In [None]:
#@title Install prerequisites and restart the session
#@markdown (Colab will report a system crash. _Don't Panic!!_)
# This cell installs the Python/system packages the later cells need and then
# restarts the kernel (last line). The restart is what Colab reports as a crash;
# presumably it is needed so the reinstalled torch packages are the ones imported
# by the following cells rather than any already-loaded versions.
import IPython

# xformers from the PyTorch CUDA 12.4 wheel index.
!pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu124
# Libraries used by the ComfyUI fork and the UI cell. gradio and python-multipart are
# pinned; the UI cell appears to be written against this gradio 3.x API.
!pip install -q torchsde einops diffusers accelerate gradio==3.50.2 python-multipart==0.0.12
# aria2 provides the `aria2c` downloader used by the "Download code and models" cell.
!apt -y install -qq aria2

# Replace torch/torchvision with CUDA 12.4 builds (torchaudio is installed from the same index).
# NOTE(review): torch is reinstalled *after* xformers; if the resolved torch version differs from
# the one the xformers wheel was built against, xformers may fail to load -- confirm they match.
!pip uninstall torch torchvision -y
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

print("\nRestarting session...")
# Restart the kernel so the next cells start from a fresh interpreter.
IPython.get_ipython().kernel.do_shutdown(restart=True)

## After the restart, run the following cells.

In [None]:
#@title Connect Google Drive
from google.colab import drive
from IPython.display import clear_output
import ipywidgets as widgets
import os
from datetime import datetime

def inf(msg, style, wdth):
    """Display a disabled ipywidgets Button as a coloured, non-clickable status label.

    Not called anywhere in the visible notebook; kept for interactive use.

    Args:
        msg: Text shown on the button.
        style: Value passed through as the button's ``button_style``.
        wdth: Minimum width passed to ``widgets.Layout(min_width=...)``.
    """
    badge = widgets.Button(
        description=msg,
        disabled=True,
        button_style=style,
        layout=widgets.Layout(min_width=wdth),
    )
    display(badge)
Connect_Google_drive = False #@param {type:"boolean"}
#@markdown (optional) Leave "Directory_name" blank for default directory name.
Directory_name = "" #@param {type:"string"}

# Save images either to Google Drive (survives deleting the runtime) or to the VM's /content.
if Connect_Google_drive:
  print("Connecting...")
  drive.mount('/content/gdrive')
  mainpth="/content/gdrive/MyDrive"
else:
  mainpth="/content"

# Leading/trailing spaces in a form field are almost always accidental.
Directory_name = Directory_name.strip()
if Directory_name == "":
  # Default to a timestamped folder so each session writes into its own directory.
  now = datetime.now()
  timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
  Directory_name = f"flux-{timestamp}"

picture_path = f'{mainpth}/{Directory_name}'
# Create exactly the directory that generate() saves into (reusing it if it already exists).
# os.makedirs also avoids passing the user-supplied name through a shell magic.
os.makedirs(picture_path, exist_ok=True)

picture_path

In [None]:
#@title Download code and models
# Clone the "totoro4" branch of camenduru's ComfyUI fork into /content/TotoroUI.
# The next cell imports `nodes`, `totoro_extras` and `totoro`; presumably these are
# resolved from this checkout via the working directory set by the %cd below.
# NOTE: `git clone` fails if /content/TotoroUI already exists (e.g. when re-running this cell).
%cd /content
!git clone -b totoro4 https://github.com/camenduru/ComfyUI /content/TotoroUI
%cd /content/TotoroUI

# Download the FP8 all-in-one FLUX.1-dev checkpoint and the realism LoRA. The later cells
# load them by these exact file names, presumably from models/checkpoints and models/loras.
# aria2c: -c resumes a partial download; -x/-s allow up to 16 connections; -k 1M is the minimum split size.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev-fp8-all-in-one.safetensors -d /content/TotoroUI/models/checkpoints -o flux1-dev-fp8-all-in-one.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux_realism_lora.safetensors -d /content/TotoroUI/models/loras -o flux_realism_lora.safetensors

In [None]:
#@title Define local routines and start gradio
import nodes
from nodes import NODE_CLASS_MAPPINGS
from totoro_extras import nodes_custom_sampler
from totoro_extras import nodes_flux
from totoro import model_management
import math
import random
import torch
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from google.colab import runtime
import gradio as gr
import os
import itertools
import re

def add_ai_metadata(image_path, prompt, seed, steps, guidance, sampler_name, scheduler, lora_strength_model, lora_strength_clip, lora_name="flux_realism_lora"):
    """
    Embeds the image-generation settings as PNG text chunks, rewriting the file in place.

    Failures are reported on stdout instead of raised, so a metadata problem never
    interrupts a generation batch.

    Args:
        image_path: Path to the PNG image file (overwritten in place).
        prompt: The text prompt used for image generation.
        seed: The random seed used for image generation.
        steps: The number of sampling steps.
        guidance: The guidance value used for generation (in this notebook, the value
            given to FluxGuidance -- not a classifier-free guidance scale).
        sampler_name: The name of the sampler used.
        scheduler: The scheduler used.
        lora_strength_model: LoRA strength applied to the diffusion model.
        lora_strength_clip: LoRA strength applied to the CLIP text encoder.
        lora_name: Value recorded in the "LoRA" chunk; defaults to the LoRA this notebook loads.
    """
    # Insertion order is preserved, so chunks are written in the same order as before.
    metadata = {
        "Prompt": prompt,
        "Seed": seed,
        "Steps": steps,
        "Guidance": guidance,
        "Sampler": sampler_name,
        "Scheduler": scheduler,
        "LoRA": lora_name,
        "LoRA Strength Model": lora_strength_model,
        "LoRA Strength CLIP": lora_strength_clip,
    }
    try:
        png_info = PngInfo()
        for key, value in metadata.items():
            png_info.add_text(key, str(value))

        # The context manager closes the source file handle even if saving fails;
        # load() reads the pixels into memory before the same path is overwritten.
        with Image.open(image_path) as img:
            img.load()
            img.save(image_path, pnginfo=png_info)
        print(f"Metadata added successfully to {image_path}")

    except FileNotFoundError:
        print(f"Error: File not found at {image_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

def optimal_dimensions(wh_ratio, target_pixels=1024**2, multiple=16):
    """
    Calculates image dimensions with a given aspect ratio and about ``target_pixels`` area.

    The ideal width is sqrt(target_pixels * w / h) and the ideal height is
    sqrt(target_pixels * h / w). Each is floored to a multiple of ``multiple``.

    Args:
      wh_ratio (str): Width-to-height ratio in the format "width:height", e.g. "16:9".
      target_pixels (int): Desired width * height before rounding (default 1024*1024).
      multiple (int): Both sides are rounded down to a multiple of this (default 16).

    Returns:
      A tuple of (new_width, new_height) representing the optimal dimensions.

    Raises:
      ValueError: If wh_ratio is not two integers separated by ':', or if any of
        the ratio terms, target_pixels or multiple is not positive.
    """
    try:
        sw, sh = wh_ratio.split(':')
        w, h = int(sw), int(sh)
    except ValueError:
        raise ValueError(f"wh_ratio must look like 'width:height', got {wh_ratio!r}") from None
    if w <= 0 or h <= 0 or target_pixels <= 0 or multiple <= 0:
        raise ValueError(f"ratio terms, target_pixels and multiple must be positive (got {wh_ratio!r})")

    # Exact integer arithmetic instead of float sqrt + floor. The largest k with
    # (k * multiple)**2 <= target_pixels * w / h is isqrt(target_pixels * w // (multiple**2 * h)),
    # so float round-off can never shave a side down by one `multiple` step.
    step_sq = multiple * multiple
    new_width = math.isqrt(target_pixels * w // (step_sq * h)) * multiple
    new_height = math.isqrt(target_pixels * h // (step_sq * w)) * multiple
    return new_width, new_height

# Instantiate the ComfyUI node objects used by generate(). Built-in nodes come from
# `nodes.NODE_CLASS_MAPPINGS`; the Flux guidance and custom-sampler nodes come from
# the totoro_extras modules imported above.
CheckpointLoaderSimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
# NOTE(review): VAELoader is not used in the visible code; generate() decodes with the
# `vae` returned by load_checkpoint below.
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

# Load the all-in-one checkpoint once, when this cell runs. generate() reuses these
# module-level objects (unet, clip, vae) for every request.
with torch.inference_mode():
    unet, clip, vae = CheckpointLoaderSimple.load_checkpoint("flux1-dev-fp8-all-in-one.safetensors")

def closestNumber(n, m):
    """Return the multiple of ``m`` nearest to ``n``.

    Two candidates are compared: the multiple reached by truncating ``n / m`` toward
    zero, and the next multiple beyond it in the direction of ``n``. On a tie the
    second (farther from zero) candidate wins.
    """
    quotient = int(n / m)  # int() truncates toward zero
    inner = m * quotient
    step = 1 if (n * m) > 0 else -1
    outer = m * (quotient + step)
    return inner if abs(n - inner) < abs(n - outer) else outer

@torch.inference_mode()
def generate(positive_prompt, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip):
    """
    Render one image with the loaded Flux checkpoint plus the realism LoRA.

    The image is written to ``{picture_path}/flux.png``, replacing any previous file of
    that name. The generation settings are then embedded via add_ai_metadata().

    Args:
        positive_prompt: Text prompt to encode.
        wh_ratio: Aspect-ratio string "width:height", resolved by optimal_dimensions().
        orientation: "portrait" swaps width and height; any other value keeps them.
        seed: Noise seed; 0 means "pick a random seed".
        steps: Number of sampling steps given to BasicScheduler.
        sampler_name: Sampler name given to KSamplerSelect.
        scheduler: Scheduler name given to BasicScheduler.
        guidance: Value given to FluxGuidance.
        lora_strength_model: LoRA strength on the diffusion model.
        lora_strength_clip: LoRA strength on the text encoder.

    Returns:
        The path of the saved PNG.
    """
    # unet/clip/vae are only read here, so no `global` declaration is needed.
    if seed == 0:
        # 0 is the UI's "random" sentinel; draw from the full unsigned 64-bit range.
        seed = random.randint(0, 2**64 - 1)
    print(f"\nSeed:  {seed}")

    width, height = optimal_dimensions(wh_ratio)
    if orientation == "portrait":
        width, height = height, width
    print(f"Dimensions:  {width}x{height} ({orientation})")

    # Patch the realism LoRA onto the base model and text encoder for this request.
    unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, "flux_realism_lora.safetensors", lora_strength_model, lora_strength_clip)

    # Text conditioning (token embeddings + pooled output), then attach the Flux guidance value.
    cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
    cond = [[cond, {"pooled_output": pooled}]]
    cond = FluxGuidance.append(cond, guidance)[0]

    noise = RandomNoise.get_noise(seed)[0]
    guider = BasicGuider.get_guider(unet_lora, cond)[0]
    sampler = KSamplerSelect.get_sampler(sampler_name)[0]
    sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]  # denoise = 1.0
    latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0]
    sample, _sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)

    decoded = VAEDecode.decode(vae, sample)[0].detach()
    # Clamp before the uint8 cast. Out-of-range values then saturate instead of going
    # through an undefined float->uint8 conversion; in-range values truncate exactly as before.
    # [0] takes the first image of the decoded batch.
    pixels = torch.clamp(decoded * 255, 0, 255).to(torch.uint8).cpu().numpy()[0]

    output_path = f"{picture_path}/flux.png"
    Image.fromarray(pixels).save(output_path)
    add_ai_metadata(output_path, positive_prompt, seed, steps, guidance, sampler_name, scheduler, lora_strength_model, lora_strength_clip)
    return output_path

def round_robin_prompts(prompt):
    """
    Expands every ``{a|b|...}`` group in a prompt into all combinations.

    Each brace group contributes one of its '|'-separated options. Text outside the
    groups is kept verbatim. Results are ordered like ``itertools.product``, so the
    last group varies fastest.

    Example:
        "a {red|blue} cup on a {table|chair}" ->
        ["a red cup on a table", "a red cup on a chair",
         "a blue cup on a table", "a blue cup on a chair"]

    Args:
        prompt: Natural-language prompt, optionally containing brace groups.

    Returns:
        A list of fully expanded prompt strings (at least one element).
    """
    # The capturing group keeps the "{...}" pieces in the split output.
    pieces = re.split(r'({.*?})', prompt)
    option_lists = [
        piece[1:-1].split('|') if (piece.startswith('{') and piece.endswith('}')) else [piece]
        for piece in pieces
    ]
    return ["".join(choice) for choice in itertools.product(*option_lists)]

def _unique_image_path(directory, stem, ext=".png"):
    """
    Return ``directory/stem + ext``, appending ``_2``, ``_3``, ... to the stem if that file exists.

    os.rename() silently replaces an existing destination on Unix. Without this check,
    a second Generate run into the same folder would overwrite the first run's images.
    """
    candidate = os.path.join(directory, f"{stem}{ext}")
    suffix = 2
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate


def _iter_generated_images(positive_prompt, quantity, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip, disconnect_when_done=False):
    """
    Generate ``quantity`` images for every prompt expanded by round_robin_prompts().

    Each saved path is yielded as soon as the file is in place. Because this is a
    generator, the queued Gradio UI shows every image as it is produced. A nonzero
    ``seed`` is reused for every image. When ``disconnect_when_done`` is true, the
    Colab runtime is unassigned after the last image.
    """
    prompts = round_robin_prompts(positive_prompt)
    num_prompts = len(prompts)
    quantity = int(quantity)  # int() in case the slider value arrives as a float
    for p_idx, prompt in enumerate(prompts):
        for i in range(quantity):
            print(f"----\nGenerating image {i+1}/{quantity}, prompt {p_idx+1}/{num_prompts}...")
            generated_path = generate(prompt, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip)
            new_image_name = _unique_image_path(picture_path, f"flux{p_idx+1}-{i+1}")
            os.rename(generated_path, new_image_name)
            yield new_image_name
    if disconnect_when_done:
        # Disconnects and deletes this runtime; anything not saved to Google Drive is lost.
        runtime.unassign()


def generate_wrapper(positive_prompt, quantity, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip, disconnect_when_done=False):
    """
    Run a whole batch (see _iter_generated_images) without streaming.

    Returns:
        The path of the last image saved, or None if nothing was generated.
    """
    last_image = None
    for last_image in _iter_generated_images(positive_prompt, quantity, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip, disconnect_when_done):
        pass
    return last_image


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            positive_prompt = gr.Textbox(lines=3, interactive=True, value="Anime drawing, full body portrait, attractive 19-year-old Caucasian woman, long straight blonde hair, red lipstick, white button-up blouse, black neck tie, black suspenders, tan suit jacket, tan shorts, tan pantyhose, tan flat shoes, smiling, sitting in a recliner, legs crossed", label="Prompt")
            with gr.Row():
                wh_ratio = gr.Dropdown(["1:1","2:1","3:2","4:3","5:3","7:4","9:7","16:9","21:11","17:15"], value="7:4", label="width:height ratio")
                orientation = gr.Dropdown(["portrait", "landscape"], label="orientation", value="portrait")
                quantity = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="quantity per prompt")
                suicide_switch = gr.Checkbox(value=False, label="Disconnect and delete runtime when done.")
            seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
            steps = gr.Slider(minimum=4, maximum=50, value=20, step=1, label="steps")
            guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
            lora_strength_model = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_model")
            lora_strength_clip = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_clip")
            sampler_name = gr.Dropdown(["euler", "heun", "heunpp2", "dpm_2", "lms", "dpmpp_2m", "ipndm", "deis", "ddim", "uni_pc", "uni_pc_bh2"], label="sampler_name", value="euler")
            scheduler = gr.Dropdown(["normal", "sgm_uniform", "simple", "ddim_uniform"], label="scheduler", value="simple")
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated image", interactive=False)

    # The checkbox must be an input for the handler to see its *current* value; the
    # component object itself is always truthy. The generator handler streams each
    # saved image into output_image (this requires the queue enabled below).
    generate_button.click(
        fn=_iter_generated_images,
        inputs=[positive_prompt, quantity, wh_ratio, orientation, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip, suicide_switch],
        outputs=output_image,
    )

demo.queue().launch(inline=False, share=True, debug=True)