In [1]:
import os
from tqdm import tqdm
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionControlNetPipeline, ControlNetModel, DPMSolverMultistepScheduler
from diffusers.utils import load_image
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# ---- Configurations ---- #
base_model_id = "runwayml/stable-diffusion-v1-5"
lora_path = "./lora_output"
textual_inversion_path = "./textual_inversion_output"
custom_token = "<ultrasound>"

input_dir = "data/LoRA_TI_ControlNet"
refined_output_dir = "data/LoRA_TI_ControlNet_refined"
os.makedirs(refined_output_dir, exist_ok=True)

num_images_per_class = 300
img2img_strength = 0.3
num_inference_steps = 50

# Prompts (same as generation prompts)
class_prompts = {
    "benign": f"an {custom_token} image showing a benign breast lesion with smooth borders",
    "malignant": f"an {custom_token} image of a malignant breast lesion with irregular borders",
    "normal": f"an {custom_token} image showing normal breast tissue without any tumor or lesion"
}

device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# ----------------- LOAD IMG2IMG PIPELINE ----------------- #
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    safety_checker=None
).to(device)

pipe.load_textual_inversion(textual_inversion_path, token=custom_token)
pipe.unet.load_attn_procs(lora_path)
pipe.enable_xformers_memory_efficient_attention()

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00,  9.62it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
 

In [5]:
# ----------------- REFINEMENT LOOP ----------------- #
for cls in ['benign', 'malignant', 'normal']:
    input_path = os.path.join(input_dir, cls)
    output_path = os.path.join(refined_output_dir, cls)
    os.makedirs(output_path, exist_ok=True)

    prompt = class_prompts[cls]
    image_files = sorted([f for f in os.listdir(input_path) if f.endswith(('.png', '.jpg'))])[:num_images_per_class]

    print(f"Refining {len(image_files)} images for class: {cls}")
    
    for i, fname in enumerate(tqdm(image_files)):
        input_image = Image.open(os.path.join(input_path, fname)).convert("RGB").resize((512, 512))

        seed = 4000 + i
        generator = torch.Generator(device=device).manual_seed(seed)

        refined = pipe(
            prompt=prompt,
            image=input_image,
            strength=img2img_strength,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

        refined.save(os.path.join(output_path, f"{cls}_refined_{i:04d}.png"))

print(f"\n✅ Image-to-Image refinement complete. Refined images saved to {refined_output_dir}/")

Refining 300 images for class: benign


100%|██████████| 15/15 [00:01<00:00, 13.43it/s]
100%|██████████| 15/15 [00:01<00:00, 13.21it/s]
100%|██████████| 15/15 [00:01<00:00, 13.10it/s]
100%|██████████| 15/15 [00:01<00:00, 13.16it/s]
100%|██████████| 15/15 [00:01<00:00, 13.10it/s]
100%|██████████| 15/15 [00:01<00:00, 13.08it/s]
100%|██████████| 15/15 [00:01<00:00, 13.05it/s]
100%|██████████| 15/15 [00:01<00:00, 13.01it/s]
100%|██████████| 15/15 [00:01<00:00, 12.97it/s]
100%|██████████| 15/15 [00:01<00:00, 12.98it/s]
100%|██████████| 15/15 [00:01<00:00, 12.94it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.92it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.99it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.99it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.98it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.96it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.95it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.91it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.87it/s]]
100%|██████████| 15/15 [00:01<00:00, 12.85it/s]]
100%|██████████| 15/15 [00:01<

Refining 300 images for class: malignant


100%|██████████| 15/15 [00:01<00:00, 11.96it/s]
100%|██████████| 15/15 [00:01<00:00, 11.95it/s]
100%|██████████| 15/15 [00:01<00:00, 11.94it/s]
100%|██████████| 15/15 [00:01<00:00, 11.97it/s]
100%|██████████| 15/15 [00:01<00:00, 11.95it/s]
100%|██████████| 15/15 [00:01<00:00, 11.95it/s]
100%|██████████| 15/15 [00:01<00:00, 11.94it/s]
100%|██████████| 15/15 [00:01<00:00, 11.93it/s]
100%|██████████| 15/15 [00:01<00:00, 11.94it/s]
100%|██████████| 15/15 [00:01<00:00, 11.93it/s]
100%|██████████| 15/15 [00:01<00:00, 11.94it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.93it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.94it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.92it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.95it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.92it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.92it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.93it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.95it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.91it/s]]
100%|██████████| 15/15 [00:01<

Refining 300 images for class: normal


100%|██████████| 15/15 [00:01<00:00, 10.40it/s]
100%|██████████| 15/15 [00:01<00:00, 10.11it/s]
100%|██████████| 15/15 [00:01<00:00, 10.43it/s]
100%|██████████| 15/15 [00:01<00:00, 10.20it/s]
100%|██████████| 15/15 [00:01<00:00, 10.10it/s]
100%|██████████| 15/15 [00:01<00:00, 10.64it/s]
100%|██████████| 15/15 [00:01<00:00, 10.60it/s]
100%|██████████| 15/15 [00:01<00:00, 10.14it/s]
100%|██████████| 15/15 [00:01<00:00, 10.13it/s]
100%|██████████| 15/15 [00:01<00:00, 10.44it/s]
100%|██████████| 15/15 [00:01<00:00, 10.11it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.20it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.78it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.45it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.08it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.28it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.09it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.14it/s]]
100%|██████████| 15/15 [00:01<00:00, 10.32it/s]]
100%|██████████| 15/15 [00:01<00:00, 11.02it/s]]
100%|██████████| 15/15 [00:01<


✅ Image-to-Image refinement complete. Refined images saved to data/LoRA_TI_ControlNet_refined/



