In [1]:
import pickle
import random
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

from dcgan import Generator, Discriminator
from dataset_loader import CustomDataset
from scipy.linalg import sqrtm
from torch.utils.data import DataLoader
from torcheval.metrics import FrechetInceptionDistance
from dcgan_with_embeddings import PathFoundationModel

In [2]:
DATASET_DIR = "../datasets/merged_embeddings/merged_dataset.pkl"
SLIDE_DIR = "../datasets/wsi"
# Use the Apple-silicon GPU when present; `torch.backends.mps.is_available` is
# the documented availability check.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed every RNG this notebook relies on (latent noise, data-loader order) so a
# Restart & Run All reproduces the same numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Instantiate the Path Foundation model by its hub identifier.
# NOTE(review): `netGoogle` is not referenced by any other cell visible in this
# notebook; if nothing downstream uses it, this cell presumably only costs
# model-loading time and memory — confirm before keeping it.
netGoogle = PathFoundationModel(
    model_name="google/path-foundation",
)

In [None]:
# Trained DCGAN generator to evaluate; the hyper-parameters passed to Generator
# must match the ones the checkpoint was trained with.
GEN_CHECKPOINT = "./dcgan_outputs/trained_models/DCGAN/generator_10.pth"

gen_model = Generator(nz=100, ngf=64, nc=3).to(DEVICE)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on
# this machine; weights_only restricts unpickling to tensors and containers,
# which is all a state dict needs.
gen_state = torch.load(GEN_CHECKPOINT, map_location=DEVICE, weights_only=True)
gen_model.load_state_dict(gen_state)
gen_model.eval()

In [None]:
# Pickled index of the training split consumed by CustomDataset.
# NOTE(review): pickle.load can execute arbitrary code — only load this file if
# it comes from a trusted source.
with open(DATASET_DIR, "rb") as f:
    train_dataset = pickle.load(f)

BATCH_SIZE = 128
NZ = 100            # latent size; must match Generator(nz=...) used for gen_model
MAX_BATCHES = None  # set to a small int for a quick smoke run; None = full dataset

# Resize(224) scales the shorter side to 224; ToTensor yields floats in [0, 1],
# the pixel range torcheval's FID expects.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])

train_data = CustomDataset(train_dataset, slide_dir=SLIDE_DIR, transform=transform)
# Order does not affect dataset-level statistics, so no shuffling is needed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)

# One metric accumulates Inception feature statistics over ALL batches and FID
# is computed once at the end. FID is defined on the full real/fake feature
# distributions; computing it per 128-image batch (2048-dim features) and
# averaging is heavily biased, and re-instantiating the metric per batch would
# also rebuild its default Inception model every iteration.
# The metric uses torcheval's default (CPU) device, so images are moved to CPU.
fid = FrechetInceptionDistance()

real_min, real_max = float("inf"), float("-inf")
fake_min, fake_max = float("inf"), float("-inf")
n_batches = 0

for i, (real_images, _) in enumerate(train_loader):
    if MAX_BATCHES is not None and i >= MAX_BATCHES:
        break

    # Match the real batch size — the final batch can be smaller than BATCH_SIZE.
    n_images = real_images.size(0)
    with torch.no_grad():
        fake_images = gen_model(torch.randn(n_images, NZ, 1, 1, device=DEVICE)).cpu()

    # Track raw value ranges (fakes BEFORE clamping) to sanity-check pixel scale.
    real_min = min(real_min, real_images.min().item())
    real_max = max(real_max, real_images.max().item())
    fake_min = min(fake_min, fake_images.min().item())
    fake_max = max(fake_max, fake_images.max().item())

    # NOTE(review): clamp(0, 1) assumes the generator already outputs [0, 1].
    # If it ends in tanh (common for DCGAN) its output lies in [-1, 1] and this
    # should be `(fake_images + 1) / 2` instead — the raw fake range printed
    # below shows which case applies.
    fake_data = fake_images.clamp(0, 1)

    fid.update(real_images.cpu(), is_real=True)
    fid.update(fake_data, is_real=False)
    n_batches += 1

if n_batches == 0:
    raise ValueError("No batches were processed; cannot compute FID.")

print(f"Real data range: [{real_min:.4f}, {real_max:.4f}]")
print(f"Raw fake data range: [{fake_min:.4f}, {fake_max:.4f}]")

fid_value = fid.compute().item()
print(f"FID over {n_batches} batches: {fid_value:.4f}")