In [84]:
import os
import torch
import cv2
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data

from torchvision.models.inception import inception_v3
import torchvision.transforms as transforms

import numpy as np
from scipy.stats import entropy

def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        print(x.size())
        return F.softmax(x).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        print(batch.shape)
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]
        pred_result = get_pred(batchv)
#         print('pred_result', pred_result)
        preds[i*batch_size:i*batch_size + batch_size_i] = pred_result

    # Now compute the mean kl-div
    split_scores = []
    
    print('preds', preds)
    print('shape', preds.shape)
    print(splits)
    for k in range(N):
#         part = preds[k * (N // splits): (k+1) * (N // splits), :]
        part = preds[k:k+1, :]
        print('part', part)
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
            print('scores', scores)
        split_scores.append(np.exp(np.mean(scores)))
        print(split_scores)
    return np.mean(split_scores), np.std(split_scores)

def run_inception_model(input_path):
    class IgnoreLabelDataset(torch.utils.data.Dataset):
        def __init__(self, input_path, fake_B):
            self.input_path = input_path
            self.fake_B = fake_B

        def __getitem__(self, index):
            fake_item = Image.open(os.path.join(self.input_path,self.fake_B[index]))
            transform_step = self.get_transform()
            fake = transform_step(fake_item)
            return fake
        def __len__(self):
            return len(self.fake_B)
        
        def get_transform(self):
            transform_step=[]
#             transform_step.append(transforms.Scale(32))
            transform_step.append(transforms.ToTensor())
            transform_step.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
            return transforms.Compose(transform_step)

    
    fake_B = []
    for fname in sorted(os.listdir(input_path)):
        if 'fake' in fname:
            fake_B.append(fname)
    print(fake_B)
#     A_paths.sort(key=lambda x: int(x.rstrip("_gt.tif").split("/")[-1]))
    IgnoreLabelDataset(input_path, fake_B)

    print ("Calculating Inception Score...")
    print (inception_score(IgnoreLabelDataset(input_path, fake_B), cuda=True, batch_size=2, resize=True, splits=10))

In [85]:
input_path = '/media/jacktang/Work/USYD/Research/Deep_Learning/GAN/pytorch-CycleGAN-and-pix2pix/checkpoints/2d_neuron_pix2pix/test_latest/images/just_test'
run_inception_model(input_path)

['epoch148_real_A_fake_B_real_B (3rd copy).png', 'epoch148_real_A_fake_B_real_B (4th copy).png', 'epoch148_real_A_fake_B_real_B (another copy).png', 'epoch148_real_A_fake_B_real_B (copy).png', 'epoch148_real_A_fake_B_real_B.png']
Calculating Inception Score...
torch.Size([2, 3, 128, 384])
torch.Size([2, 1000])
torch.Size([2, 3, 128, 384])
torch.Size([2, 1000])
torch.Size([1, 3, 128, 384])
torch.Size([1, 1000])
preds [[0.00068604 0.00067171 0.00035894 ... 0.00015238 0.00034202 0.00255238]
 [0.00068604 0.00067171 0.00035894 ... 0.00015238 0.00034202 0.00255238]
 [0.00068604 0.00067171 0.00035894 ... 0.00015238 0.00034202 0.00255238]
 [0.00068604 0.00067171 0.00035894 ... 0.00015238 0.00034202 0.00255238]
 [0.00068604 0.00067171 0.00035894 ... 0.00015238 0.00034202 0.00255238]]
shape (5, 1000)
10
part [[6.86044863e-04 6.71706861e-04 3.58937250e-04 6.30160619e-04
  7.63653254e-04 6.57825149e-04 2.74076825e-04 5.41864603e-04
  3.67724046e-04 5.24537114e-04 1.73393855e-04 1.60900119e-04
  1.

