In [1]:
#%load_ext autoreload
#%autoreload 2

#import os
#os.chdir("../")

#%pylab inline
#pylab.rcParams['figure.figsize'] = (10, 5)

In [2]:
import torchvision
from torchvision.models.inception import model_urls
import numpy as np
import scipy.misc
from torchvision import models, transforms
from PIL import Image
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm


class InceptionScore:
    
    def __init__(self, gpu = True):
        """
        Class for the computation of the inception score
        
        Parameters
        ----------
        gpu : bool
            if True:  computation is executed on the gpu
            if False: computation is executed on the cpu
        """
        
        # set global vars
        self.gpu = gpu
        
        # load model
        try:
            self.incept = torchvision.models.inception_v3(pretrained=True)
        except:
            name = 'inception_v3_google'
            model_urls[name] = model_urls[name].replace('https://', 'http://')
            self.incept = torchvision.models.inception_v3(pretrained=True)
        self.incept.training = False
        self.incept.transform_input = False
        if self.gpu:
            self.incept = self.incept.cuda()
        self.incept.eval()
    
        # init data transformer
        normalize = transforms.Normalize(
           mean=[0.485, 0.456, 0.406],
           std=[0.229, 0.224, 0.225]
        )
        self.preprocess = transforms.Compose([
           transforms.ToTensor(),
           normalize
        ])
        
    
    def score(self, imgs_fake, imgs_real, batch_size=32, splits=10):
        """
        Function to compute the inception score
        
        Parameters
        ----------
        imgs : numpy array
            array of the shape (N, X, Y, C)
        batch_size : int
            batch size for the prediction with the inception net
        splits : int
            The inception score is computed for a package of images.
            The variable 'splits' defines the number of these packages.
            Multiple computations of the score (for each package one) are 
            needed to compute a standard diviation (error) for the final
            score.
        """
        
        # preprocess fake images
        if imgs_fake.shape[0] != 299 or imgs_fake.shape[1] != 299:
            imgs_fake = [scipy.misc.imresize(img, (299, 299)) for img in imgs_fake]
            imgs_fake = np.array(imgs_fake)
        n_batches = 1 + (len(imgs_fake) / batch_size)
        batches = np.array_split(imgs_fake, n_batches)
        
        # get prediction vectors of inception net for fake imgs
        preds_fake = []
        for batch in tqdm(batches):
            imgs = [Image.fromarray(img) for img in batch]
            imgs = torch.stack([self.preprocess(img) for img in imgs])
            if self.gpu:
                imgs = imgs.cuda()
            imgs = Variable(imgs)
            pred = self.incept(imgs)
            pred = F.softmax(pred)
            preds_fake.append(pred.data.cpu().numpy())    
        preds_fake = np.concatenate(preds_fake)
        self. preds_fake = preds_fake
        
        # preprocess real images
        if imgs_real.shape[0] != 299 or imgs_real.shape[1] != 299:
            imgs_real = [scipy.misc.imresize(img, (299, 299)) for img in imgs_real]
            imgs_real = np.array(imgs_real)
        n_batches = 1 + (len(imgs_real) / batch_size)
        batches = np.array_split(imgs_real, n_batches)
        
        # get prediction vectors of inception net for fake imgs
        preds_real = []
        for batch in tqdm(batches):
            imgs = [Image.fromarray(img) for img in batch]
            imgs = torch.stack([self.preprocess(img) for img in imgs])
            if self.gpu:
                imgs = imgs.cuda()
            imgs = Variable(imgs)
            pred = self.incept(imgs)
            pred = F.softmax(pred)
            preds_real.append(pred.data.cpu().numpy())    
        preds_real = np.concatenate(preds_real)
        self.preds_real = preds_real
        
        # compute inception score
        scores = []
        for i in range(splits):
            part_fake = preds_fake[(i * preds_fake.shape[0] // splits): \
                         ((i + 1) * preds_fake.shape[0] // splits), :]
            part_real = preds_real[(i * preds_real.shape[0] // splits): \
                         ((i + 1) * preds_real.shape[0] // splits), :]
            kl = part * (np.log(part_fake) - \
                         np.log(np.expand_dims(np.mean(part_real, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            kl = np.exp(kl)
            kl = kl - (part * (np.log(np.expand_dims(np.mean(part_fake, 0), 0)) - \
                               np.log(np.expand_dims(np.mean(part_real, 0), 0))))

            scores.append(kl)
            
        return np.mean(scores), np.std(scores)

In [3]:
trainset = torchvision.datasets.CIFAR10(root = '/tmp', download=True)
x_real = trainset.train_data
n_pixels = x_real.shape[0]*x_real.shape[1]*x_real.shape[2]*x_real.shape[3]
noise = np.random.normal(127, 3, n_pixels).reshape(x_real.shape)
x_fake = x_real.copy() + noise

incept_score = InceptionScore()
mean, std = incept_score.score(x_fake, x_real)

  m.weight.data.copy_(values)
100%|██████████| 1563/1563 [05:12<00:00,  5.13it/s]
100%|██████████| 1563/1563 [05:14<00:00,  4.99it/s]

Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar-10-python.tar.gz





NameError: global name 'part' is not defined

In [4]:
preds_fake = incept_score.preds_fake
preds_real = incept_score.preds_real

In [6]:
# compute inception score
splits = 10

scores = []
for i in range(splits):
    part_fake = preds_fake[(i * preds_fake.shape[0] // splits): \
                 ((i + 1) * preds_fake.shape[0] // splits), :]
    part_real = preds_real[(i * preds_real.shape[0] // splits): \
                 ((i + 1) * preds_real.shape[0] // splits), :]
    
    p_star = np.expand_dims(np.mean(part_fake, 0), 0)
    p      = np.expand_dims(np.mean(part_real, 0), 0)
    
    kl = part_fake * (np.log(part_fake) - np.log(p))
    kl = np.mean(np.sum(kl, 1))
    kl = kl - (p_star * (np.log(p_star) - np.log(p))
    kl = np.exp(kl)

    scores.append(kl)

In [8]:
np.mean(scores), np.std(scores)

(11.251804, 0.17426911)