In [None]:
import os
import torch
import numpy as np
import torchvision.transforms as TF

from tqdm import tqdm
from PIL import Image
from glob import glob
from pytorch_fid.inception import InceptionV3

In [None]:
# Dimensionality of the InceptionV3 features to compare; must be a key of
# InceptionV3.BLOCK_INDEX_BY_DIM.
dims = 2048
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Select the InceptionV3 output block that yields `dims`-dimensional features.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

# eval() switches BatchNorm layers to inference mode. Left in training mode they
# would normalise with per-batch statistics (so a pair's features would depend on
# which other images share its batch) and update running stats on every forward
# pass. pytorch_fid's own feature extraction calls model.eval() for the same reason.
model = InceptionV3([block_idx]).to(device).eval()

In [None]:
class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset of same-named image pairs stored under two sibling folders.

    Item ``i`` is the tuple ``(gt, pred)``. ``gt`` is loaded from
    ``data_root/<gt_dir>/<files[i]>`` and ``pred`` from
    ``data_root/<pred_dir>/<files[i]>``. Both images are converted to RGB and,
    if ``transforms`` is given, passed through it independently.

    Args:
        data_root: Folder containing the two image sub-folders.
        files: File names (not paths) expected to exist in both sub-folders.
        transforms: Optional callable applied to each RGB PIL image.
        gt_dir: Sub-folder holding the reference ("gt") images (default ``"gts"``).
        pred_dir: Sub-folder holding the generated images (default ``"generated"``).
    """

    def __init__(self, data_root, files, transforms=None, gt_dir="gts", pred_dir="generated"):
        self.data_root = data_root
        self.file_names = files
        self.transforms = transforms
        self.gt_dir = gt_dir
        self.pred_dir = pred_dir

    def __len__(self):
        return len(self.file_names)

    def _load(self, path):
        """Open ``path`` as an RGB image and apply ``self.transforms`` if set."""
        # The context manager closes the file handle even if decoding fails;
        # convert() returns a new in-memory image that outlives the handle.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __getitem__(self, i):
        name = self.file_names[i]
        gt_path = os.path.join(self.data_root, self.gt_dir, name)
        pred_path = os.path.join(self.data_root, self.pred_dir, name)
        return self._load(gt_path), self._load(pred_path)

In [None]:
# Experiment to evaluate. Its folder is expected to hold same-named images
# under `gts/` and `generated/`.
ex_name = "art_newbreeder_allparents_bs8_disc"
data_root = f"../gen_images/{ex_name}"
# Derive the file list from data_root itself so the two cannot drift apart.
# basename() is portable (unlike split('/')), and sorting makes the order
# independent of the filesystem's glob order.
file_names = sorted(
    os.path.basename(p) for p in glob(os.path.join(data_root, "gts", "*.jpg"))
)

In [None]:
# Pair dataset over the experiment folder; each image is turned into a tensor
# by torchvision's ToTensor before batching.
dataset = ImagePathDataset(
    data_root=data_root,
    files=file_names,
    transforms=TF.ToTensor(),
)

In [None]:
# In-order iteration over every pair. drop_last=False keeps the final partial
# batch so no pair is left out of the average.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Number of (gt, generated) image pairs found; shown as this cell's output.
len(dataset)

In [None]:
# sklearn's cosine_similarity(X, Y) returns the matrix of similarities between
# every row of X and every row of Y. Single feature vectors must therefore be
# passed as (1, n_features) arrays, e.g.
# cosine_similarity(a.reshape(1, -1), b.reshape(1, -1)).
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
def pooled_features(model, batch):
    """Return one feature vector per image in ``batch`` (2-D: batch x channels).

    Runs ``model`` without gradient tracking and takes its first output. That
    output is indexed as 4-D here; presumably (B, C, H, W), to be confirmed. If
    it still has spatial extent (which happens when dims != 2048), it is
    global-average-pooled to 1x1 before the two trailing singleton axes are
    squeezed away.
    """
    with torch.no_grad():
        feats = model(batch)[0]
    if feats.size(2) != 1 or feats.size(3) != 1:
        feats = torch.nn.functional.adaptive_avg_pool2d(feats, output_size=(1, 1))
    return feats.squeeze(3).squeeze(2)


# Running sum and count of per-pair similarities. The mean is taken at the end
# over the pairs actually scored.
cos_sim_sum = 0.0
n_pairs = 0

for batch_gt, batch_pred in tqdm(dataloader):
    feat_gt = pooled_features(model, batch_gt.to(device))
    feat_pred = pooled_features(model, batch_pred.to(device))

    # Row i of feat_gt is compared only with row i of feat_pred (its own pair),
    # computed in one vectorised call on the model's device.
    sims = torch.nn.functional.cosine_similarity(feat_gt, feat_pred, dim=1)
    cos_sim_sum += sims.sum().item()
    n_pairs += sims.numel()

if n_pairs == 0:
    raise ValueError("No image pairs were scored; check data_root and file_names.")

avg_cos_sim = cos_sim_sum / n_pairs
print(f"{avg_cos_sim:.3f}")