In [5]:
import os
from tqdm import tqdm
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DPMSolverMultistepScheduler
from diffusers.utils import load_image
import numpy as np

In [2]:
# ---- Configurations ---- #
base_model_id = "runwayml/stable-diffusion-v1-5"
lora_path = "./lora_output"
textual_inversion_path = "./textual_inversion_output"
custom_token = "<ultrasound>"

output_dir = "data/LoRA_TI_ControlNet/"
controlnet_model_id = "lllyasviel/sd-controlnet-canny"
control_dir = "data/BUSI_edges"  # folder with edge maps (matching class)

num_images_per_class = 300

# Prompts for each class
class_prompts = {
    "benign": f"an {custom_token} image showing a benign breast lesion with smooth borders",
    "malignant": f"an {custom_token} image of a malignant breast lesion with irregular borders",
    "normal": f"an {custom_token} image showing normal breast tissue without any tumor or lesion"
}

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# ----------------- LOAD CONTROLNET PIPELINE ----------------- #
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(base_model_id, subfolder="scheduler"),
    safety_checker=None
).to(device)

# Load TI and LoRA
pipe.load_textual_inversion(textual_inversion_path, token=custom_token)
pipe.unet.load_attn_procs(lora_path)
pipe.enable_xformers_memory_efficient_attention()


Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 14.01it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  deprecate("load_

In [6]:
# ----------------- GENERATION LOOP ----------------- #
os.makedirs(output_dir, exist_ok=True)

for cls in ['benign', 'malignant', 'normal']:
    save_path = os.path.join(output_dir, cls)
    os.makedirs(save_path, exist_ok=True)

    prompt = class_prompts[cls]
    edge_files = sorted([f for f in os.listdir(control_dir) if f.startswith(cls) and f.endswith(('.png', '.jpg'))])
    target_count = num_images_per_class
    img_index = 0

    print(f"Generating {num_images_per_class} images for class: {cls}")
    
    if cls == "benign":
        # One edge map per image
        for i, fname in enumerate(tqdm(edge_files[:target_count])):
            control_image = load_image(os.path.join(control_dir, fname)).resize((512, 512))
            seed = 1000 + i
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(prompt=prompt, image=control_image, num_inference_steps=50, generator=generator).images[0]
            image.save(os.path.join(save_path, f"{cls}_{img_index:04d}.png"))
            img_index += 1
    elif cls == "malignant":
        # Fewer edge maps: reuse with different seeds
        repeats = (target_count + len(edge_files) - 1) // len(edge_files)
        edge_pool = (edge_files * repeats)[:target_count]
        for i, fname in enumerate(tqdm(edge_pool)):
            control_image = load_image(os.path.join(control_dir, fname)).resize((512, 512))
            seed = 2000 + i
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(prompt=prompt, image=control_image, num_inference_steps=50, generator=generator).images[0]
            image.save(os.path.join(save_path, f"{cls}_{img_index:04d}.png"))
            img_index += 1

    elif cls == "normal":
        # No edge maps → use blank image + prompt
        blank = Image.fromarray(np.zeros((512, 512), dtype=np.uint8))
        for i in tqdm(range(target_count)):
            seed = 3000 + i
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(prompt=prompt, image=blank, num_inference_steps=50, generator=generator).images[0]
            image.save(os.path.join(save_path, f"{cls}_{img_index:04d}.png"))
            img_index += 1


print(f"\n✅ ControlNet + LoRA + TI image generation complete. Images saved to {output_dir}/")


Generating 300 images for class: normal


100%|██████████| 50/50 [00:04<00:00, 10.90it/s]
100%|██████████| 50/50 [00:04<00:00, 10.81it/s]
100%|██████████| 50/50 [00:04<00:00, 10.73it/s]
100%|██████████| 50/50 [00:04<00:00, 10.70it/s]
100%|██████████| 50/50 [00:04<00:00, 10.68it/s]
100%|██████████| 50/50 [00:04<00:00, 10.59it/s]
100%|██████████| 50/50 [00:04<00:00, 10.55it/s]
100%|██████████| 50/50 [00:04<00:00, 10.49it/s]
100%|██████████| 50/50 [00:04<00:00, 10.43it/s]
100%|██████████| 50/50 [00:04<00:00, 10.37it/s]
100%|██████████| 50/50 [00:04<00:00, 10.31it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.26it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.21it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.19it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.16it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.14it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.11it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.09it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.07it/s]]
100%|██████████| 50/50 [00:04<00:00, 10.05it/s]]
100%|██████████| 50/50 [00:04<


✅ ControlNet + LoRA + TI image generation complete. Images saved to data/LoRA_TI_ControlNet//



