In [None]:
# Import necessary libraries
import fiftyone as fo
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.models import inception_v3
from src.data_loading import load_datasets, merge_datasets
from src.image_embeddings import get_inception_activations, calculate_fid, calculate_inception_score
import numpy as np

# Load datasets
# `load_datasets` is a project helper; it is expected to hand back the two
# real-image datasets followed by the synthetic one, in that order.
real_dataset_1, real_dataset_2, syn_dataset = load_datasets()

# Merge real datasets
# The merged collection is what the FID computation later uses as its "real"
# reference distribution.
# NOTE(review): the merge happens before tagging below, so whether
# `merged_real_dataset` carries the `dataset_type` field depends on how
# `merge_datasets` builds its result -- confirm if the tag is needed there.
merged_real_dataset = merge_datasets(real_dataset_1, real_dataset_2)

# Tag datasets
# Record each sample's provenance in a `dataset_type` field so real and
# synthetic samples can be told apart once they are combined.
# `set_values` writes the whole field in one batched operation instead of one
# database round-trip per sample (`sample.save()` inside a Python loop).
for _dataset, _label in (
    (real_dataset_1, "real"),
    (real_dataset_2, "real"),
    (syn_dataset, "synthetic"),
):
    _count = len(_dataset)
    if _count:  # nothing to tag in an empty collection
        _dataset.set_values("dataset_type", [_label] * _count)

# Merge datasets for visualization
# Dataset names are unique and this dataset is persistent, so it survives
# between sessions. Re-running this cell would therefore fail on the
# constructor because the name is already taken. Drop any previous copy first
# so the dataset is always rebuilt from the current sources.
# NOTE: deleting discards anything that was stored only on the previous copy
# (e.g. saved views or runs attached to it).
_embeddings_name = "embeddings_dataset"
if fo.dataset_exists(_embeddings_name):
    fo.delete_dataset(_embeddings_name)

embeddings_dataset = fo.Dataset(name=_embeddings_name, persistent=True)
# Insertion order is preserved: both real datasets first, then the synthetic one.
for _source in (real_dataset_1, real_dataset_2, syn_dataset):
    embeddings_dataset.add_samples(_source)

# Launch FiftyOne app
# `remote=True` starts the App server without trying to open a local browser
# window; keep `session` referenced so the App connection stays alive.
session = fo.launch_app(embeddings_dataset, remote=True)

# Load Inception model
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: `pretrained=True` is deprecated in newer torchvision releases in favour
# of the `weights=` argument; it is kept here so the cell also runs on older
# torchvision versions.
inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
# Swap the classification head for a pass-through so the forward pass returns
# the features that previously fed the classifier instead of class logits.
inception_model.fc = torch.nn.Identity()
# Switch to evaluation mode: otherwise dropout stays active, batch-norm layers
# use (and update) per-batch statistics, and the training-mode forward pass can
# return the auxiliary-logits tuple instead of a plain tensor. Any of these
# makes the extracted activations non-deterministic or unusable for FID.
inception_model.eval()

# Get activations
# Feature extraction is inference only, so run it with autograd disabled to
# avoid building a computation graph (and holding its memory) for every batch.
with torch.no_grad():
    real_activations = get_inception_activations(merged_real_dataset, inception_model, device)
    syn_activations = get_inception_activations(syn_dataset, inception_model, device)

# Calculate FID
# Compares the activation statistics of the merged real data against those of
# the synthetic data via the project's `calculate_fid` helper.
fid = calculate_fid(real_activations, syn_activations)
print(f"FID: {fid}")

# Calculate Inception Score
# NOTE(review): `syn_activations` come from a model whose `fc` layer was
# replaced by `Identity`, so they are presumably pooled features rather than
# class logits/probabilities. The Inception Score is defined on class
# probabilities -- confirm that `calculate_inception_score` accounts for this,
# otherwise the reported score is not a true Inception Score.
is_mean, is_std = calculate_inception_score(syn_activations)
print(f"Inception Score: {is_mean} ± {is_std}")
