# Google Drive Mount

In [10]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Install dependencies

In [11]:
!pip install -q diffusers transformers accelerate ftfy lpips einops torchmetrics Pillow tqdm xformers torch-fidelity

# Imports

In [12]:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import lpips
from torchvision.transforms import ToTensor, Resize
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.transforms as transforms
from peft import PeftModel
from torch.utils.data import Dataset, DataLoader

# Config

In [13]:
# Global config
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

LORA_ROOT = "/content/drive/MyDrive/Colab Notebooks/sd_lora"
LORA_WEIGHTS = f"{LORA_ROOT}/output/lora_weights"
TEST_IMAGES_DIR = f"{LORA_ROOT}/test/images"
TEST_CAPTIONS_DIR = f"{LORA_ROOT}/test/captions"
RESOLUTION = 512
BATCH_SIZE = 2
NUM_INFERENCE_STEPS = 50
SAVE_GENERATED_IMAGES = True
GENERATED_DIR = f"{LORA_ROOT}/generated_test_images"

# Prepare test dataset

In [14]:
class ImageCaptionDataset(Dataset):
    def __init__(self, images_dir, captions_dir, resolution=512):
        self.images_dir = Path(images_dir)
        self.captions_dir = Path(captions_dir)
        self.ids = sorted([p.stem for p in self.images_dir.glob("*.jpg")])
        self.resolution = resolution
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)  # scale [-1,1]
        ])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img = Image.open(self.images_dir / f"{id_}.jpg").convert("RGB")
        img = self.transform(img)
        caption = (self.captions_dir / f"{id_}.txt").read_text(encoding="utf-8").strip()
        return {"image": img, "caption": caption, "id": id_}

# Load dataset
test_dataset = ImageCaptionDataset(TEST_IMAGES_DIR, TEST_CAPTIONS_DIR, RESOLUTION)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


# Load Stable Diffusion + LoRA

In [15]:
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    safety_checker=None
).to(DEVICE)

# Load LoRA weights into UNet
pipe.unet = PeftModel.from_pretrained(pipe.unet, LORA_WEIGHTS).to(DEVICE)

# Use DPM solver for faster inference
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Memory optimizations
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


# Initialize metrics

In [16]:
# LPIPS perceptual similarity
lpips_alex = lpips.LPIPS(net='alex').to(DEVICE)

# FID metric
fid_metric = FrechetInceptionDistance(feature=2048).to(DEVICE)

# Transform helpers
to_tensor = ToTensor()
resize = Resize((RESOLUTION, RESOLUTION))

# Optional: create folder for generated images
if SAVE_GENERATED_IMAGES:
    Path(GENERATED_DIR).mkdir(parents=True, exist_ok=True)

# Store lpips results
lpips_scores  = []

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


# Generate images and compute metrics

In [17]:
pipe.scheduler.set_timesteps(NUM_INFERENCE_STEPS)
pipe.enable_attention_slicing()  # reduce VRAM usage

for batch in tqdm(test_loader, desc="Testing"):
    captions = batch["caption"]
    gt_images = batch["image"].to(DEVICE)  # should be [-1,1]

    # Generate images
    with torch.autocast("cuda"):
        gen_images = pipe(
            captions,
            height=RESOLUTION,
            width=RESOLUTION,
            num_inference_steps=NUM_INFERENCE_STEPS
        ).images

    # Convert generated images to tensor [-1,1]
    gen_tensors = torch.stack([to_tensor(resize(img)) * 2 - 1 for img in gen_images]).to(DEVICE)

    # --- LPIPS per image ---
    lpips_score = lpips_alex(gen_tensors, gt_images).mean().item()
    lpips_scores.append(lpips_score)

    # --- FID update ---
    gt_images_fid = ((gt_images + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    gen_images_fid = ((gen_tensors + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    fid_metric.update(gt_images_fid, real=True)
    fid_metric.update(gen_images_fid, real=False)

    # --- Save generated images ---
    if SAVE_GENERATED_IMAGES:
        for img, id_ in zip(gen_images, batch["id"]):
            img.save(Path(GENERATED_DIR) / f"{id_}_gen.png")


Testing:   0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:   4%|▍         | 1/24 [00:31<11:54, 31.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:   8%|▊         | 2/24 [01:03<11:35, 31.60s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  12%|█▎        | 3/24 [01:32<10:40, 30.52s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  17%|█▋        | 4/24 [02:01<09:59, 29.97s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  21%|██        | 5/24 [02:31<09:32, 30.15s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  25%|██▌       | 6/24 [03:00<08:56, 29.78s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  29%|██▉       | 7/24 [03:30<08:25, 29.74s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  33%|███▎      | 8/24 [03:59<07:53, 29.62s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  38%|███▊      | 9/24 [04:29<07:22, 29.49s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  42%|████▏     | 10/24 [04:58<06:50, 29.35s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  46%|████▌     | 11/24 [05:27<06:21, 29.33s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  50%|█████     | 12/24 [05:56<05:50, 29.17s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  54%|█████▍    | 13/24 [06:25<05:19, 29.06s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  58%|█████▊    | 14/24 [06:53<04:49, 28.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  62%|██████▎   | 15/24 [07:22<04:19, 28.86s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  67%|██████▋   | 16/24 [07:51<03:50, 28.87s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  71%|███████   | 17/24 [08:20<03:21, 28.85s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  75%|███████▌  | 18/24 [08:48<02:52, 28.74s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  79%|███████▉  | 19/24 [09:17<02:23, 28.77s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  83%|████████▎ | 20/24 [09:46<01:55, 28.93s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  88%|████████▊ | 21/24 [10:15<01:26, 28.82s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  92%|█████████▏| 22/24 [10:44<00:58, 29.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing:  96%|█████████▌| 23/24 [11:14<00:29, 29.09s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing: 100%|██████████| 24/24 [11:29<00:00, 28.72s/it]


# Compute final metrics

In [18]:
final_fid = fid_metric.compute().item()
avg_lpips = sum(lpips_scores )/len(lpips_scores )

print(f"Average LPIPS: {avg_lpips:.4f}")
print(f"FID score: {final_fid:.4f}")

Average LPIPS: 0.7760
FID score: 153.7604
