In [1]:
import os
from tqdm import tqdm
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# ---- Configurations ---- #
base_model_id = "runwayml/stable-diffusion-v1-5"  # or your custom fine-tuned base model
lora_path = "./lora_output"                      # folder containing your LoRA weights
textual_inversion_path = "./textual_inversion_output"    # folder containing your .bin or .pt or .safetensors file
custom_token = "<ultrasound>"                     # this is the token you used for textual inversion

output_dir = "data/LoRA_TI/"

# Number of images per class to generate
num_images_per_class = 300

# Prompts for each class
class_prompts = {
    "benign": f"an ultrasound image showing a benign breast lesion with smooth borders",
    "malignant": f"an ultrasound image of a malignant breast lesion with irregular borders",
    "normal": f"an ultrasound image showing normal breast tissue without any tumor or lesion"
}

device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:
# ----------------- LOAD PIPELINE ----------------- #
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    safety_checker=None,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(base_model_id, subfolder="scheduler")
).to(device)

# ---- Load Textual Inversion ---- #
pipe.load_textual_inversion(textual_inversion_path, token=custom_token)
# ---- Load LoRA weights ---- #
pipe.unet.load_attn_procs(lora_path)
# Optional: Enable memory-efficient attention if you're on a limited GPU
pipe.enable_xformers_memory_efficient_attention()

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 12.19it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  deprecate("loa

In [4]:
# ----------------- AUGMENTATION LOOP ----------------- #
os.makedirs(output_dir, exist_ok=True)

for cls, prompt in class_prompts.items():
    save_path = os.path.join(output_dir, cls)
    os.makedirs(save_path, exist_ok=True)

    print(f"Generating {num_images_per_class} images for class: {cls}")
    for i in tqdm(range(num_images_per_class)):
        seed = 1000 + i
        generator = torch.Generator(device=device).manual_seed(seed)
        image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]

        img_path = os.path.join(save_path, f"lora_texInv_{cls}_{i:03d}.png")
        image.save(img_path)

print(f"\n✅ Augmentation complete. Images saved to {output_dir}/")


Generating 300 images for class: benign


100%|██████████| 50/50 [00:03<00:00, 14.58it/s]
100%|██████████| 50/50 [00:03<00:00, 15.02it/s]
100%|██████████| 50/50 [00:03<00:00, 14.94it/s]
100%|██████████| 50/50 [00:03<00:00, 14.86it/s]
100%|██████████| 50/50 [00:03<00:00, 14.79it/s]
100%|██████████| 50/50 [00:03<00:00, 14.71it/s]
100%|██████████| 50/50 [00:03<00:00, 14.75it/s]
100%|██████████| 50/50 [00:03<00:00, 14.71it/s]
100%|██████████| 50/50 [00:03<00:00, 14.63it/s]
100%|██████████| 50/50 [00:03<00:00, 14.55it/s]
100%|██████████| 50/50 [00:03<00:00, 14.52it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.48it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.42it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.37it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.32it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.28it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.24it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.20it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.17it/s]]
100%|██████████| 50/50 [00:03<00:00, 14.13it/s]]
100%|██████████| 50/50 [00:03<

Generating 300 images for class: malignant


100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<

Generating 300 images for class: normal


100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.64it/s]]
100%|██████████| 50/50 [00:03<00:00, 13.63it/s]]
100%|██████████| 50/50 [00:03<


✅ Augmentation complete. Images saved to data/LoRA_textual_inversion//



