In [5]:
import numpy as np
import torch
from torch.nn.functional import adaptive_avg_pool2d
from torchvision.models import inception_v3
from torchvision import transforms
from scipy.linalg import sqrtm
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import os
import random
import csv
from tqdm import tqdm

# Function to calculate FID
def calculate_fid(real_images, generated_images, device):
    # Load the InceptionV3 model pre-trained on ImageNet
    model = inception_v3(pretrained=True, transform_input=True).to(device)
    model.fc = torch.nn.Identity()  # Remove the final classification layer
    model.eval()

    def get_activations(images):
        with torch.no_grad():
            # Resize and preprocess images
            images = adaptive_avg_pool2d(images, (299, 299))
            activations = model(images)
        return activations

    # Get the feature vectors for both sets of images
    real_features = get_activations(real_images).cpu().numpy()
    generated_features = get_activations(generated_images).cpu().numpy()

    # Calculate the mean and covariance of the features
    mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu_generated, sigma_generated = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)

    # Calculate the Fréchet distance
    diff = mu_real - mu_generated
    covmean, _ = sqrtm(sigma_real.dot(sigma_generated), disp=False)

    # Check for numerical issues
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = np.sum(diff**2) + np.trace(sigma_real + sigma_generated - 2 * covmean)
    return fid

class SingleClassImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.images = [os.path.join(root, file) for file in os.listdir(root) if file.endswith(('png', 'jpg', 'jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Transform for the images
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset directories to loop through
dataset_names = ['all_defect', 'double_crossing', 'gap_crossing', 'gap_double', 'no_defect', 'only_crossing', 'only_double', 'only_gap']

# Set the seed for reproducibility
seed_value = 42
random.seed(seed_value)
torch.manual_seed(seed_value)

# Prepare CSV to save FID values
csv_file = 'fid_values_dreambooth_txt2img.csv'

with open(csv_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Defect Type', 'FID Value'])  # Write header

    # Use tqdm to show progress bar
    for dataset_name in tqdm(dataset_names, desc="Calculating FID for datasets"):
        print(f"Processing dataset: {dataset_name}")
        
        # Load datasets from directories
        real_images_dir = f'traing_data_for_FID_filtered/{dataset_name}'
        generated_images_dir = f'dreambooth_txt2img_for_FID_resized/{dataset_name}'

        real_dataset = SingleClassImageFolder(root=real_images_dir, transform=transform)
        generated_dataset = SingleClassImageFolder(root=generated_images_dir, transform=transform)

        # Randomly select a subset of 100 images
        subset_size = 100
        real_subset, _ = random_split(real_dataset, [subset_size, len(real_dataset) - subset_size])
        generated_subset, _ = random_split(generated_dataset, [subset_size, len(generated_dataset) - subset_size])

        real_loader = DataLoader(real_subset, batch_size=subset_size, shuffle=False)
        generated_loader = DataLoader(generated_subset, batch_size=subset_size, shuffle=False)

        # Get real and generated images
        real_images = next(iter(real_loader))
        generated_images = next(iter(generated_loader))

        # Move images to the same device as the model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        real_images = real_images.to(device)
        generated_images = generated_images.to(device)

        # Calculate FID
        fid_value = calculate_fid(real_images, generated_images, device)
        
        # Output the FID value for the current dataset
        print(f"FID for {dataset_name}: {fid_value}")
        
        # Write the dataset name and FID value to the CSV file
        writer.writerow([dataset_name, fid_value])

print(f"FID values have been saved to {csv_file}")


Calculating FID for datasets:   0%|                   | 0/8 [00:00<?, ?it/s]

Processing dataset: all_defect


Calculating FID for datasets:  12%|█▍         | 1/8 [00:24<02:51, 24.54s/it]

FID for all_defect: 82.89404468934539
Processing dataset: double_crossing


Calculating FID for datasets:  25%|██▊        | 2/8 [00:43<02:08, 21.39s/it]

FID for double_crossing: 128.7517346948915
Processing dataset: gap_crossing


Calculating FID for datasets:  38%|████▏      | 3/8 [01:05<01:47, 21.43s/it]

FID for gap_crossing: 57.58578675341671
Processing dataset: gap_double


Calculating FID for datasets:  50%|█████▌     | 4/8 [01:18<01:13, 18.41s/it]

FID for gap_double: 89.86304416336476
Processing dataset: no_defect


Calculating FID for datasets:  62%|██████▉    | 5/8 [01:29<00:46, 15.44s/it]

FID for no_defect: 67.86065848521105
Processing dataset: only_crossing


Calculating FID for datasets:  75%|████████▎  | 6/8 [01:38<00:26, 13.46s/it]

FID for only_crossing: 146.8109826031206
Processing dataset: only_double


Calculating FID for datasets:  88%|█████████▋ | 7/8 [02:08<00:18, 18.66s/it]

FID for only_double: 68.6564465421372
Processing dataset: only_gap


Calculating FID for datasets: 100%|███████████| 8/8 [02:32<00:00, 19.04s/it]

FID for only_gap: 87.90062747689208
FID values have been saved to fid_values_dreambooth_txt2img.csv



