In [1]:
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
import os
from PIL import Image, ImageChops
from diffusers import StableDiffusionXLImg2ImgPipeline, LatentConsistencyModelImg2ImgPipeline, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, DEISMultistepScheduler
from torchvision.ops import box_convert
import torch
import numpy as np
import shutil

# Function to load YOLO bounding boxes from a text file
def load_yolo_boxes(label_path, img_width, img_height):
    """Read a YOLO label file and return its boxes as absolute ``xyxy`` pixels.

    Every non-blank line must hold five whitespace-separated fields,
    ``class x_center y_center width height``. The four geometry values are
    multiplied by the image size to un-normalize them; the class id is
    discarded.

    Args:
        label_path: Path to the YOLO ``.txt`` label file.
        img_width: Image width in pixels (scales x_center and width).
        img_height: Image height in pixels (scales y_center and height).

    Returns:
        A NumPy array of shape ``(N, 4)`` whose rows are
        ``[x_min, y_min, x_max, y_max]``. ``N`` is 0 when the file contains
        no boxes (empty file or blank lines only).

    Raises:
        ValueError: If a non-blank line does not have exactly five fields, or
            a field cannot be parsed as a number.
    """
    boxes = []
    with open(label_path, 'r') as file:
        for line_no, line in enumerate(file, start=1):
            parts = line.split()
            if not parts:
                # Tolerate blank lines / a trailing newline at end of file.
                continue
            if len(parts) != 5:
                raise ValueError(
                    f"{label_path}:{line_no}: expected 5 fields "
                    f"(class cx cy w h), got {len(parts)}"
                )
            _, x_center, y_center, width, height = map(float, parts)
            boxes.append([x_center, y_center, width, height])

    if not boxes:
        # An empty list would become a shape-(0,) tensor that cannot be
        # broadcast against the 4-element scale below, so return early with
        # a correctly shaped empty result instead.
        return torch.zeros((0, 4)).numpy()

    boxes_tensor = torch.tensor(boxes)
    scale = torch.tensor(
        [img_width, img_height, img_width, img_height],
        dtype=boxes_tensor.dtype,
    )
    boxes_unnorm = boxes_tensor * scale
    boxes_xyxy = box_convert(boxes=boxes_unnorm, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    return boxes_xyxy
# ----------------------------------------------------------------------------------------------------------------------
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only dependable on CUDA. On the CPU fallback, load the
# same weights but run them in float32 so the pipeline stays usable.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32

# ----------------------------------------------------------------------------------------------------------------------
# Alternative pipelines, kept for reference:
#pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)

#pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16).to(device)

# Active pipeline: Stable Diffusion 2 inpainting.
pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting", torch_dtype=pipe_dtype, variant="fp16", use_safetensors=True).to(device)

# pipe = AutoPipelineForInpainting.from_pretrained('Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16")
# pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
# pipe = pipe.to("cuda")


# ----------------------------------------------------------------------------------------------------------------------
# Load InceptionV3 for FID feature extraction
# ----------------------------------------------------------------------------------------------------------------------
# Build the pretrained InceptionV3, move it to `device`, and switch it to
# inference mode in one chained expression.
inception_model = models.inception_v3(
    pretrained=True,
    aux_logits=True,
    transform_input=False,
).to(device).eval()

# Replace the classifier head (and the dropout in front of it) with
# pass-through layers so the forward pass returns features, not class logits.
inception_model.fc = nn.Identity()
inception_model.dropout = nn.Identity()

# Image preprocessing applied before feature extraction:
# resize to 299x299, convert to a tensor, then normalize with ImageNet statistics.
fid_transform = T.Compose([
    T.Resize((299, 299)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def get_inception_feat(pil_img: Image.Image) -> torch.Tensor:
    """
    Extract an Inception feature vector from a single PIL image.

    The image is preprocessed with `fid_transform`, given a leading batch
    axis, moved to `device`, and run through `inception_model` (whose final
    layers were replaced by Identity). Gradients are disabled by the
    decorator. The batch axis is removed again and the result is returned
    on the CPU (2048 values per the shapes annotated below).
    """
    preprocessed = fid_transform(pil_img)
    batch = torch.unsqueeze(preprocessed, 0).to(device)  # (1, 3, 299, 299)
    features = inception_model(batch)                    # (1, 2048)
    return torch.squeeze(features, 0).cpu()              # (2048,)

# ----------------------------------------------------------------------------------------------------------------------
def compute_fid_between_two_images(orig_path: str, var_img: Image.Image) -> float:
    """
    Return the squared Euclidean distance between the Inception features of
    the image stored at `orig_path` and those of `var_img`.

    NOTE: despite the name, this is not a true Frechet Inception Distance
    (which compares feature statistics of two image *sets*); it is a
    per-pair feature distance used here to rank generated variations.
    Smaller means closer to the original.

    Args:
        orig_path: Path of the reference image on disk; it is loaded and
            converted to RGB.
        var_img: Candidate image, already in memory.

    Returns:
        The distance as a Python float.
    """
    # Use a context manager so the file handle is released deterministically;
    # convert() returns an independent in-memory copy that outlives the file.
    with Image.open(orig_path) as src:
        img_orig = src.convert("RGB")

    feat_orig = get_inception_feat(img_orig)  # (2048,)
    feat_var  = get_inception_feat(var_img)   # (2048,)
    return torch.sum((feat_orig - feat_var) ** 2).item()

# ----------------------------------------------------------------------------------------------------------------------
def _boxes_to_mask(boxes_xyxy, img_width, img_height):
    """Rasterize xyxy pixel boxes into a uint8 mask: 255 inside boxes, 0 elsewhere."""
    mask = np.zeros((img_height, img_width), dtype=np.uint8)
    for x0, y0, x1, y1 in boxes_xyxy:
        # Clamp to the image bounds. Without this, a box that spills past the
        # left/top border yields a negative index, which NumPy interprets as
        # counting from the end and would mark the wrong region (or nothing).
        left   = min(max(int(x0), 0), img_width)
        top    = min(max(int(y0), 0), img_height)
        right  = min(max(int(x1), 0), img_width)
        bottom = min(max(int(y1), 0), img_height)
        mask[top:bottom, left:right] = 255
    return mask


def generate_and_select_variations(original_image_path: str, num_variations: int, top_k: int):
    """
    Generate inpainted variations of one labelled image and keep the closest ones.

    Steps:
    1. Load the original image and its YOLO label (``<image stem>.txt`` in
       `yolo_labels_folder`).
    2. Build a mask that is black inside the labelled boxes and white
       elsewhere, so the inpainting pipeline repaints only the area outside
       the boxes and the labelled objects stay in place.
    3. Call the module-level `pipe` once with
       ``num_images_per_prompt = num_variations``.
    4. Score each result with `compute_fid_between_two_images` and keep the
       `top_k` with the smallest distance.
    5. Save each kept image as JPEG in `output_folder`, named
       ``<base>_varNN_fid<dist>`` where ``<base>`` is the part of the
       filename before any ".rf.".
    6. Copy the original label file (if it exists) into
       `output_label_folder` under the same basename.

    Reads the module-level names `pipe`, `prompt`, `yolo_labels_folder`,
    `output_folder` and `output_label_folder`.

    Args:
        original_image_path: Path of the source image.
        num_variations: How many candidates to generate in the single call.
        top_k: How many of the closest candidates to save; values <= 0 save
            nothing.

    Returns:
        The distances of the saved images, in ascending order.
    """
    # --- derive clean base name ---
    orig_filename = os.path.basename(original_image_path)
    root, ext = os.path.splitext(orig_filename)

    # split at ".rf." and keep the first part
    clean_base = root.split('.rf.')[0]

    # load original image; convert() gives an in-memory copy so the file
    # handle can be closed right away
    with Image.open(original_image_path) as src:
        img_orig = src.convert("RGB")
    img_width, img_height = img_orig.size

    # corresponding label path: same stem as the image (including any
    # ".rf.<hash>" part), with a .txt extension
    orig_label = os.path.join(yolo_labels_folder, root + ".txt")
    has_label = os.path.isfile(orig_label)
    # An image without a label file has no boxes to protect, so the whole
    # frame becomes eligible for repainting instead of raising here.
    boxes_xyxy = load_yolo_boxes(orig_label, img_width, img_height) if has_label else []

    # build mask for inpainting: boxes are marked white first, then the mask
    # is inverted so that everything *except* the boxes is repainted
    image_mask    = Image.fromarray(_boxes_to_mask(boxes_xyxy, img_width, img_height))
    image_resized = img_orig.resize((1024, 1024))
    mask_resized  = image_mask.resize((1024, 1024))
    mask_inverted = ImageChops.invert(mask_resized)

    STRENGTH = .2 ## level of variation from original

    # Guidance scale (CFG scale):
    # How strongly the model adheres to the prompt. Higher values make it more strict.
    GUIDANCE_SCALE = 10.0

    # Nominal scheduler steps. With strength < 1 only part of the schedule is
    # run: the progress bars show 200 steps per call for STRENGTH = 0.2.
    NUM_INFERENCE_STEPS = 1000
    NEGATIVE_PROMPT = "blurry, low quality, cartoon, painting, drawing, ugly, deformed, out of focus, unrealistic, pixelated, bad composition, watermark, text"

    # generate all variations in one call
    out = pipe(
        prompt=prompt,
        image=image_resized,
        strength=STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
        negative_prompt=NEGATIVE_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=num_variations,
        mask_image=mask_inverted
    )

    generated_images = out.images  # list length = num_variations

    # score every candidate against the original
    scored = [
        (compute_fid_between_two_images(original_image_path, var_img), var_img)
        for var_img in generated_images
    ]

    # select top_k by ascending distance (stable sort keeps generation order
    # on ties); max(..., 0) stops a negative top_k from slicing from the end
    scored.sort(key=lambda pair: pair[0])
    top_variations = scored[:max(top_k, 0)]

    # The file is always JPEG-encoded, so make sure the extension says so;
    # keep the original spelling when it is already a JPEG extension.
    save_ext = ext if ext.lower() in (".jpg", ".jpeg") else ".jpg"

    saved_distances = []
    for rank, (dist, img_var) in enumerate(top_variations, start=1):
        # resize back to original so the copied YOLO label still matches
        img_var = img_var.resize((img_width, img_height))

        save_name = f"{clean_base}_var{rank:02d}_fid{dist:.4f}"
        save_path = os.path.join(output_folder, f"{save_name}{save_ext}")

        # save JPEG
        # (An earlier, commented-out variant pasted the original box crops
        # back onto the result before saving.)
        img_var.save(save_path, format="JPEG")

        # copy label to match exactly
        if has_label:
            new_label = os.path.join(output_label_folder, f"{save_name}.txt")
            shutil.copyfile(orig_label, new_label)

        saved_distances.append(dist)

    return saved_distances


# ----------------------------------------------------------------------------------------------------------------------
# Run configuration. These module-level names are read by
# generate_and_select_variations() and by the __main__ loop below.
# Folder scanned for input images (.jpg / .jpeg / .png).
source_folder       = r"C:\Users\cmull\DataspellProjects\AutoAnnotate\AI gen project\GT-red-leaf-11\train\images"
# Destination for the generated images.
output_folder       = "./RLgen_SDIP/images"
# Destination for the label files copied alongside each generated image.
output_label_folder = "./RLgen_SDIP/labels"
# Text prompt passed to the pipeline on every call.
prompt              = "generate exact variation"
# Folder holding the YOLO .txt labels, looked up by image filename stem.
yolo_labels_folder  = r"C:\Users\cmull\DataspellProjects\AutoAnnotate\AI gen project\GT-red-leaf-11\train\labels"

# Create both output folders up front; exist_ok makes re-running safe.
os.makedirs(output_folder, exist_ok=True)
os.makedirs(output_label_folder, exist_ok=True)

if __name__ == "__main__":
    # Generation settings are identical for every image, so set them once:
    # generate 4 variations per source image and keep only the closest 1.
    num_variations = 4
    top_k          = 1

    # Sorted for a deterministic processing order.
    all_files = sorted(os.listdir(source_folder))
    for filename in all_files:
        # Only process images
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        orig_path = os.path.join(source_folder, filename)

        fid_list = generate_and_select_variations(orig_path, num_variations, top_k)
        print(f"Image {filename}: top FID distances = {fid_list}")


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]



  0%|          | 0/200 [00:00<?, ?it/s]

Image 20220722_111508_jpg.rf.243aaa5adde3ce6ec67ce39e5a231cb8.jpg: top FID distances = [60.778953552246094]


  0%|          | 0/200 [00:00<?, ?it/s]

Image 20220722_111508_jpg.rf.8410cc0f48b76f0e0dfdc00786f11022.jpg: top FID distances = [59.176395416259766]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3382_JPG.rf.503632c6c007e68c8b94e3f89424f34a.jpg: top FID distances = [49.42195129394531]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3382_JPG.rf.8e9f4eec0f9cfe35c1e9582c7b1362e5.jpg: top FID distances = [38.378299713134766]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3386_JPG.rf.08bdf498a065c24b6e2087f982b1dfdd.jpg: top FID distances = [61.70756530761719]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3386_JPG.rf.1ffa59994570d99d4cbf70a3edd509fd.jpg: top FID distances = [56.19220733642578]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3407_JPG.rf.10d8badc05264ba276a4595a48792e1a.jpg: top FID distances = [34.01392364501953]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3407_JPG.rf.f21ede95a95bb3acd65059c2cb437bd5.jpg: top FID distances = [72.39913940429688]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3408_JPG.rf.3d414036e7d598b9391e575f2d5c64d4.jpg: top FID distances = [25.39070701599121]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3408_JPG.rf.bdb69b5fd4f7f942ef36f54e9fddaf32.jpg: top FID distances = [29.253698348999023]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3411_JPG.rf.17f89f664b8e9589d1d71e8022be0df3.jpg: top FID distances = [43.3604621887207]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3411_JPG.rf.ee54e6b6583654964bfb1cb1744fb9c3.jpg: top FID distances = [51.61919403076172]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3416_JPG.rf.80579040152742f6add2e5463aa953bb.jpg: top FID distances = [51.35577392578125]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3416_JPG.rf.b48ad47860ac5b7d8dfe6a686cf59ae0.jpg: top FID distances = [49.6343879699707]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3417_JPG.rf.0d5e0f43cf0641e388703735b5579bd7.jpg: top FID distances = [54.596466064453125]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3417_JPG.rf.9f8472484e0f3fa9a60fb90e5d01862a.jpg: top FID distances = [48.8845329284668]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3418_JPG.rf.1b19cd7345d7077389e1a044c5472765.jpg: top FID distances = [39.138851165771484]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3418_JPG.rf.94436962c8f7740688db6f3575b7de05.jpg: top FID distances = [36.384586334228516]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3450_JPG.rf.78f4c7ed2434e75c6f96cb99e74b4e57.jpg: top FID distances = [82.04019165039062]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3450_JPG.rf.c9406ee60b46ea928b577ba89e9448b5.jpg: top FID distances = [83.75072479248047]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3451_JPG.rf.59c4d6181a626fe8fdcb1774f05e9c80.jpg: top FID distances = [134.7001953125]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3451_JPG.rf.fadff852701d8bab6ddaed4bfaf33d5c.jpg: top FID distances = [158.1081085205078]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3479_JPG.rf.0048c40a3547d339ddc8112b7317607f.jpg: top FID distances = [71.8448257446289]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3479_JPG.rf.db569453b012f03ff59606e758d0b15e.jpg: top FID distances = [78.77117919921875]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3482_JPG.rf.3d90b60d9ee691e262c95ed31a071307.jpg: top FID distances = [70.56239318847656]


  0%|          | 0/200 [00:00<?, ?it/s]

Image DSCF3482_JPG.rf.e39520c5fd88cd5f9fe28f557656e452.jpg: top FID distances = [89.8516845703125]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0395_jpg.rf.1d57b0a13a98c22ca481c03b4c26adf6.jpg: top FID distances = [24.27971649169922]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0395_jpg.rf.28cf16a3c3d8d0734dbc830090df3746.jpg: top FID distances = [36.418949127197266]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0410_jpg.rf.2f88de79348d8351abf32930c1088ace.jpg: top FID distances = [33.94023132324219]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0410_jpg.rf.8bb8045e8877bfd6510db41dba1e7c1b.jpg: top FID distances = [35.69694137573242]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0412_jpg.rf.37f556281dc86171ed21f830316aa1a0.jpg: top FID distances = [38.446258544921875]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0412_jpg.rf.ad45d22bbd3064cf1c06329a08a148f7.jpg: top FID distances = [29.085603713989258]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0414_jpg.rf.50494c852cf5da15f6ee6807ec15f7cd.jpg: top FID distances = [60.3425407409668]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0414_jpg.rf.af8297adb59dcbfb0e7535c55debc144.jpg: top FID distances = [75.35189056396484]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0446_jpg.rf.6988e41ab07ac2ed51b9a7e9836b58a3.jpg: top FID distances = [27.77134132385254]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0446_jpg.rf.92e73abf35db21c383584680c7241301.jpg: top FID distances = [38.71853256225586]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0447_jpg.rf.ac691c2add1fa692b147c38fc5683456.jpg: top FID distances = [28.56695556640625]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0447_jpg.rf.c55985cc6a558d4a7779fdb7e1cc91ef.jpg: top FID distances = [33.28870391845703]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0474_jpg.rf.3cd940097b57e7c4ae9843414dc74d90.jpg: top FID distances = [19.83914566040039]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0474_jpg.rf.57243dc2e7bc802d0fddd230b1b673fe.jpg: top FID distances = [20.160791397094727]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0499_jpg.rf.242b11226ea3771fd01a69b4fe211cf6.jpg: top FID distances = [47.0510139465332]


  0%|          | 0/200 [00:00<?, ?it/s]

Image IMG_0499_jpg.rf.de074b8115290623fb20c2e5d04c81de.jpg: top FID distances = [61.98786926269531]
