In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


import easydict
# Run configuration. An EasyDict stands in for the argparse namespace of the
# original command-line script so the options can be read as attributes.
args = easydict.EasyDict(dict(
    # --- data ---
    dataset='cifar10',          # which branch of the dataset selection below is used
    dataroot='../../dataset',   # root directory handed to the torchvision dataset
    workers=2,                  # DataLoader worker processes
    batchSize=64,
    imageSize=64,               # images are resized to this edge length
    # --- model sizes ---
    nz=100,                     # latent vector length
    ngf=64,                     # generator base channel count
    ndf=64,                     # discriminator base channel count
    # --- optimisation ---
    niter=25,                   # number of training epochs
    lr=0.0002,
    beta1=0.5,                  # first Adam beta
    # --- runtime ---
    cuda=False,
    dry_run=False,              # when True the training loop stops after one step
    ngpu=1,
    netG='',                    # optional generator checkpoint path ('' = none)
    netD='',                    # optional discriminator checkpoint path ('' = none)
    manualSeed=None,            # None -> a random seed is drawn later
    classes=None,               # comma-separated class list, read only for 'lsun'
    outf='result_image',        # output directory for images and checkpoints
))


# The rest of the notebook reads the configuration through `opt`
# (the name the argparse-based script used).
opt = args
print(opt)

# Create the output directory. `exist_ok=True` makes re-running the cell safe
# while still surfacing genuine failures (permission denied, path is a file),
# which a blanket `except OSError: pass` would have hidden.
os.makedirs(opt.outf, exist_ok=True)

# Draw a seed when none was configured, and print it so the run can be repeated.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Every dataset except the synthetic 'fake' one is read from `dataroot`.
if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
    raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)

# Build `dataset` and set `nc` (number of image channels) for the configured
# dataset name. Every branch must define both names, because the networks
# defined later read `nc`.
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    # `classes` defaults to None in the config; fail with a clear message
    # instead of an AttributeError on `None.split`.
    if not opt.classes:
        raise ValueError("`classes` parameter is required for dataset \"lsun\"")
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    # NOTE(review): `download=True` is commented out, so the CIFAR-10 files are
    # presumably expected to already exist under `dataroot` - confirm.
    dataset = dset.CIFAR10(root=opt.dataroot, #download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3
else:
    # Without this branch an unknown name left `dataset` and `nc` undefined and
    # surfaced as a NameError at the assert below.
    raise ValueError("Unknown dataset \"%s\"" % opt.dataset)

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

# Compute device for the networks and tensors created below.
# NOTE(review): the CUDA index is hard-coded to 2, while the data_parallel calls
# in Generator/Discriminator target devices range(ngpu), i.e. starting at 0 -
# confirm this matches the machine the notebook runs on.
device = torch.device("cuda:2") if opt.cuda else torch.device("cpu")

# Integer copies of the size options; the model classes read nz/ngf/ndf as
# module-level names.
ngpu, nz, ngf, ndf = (int(getattr(opt, key)) for key in ("ngpu", "nz", "ngf", "ndf"))


# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialise one module in place; meant to be used with ``net.apply``.

    Modules are matched on their class name:

    * name contains ``Conv``      -> weight ~ N(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0

    Any other module is left untouched. Modules that match by name but carry
    no ``weight``/``bias`` tensor (e.g. ``BatchNorm2d(affine=False)`` or a
    wrapper class whose name merely contains "Conv") are skipped instead of
    raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # Missing attributes and attributes set to None are treated the same way.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)


class Generator(nn.Module):
    """DCGAN generator: a stack of transposed convolutions.

    A latent tensor of shape (nz) x 1 x 1 is up-sampled through
    4x4 -> 8x8 -> 16x16 -> 32x32 and finally to a (nc) x 64 x 64 image that is
    squashed by ``Tanh``. ``nz``, ``ngf`` and ``nc`` are read from module-level
    names when the instance is constructed.

    Args:
        ngpu: number of GPUs; values above 1 make ``forward`` use
            ``nn.parallel.data_parallel`` for CUDA inputs.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage is ConvTranspose2d -> BatchNorm2d -> ReLU with a 4x4 kernel.
        hidden_stages = (
            (nz,      ngf * 8, 1, 0),   # Z           -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),   #             -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),   #             -> (ngf*2) x 16 x 16
            (ngf * 2, ngf,     2, 1),   #             -> (ngf) x 32 x 32
        )
        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Layers are registered in the same order as before, so the
        # `main.<index>` state-dict keys are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the network, spreading the batch over GPUs when configured."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))


# Build the generator, apply the DCGAN weight initialisation, then optionally
# overwrite the weights from a checkpoint.
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    # map_location remaps the stored tensors onto the device used here, so a
    # checkpoint saved on a different device (e.g. a GPU that is not present on
    # this machine) can still be loaded.
    netG.load_state_dict(torch.load(opt.netG, map_location=device))
print(netG)


class Discriminator(nn.Module):
    """DCGAN discriminator: a stack of strided convolutions.

    A (nc) x 64 x 64 image is down-sampled through 32x32 -> 16x16 -> 8x8 -> 4x4
    and reduced to one sigmoid score per sample. ``nc`` and ``ndf`` are read
    from module-level names when the instance is constructed.

    Args:
        ngpu: number of GPUs; values above 1 make ``forward`` use
            ``nn.parallel.data_parallel`` for CUDA inputs.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Input stage has no batch norm: (nc) x 64 x 64 -> (ndf) x 32 x 32
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each hidden stage doubles the channels and halves the resolution:
        # (ndf*2) x 16 x 16, (ndf*4) x 8 x 8, (ndf*8) x 4 x 4
        for multiplier in (1, 2, 4):
            in_ch = ndf * multiplier
            out_ch = in_ch * 2
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Output stage: (ndf*8) x 4 x 4 -> 1 x 1 x 1 score
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        # Same registration order as before, so `main.<index>` keys are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor with one score per input sample."""
        if input.is_cuda and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        # Flatten to a column, then drop the trailing singleton dimension.
        return scores.view(-1, 1).squeeze(1)


# Build the discriminator, apply the DCGAN weight initialisation, then
# optionally overwrite the weights from a checkpoint.
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    # map_location remaps the stored tensors onto the device used here, so a
    # checkpoint saved on a different device (e.g. a GPU that is not present on
    # this machine) can still be loaded.
    netD.load_state_dict(torch.load(opt.netD, map_location=device))
print(netD)

# Binary cross-entropy between the discriminator's sigmoid score and the label.
criterion = nn.BCELoss()

# One latent batch drawn once and reused for the periodic sample images.
fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)

# Target values written into the label tensor during training.
real_label, fake_label = 1, 0

# setup optimizer: both networks use Adam with identical hyper-parameters
adam_kwargs = dict(lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)


{'dataset': 'cifar10', 'dataroot': '../../dataset', 'workers': 2, 'batchSize': 64, 'imageSize': 64, 'nz': 100, 'ngf': 64, 'ndf': 64, 'niter': 25, 'lr': 0.0002, 'beta1': 0.5, 'cuda': False, 'dry_run': False, 'ngpu': 1, 'netG': '', 'netD': '', 'manualSeed': None, 'classes': None, 'outf': 'result_image'}
Random Seed:  71
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTr

In [2]:
import numpy as np
def calculate_activation_statistics(dataloader, model, batch_size=128, dims=2048, device=None):
    """Run ``model`` over every batch and collect one feature row per sample.

    Args:
        dataloader: iterable of batches; element ``[0]`` of each batch is the
            image tensor fed to the model.
        model: module whose call returns a sequence; element ``[0]`` must be a
            4-D tensor (the code reads ``size(2)`` and ``size(3)``).
        batch_size: unused; kept for signature compatibility.
        dims: unused; kept for signature compatibility.
        device: where batches (and the model) are placed. ``None`` uses the
            device the model's parameters already live on (CPU when the model
            has no parameters).

    Returns:
        numpy array of shape ``(n_samples, n_features)``.

    Notes:
        * The forward passes run under ``torch.no_grad()`` so no autograd graph
          is kept alive while features are collected.
        * When ``device`` is given the model is moved there first; otherwise a
          CPU model combined with ``device='cuda:0'`` fails with a device
          mismatch.
        * Global average pooling uses ``torch.nn.functional``; the previous
          bare ``adaptive_avg_pool2d`` name was never imported.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    else:
        model.to(device)
    model.eval()

    pred_list = []
    with torch.no_grad():
        for data in dataloader:
            batch = data[0].to(device)
            pred = model(batch)[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

            pred_list.append(pred.detach().cpu())

    pred = torch.cat(pred_list)
    act = pred.numpy().reshape(pred.size(0), -1)
    return act
    
def act_to_mu_sig(act):
    """Summarise a feature matrix by its mean vector and covariance matrix.

    Args:
        act: array-like with one row per sample and one column per feature.

    Returns:
        ``(mu, sigma)`` - the per-column mean and the covariance computed with
        columns treated as variables (``rowvar=False``).
    """
    return np.mean(act, axis=0), np.cov(act, rowvar=False)
    
    
from scipy import linalg
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Args:
        mu1, mu2: mean vectors of the two distributions (same length).
        sigma1, sigma2: covariance matrices of the two distributions (same shape).
        eps: value added to both covariance diagonals when the matrix square
            root of their product is not finite.

    Returns:
        The squared Frechet distance as a scalar.

    Raises:
        AssertionError: if the means or covariances differ in shape.
        ValueError: if the matrix square root has a non-negligible imaginary
            diagonal.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    cov_product = sigma1.dot(sigma2)
    try:
        # disp=False returns (sqrt, error_estimate) without printing a warning.
        covmean, _ = linalg.sqrtm(cov_product, disp=False)
    except TypeError:
        # NOTE(review): SciPy has announced deprecation of the `disp` keyword;
        # on a build that rejects it, fall back to the plain call, which
        # returns only the matrix.
        covmean = linalg.sqrtm(cov_product)

    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a small imaginary part; drop it when negligible.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)
            
            
def calculate_fretchet(images_real,images_fake,model):
    """Frechet distance between the features of two image batches.

    Each tensor is wrapped as a one-batch iterable, because
    ``calculate_activation_statistics`` in this cell iterates over batches and
    reads element ``[0]`` of each. The previous call passed the raw tensor plus
    a ``cuda=True`` keyword that this cell's function does not accept.

    Args:
        images_real: tensor holding a batch of real images.
        images_fake: tensor holding a batch of generated images.
        model: feature extractor handed to ``calculate_activation_statistics``.

    Returns:
        The value computed by ``calculate_frechet_distance``.
    """
    real_act = calculate_activation_statistics([(images_real,)], model)
    fake_act = calculate_activation_statistics([(images_fake,)], model)

    mu_1, std_1 = act_to_mu_sig(real_act)
    mu_2, std_2 = act_to_mu_sig(fake_act)

    fid_value = calculate_frechet_distance(mu_1, std_1, mu_2, std_2)
    return fid_value

In [3]:
from pytorch_fid.inception import InceptionV3
# Index of the InceptionV3 block registered under the 2048 key of
# BLOCK_INDEX_BY_DIM.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
# Feature extractor restricted to that block.
# NOTE(review): the model is not moved to any device here, while a later cell
# passes device='cuda:0' alongside it - confirm where the move is meant to happen.
InceptionModel = InceptionV3([block_idx])
from torch.utils.data import  TensorDataset, DataLoader
import prdc

In [None]:

# Evaluate generated samples against real-data statistics.
# NOTE(review): torchlist_to_dataloader, fake_data_list, num_workers, real_mu,
# real_sigma and real_act are not defined in any visible cell; this cell depends
# on kernel state created elsewhere and fails on a fresh Restart & Run All.
fakeDataloader = torchlist_to_dataloader(fake_data_list, 512,num_workers=num_workers)
# One feature row per generated sample, then mean/covariance of those rows.
fake_act = calculate_activation_statistics(fakeDataloader,InceptionModel,device='cuda:0')
fake_mu, fake_sigma=act_to_mu_sig(fake_act)
fid_value_static = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)

# Shuffle the rows and keep at most 10000 of each set before handing them to prdc.
real_pick = np.random.permutation(real_act)[:10000]
fake_pick = np.random.permutation(fake_act)[:10000]
prdc_metrics = prdc.compute_prdc(real_features=real_pick, fake_features=fake_pick, nearest_k=5)


# Store the FID next to the prdc results under the key 'fid'.
prdc_metrics['fid'] = float(fid_value_static)

<All keys matched successfully>

In [11]:

# DCGAN training loop with periodic FID logging, sample images and checkpoints.
#
# NOTE(review): `model` (the Inception feature extractor) and
# `calculate_fretchet` are defined in other cells; both must exist before this
# cell runs.

if opt.dry_run:
    opt.niter = 1

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): this step must not push gradients into the generator
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if i % 100 == 0:
            # Logging only: no gradients are needed, so skip building autograd
            # graphs for the Inception forward pass and the sample generation.
            with torch.no_grad():
                fretchet_dist = calculate_fretchet(real_cpu, fake.detach(), model)
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f, FID:%.4f'
                      % (epoch, opt.niter, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, fretchet_dist))
                vutils.save_image(real_cpu,
                                  '%s/real_samples.png' % opt.outf,
                                  normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                                  '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                                  normalize=True)

        if opt.dry_run:
            break
    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

    # Dump one grid of generated images per dataloader batch.
    # Changes from the previous version of this loop:
    #  * every batch gets its own file name (epoch + batch index); before, each
    #    iteration overwrote the same path, which was also the path of the
    #    fixed-noise sample saved above. Note that this now writes
    #    len(dataloader) files per epoch;
    #  * the dataloader is only used for its length, so no images are loaded;
    #  * the configured batch size is used instead of `batch_size`, which holds
    #    the size of the last training batch and can be smaller;
    #  * generation runs without autograd, and a dry run stops after one batch.
    with torch.no_grad():
        for i in range(len(dataloader)):
            noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
            fake = netG(noise)
            vutils.save_image(fake,
                              '%s/fake_samples_epoch_%03d_%d.png' % (opt.outf, epoch, i),
                              normalize=True)
            if opt.dry_run:
                break


[0/1][0/782] Loss_D: 1.8198 Loss_G: 5.6668 D(x): 0.4607 D(G(z)): 0.5231 / 0.0045
[0/1][1/782] Loss_D: 0.9995 Loss_G: 5.2463 D(x): 0.6394 D(G(z)): 0.3397 / 0.0067
[0/1][2/782] Loss_D: 0.8674 Loss_G: 5.9454 D(x): 0.7720 D(G(z)): 0.3599 / 0.0040
[0/1][3/782] Loss_D: 0.7985 Loss_G: 6.2013 D(x): 0.7579 D(G(z)): 0.3117 / 0.0026
[0/1][4/782] Loss_D: 0.7916 Loss_G: 6.6446 D(x): 0.7509 D(G(z)): 0.3070 / 0.0018
[0/1][5/782] Loss_D: 0.8424 Loss_G: 7.8631 D(x): 0.7775 D(G(z)): 0.3824 / 0.0005
[0/1][6/782] Loss_D: 0.5929 Loss_G: 7.0782 D(x): 0.7528 D(G(z)): 0.1815 / 0.0012
[0/1][7/782] Loss_D: 0.9935 Loss_G: 8.0840 D(x): 0.7227 D(G(z)): 0.3609 / 0.0005
[0/1][8/782] Loss_D: 0.7380 Loss_G: 7.5893 D(x): 0.7405 D(G(z)): 0.2106 / 0.0008
[0/1][9/782] Loss_D: 0.7470 Loss_G: 9.3534 D(x): 0.8024 D(G(z)): 0.3266 / 0.0001
[0/1][10/782] Loss_D: 0.3753 Loss_G: 8.1213 D(x): 0.8242 D(G(z)): 0.0958 / 0.0005
[0/1][11/782] Loss_D: 0.7844 Loss_G: 10.0659 D(x): 0.7929 D(G(z)): 0.3132 / 0.0001
[0/1][12/782] Loss_D: 0.2

[0/1][101/782] Loss_D: 1.4458 Loss_G: 6.6411 D(x): 0.4071 D(G(z)): 0.0017 / 0.0022
[0/1][102/782] Loss_D: 0.0918 Loss_G: 3.6876 D(x): 0.9508 D(G(z)): 0.0358 / 0.0566
[0/1][103/782] Loss_D: 0.6709 Loss_G: 6.5954 D(x): 0.9643 D(G(z)): 0.3843 / 0.0052
[0/1][104/782] Loss_D: 0.2321 Loss_G: 6.4633 D(x): 0.8667 D(G(z)): 0.0443 / 0.0066
[0/1][105/782] Loss_D: 0.3123 Loss_G: 4.9794 D(x): 0.8345 D(G(z)): 0.0469 / 0.0255
[0/1][106/782] Loss_D: 0.2917 Loss_G: 4.2273 D(x): 0.9195 D(G(z)): 0.1633 / 0.0242
[0/1][107/782] Loss_D: 0.4516 Loss_G: 4.9444 D(x): 0.8908 D(G(z)): 0.2279 / 0.0176
[0/1][108/782] Loss_D: 0.5439 Loss_G: 3.8339 D(x): 0.7454 D(G(z)): 0.1400 / 0.0397
[0/1][109/782] Loss_D: 0.8533 Loss_G: 8.2772 D(x): 0.8813 D(G(z)): 0.4419 / 0.0005
[0/1][110/782] Loss_D: 0.8947 Loss_G: 4.0663 D(x): 0.5382 D(G(z)): 0.0093 / 0.0270
[0/1][111/782] Loss_D: 0.7381 Loss_G: 7.2695 D(x): 0.8966 D(G(z)): 0.4103 / 0.0034
[0/1][112/782] Loss_D: 0.1661 Loss_G: 6.8422 D(x): 0.8960 D(G(z)): 0.0411 / 0.0029
[0/1

[0/1][201/782] Loss_D: 0.1019 Loss_G: 7.5799 D(x): 0.9521 D(G(z)): 0.0466 / 0.0008
[0/1][202/782] Loss_D: 0.2884 Loss_G: 6.6736 D(x): 0.8567 D(G(z)): 0.0859 / 0.0019
[0/1][203/782] Loss_D: 0.2430 Loss_G: 11.3796 D(x): 0.9637 D(G(z)): 0.1710 / 0.0000
[0/1][204/782] Loss_D: 0.0454 Loss_G: 10.6418 D(x): 0.9595 D(G(z)): 0.0018 / 0.0000
[0/1][205/782] Loss_D: 0.1921 Loss_G: 6.5065 D(x): 0.8837 D(G(z)): 0.0039 / 0.0021
[0/1][206/782] Loss_D: 0.4421 Loss_G: 15.7772 D(x): 0.9922 D(G(z)): 0.3237 / 0.0000
[0/1][207/782] Loss_D: 0.6255 Loss_G: 14.5500 D(x): 0.6617 D(G(z)): 0.0000 / 0.0000
[0/1][208/782] Loss_D: 0.1008 Loss_G: 9.8340 D(x): 0.9535 D(G(z)): 0.0002 / 0.0001
[0/1][209/782] Loss_D: 0.0584 Loss_G: 5.3255 D(x): 0.9885 D(G(z)): 0.0443 / 0.0084
[0/1][210/782] Loss_D: 0.9511 Loss_G: 18.3838 D(x): 0.9966 D(G(z)): 0.5436 / 0.0000
[0/1][211/782] Loss_D: 0.6022 Loss_G: 18.3160 D(x): 0.7193 D(G(z)): 0.0000 / 0.0000
[0/1][212/782] Loss_D: 0.3401 Loss_G: 12.5566 D(x): 0.8300 D(G(z)): 0.0000 / 0.00

[0/1][301/782] Loss_D: 0.1235 Loss_G: 5.0213 D(x): 0.9499 D(G(z)): 0.0612 / 0.0095
[0/1][302/782] Loss_D: 0.2690 Loss_G: 4.6818 D(x): 0.8852 D(G(z)): 0.1105 / 0.0149
[0/1][303/782] Loss_D: 0.2194 Loss_G: 5.1393 D(x): 0.9189 D(G(z)): 0.1093 / 0.0098
[0/1][304/782] Loss_D: 0.2922 Loss_G: 3.6590 D(x): 0.8388 D(G(z)): 0.0655 / 0.0430
[0/1][305/782] Loss_D: 0.4706 Loss_G: 7.8205 D(x): 0.9254 D(G(z)): 0.2683 / 0.0015
[0/1][306/782] Loss_D: 0.6327 Loss_G: 4.7442 D(x): 0.6537 D(G(z)): 0.0035 / 0.0182
[0/1][307/782] Loss_D: 0.1057 Loss_G: 3.9421 D(x): 0.9765 D(G(z)): 0.0734 / 0.0318
[0/1][308/782] Loss_D: 0.3387 Loss_G: 7.2334 D(x): 0.9570 D(G(z)): 0.2304 / 0.0013
[0/1][309/782] Loss_D: 0.1919 Loss_G: 6.6068 D(x): 0.8726 D(G(z)): 0.0078 / 0.0036
[0/1][310/782] Loss_D: 0.2173 Loss_G: 4.0303 D(x): 0.8586 D(G(z)): 0.0273 / 0.0534
[0/1][311/782] Loss_D: 0.2996 Loss_G: 5.3626 D(x): 0.9551 D(G(z)): 0.1945 / 0.0099
[0/1][312/782] Loss_D: 0.2791 Loss_G: 4.5355 D(x): 0.8957 D(G(z)): 0.0741 / 0.0184
[0/1

[0/1][401/782] Loss_D: 0.5516 Loss_G: 6.0585 D(x): 0.9699 D(G(z)): 0.3721 / 0.0041
[0/1][402/782] Loss_D: 0.2644 Loss_G: 6.0395 D(x): 0.8031 D(G(z)): 0.0096 / 0.0042
[0/1][403/782] Loss_D: 0.1937 Loss_G: 3.7527 D(x): 0.8603 D(G(z)): 0.0277 / 0.0411
[0/1][404/782] Loss_D: 0.3494 Loss_G: 4.3054 D(x): 0.9118 D(G(z)): 0.1836 / 0.0181
[0/1][405/782] Loss_D: 0.3077 Loss_G: 5.8385 D(x): 0.8967 D(G(z)): 0.1468 / 0.0068
[0/1][406/782] Loss_D: 0.6502 Loss_G: 1.9192 D(x): 0.6844 D(G(z)): 0.0430 / 0.2064
[0/1][407/782] Loss_D: 1.4579 Loss_G: 10.5224 D(x): 0.9502 D(G(z)): 0.6670 / 0.0001
[0/1][408/782] Loss_D: 2.2102 Loss_G: 2.6331 D(x): 0.2911 D(G(z)): 0.0012 / 0.1416
[0/1][409/782] Loss_D: 2.0689 Loss_G: 9.6652 D(x): 0.9218 D(G(z)): 0.7429 / 0.0003
[0/1][410/782] Loss_D: 3.1011 Loss_G: 2.6085 D(x): 0.1330 D(G(z)): 0.0074 / 0.1193
[0/1][411/782] Loss_D: 0.9621 Loss_G: 4.0535 D(x): 0.8242 D(G(z)): 0.4397 / 0.0265
[0/1][412/782] Loss_D: 0.4750 Loss_G: 4.2111 D(x): 0.7856 D(G(z)): 0.1505 / 0.0236
[0/

[0/1][501/782] Loss_D: 0.3664 Loss_G: 4.3006 D(x): 0.9117 D(G(z)): 0.2127 / 0.0202
[0/1][502/782] Loss_D: 0.4245 Loss_G: 4.4822 D(x): 0.8446 D(G(z)): 0.1862 / 0.0217
[0/1][503/782] Loss_D: 0.2371 Loss_G: 4.5760 D(x): 0.8792 D(G(z)): 0.0831 / 0.0202
[0/1][504/782] Loss_D: 0.6607 Loss_G: 2.1833 D(x): 0.6531 D(G(z)): 0.1107 / 0.1744
[0/1][505/782] Loss_D: 1.0769 Loss_G: 6.9779 D(x): 0.9319 D(G(z)): 0.5527 / 0.0014
[0/1][506/782] Loss_D: 0.7180 Loss_G: 4.6744 D(x): 0.5603 D(G(z)): 0.0119 / 0.0148
[0/1][507/782] Loss_D: 0.2603 Loss_G: 2.7236 D(x): 0.8684 D(G(z)): 0.0923 / 0.0930
[0/1][508/782] Loss_D: 0.7547 Loss_G: 4.2939 D(x): 0.8409 D(G(z)): 0.3556 / 0.0239
[0/1][509/782] Loss_D: 0.5030 Loss_G: 3.4005 D(x): 0.7761 D(G(z)): 0.1393 / 0.0435
[0/1][510/782] Loss_D: 0.4978 Loss_G: 2.1672 D(x): 0.7313 D(G(z)): 0.1005 / 0.1454
[0/1][511/782] Loss_D: 0.8231 Loss_G: 6.4561 D(x): 0.9356 D(G(z)): 0.4815 / 0.0026
[0/1][512/782] Loss_D: 0.5855 Loss_G: 5.2210 D(x): 0.6393 D(G(z)): 0.0087 / 0.0106
[0/1

[0/1][601/782] Loss_D: 0.2834 Loss_G: 3.3693 D(x): 0.8809 D(G(z)): 0.1292 / 0.0426
[0/1][602/782] Loss_D: 0.8014 Loss_G: 6.1535 D(x): 0.9327 D(G(z)): 0.4789 / 0.0052
[0/1][603/782] Loss_D: 1.4815 Loss_G: 0.8986 D(x): 0.3384 D(G(z)): 0.0277 / 0.5039
[0/1][604/782] Loss_D: 2.1275 Loss_G: 7.5631 D(x): 0.9477 D(G(z)): 0.7781 / 0.0012
[0/1][605/782] Loss_D: 1.3904 Loss_G: 4.2363 D(x): 0.3463 D(G(z)): 0.0079 / 0.0276
[0/1][606/782] Loss_D: 0.7975 Loss_G: 1.1045 D(x): 0.6112 D(G(z)): 0.1302 / 0.3823
[0/1][607/782] Loss_D: 1.7659 Loss_G: 6.0084 D(x): 0.9585 D(G(z)): 0.7494 / 0.0064
[0/1][608/782] Loss_D: 1.0994 Loss_G: 4.2747 D(x): 0.4748 D(G(z)): 0.0465 / 0.0205
[0/1][609/782] Loss_D: 0.4971 Loss_G: 2.5694 D(x): 0.7343 D(G(z)): 0.0826 / 0.1207
[0/1][610/782] Loss_D: 0.7496 Loss_G: 5.2125 D(x): 0.9251 D(G(z)): 0.4287 / 0.0094
[0/1][611/782] Loss_D: 0.4490 Loss_G: 3.7542 D(x): 0.7335 D(G(z)): 0.0477 / 0.0364
[0/1][612/782] Loss_D: 0.4362 Loss_G: 2.8736 D(x): 0.8121 D(G(z)): 0.1635 / 0.0722
[0/1

[0/1][701/782] Loss_D: 0.6537 Loss_G: 3.8301 D(x): 0.7403 D(G(z)): 0.2202 / 0.0369
[0/1][702/782] Loss_D: 0.5063 Loss_G: 4.6598 D(x): 0.8439 D(G(z)): 0.2253 / 0.0136
[0/1][703/782] Loss_D: 0.6971 Loss_G: 5.6531 D(x): 0.7863 D(G(z)): 0.2965 / 0.0058
[0/1][704/782] Loss_D: 0.8235 Loss_G: 3.6311 D(x): 0.6138 D(G(z)): 0.1634 / 0.0413
[0/1][705/782] Loss_D: 0.5540 Loss_G: 4.1529 D(x): 0.7982 D(G(z)): 0.2376 / 0.0252
[0/1][706/782] Loss_D: 0.5003 Loss_G: 4.7362 D(x): 0.8298 D(G(z)): 0.2124 / 0.0150
[0/1][707/782] Loss_D: 0.6762 Loss_G: 3.4996 D(x): 0.6959 D(G(z)): 0.1874 / 0.0456
[0/1][708/782] Loss_D: 0.5618 Loss_G: 5.5434 D(x): 0.8811 D(G(z)): 0.3088 / 0.0059
[0/1][709/782] Loss_D: 0.3414 Loss_G: 4.4233 D(x): 0.8070 D(G(z)): 0.0887 / 0.0186
[0/1][710/782] Loss_D: 0.6704 Loss_G: 4.3761 D(x): 0.7957 D(G(z)): 0.2861 / 0.0196
[0/1][711/782] Loss_D: 0.7317 Loss_G: 6.4533 D(x): 0.8375 D(G(z)): 0.3438 / 0.0034
[0/1][712/782] Loss_D: 1.0046 Loss_G: 1.8049 D(x): 0.4923 D(G(z)): 0.0397 / 0.2123
[0/1

In [23]:
from pytorch_fid.inception import InceptionV3
# Index of the InceptionV3 block registered under the 2048 key of
# BLOCK_INDEX_BY_DIM.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
# Feature extractor used by the FID helpers, placed on the notebook's `device`.
# NOTE(review): the training-loop cell reads `model`, so this cell has to run
# before it even though it appears later in the notebook.
model = InceptionV3([block_idx]).to(device)

In [25]:
model

InceptionV3(
  (blocks): ModuleList(
    (0): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      

In [27]:
import numpy as np
def calculate_activation_statistics(images,model,batch_size=128, dims=2048,
                    cuda=False):
    """Mean and covariance of the model features for one batch of images.

    Args:
        images: tensor holding the batch fed to the model.
        model: module whose call returns a sequence; element ``[0]`` must be a
            4-D tensor (the code reads ``size(2)`` and ``size(3)``).
        batch_size: unused; kept for signature compatibility.
        dims: unused; kept for signature compatibility.
        cuda: when True the batch is moved to the notebook-level ``device``
            before the forward pass.

    Returns:
        ``(mu, sigma)`` - per-feature mean and feature covariance
        (``np.cov`` with ``rowvar=False``) of the activations.

    Notes:
        * The forward pass runs under ``torch.no_grad()`` so no autograd graph
          is built, even when ``images`` comes straight from the generator.
        * Global average pooling uses ``torch.nn.functional``; the previous
          bare ``adaptive_avg_pool2d`` name was never imported.
        * The unused ``np.empty((len(images), dims))`` pre-allocation was removed.
    """
    model.eval()

    batch = images.to(device) if cuda else images

    with torch.no_grad():
        pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

    # One row per sample, one column per feature.
    act = pred.detach().cpu().numpy().reshape(pred.size(0), -1)

    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

In [31]:
from scipy import linalg
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Args:
        mu1, mu2: mean vectors of the two distributions (same length).
        sigma1, sigma2: covariance matrices of the two distributions (same shape).
        eps: value added to both covariance diagonals when the matrix square
            root of their product is not finite.

    Returns:
        The squared Frechet distance as a scalar.

    Raises:
        AssertionError: if the means or covariances differ in shape.
        ValueError: if the matrix square root has a non-negligible imaginary
            diagonal.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    cov_product = sigma1.dot(sigma2)
    try:
        # disp=False returns (sqrt, error_estimate) without printing a warning.
        covmean, _ = linalg.sqrtm(cov_product, disp=False)
    except TypeError:
        # NOTE(review): SciPy has announced deprecation of the `disp` keyword;
        # on a build that rejects it, fall back to the plain call, which
        # returns only the matrix.
        covmean = linalg.sqrtm(cov_product)

    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a small imaginary part; drop it when negligible.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

In [16]:
def calculate_fretchet(images_real,images_fake,model):
    """Frechet distance between the model features of two image batches.

    Both batches are passed to ``calculate_activation_statistics`` with
    ``cuda=True``; the resulting (mean, covariance) pairs are compared with
    ``calculate_frechet_distance``.

    Args:
        images_real: tensor holding a batch of real images.
        images_fake: tensor holding a batch of generated images.
        model: feature extractor handed to ``calculate_activation_statistics``.

    Returns:
        The value computed by ``calculate_frechet_distance``.
    """
    (mean_real, cov_real), (mean_fake, cov_fake) = (
        calculate_activation_statistics(batch, model, cuda=True)
        for batch in (images_real, images_fake)
    )
    return calculate_frechet_distance(mean_real, cov_real, mean_fake, cov_fake)

In [33]:
fretchet_dist

434.27731065758974

In [10]:
import tqdm

In [None]:
# Save every real batch as an image grid.
# The batch index is part of the file name; with a single fixed name each
# iteration overwrote the previous one and only the last batch survived.
for i, data in enumerate(dataloader, 0):
    real_cpu = data[0].to(device)
    vutils.save_image(real_cpu,
        '%s/real_samples_%d.png' % (opt.outf, i),
        normalize=True)

In [12]:
# Save one grid of generated images per dataloader batch.
# NOTE(review): `epoch` is whatever value the training loop left behind; this
# cell depends on that cell having run.
fake_dir = os.path.join(opt.outf, 'fake_samples')
# save_image does not create directories, so make sure the target exists.
os.makedirs(fake_dir, exist_ok=True)

# Only the number of batches is needed, so iterate over a range instead of the
# dataloader: no images are loaded and tqdm knows the total.
# The configured batch size replaces `batch_size`, which holds the size of the
# last training batch and can be smaller.
with torch.no_grad():  # generation only, no autograd graph needed
    for i in tqdm.tqdm(range(len(dataloader))):
        noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
        fake = netG(noise)
        vutils.save_image(fake,
                '%s/fake_samples_epoch_%03d_%d.png' % (fake_dir, epoch, i),
                normalize=True)

782it [01:21,  9.54it/s]


In [3]:
from pytorch_fid import fid_score
import torch

In [None]:
64*64*64

In [9]:
526 / 64

8.21875

In [7]:
vutils??

In [37]:
64*3*530*530

53932800

In [38]:
64*3*64*64

786432

In [39]:
53932800 / 786432

68.5791015625

In [None]:
b /opt/conda/envs/stylegan2/lib/python3.9/site-packages/pytorch_fid/fid_score.py:128