In [1]:
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance

In [2]:
from config_file import config

In [None]:
class ImageFolderWithoutLabels(Dataset):
    def __init__(self, image_dir, transform=None):
        self.paths = list(Path(image_dir).glob("*.jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        # Преобразование в uint8
        img = (img * 255).clamp(0, 255).byte()
        return img

In [None]:
def get_fid(real_dir: Path, fake_dir: Path):
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),  # вернёт float32, но мы преобразуем ниже
    ])
    
    # Датасеты
    real_dataset = ImageFolderWithoutLabels(real_dir, transform=transform)
    fake_dataset = ImageFolderWithoutLabels(fake_dir, transform=transform)

    # Дата лоадеры
    real_loader = DataLoader(real_dataset, batch_size=32)
    fake_loader = DataLoader(fake_dataset, batch_size=32)

    # Метрика FID
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fid = FrechetInceptionDistance(feature=2048).to(device)

    # Обновляем метрику
    for imgs in real_loader:
        fid.update(imgs.to(device), real=True)

    for imgs in fake_loader:
        fid.update(imgs.to(device), real=False)

    # Результат
    fid_score = fid.compute()
    return fid_score

In [5]:
# Пути к папкам
real_dir = Path(config.IMAGES_PATH / "validation")
fake_dir = Path(config.IMAGES_PATH / "sd-generated")

In [6]:
orig_fid = get_fid(real_dir, fake_dir)
print(f"📊 FID: {orig_fid.item():.4f}")

📊 FID: 285.6114


In [None]:
import os
real_dir = Path(config.IMAGES_PATH / "validation")

generated_image_dirs = [Path(name) for name in os.listdir(config.IMAGES_PATH / '512x512') if name.startswith('sd_trained')]
generated_image_dirs

[WindowsPath('sd_trained_unet_epoch_110.pt'),
 WindowsPath('sd_trained_unet_epoch_50.pt'),
 WindowsPath('sd_trained_unet_epoch_80.pt')]

In [None]:
# size = 512
# transform = transforms.Compose([
#     transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
#     transforms.CenterCrop(size),
# ])

# fake_images = os.listdir(config.IMAGES_PATH / '512x512' / generated_image_dirs[0])
# real_images = os.listdir(config.IMAGES_PATH / 'validation')

# images_to_validate = list(set(fake_images) & set(real_images))

# for filename in images_to_validate:
#     img = transform(Image.open(config.IMAGES_PATH / 'validation' / filename))
#     img.save(config.IMAGES_PATH / 'real-validation' / filename)

In [19]:
for image_dir in generated_image_dirs:
    fid = get_fid(real_dir, config.IMAGES_PATH / '512x512' / image_dir)

    print(f"📊 FID for {image_dir}: {fid.item():.4f}")

📊 FID for sd_trained_unet_epoch_110.pt: 307.3285
📊 FID for sd_trained_unet_epoch_50.pt: 294.8206
📊 FID for sd_trained_unet_epoch_80.pt: 299.6521


📊 FID for sd_trained_unet_epoch_10.pt: 344.3788

📊 FID for sd_trained_unet_epoch_110.pt: 295.6270

📊 FID for sd_trained_unet_epoch_50.pt: 284.7102

📊 FID for sd_trained_unet_epoch_80.pt: 295.8654