In [None]:
# To render using a Flux Pruned Checkpoint or LoRA model copied from your mounted Google Drive, run this cell first
# Mounts Google Drive at /content/drive; the checkpoint and LoRA Drive paths entered in the
# configuration cell are expected to point somewhere under that mount point.
from google.colab import drive
drive.mount('/content/drive')

# **PRUNED FLUX DEV FP8 CONFIGURATION SETUP AND USAGE INSTRUCTIONS**
- Requires a pruned FLUX fp8 checkpoint in safetensors format (UNET only), ~11 GB or smaller.
- Choose one CHECKPOINT_SOURCE from menu below (googledrive, civitai, or huggingface).
- If using googledrive as your CHECKPOINT_SOURCE, paste the googledrive path of your pre-downloaded checkpoint model into CHECKPOINT_DRIVE_PATH.
- If using civitai as your CHECKPOINT_SOURCE, choose the civitai_checkpoint_option to either use a publicly available civitai model (no further action required) or enter your own civitai pruned fp8 checkpoint model download URL (paste into CHECKPOINT_CAI_URL).
- If using huggingface as your CHECKPOINT_SOURCE, choose the hf_checkpoint_option to either use a publicly available huggingface model (no further action required) or enter your own huggingface pruned fp8 checkpoint model download URL (paste into CHECKPOINT_HF_URL).
- Note: most Civitai models and a few Hugging Face models (checkpoints or LoRAs) require an API token to download directly into Colab. Paste either or both API tokens below, and they will automatically be used for your checkpoint and LoRA downloads. (API tokens are not necessary if you use a pre-downloaded or publicly available checkpoint model and no LoRAs.)
- To use one or two optional LoRAs in your render, you can either download them from a Civitai or Hugging Face URL that you paste into a text box in the Gradio interface, or you can copy over a pre-downloaded LoRA from your mounted Google Drive by selecting it through a dropdown menu in the Gradio Interface.
- To use LoRAs from googledrive, check the use_drive_loras box and paste in your googledrive LoRA directory path into LORA_DRIVE_DIR.
- You can potentially use different LoRA sources for each LoRA, no matter what choice you make for CHECKPOINT_SOURCE, and you can change the LoRAs to be used in each render through the Gradio Interface without restarting the Colab.  Once a LoRA has been downloaded, it will be saved for the duration of the run, so it will load faster, for shorter render times, if you reuse it.
- Renders take ~3 min with default settings and one LoRA; renders without LoRAs are faster, and renders with two LoRAs are slower. If you want to experiment, this Colab has also run successfully with Schnell or hybrid safetensors checkpoint models that require fewer steps.
- The seed used for each render is printed in the output of the setup cell below, so you can adjust parameters in the Gradio interface and rerun a render with the same seed.
- Have fun rendering!

In [None]:
# @title Pruned Flux Configuration Setup

# Checkpoint Source Selection
CHECKPOINT_SOURCE = "civitai" # @param ["googledrive", "civitai", "huggingface"]

# Google Drive Configuration
CHECKPOINT_DRIVE_PATH = "" # @param {type:"string"}
# Pasted paths often carry stray spaces/newlines, which would make os.path.exists fail.
CHECKPOINT_DRIVE_PATH = CHECKPOINT_DRIVE_PATH.strip()

# Civitai Configuration
civitai_checkpoint_option = "Use publicly available model (Flux Unchained)" # @param ["Use publicly available model (Flux Unchained)", "Paste custom URL below"]
CHECKPOINT_CAI_URL = "https://civitai.com/api/download/models/742989" # @param {type:"string"}

if civitai_checkpoint_option == "Use publicly available model (Flux Unchained)":
    CHECKPOINT_CAI_URL = "https://civitai.com/api/download/models/742989"
else:
    # Clean URL: drop whitespace and any query string (the download step appends its own).
    CHECKPOINT_CAI_URL = CHECKPOINT_CAI_URL.strip().split("?")[0]

# Hugging Face Configuration  
hf_checkpoint_option = "Use publicly available model" # @param ["Use publicly available model", "Paste custom URL below"]
CHECKPOINT_HF_URL = "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev-fp8.safetensors" # @param {type:"string"}

if hf_checkpoint_option == "Use publicly available model":
    CHECKPOINT_HF_URL = "https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/flux1-dev-fp8.safetensors"
else:
    # Clean URL: drop whitespace and any query string such as "?download=true".
    CHECKPOINT_HF_URL = CHECKPOINT_HF_URL.strip().split("?")[0]

# API Tokens
CIVITAI_TOKEN = "" # @param {type:"string"}
HF_TOKEN = "" # @param {type:"string"}

# Strip pasted whitespace, then fall back to placeholder values when empty. The later
# cells treat these placeholders as "no token provided".
CIVITAI_TOKEN = CIVITAI_TOKEN.strip() or "YOUR_CIVITAI_TOKEN_HERE"
HF_TOKEN = HF_TOKEN.strip() or "YOUR_HUGGINGFACE_TOKEN_HERE"

# LoRA Configuration
use_drive_loras = False # @param {type:"boolean"}
LORA_DRIVE_DIR = "" # @param {type:"string"}
LORA_DRIVE_DIR = LORA_DRIVE_DIR.strip()

# Summary
print("\n" + "=" * 50)
print("CONFIGURATION SUMMARY:")
print(f"✓ Checkpoint Source: {CHECKPOINT_SOURCE}")
if CHECKPOINT_SOURCE == "civitai":
    print(f"✓ Civitai URL: {CHECKPOINT_CAI_URL}")
    print(f"✓ Civitai Token: {'Provided' if CIVITAI_TOKEN != 'YOUR_CIVITAI_TOKEN_HERE' else 'Not provided'}")
elif CHECKPOINT_SOURCE == "huggingface":
    print(f"✓ HuggingFace URL: {CHECKPOINT_HF_URL}")
    print(f"✓ HuggingFace Token: {'Provided' if HF_TOKEN != 'YOUR_HUGGINGFACE_TOKEN_HERE' else 'Not provided'}")
else:
    print(f"✓ Google Drive Path: {CHECKPOINT_DRIVE_PATH if CHECKPOINT_DRIVE_PATH else 'Not specified'}")

if use_drive_loras and LORA_DRIVE_DIR:
    print(f"✓ LoRA Directory: {LORA_DRIVE_DIR}")
else:
    print("✓ LoRA Directory: Not configured")
print("=" * 50)

In [None]:
# @title After configuration, change runtime type to T4 GPU, run this cell for setup, ~7min, then launch the link for the Gradio Interface appearing at the bottom.

# Clone the TotoroUI fork (branch totoro4) and work from inside it, so that its
# `nodes`, `totoro` and `totoro_extras` packages are importable further down this cell.
# NOTE(review): on a re-run the clone fails because /content/TotoroUI already exists;
# the existing checkout is then used as-is.
%cd /content
!git clone -b totoro4 https://github.com/camenduru/ComfyUI /content/TotoroUI
%cd /content/TotoroUI

import os
# Set before `import torch` below so the CUDA caching allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Pinned dependency versions (gradio 4.44.1, xformers, torch 2.4.1 CUDA 12.4 wheels).
!pip install -q torchsde einops diffusers accelerate xformers==0.0.28.post1 gradio==4.44.1 python-multipart==0.0.12
!pip install pydantic==2.9.2
!pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
# aria2c performs the multi-connection model downloads below.
!apt -y install -qq aria2


import random
import re
import shutil
import subprocess

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

import nodes
from nodes import NODE_CLASS_MAPPINGS
from totoro_extras import nodes_custom_sampler
from totoro_extras import nodes_flux
from totoro import model_management
import totoro.utils


def patched_flux_to_diffusers(mmdit_config, output_prefix=""):
    """
    Enhanced Flux model key mapping for LoRA compatibility.

    Builds a dict mapping diffusers-style Flux parameter names
    (``transformer_blocks.N...``, ``single_transformer_blocks.N...``,
    ``context_embedder...``, ...) to native Flux names (``double_blocks.N...``,
    ``single_blocks.N...``, ``txt_in...``, ...). Every key is registered twice:
    bare and with a ``transformer.`` prefix. Keys are produced for the ``.weight``
    and ``.bias`` suffixes and for PEFT ``.lora_A.weight`` / ``.lora_B.weight``
    (whose native counterparts use ``.A.weight`` / ``.B.weight``).

    A value is either the native name, or ``(native_name, (0, start, size))`` when
    the diffusers tensor is a slice of a fused native tensor: q/k/v map onto
    consecutive ``hidden_size``-sized slices of ``qkv`` (double blocks) or
    ``linear1`` (single blocks), and a single block's ``proj_mlp`` onto the
    ``4 * hidden_size`` slice that follows them.

    Besides attention, MLP, modulation and embedder layers, this also maps the
    image input embedding (``x_embedder`` -> ``img_in``), the guidance embedding
    (``guidance_embedder`` -> ``guidance_in``) and the attention QK-norm scales,
    so LoRAs that train those layers are not left unmapped.

    Args:
        mmdit_config: dict read for ``depth`` (number of double blocks),
            ``depth_single_blocks`` and ``hidden_size``; each defaults to 0.
        output_prefix: string prepended to every native key.

    Returns:
        dict[str, str | tuple[str, tuple[int, int, int]]]
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    suffixes = (".weight", ".bias", ".lora_A.weight", ".lora_B.weight")

    # (0, start, size) slice specs into fused q/k/v(/mlp) tensors.
    q_slice = (0, 0, hidden_size)
    k_slice = (0, hidden_size, hidden_size)
    v_slice = (0, hidden_size * 2, hidden_size)
    mlp_slice = (0, hidden_size * 3, hidden_size * 4)

    key_map = {}

    def add(mapping):
        # Register each entry both bare and with the "transformer." prefix.
        key_map.update(mapping)
        key_map.update({f"transformer.{k}": v for k, v in mapping.items()})

    def native_suffix(end):
        # PEFT "lora_A"/"lora_B" become "A"/"B" on the native side.
        return end.replace("lora_A", "A").replace("lora_B", "B")

    # Double-stream blocks (diffusers "transformer_blocks").
    for index in range(n_double_layers):
        prefix = f"{output_prefix}double_blocks.{index}"
        prefix_from = f"transformer_blocks.{index}"

        for end in suffixes:
            clean_end = native_suffix(end)
            img_qkv = f"{prefix}.img_attn.qkv{clean_end}"
            txt_qkv = f"{prefix}.txt_attn.qkv{clean_end}"
            # Separate q/k/v projections land in slices of the fused qkv tensors.
            add({
                f"{prefix_from}.attn.to_q{end}": (img_qkv, q_slice),
                f"{prefix_from}.attn.to_k{end}": (img_qkv, k_slice),
                f"{prefix_from}.attn.to_v{end}": (img_qkv, v_slice),
                f"{prefix_from}.attn.add_q_proj{end}": (txt_qkv, q_slice),
                f"{prefix_from}.attn.add_k_proj{end}": (txt_qkv, k_slice),
                f"{prefix_from}.attn.add_v_proj{end}": (txt_qkv, v_slice),
            })
            # Output projections map one-to-one (no split).
            add({
                f"{prefix_from}.attn.to_out.0{end}": f"{prefix}.img_attn.proj{clean_end}",
                f"{prefix_from}.attn.to_add_out{end}": f"{prefix}.txt_attn.proj{clean_end}",
            })

        # MLP and modulation layers.
        for end in suffixes:
            clean_end = native_suffix(end)
            add({
                f"{prefix_from}.ff.net.0.proj{end}": f"{prefix}.img_mlp.0{clean_end}",
                f"{prefix_from}.ff.net.2{end}": f"{prefix}.img_mlp.2{clean_end}",
                f"{prefix_from}.ff_context.net.0.proj{end}": f"{prefix}.txt_mlp.0{clean_end}",
                f"{prefix_from}.ff_context.net.2{end}": f"{prefix}.txt_mlp.2{clean_end}",
                f"{prefix_from}.norm1.linear{end}": f"{prefix}.img_mod.lin{clean_end}",
                f"{prefix_from}.norm1_context.linear{end}": f"{prefix}.txt_mod.lin{clean_end}",
            })

        # QK-norm scales: ".weight" on the diffusers side, ".scale" natively.
        add({
            f"{prefix_from}.attn.norm_q.weight": f"{prefix}.img_attn.norm.query_norm.scale",
            f"{prefix_from}.attn.norm_k.weight": f"{prefix}.img_attn.norm.key_norm.scale",
            f"{prefix_from}.attn.norm_added_q.weight": f"{prefix}.txt_attn.norm.query_norm.scale",
            f"{prefix_from}.attn.norm_added_k.weight": f"{prefix}.txt_attn.norm.key_norm.scale",
        })

    # Single-stream blocks: q/k/v and proj_mlp are fused into linear1.
    for index in range(n_single_layers):
        prefix = f"{output_prefix}single_blocks.{index}"
        prefix_from = f"single_transformer_blocks.{index}"

        for end in suffixes:
            clean_end = native_suffix(end)
            linear1 = f"{prefix}.linear1{clean_end}"
            add({
                f"{prefix_from}.attn.to_q{end}": (linear1, q_slice),
                f"{prefix_from}.attn.to_k{end}": (linear1, k_slice),
                f"{prefix_from}.attn.to_v{end}": (linear1, v_slice),
                f"{prefix_from}.proj_mlp{end}": (linear1, mlp_slice),
            })
            add({
                f"{prefix_from}.proj_out{end}": f"{prefix}.linear2{clean_end}",
                f"{prefix_from}.norm.linear{end}": f"{prefix}.modulation.lin{clean_end}",
            })

        add({
            f"{prefix_from}.attn.norm_q.weight": f"{prefix}.norm.query_norm.scale",
            f"{prefix_from}.attn.norm_k.weight": f"{prefix}.norm.key_norm.scale",
        })

    # Model-level embedders and the final projection.
    for end in suffixes:
        clean_end = native_suffix(end)
        add({
            f"context_embedder{end}": f"{output_prefix}txt_in{clean_end}",
            f"x_embedder{end}": f"{output_prefix}img_in{clean_end}",
            f"time_text_embed.timestep_embedder.linear_1{end}": f"{output_prefix}time_in.in_layer{clean_end}",
            f"time_text_embed.timestep_embedder.linear_2{end}": f"{output_prefix}time_in.out_layer{clean_end}",
            f"time_text_embed.text_embedder.linear_1{end}": f"{output_prefix}vector_in.in_layer{clean_end}",
            f"time_text_embed.text_embedder.linear_2{end}": f"{output_prefix}vector_in.out_layer{clean_end}",
            f"time_text_embed.guidance_embedder.linear_1{end}": f"{output_prefix}guidance_in.in_layer{clean_end}",
            f"time_text_embed.guidance_embedder.linear_2{end}": f"{output_prefix}guidance_in.out_layer{clean_end}",
            f"proj_out{end}": f"{output_prefix}final_layer.linear{clean_end}",
        })

    return key_map

# Swap totoro's Flux key-mapping helper for the patched version defined above;
# presumably totoro's LoRA loader looks it up on the module at load time.
setattr(totoro.utils, "flux_to_diffusers", patched_flux_to_diffusers)

print("Patched Flux LoRA compatibility layer with correct component mapping")


def is_valid_token(token):
    """Return True if ``token`` looks like a real API token.

    Empty/whitespace-only strings, ``None``, non-string values and the placeholder
    strings used by the configuration cell (including ``YOUR_CIVITAI_TOKEN_HERE``,
    which was previously missing and let the placeholder be sent as a token) all
    count as "no token".
    """
    placeholder_values = {
        "YOUR_TOKEN_HERE",
        "YOUR_CIVITAI_TOKEN_HERE",
        "YOUR_HUGGINGFACE_TOKEN_HERE",
    }
    if not isinstance(token, str):
        return False
    stripped = token.strip()
    return bool(stripped) and stripped not in placeholder_values

def list_drive_loras(lora_dir):
    """Return the sorted ``.safetensors`` filenames directly inside ``lora_dir``.

    An empty/None ``lora_dir``, or one that is not an existing directory, yields an
    empty list instead of raising. Sub-directories are not listed, and the sort
    gives the Drive dropdown a stable order.
    """
    if not lora_dir or not os.path.isdir(lora_dir):
        return []
    return sorted(
        entry
        for entry in os.listdir(lora_dir)
        if entry.endswith('.safetensors') and os.path.isfile(os.path.join(lora_dir, entry))
    )

# Fetch the UNET checkpoint into models/unet as pruned-checkpoint.safetensors from the
# configured source (Google Drive copy, Civitai download or Hugging Face download).
_UNET_DIR = "/content/TotoroUI/models/unet"
_UNET_FILENAME = "pruned-checkpoint.safetensors"


def _download_checkpoint(url, headers=()):
    """Download ``url`` into the UNET folder with aria2c; raise RuntimeError on failure.

    The command is given to subprocess as an argument list (no shell), so tokens or
    URLs containing quotes or shell metacharacters are passed through literally.
    """
    cmd = [
        "aria2c", "--console-log-level=error", "-c", "-x", "16", "-s", "16", "-k", "1M",
        url, "-d", _UNET_DIR, "-o", _UNET_FILENAME,
    ]
    cmd.extend(f"--header={header}" for header in headers)
    returncode = subprocess.run(cmd).returncode
    # Fail here rather than later with an obscure error from UNETLoader.
    if returncode != 0 or not os.path.exists(os.path.join(_UNET_DIR, _UNET_FILENAME)):
        raise RuntimeError(
            f"Checkpoint download failed (aria2c exit code {returncode}). "
            "If this is a private or gated model, make sure the appropriate API token is provided."
        )


_checkpoint_source = CHECKPOINT_SOURCE.lower()
if _checkpoint_source == "googledrive":
    if not CHECKPOINT_DRIVE_PATH or not os.path.isfile(CHECKPOINT_DRIVE_PATH):
        raise Exception(f"Checkpoint not found at: {CHECKPOINT_DRIVE_PATH}")
    # Copy checkpoint to ComfyUI directory
    checkpoint_dest = os.path.join(_UNET_DIR, _UNET_FILENAME)
    os.makedirs(_UNET_DIR, exist_ok=True)
    shutil.copy2(CHECKPOINT_DRIVE_PATH, checkpoint_dest)
    print(f"Copied checkpoint from Google Drive to: {checkpoint_dest}")
elif _checkpoint_source == "civitai":
    if not is_valid_token(CIVITAI_TOKEN):
        print("Warning: No Civitai token provided. Attempting download in case model is public...")
    # Civitai accepts the API token as a query parameter.
    download_url = f"{CHECKPOINT_CAI_URL}{'?token=' + CIVITAI_TOKEN if is_valid_token(CIVITAI_TOKEN) else ''}"
    _download_checkpoint(download_url)
else:  # huggingface
    if not is_valid_token(HF_TOKEN):
        print("Warning: No Hugging Face token provided. Attempting download in case model is public...")
    download_url = CHECKPOINT_HF_URL
    if "huggingface.co/" in download_url:
        # ".../blob/..." is the HTML file viewer; ".../resolve/..." serves the raw file.
        download_url = download_url.replace("/blob/", "/resolve/", 1)
    hf_headers = [f"Authorization: Bearer {HF_TOKEN}"] if is_valid_token(HF_TOKEN) else []
    _download_checkpoint(download_url, hf_headers)

# Download other required models
# The VAE (ae.sft, loaded by VAELoader below) and the two text encoders used by
# DualCLIPLoader below (CLIP-L and an fp8 T5-XXL), into the TotoroUI model folders.
# NOTE(review): these shell escapes do not stop the cell if a download fails; a missing
# file only surfaces when the loaders below run.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/ae.sft -d /content/TotoroUI/models/vae -o ae.sft
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/clip_l.safetensors -d /content/TotoroUI/models/clip -o clip_l.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/t5xxl_fp8_e4m3fn.safetensors -d /content/TotoroUI/models/clip -o t5xxl_fp8_e4m3fn.safetensors


# Instantiate the node objects used by generate(), in the same order as before:
# core loaders, the Flux guidance node (currently unused by generate()), the
# custom-sampler nodes, then the VAE / latent nodes.
DualCLIPLoader, UNETLoader, LoraLoader = (
    NODE_CLASS_MAPPINGS[node_name]()
    for node_name in ("DualCLIPLoader", "UNETLoader", "LoraLoader")
)
FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
RandomNoise, BasicGuider, KSamplerSelect, BasicScheduler, SamplerCustomAdvanced = (
    nodes_custom_sampler.NODE_CLASS_MAPPINGS[node_name]()
    for node_name in ("RandomNoise", "BasicGuider", "KSamplerSelect", "BasicScheduler", "SamplerCustomAdvanced")
)
VAELoader, VAEDecode, EmptyLatentImage = (
    NODE_CLASS_MAPPINGS[node_name]()
    for node_name in ("VAELoader", "VAEDecode", "EmptyLatentImage")
)


# Load the text encoders, the checkpoint UNET (with the "fp8_e4m3fn" weight-dtype
# option) and the VAE once, outside autograd.
with torch.inference_mode():
    clip = DualCLIPLoader.load_clip(
        "t5xxl_fp8_e4m3fn.safetensors", "clip_l.safetensors", "flux"
    )[0]
    unet = UNETLoader.load_unet(
        "pruned-checkpoint.safetensors", "fp8_e4m3fn"
    )[0]
    vae = VAELoader.load_vae("ae.sft")[0]

def closestNumber(n, m):
    """Return the multiple of ``m`` nearest to ``n``.

    Candidates are the truncated multiple ``m * int(n / m)`` and its neighbour one
    step away from zero (or below zero when ``n * m`` is not positive); on a tie
    the neighbour wins.
    """
    quotient = int(n / m)
    truncated = m * quotient
    direction = 1 if n * m > 0 else -1
    neighbour = m * (quotient + direction)
    return truncated if abs(n - truncated) < abs(n - neighbour) else neighbour

downloaded_loras = {}

# Folder that downloaded/copied LoRAs are placed in (only the filename is passed to LoraLoader).
_LORA_DIR = '/content/TotoroUI/models/loras'


def _lora_filename_from_disposition(content_disposition):
    """Return a bare filename from a Content-Disposition header, or None.

    Reads the plain ``filename=`` parameter (quoted or not); ``filename*=`` is not
    matched by accident. Directory components are stripped so a server-supplied
    name cannot point outside the LoRA folder.
    """
    if not content_disposition:
        return None
    match = re.search(r'filename="([^"]*)"|filename=([^;]*)', content_disposition)
    if not match:
        return None
    raw_name = match.group(1) if match.group(1) is not None else match.group(2)
    filename = os.path.basename(raw_name.strip().split("?")[0])
    return filename if filename not in ("", ".", "..") else None


def _run_aria2c_for_lora(url, filepath, headers=()):
    """Download ``url`` to ``filepath`` with aria2c and return aria2c's exit code.

    The command is an argument list (no shell): LoRA URLs come from the Gradio form,
    which is reachable through the public share link, so they must never be
    interpreted as shell syntax.
    """
    cmd = [
        "aria2c", "--console-log-level=error", "-c", "-x", "16", "-s", "16", "-k", "1M",
        url, "-d", os.path.dirname(filepath), "-o", os.path.basename(filepath),
    ]
    cmd.extend(f"--header={header}" for header in headers)
    return subprocess.run(cmd).returncode


# Download (or copy from Drive) a LoRA, with per-session caching and filename collision handling
def download_lora(url, source):
    """Make a LoRA available in the loras folder and return its local path.

    Args:
        url: download URL for "civitai"/"huggingface", or the file path for "googledrive".
        source: "civitai", "huggingface" or "googledrive" (case-insensitive; any other
            value is handled as huggingface).

    Returns:
        Path of the LoRA file. Results are cached per (source, url) while the file exists.

    Raises:
        Exception: the Drive file is missing, the URL is not http(s), Civitai answers
            with a non-200 status, or aria2c fails / does not produce the file.
    """
    global downloaded_loras

    cache_key = f"{source}:{url}"
    cached_path = downloaded_loras.get(cache_key)
    if cached_path is not None:
        if os.path.exists(cached_path):
            print(f"Using cached LoRA: {cached_path}")
            return cached_path
        # Cached file no longer exists, remove from cache
        del downloaded_loras[cache_key]

    def get_unique_filepath(base_path, filename):
        """Return a path for ``filename`` in ``base_path`` that no existing file occupies.

        A taken name gets "_{source}_{counter}" appended before the extension. A taken
        path is never reused: it may hold a different LoRA that merely shares the
        filename (the same URL is already served from the cache above).
        """
        name, ext = os.path.splitext(filename)
        filepath = os.path.join(base_path, filename)
        counter = 1
        while os.path.exists(filepath):
            filepath = os.path.join(base_path, f"{name}_{source}_{counter}{ext}")
            counter += 1
        return filepath

    def require_http_url(candidate):
        # Anything else could be parsed by aria2c as a command-line option.
        if not candidate.lower().startswith(("http://", "https://")):
            raise Exception(f"Invalid LoRA URL (must start with http:// or https://): {candidate}")

    def ensure_downloaded(filepath, filename, returncode):
        # A non-zero exit can leave a partial file behind, so check both.
        if returncode != 0 or not os.path.exists(filepath):
            raise Exception(
                f"Download failed: {filename} not found (aria2c exit code {returncode}). "
                "If this is a private model, make sure to provide the appropriate token."
            )

    try:
        if source.lower() == "googledrive":
            # For Google Drive, url is actually the path to the LoRA file
            if not os.path.exists(url):
                raise Exception(f"LoRA not found at: {url}")
            filename = os.path.basename(url)
            filepath = get_unique_filepath(_LORA_DIR, filename)
            os.makedirs(_LORA_DIR, exist_ok=True)
            shutil.copy2(url, filepath)
            print(f"LoRA path: {filepath}")
            downloaded_loras[cache_key] = filepath
            return filepath

        elif source.lower() == "civitai":
            require_http_url(url)
            has_token = is_valid_token(CIVITAI_TOKEN)
            if not has_token:
                print("Warning: No Civitai token provided. Attempting LoRA download in case model is public...")
            headers = {"Authorization": f"Bearer {CIVITAI_TOKEN}"} if has_token else {}
            # Only the response headers are needed (for the filename); stream=True keeps
            # requests from pulling the whole LoRA into memory before aria2c downloads it.
            with requests.get(url, headers=headers, stream=True, timeout=60) as response:
                status_code = response.status_code
                content_disposition = response.headers.get('Content-Disposition')
            if status_code != 200:
                raise Exception(
                    f"Civitai request failed with HTTP {status_code} for {url}. "
                    "If this is a private model, make sure to provide the appropriate token."
                )
            filename = (_lora_filename_from_disposition(content_disposition)
                        or f"lora_{len(downloaded_loras)}.safetensors")
            filepath = get_unique_filepath(_LORA_DIR, filename)
            os.makedirs(_LORA_DIR, exist_ok=True)
            # Civitai accepts the API token as a query parameter.
            download_url = f"{url}?token={CIVITAI_TOKEN}" if has_token else url
            returncode = _run_aria2c_for_lora(download_url, filepath)
            ensure_downloaded(filepath, filename, returncode)
            print(f"LoRA downloaded to: {filepath}")
            downloaded_loras[cache_key] = filepath
            return filepath

        else:  # huggingface
            fetch_url = url
            if "huggingface.co/" in fetch_url:
                # ".../blob/..." is the HTML file viewer; ".../resolve/..." serves the raw file.
                fetch_url = fetch_url.replace("/blob/", "/resolve/", 1)
            require_http_url(fetch_url)
            os.makedirs(_LORA_DIR, exist_ok=True)
            # Filename at end of HF URL (fallback if the URL ends with "/")
            filename = fetch_url.split("/")[-1] or f"lora_{len(downloaded_loras)}.safetensors"
            filepath = get_unique_filepath(_LORA_DIR, filename)

            aria_headers = []
            if is_valid_token(HF_TOKEN):
                aria_headers.append(f"Authorization: Bearer {HF_TOKEN}")
            else:
                print("Warning: No Hugging Face token provided. Attempting LoRA download in case model is public...")

            returncode = _run_aria2c_for_lora(fetch_url, filepath, aria_headers)
            ensure_downloaded(filepath, filename, returncode)
            print(f"LoRA downloaded to: {filepath}")
            downloaded_loras[cache_key] = filepath
            return filepath

    except Exception as e:
        print(f"Error downloading LoRA: {str(e)}")
        raise


@torch.inference_mode()
def generate(positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance,
             lora1_source, lora1_url, lora1_strength_model, lora1_strength_clip,
             lora2_source, lora2_url, lora2_strength_model, lora2_strength_clip, clip_skip):
    """Render one image from the Gradio form and return the path of the saved PNG.

    Args:
        positive_prompt: text prompt.
        width, height: requested size; each is snapped to the nearest multiple of 16.
        seed: 0 picks a random seed; any other value is used as-is. The seed actually
            used is printed so the render can be reproduced.
        steps, sampler_name, scheduler: passed to the custom-sampler nodes.
        guidance: Flux guidance value stored in the conditioning.
        lora1_source / lora2_source: "civitai", "huggingface" or "googledrive".
        lora1_url / lora2_url: LoRA URL or Drive path; blank (or whitespace) skips that LoRA.
        lora*_strength_model / lora*_strength_clip: LoRA strengths for UNET and CLIP.
        clip_skip: number of trailing conditioning positions to drop (see note below).

    Returns:
        "/content/flux.png", where the image is written.

    Raises:
        Any exception from LoRA download or sampling is printed and re-raised.
    """
    # Random seeds stay below 2**53: the seed slider round-trips through a JavaScript
    # number, which cannot hold larger integers exactly, so a bigger printed seed could
    # not be typed back in to reproduce the render. 0 is excluded because it means "random".
    max_reproducible_seed = 2**53 - 1
    seed = int(seed)
    if seed == 0:
        seed = random.randint(1, max_reproducible_seed)
    print(f"Using seed: {seed}")

    try:
        # Start from the base models; each LoRA is layered on top of the previous result.
        unet_lora, clip_lora = unet, clip

        lora_slots = (
            (1, lora1_source, lora1_url, lora1_strength_model, lora1_strength_clip),
            (2, lora2_source, lora2_url, lora2_strength_model, lora2_strength_clip),
        )
        used_lora = False
        for slot, source, url, strength_model, strength_clip in lora_slots:
            url = (url or "").strip()  # whitespace-only input means "no LoRA"
            if not url:
                continue
            used_lora = True
            lora_path = download_lora(url.split("?")[0], source)
            # download_lora places the file in the loras folder; only its filename is
            # handed to LoraLoader.
            lora_filename = os.path.basename(lora_path)
            print(f"Using LoRA {slot}: {lora_filename} from {source}")
            unet_lora, clip_lora = LoraLoader.load_lora(unet_lora, clip_lora, lora_filename,
                                                        strength_model, strength_clip)

        if not used_lora:
            print("No LoRA URLs provided, using base model")

        # Encode the prompt
        tokens = clip_lora.tokenize(positive_prompt)
        cond, pooled = clip_lora.encode_from_tokens(tokens, return_pooled=True)

        # NOTE(review): despite the "CLIP Skip" label this does not select an earlier
        # CLIP layer; it drops the last `clip_skip` positions of the conditioning sequence.
        clip_skip = int(clip_skip)
        if clip_skip > 0:
            cond = cond[:, :cond.shape[1] - clip_skip]

        # Guidance is written straight into the conditioning dict, next to the pooled
        # embedding; the FluxGuidance node is not used here.
        conditioning = [[cond, {"pooled_output": pooled, "guidance": guidance}]]

        # Set up sampling parameters
        noise = RandomNoise.get_noise(seed)[0]
        guider = BasicGuider.get_guider(unet_lora, conditioning)[0]
        sampler = KSamplerSelect.get_sampler(sampler_name)[0]
        sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]

        # Create latent noise image
        latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0]

        # Perform sampling
        sample, _ = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
        model_management.soft_empty_cache()

        # Decode, then convert the first image of the batch to 8-bit. Assumes the decoded
        # tensor is (batch, H, W, C) in [0, 1] (TODO confirm); values are clamped and
        # rounded rather than truncated, which would bias every channel slightly darker.
        decoded = VAEDecode.decode(vae, sample)[0].detach()
        pixels = (decoded[0].float().clamp(0.0, 1.0) * 255.0).round().to(torch.uint8).cpu().numpy()
        Image.fromarray(pixels).save("/content/flux.png")
        return "/content/flux.png"

    except Exception as e:
        print(f"Error during generation: {str(e)}")
        raise


# Modified Gradio interface with complete UI handling
# Layout: a left column of inputs (prompt, sampling settings, two LoRA slots) and a
# right column showing the generated image. Component creation order fixes the layout.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            positive_prompt = gr.Textbox(lines=3, interactive=True,
                                       value="Intergalactic beauty contest winner, pretty female humanoid alien in party dress, holding up sign saying 'UNET Flux With 2 LoRA Colab'",
                                       label="Prompt")
            # Snapped to a multiple of 16 again inside generate().
            width = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="width")
            height = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="height")
            # 0 means "pick a random seed" (handled in generate()).
            # NOTE(review): the browser represents this value as a JavaScript number, which is
            # only exact up to 2**53, so seeds near this maximum may not round-trip exactly.
            seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
            steps = gr.Slider(minimum=4, maximum=50, value=20, step=1, label="steps")
            guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
            # Names are passed unchanged to KSamplerSelect / BasicScheduler in generate().
            sampler_name = gr.Dropdown(["euler", "heun", "heunpp2", "lms", "dpm_2", "dpmpp_2m", "deis", "ddim", "uni_pc", "uni_pc_bh2"],
                                     label="sampler_name", value="dpmpp_2m")
            scheduler = gr.Dropdown(["normal", "sgm_uniform", "simple", "ddim_uniform"],
                                  label="scheduler", value="sgm_uniform")

            # Modified LoRA 1 settings with Google Drive support
            # The Drive dropdown starts hidden; the source-change handler wired below decides
            # when to show it. generate() reads the URL/Path textbox, not the dropdown.
            with gr.Group():
                gr.Markdown("### LoRA 1 Settings")
                lora1_source = gr.Radio(["civitai", "huggingface", "googledrive"],
                                      label="LoRA 1 Source", value="civitai")
                lora1_drive_files = gr.Dropdown(choices=[], label="Select LoRA 1 from Drive", visible=False)
                lora1_url = gr.Textbox(lines=1, interactive=True, value="", label="LoRA 1 URL/Path")
                lora1_strength_model = gr.Slider(minimum=-1.0, maximum=1.5, value=1.0, step=0.05,
                                               label="LoRA 1 Strength (Model)")
                lora1_strength_clip = gr.Slider(minimum=-1.0, maximum=1.5, value=1.0, step=0.05,
                                              label="LoRA 1 Strength (CLIP)")

            # Modified LoRA 2 settings with Google Drive support
            # Same structure as LoRA 1.
            with gr.Group():
                gr.Markdown("### LoRA 2 Settings")
                lora2_source = gr.Radio(["civitai", "huggingface", "googledrive"],
                                      label="LoRA 2 Source", value="civitai")
                lora2_drive_files = gr.Dropdown(choices=[], label="Select LoRA 2 from Drive", visible=False)
                lora2_url = gr.Textbox(lines=1, interactive=True, value="", label="LoRA 2 URL/Path")
                lora2_strength_model = gr.Slider(minimum=-1.0, maximum=1.5, value=1.0, step=0.05,
                                               label="LoRA 2 Strength (Model)")
                lora2_strength_clip = gr.Slider(minimum=-1.0, maximum=1.5, value=1.0, step=0.05,
                                              label="LoRA 2 Strength (CLIP)")

            # NOTE(review): generate() implements this by trimming trailing conditioning
            # positions, not by selecting an earlier CLIP layer.
            clip_skip = gr.Slider(minimum=0, maximum=2, value=0, step=1, label="CLIP Skip")
            generate_button = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Generated image", interactive=False)

    # Updated interface update functions
    def update_lora_interface(source, url_textbox):
        """Show or hide a LoRA slot's Drive dropdown when its source radio changes.

        Args:
            source: selected LoRA source ("civitai", "huggingface" or "googledrive").
            url_textbox: current URL/Path textbox value (unused; the textbox is reset).

        Returns:
            (dropdown update, textbox update). The Drive dropdown is shown, and the
            textbox locked, only when googledrive is selected, Drive LoRAs are enabled
            in the configuration cell (use_drive_loras plus a LORA_DRIVE_DIR), and that
            folder contains .safetensors files. Otherwise the dropdown is hidden and the
            textbox stays editable, so a URL or a Drive file path can be pasted directly
            instead of leaving the user with an empty dropdown and a locked textbox.
        """
        drive_dir = globals().get("LORA_DRIVE_DIR", "")
        drive_enabled = bool(globals().get("use_drive_loras", False)) and bool(drive_dir)
        if source == "googledrive" and drive_enabled:
            drive_files = list_drive_loras(drive_dir)
            if drive_files:
                return (
                    gr.update(visible=True, choices=drive_files, value=None),
                    gr.update(visible=True, interactive=False, value="")
                )
            print(f"No .safetensors LoRAs found in {drive_dir}; paste a Drive path instead.")
        return (
            gr.update(visible=False, choices=[], value=None),
            gr.update(visible=True, interactive=True, value="")
        )

    def update_lora_url(selected_file, source):
        """Copy a Drive dropdown selection into the LoRA URL/Path textbox.

        With the googledrive source and a file selected, the textbox receives that file's
        path under LORA_DRIVE_DIR and is made read-only. In every other case the textbox
        is simply cleared, keeping whatever interactivity it already had.
        """
        if not selected_file or source != "googledrive":
            return gr.update(value="", visible=True)
        drive_path = f"{LORA_DRIVE_DIR}/{selected_file}"
        return gr.update(value=drive_path, visible=True, interactive=False)

    # Wire each LoRA slot (LoRA 1 first, then LoRA 2) to the helpers above:
    # changing the source updates the Drive dropdown and resets the URL box, and
    # picking a Drive file fills the URL box with that file's path.
    lora_slot_components = (
        (lora1_source, lora1_drive_files, lora1_url),
        (lora2_source, lora2_drive_files, lora2_url),
    )
    for slot_source, slot_drive_files, slot_url in lora_slot_components:
        slot_source.change(
            fn=update_lora_interface,
            inputs=[slot_source, slot_url],
            outputs=[slot_drive_files, slot_url],
        )
        slot_drive_files.change(
            fn=update_lora_url,
            inputs=[slot_drive_files, slot_source],
            outputs=[slot_url],
        )

    # Inputs in exactly the positional order generate() expects.
    generate_inputs = [
        positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance,
        lora1_source, lora1_url, lora1_strength_model, lora1_strength_clip,
        lora2_source, lora2_url, lora2_strength_model, lora2_strength_clip, clip_skip,
    ]
    generate_button.click(fn=generate, inputs=generate_inputs, outputs=output_image)

# Launch the UI through Gradio's request queue, without embedding it in the notebook
# (inline=False). share=True also creates a public share link: anyone holding that link
# can submit prompts and LoRA URLs/paths to this runtime. debug=True keeps this cell
# attached, so output printed by generate() (e.g. the seed used) appears here.
demo.queue().launch(inline=False, share=True, debug=True)