In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torchvision.models.inception import inception_v3
from scipy import linalg
from sklearn.metrics.pairwise import polynomial_kernel

import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm 

from scipy.stats import entropy

In [2]:
class DCGAN(nn.Module):
    def __init__(self, nz, ngf):
        super(DCGAN, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        return self.main(x)

In [3]:
def Inception_score(inception_model, gan_model, test_dl=None, fixed_z=None):
    preds = []
    if fixed_z != None:
        for i in range(fixed_z.shape[0]):
            fakes = gan_model(fixed_z[i])
            pred = get_pred(up(fakes), inception_model)
            preds.append(pred)
        preds = np.array(preds).reshape(-1,pred.shape[1])
    if test_dl != None:
        for datas, targets in tqdm(test_dl):
            if datas.shape[0] < TEST_BATCH_SIZE:
                break
            pred = get_pred(up(datas).to(device), inception_model)
            preds.append(pred)
        preds = np.array(preds)
        preds = preds.reshape(-1,pred.shape[1])
        
    
    py = preds.mean(0)
    scores = []
    for i in range(preds.shape[0]):
        scores.append(entropy(preds[i], py))
    inception_score = np.exp(np.mean(scores))

    return inception_score

In [4]:
def calculate_activation_statistics(act):
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

def FID(inception_model, gan_model, test_dl, fixed_z):
    bootstrap=True
    n_bootstraps=10
    
    real_preds = []
    if test_dl != None:
        for datas, targets in tqdm(test_dl):
            if datas.shape[0] < TEST_BATCH_SIZE:
                break
            pred = get_pred(up(datas).to(device), inception_model, fid=True)
            real_preds.append(pred)
        real_preds = np.array(real_preds).reshape(-1,pred.shape[1])
    fake_preds = []
    if fixed_z != None:
        for i in range(fixed_z.shape[0]):
            fakes = gan_model(fixed_z[i])
            pred = get_pred(up(fakes), inception_model, fid=True)
            fake_preds.append(pred)
        fake_preds = np.array(fake_preds).reshape(-1,pred.shape[1])
    
    n_bootstraps = n_bootstraps if bootstrap else 1
    fid_values = np.zeros((n_bootstraps))
    with tqdm(range(n_bootstraps), desc='FID') as bar:
        for i in bar:
            act1_bs = real_preds[np.random.choice(real_preds.shape[0], real_preds.shape[0], replace=True)]
            act2_bs = fake_preds[np.random.choice(fake_preds.shape[0], fake_preds.shape[0], replace=True)]
            m1, s1 = calculate_activation_statistics(act1_bs)
            m2, s2 = calculate_activation_statistics(act2_bs)
            fid_values[i] = calculate_frechet_distance(m1, s1, m2, s2)
            bar.set_postfix({'mean': fid_values[:i+1].mean()})

    return fid_values.mean(), fid_values.std()

In [5]:
def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
                            ret_var=True, output=sys.stdout, **kernel_args):
    m = min(codes_g.shape[0], codes_r.shape[0])
    mmds = np.zeros(n_subsets)
    if ret_var:
        vars = np.zeros(n_subsets)
    choice = np.random.choice

    with tqdm(range(n_subsets), desc='MMD', file=output) as bar:
        for i in bar:
            g = codes_g[choice(len(codes_g), subset_size, replace=False)]
            r = codes_r[choice(len(codes_r), subset_size, replace=False)]
            o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
            if ret_var:
                mmds[i], vars[i] = o
            else:
                mmds[i] = o
            bar.set_postfix({'mean': mmds[:i+1].mean()})
    return (mmds, vars) if ret_var else mmds
def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
                   var_at_m=None, ret_var=True):
    # use  k(x, y) = (gamma <x, y> + coef0)^degree
    # default gamma is 1 / dim
    X = codes_g
    Y = codes_r

    K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
    K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
    K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)

    return _mmd2_and_variance(K_XX, K_XY, K_YY,
                              var_at_m=var_at_m, ret_var=ret_var)
def _sqn(arr):
    flat = np.ravel(arr)
    return flat.dot(flat)
def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
                       mmd_est='unbiased', block_size=1024,
                       var_at_m=None, ret_var=True):
    # based on
    # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
    # but changed to not compute the full kernel matrix at once
    m = K_XX.shape[0]
    assert K_XX.shape == (m, m)
    assert K_XY.shape == (m, m)
    assert K_YY.shape == (m, m)
    if var_at_m is None:
        var_at_m = m

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
        sum_diag2_X = sum_diag2_Y = m
    else:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)

        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()

        sum_diag2_X = _sqn(diag_X)
        sum_diag2_Y = _sqn(diag_Y)

    Kt_XX_sums = K_XX.sum(axis=1) - diag_X
    Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(axis=0)
    K_XY_sums_1 = K_XY.sum(axis=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    if mmd_est == 'biased':
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
                + (Kt_YY_sum + sum_diag_Y) / (m * m)
                - 2 * K_XY_sum / (m * m))
    else:
        assert mmd_est in {'unbiased', 'u-statistic'}
        mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
        if mmd_est == 'unbiased':
            mmd2 -= 2 * K_XY_sum / (m * m)
        else:
            mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))

    if not ret_var:
        return mmd2

    Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
    Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
    K_XY_2_sum = _sqn(K_XY)

    dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
    dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)

    m1 = m - 1
    m2 = m - 2
    zeta1_est = (
        1 / (m * m1 * m2) * (
            _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 1 / (m * m * m1) * (
            _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
        - 2 / m**4 * K_XY_sum**2
        - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    zeta2_est = (
        1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 2 / (m * m) * K_XY_2_sum
        - 2 / m**4 * K_XY_sum**2
        - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
               + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)

    return mmd2, var_est

def KID(inception_model, gan_model, test_dl, fixed_z):
    real_preds = []
    if test_dl != None:
        for datas, targets in tqdm(test_dl):
            if datas.shape[0] < TEST_BATCH_SIZE:
                break
            pred = get_pred(up(datas).to(device), inception_model, fid=True)
            real_preds.append(pred)
        real_preds = np.array(real_preds).reshape(-1,pred.shape[1])
    fake_preds = []
    if fixed_z != None:
        for i in range(fixed_z.shape[0]):
            fakes = gan_model(fixed_z[i])
            pred = get_pred(up(fakes), inception_model, fid=True)
            fake_preds.append(pred)
        fake_preds = np.array(fake_preds).reshape(-1,pred.shape[1])
    
    kid_values = polynomial_mmd_averages(real_preds, fake_preds, n_subsets=100)
    
    return kid_values[0].mean()

In [6]:
def get_pred(x, model, fid=False):
    x = model(x)
    if fid:
        pred = x.cpu().detach().numpy()
    else:
        pred = F.softmax(x).data.cpu().numpy()
    return pred

In [7]:
TEST_BATCH_SIZE = 128
nz = 100
ngf = 64
device = f"cuda:{0}" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5),(0.5))])

test_dataset = datasets.CelebA(root='/home/image/CelebA/data/', split="test", target_type='attr', download=False, transform=transform)
test_dl = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

dcgan = DCGAN(nz,ngf).to(device)
dcgan.load_state_dict(torch.load("checkpoints/DCGAN_CelebA_netG.pth"))
inception_model = inception_v3(pretrained=True).to(device)
inception_model.eval()

fixed_z = torch.randn(156, TEST_BATCH_SIZE, nz, 1, 1, device=device).requires_grad_()

up = nn.Upsample(size=(128, 128), mode='bilinear')

In [57]:
real_is = Inception_score(inception_model, dcgan, test_dl=test_dl)
print(real_is)
fake_is = Inception_score(inception_model, dcgan, fixed_z=fixed_z)
print(fake_is)

  "See the documentation of nn.Upsample for details.".format(mode)
  This is separate from the ipykernel package so we can avoid doing imports until
 99%|█████████▉| 155/156 [00:54<00:00,  2.85it/s]


15.160202
11.840354


In [26]:
fid = FID(inception_model, dcgan, test_dl=test_dl, fixed_z=fixed_z)
print(fid)

  "See the documentation of nn.Upsample for details.".format(mode)
 99%|█████████▉| 155/156 [00:52<00:00,  2.96it/s]
FID: 100%|██████████| 10/10 [00:15<00:00,  1.52s/it, mean=404]

(403.90210336574563, 8.99333050911341)





In [29]:
kid = KID(inception_model, dcgan, test_dl=test_dl, fixed_z=fixed_z)
print(kid)

  "See the documentation of nn.Upsample for details.".format(mode)
 99%|█████████▉| 155/156 [00:51<00:00,  2.99it/s]


MMD: 100%|██████████| 100/100 [00:06<00:00, 15.04it/s, mean=3.88]
3.8781415670870865


In [8]:
wgan = DCGAN(nz,ngf).to(device)
wgan.load_state_dict(torch.load("checkpoints/WGAN_CelebA_netG.pth"))

<All keys matched successfully>

In [9]:
real_is = Inception_score(inception_model, wgan, test_dl=test_dl)
print(real_is)
fake_is = Inception_score(inception_model, wgan, fixed_z=fixed_z)
print(fake_is)

fid = FID(inception_model, wgan, test_dl=test_dl, fixed_z=fixed_z)
print(fid)

kid = KID(inception_model, wgan, test_dl=test_dl, fixed_z=fixed_z)
print(kid)

  "See the documentation of nn.Upsample for details.".format(mode)
  
 99%|█████████▉| 155/156 [06:28<00:02,  2.51s/it]


15.160202


  0%|          | 0/156 [00:00<?, ?it/s]

11.197154


 99%|█████████▉| 155/156 [00:58<00:00,  2.64it/s]
FID: 100%|██████████| 10/10 [00:09<00:00,  1.00it/s, mean=538]
  0%|          | 0/156 [00:00<?, ?it/s]

(538.4749462789041, 8.901178003748322)


 99%|█████████▉| 155/156 [00:53<00:00,  2.90it/s]


MMD: 100%|██████████| 100/100 [00:05<00:00, 18.39it/s, mean=4.93]
4.9309125899499495
