# Google Drive Mount

In [19]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Install dependencies

In [20]:
!pip install -q diffusers transformers accelerate ftfy lpips einops torchmetrics Pillow tqdm xformers torch-fidelity

# Imports

In [21]:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from torchvision.transforms import ToTensor, Resize
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.transforms as transforms
from peft import PeftModel
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from torch.nn import functional as F

# Config

In [22]:
# Device and base model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Paths for LoRA weights
LORA_ROOT = "/content/drive/MyDrive/Colab Notebooks/sd_lora"
OUTPUT_DIR = f"{LORA_ROOT}/outputs/unet_only"
UNET_LORA_WEIGHTS = f"{OUTPUT_DIR}/final/unet"
TEXT_ENCODER_LORA_WEIGHTS = f"{OUTPUT_DIR}/final/text_encoder"

# Dataset paths for testing
TEST_IMAGES_DIR = f"{LORA_ROOT}/dataset/test/images"
TEST_CAPTIONS_DIR = f"{LORA_ROOT}/dataset/test/captions"

# Generation / testing parameters
RESOLUTION = 512
BATCH_SIZE = 2
NUM_INFERENCE_STEPS = 50          # DDIM / DPM solver steps for generation
SAVE_GENERATED_IMAGES = True

# Unified test output directory
TEST_OUTPUT_DIR = f"{LORA_ROOT}/testing/metrics_test/unet_only/final"
GENERATED_DIR = f"{TEST_OUTPUT_DIR}/generated_images"
Path(GENERATED_DIR).mkdir(parents=True, exist_ok=True)

# Check if LoRA weights for text encoder exist (optional)
USE_TEXT_ENCODER_LORA = Path(TEXT_ENCODER_LORA_WEIGHTS).exists()

# Prepare test dataset

In [23]:
class ImageCaptionDataset(Dataset):
    def __init__(self, images_dir, captions_dir, resolution=512):
        self.images_dir = Path(images_dir)
        self.captions_dir = Path(captions_dir)
        self.ids = sorted([p.stem for p in self.images_dir.glob("*.jpg")])
        self.resolution = resolution
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)  # scale [-1,1]
        ])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img = Image.open(self.images_dir / f"{id_}.jpg").convert("RGB")
        img = self.transform(img)
        caption = (self.captions_dir / f"{id_}.txt").read_text(encoding="utf-8").strip()
        return {"image": img, "caption": caption, "id": id_}

# Load dataset
test_dataset = ImageCaptionDataset(TEST_IMAGES_DIR, TEST_CAPTIONS_DIR, RESOLUTION)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Load Stable Diffusion + LoRA

In [24]:
# Load base pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    safety_checker=None
).to(DEVICE)

# Load LoRA weights for UNet
pipe.unet = PeftModel.from_pretrained(pipe.unet, UNET_LORA_WEIGHTS).to(DEVICE)

# Optionally load LoRA weights for Text Encoder
if USE_TEXT_ENCODER_LORA:
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, TEXT_ENCODER_LORA_WEIGHTS).to(DEVICE)
    print("Loaded LoRA weights for Text Encoder.")

# Use DPM solver for faster inference
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Memory optimizations
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


# Initialize metrics

In [25]:
# CLIP setup
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# FID metric
fid_metric = FrechetInceptionDistance(feature=2048).to(DEVICE)

# Transform helpers
to_tensor = ToTensor()
resize = transforms.Resize((RESOLUTION, RESOLUTION))

# Optional: create folder for generated images
if SAVE_GENERATED_IMAGES:
    Path(GENERATED_DIR).mkdir(parents=True, exist_ok=True)

# Generate images and compute metrics

In [26]:


pipe.scheduler.set_timesteps(NUM_INFERENCE_STEPS)
pipe.enable_attention_slicing()  # reduce VRAM usage

clip_scores = []
clip_cos_sims = []

for batch in tqdm(test_loader, desc="Testing"):
    captions = batch["caption"]
    gt_images = batch["image"].to(DEVICE)  # [-1,1] scale

    # Generate images
    with torch.autocast("cuda"):
        gen_images = pipe(
            captions,
            height=RESOLUTION,
            width=RESOLUTION,
            num_inference_steps=NUM_INFERENCE_STEPS
        ).images

    # Convert generated images to tensor [-1,1]
    gen_tensors = torch.stack([to_tensor(resize(img)) * 2 - 1 for img in gen_images]).to(DEVICE)

    # ------------------ CLIPScore & cosine similarity ------------------
    # Preprocess for CLIP
    clip_inputs = clip_processor(
        text=captions,
        images=[Image.fromarray(((img.permute(1,2,0).cpu().numpy()+1)/2*255).astype("uint8")) for img in gen_tensors],
        return_tensors="pt",
        padding=True
    ).to(DEVICE)

    # Forward pass
    with torch.no_grad():
        clip_outputs = clip_model(**clip_inputs)
        image_embeds = F.normalize(clip_outputs.image_embeds, dim=-1)
        text_embeds = F.normalize(clip_outputs.text_embeds, dim=-1)

        # Cosine similarity
        cos_sim = (image_embeds * text_embeds).sum(dim=-1)
        clip_cos_sims.extend(cos_sim.cpu().tolist())

        # CLIPScore (scaled 0-100)
        clip_score = ((cos_sim + 1)/2 * 100)
        clip_scores.extend(clip_score.cpu().tolist())

    # ------------------ FID update ------------------
    gt_images_fid = ((gt_images + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    gen_images_fid = ((gen_tensors + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    fid_metric.update(gt_images_fid, real=True)
    fid_metric.update(gen_images_fid, real=False)

    # ------------------ Save generated images ------------------
    if SAVE_GENERATED_IMAGES:
        for img, id_ in zip(gen_images, batch["id"]):
            img.save(Path(GENERATED_DIR) / f"{id_}_gen.png")


Testing:   0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:   4%|▍         | 1/24 [00:27<10:32, 27.51s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:   8%|▊         | 2/24 [00:54<09:54, 27.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  12%|█▎        | 3/24 [01:19<09:15, 26.44s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  17%|█▋        | 4/24 [01:45<08:44, 26.22s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  21%|██        | 5/24 [02:12<08:23, 26.50s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  25%|██▌       | 6/24 [02:40<08:01, 26.76s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  29%|██▉       | 7/24 [03:06<07:30, 26.49s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  33%|███▎      | 8/24 [03:31<07:00, 26.31s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  38%|███▊      | 9/24 [03:57<06:32, 26.19s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  42%|████▏     | 10/24 [04:23<06:05, 26.12s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  46%|████▌     | 11/24 [04:49<05:38, 26.07s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  50%|█████     | 12/24 [05:15<05:12, 26.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  54%|█████▍    | 13/24 [05:41<04:46, 26.03s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  58%|█████▊    | 14/24 [06:07<04:20, 26.02s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  62%|██████▎   | 15/24 [06:33<03:53, 25.98s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  67%|██████▋   | 16/24 [06:59<03:27, 25.98s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  71%|███████   | 17/24 [07:25<03:01, 25.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  75%|███████▌  | 18/24 [07:51<02:35, 25.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  79%|███████▉  | 19/24 [08:17<02:09, 25.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  83%|████████▎ | 20/24 [08:43<01:43, 25.95s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  88%|████████▊ | 21/24 [09:09<01:17, 25.97s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  92%|█████████▏| 22/24 [09:35<00:51, 25.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  96%|█████████▌| 23/24 [10:01<00:25, 25.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing: 100%|██████████| 24/24 [10:15<00:00, 25.63s/it]


# Compute final metrics

In [27]:
final_fid = fid_metric.compute().item()
avg_clip_score = sum(clip_scores) / len(clip_scores)
avg_cos_sim = sum(clip_cos_sims) / len(clip_cos_sims)

# Prepare content
metrics_content = f"""
FID score: {final_fid:.4f}
Average CLIPScore (0-100): {avg_clip_score:.2f}
Average CLIP cosine similarity (-1 to 1): {avg_cos_sim:.4f}
"""

# Print metrics
print(metrics_content.strip())

# Save metrics as TXT
metrics_txt_path = Path(GENERATED_DIR).parent / "metrics.txt"
with open(metrics_txt_path, "w") as f:
    f.write(metrics_content.strip())

print(f"Metrics saved to {metrics_txt_path}")

FID score: 165.0374
Average CLIPScore (0-100): 64.61
Average CLIP cosine similarity (-1 to 1): 0.2922
Metrics saved to /content/drive/MyDrive/Colab Notebooks/sd_lora/testing/metrics_test/unet_only/final/metrics.txt
