In [None]:
import pickle
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import logging

from sklearn.manifold import TSNE
from scipy.linalg import sqrtm
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision import transforms
from utils.dataset_loader import CustomDatasetFromSlide, CustomDatasetWithGenerated

In [None]:
# Notebook configuration: input locations, compute device, and file logging.
SLIDE_DIR = "/home/cilem/Lfstorage/wsis"
GENERATED_DIR = "../diffusion/generated_images/generated_images_t2i.pkl"

# Prefer the second GPU when one exists. On single-GPU hosts "cuda:1" is an
# invalid ordinal, so fall back to the first GPU, and to the CPU without CUDA.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device("cuda:1")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# All log records of this notebook are written to fid_t2i.log.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S', filename='fid_t2i.log')
logging.info(f"Device: {DEVICE}")

In [None]:
# Load the real-image dataset from the Hugging Face hub. No split is requested,
# so the returned object holds every available split.
train_dataset = load_dataset("Cilem/mixed-histopathology-512")

# len() on the un-split result counts splits, not samples; num_rows reports the
# number of rows (per split when several splits are present).
logging.info(f"Train dataset size: {train_dataset.num_rows}")

# NOTE(security): pickle.load can execute arbitrary code; only load files that
# were produced by a trusted pipeline.
# The context manager closes the file handle once the payload is read.
with open(GENERATED_DIR, "rb") as generated_file:
    generated_dataset = pickle.load(generated_file)

In [None]:
# torchvision builds the pretrained Inception v3 with transform_input=True, which
# re-maps ImageNet-normalised input to the range the weights were trained on.
# The images therefore have to be normalised with the ImageNet statistics;
# a plain 0.5/0.5 normalisation would be re-scaled a second time inside the model.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to Inception's 299x299 input size, convert to a tensor, then normalise.
transform = transforms.Compose([transforms.Resize((299, 299)),
                                transforms.ToTensor(),
                                transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

logging.info(f"Transforms: {transform}")

# Not read by any cell visible in this notebook; kept in case other code uses it.
fids = []

# Real images, wrapped by the project dataset class with the shared transform.
train_data = CustomDatasetFromSlide(train_dataset, slide_dir=SLIDE_DIR, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=False)

# Generated images go through the same transform so both sets are comparable.
generated_data = CustomDatasetWithGenerated(generated_dataset, transform=transform)
gen_loader = DataLoader(generated_data, batch_size=32, shuffle=False)

In [None]:
class InceptionV3(nn.Module):
    """Thin wrapper around torchvision's pretrained Inception v3.

    Args:
        return_features: When False (default) the network keeps its final
            fully connected layer and ``forward`` returns class logits.
            When True that layer is replaced by ``nn.Identity`` so ``forward``
            returns the activations that would normally feed the classifier.
    """

    def __init__(self, return_features=False):
        super().__init__()
        self.inception = models.inception_v3(weights="DEFAULT")
        if return_features:
            # Drop the classifier head so the pooled features are exposed.
            self.inception.fc = nn.Identity()

    def forward(self, x):
        """Run the wrapped network and return its main output tensor."""
        out = self.inception(x)
        # The wrapped model may return a tuple (for example when auxiliary
        # logits are produced); only the first, main output is used.
        if isinstance(out, tuple):
            out = out[0]
        return out

In [None]:
# Build the network, switch it to inference mode and move it to the target device.
model = InceptionV3()
model.eval()
model.to(DEVICE)
logging.info(f"Model: {model}")

In [None]:
def get_features(dataloader, model, key, device=None):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        dataloader: Iterable of batches; each batch is indexed with ``key`` to
            obtain the input tensor.
        model: Callable applied to the input tensor; it must return a tensor.
        key: Key that selects the input tensor inside each batch.
        device: Device the inputs are moved to before the forward pass.
            Defaults to the notebook-level ``DEVICE``.

    Returns:
        numpy.ndarray: The per-batch outputs concatenated along axis 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        device = DEVICE

    batch_outputs = []
    # Gradients are never needed for feature extraction.
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[key].to(device)
            batch_outputs.append(model(inputs).cpu().numpy())

    if not batch_outputs:
        raise ValueError("dataloader yielded no batches; cannot extract features")
    return np.concatenate(batch_outputs, axis=0)

In [None]:
# Extract model outputs for the real images first, then for the generated ones.
train_features, generated_features = (
    get_features(loader, model, key=batch_key)
    for loader, batch_key in ((train_loader, "image"), (gen_loader, "generated"))
)

In [None]:
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet distance between two Gaussians.

    The value is ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite (near-singular covariances).

    Returns:
        The Frechet distance as a scalar.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)

    # A singular product can make sqrtm return NaN/inf; regularise the
    # covariances slightly and retry instead of propagating a useless score.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error can leave a small imaginary component; only the real
    # part is used.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

In [None]:
# First- and second-order statistics of each feature set
# (rows are samples, columns are feature dimensions).
mu_train = np.mean(train_features, axis=0)
sigma_train = np.cov(train_features, rowvar=False)
mu_generated = np.mean(generated_features, axis=0)
sigma_generated = np.cov(generated_features, rowvar=False)

logging.info(f"mu_train: {mu_train}, sigma_train: {sigma_train}")
logging.info(f"mu_generated: {mu_generated}, sigma_generated: {sigma_generated}")

In [None]:
# Compute the FID score, then report it on stdout and in the log file.
fid_score = calculate_fid(mu_train, sigma_train, mu_generated, sigma_generated)
fid_message = f"FID Score: {fid_score}"
print(fid_message)
logging.info(fid_message)

In [None]:
# Visualisation with t-SNE: stack both feature sets and remember which rows
# are real (label 0) and which are generated (label 1).
all_features = np.vstack((train_features, generated_features))
labels = np.concatenate((
    np.zeros(len(train_features), dtype=int),
    np.ones(len(generated_features), dtype=int),
))

# Fixed random_state keeps the 2-D embedding reproducible between runs.
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
features_2d = tsne.fit_transform(all_features)

In [None]:
# Scatter the 2-D embedding, one series per data source, then show and save it.
fig, ax = plt.subplots(figsize=(8, 6))
for class_id, class_name in ((0, 'Real Data'), (1, 'Generated Data')):
    class_mask = labels == class_id
    ax.scatter(features_2d[class_mask, 0], features_2d[class_mask, 1], label=class_name, alpha=0.5)
ax.legend()
ax.set_title("TSNE Visualization of Feature Distributions Text2Image")
plt.show()
fig.savefig("tsne_text2image.png")

In [None]:
from torch.nn import functional as F
def inception_score(dataloader, model, key, device=None):
    """Compute the Inception Score ``exp(mean_x KL(p(y|x) || p(y)))``.

    Args:
        dataloader: Iterable of batches; each batch is indexed with ``key`` to
            obtain the input tensor.
        model: Callable returning class logits of shape (batch, classes) for
            an input tensor.
        key: Key that selects the input tensor inside each batch.
        device: Device the inputs are moved to before the forward pass.
            Defaults to the notebook-level ``DEVICE``.

    Returns:
        float: The Inception Score.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        device = DEVICE

    log_prob_batches = []
    with torch.no_grad():
        for batch in dataloader:
            logits = model(batch[key].to(device))
            # log_softmax stays finite where log(softmax(x)) would hit log(0).
            log_prob_batches.append(F.log_softmax(logits, dim=1).cpu())

    if not log_prob_batches:
        raise ValueError("dataloader yielded no batches; cannot compute the Inception Score")

    log_probs = torch.cat(log_prob_batches, dim=0)
    probs = log_probs.exp()

    # Marginal class distribution p(y); clamped so its log is always finite.
    marginal = probs.mean(dim=0, keepdim=True)
    log_marginal = torch.log(marginal.clamp_min(torch.finfo(marginal.dtype).tiny))

    # Compute the KL divergence of every sample against the marginal.
    kl_div = (probs * (log_probs - log_marginal)).sum(dim=1)

    # Exponentiate while the mean is still a tensor: torch.exp rejects the
    # plain Python float that .item() returns.
    return torch.exp(kl_div.mean()).item()

In [None]:
# Score the generated images, then report the result on stdout and in the log file.
is_score = inception_score(gen_loader, model, key="generated")
is_message = f"Inception Score: {is_score}"
print(is_message)
logging.info(is_message)