In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision
import numpy as np

from fjd_metric import FJDMetric
from embeddings import OneHotEmbedding, InceptionEmbedding
from torchvision.models.inception import inception_v3
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy

In [12]:
batch_size = 100


class oldNetG(nn.Module):
    def __init__(self, lenz):
        super(oldNetG, self).__init__()
        self.lenz = lenz
        self.l = nn.Linear(110,384)

        self.t1 = nn.ConvTranspose2d(384, 192, 6, (2,2))
        self.bn1 = nn.BatchNorm2d(192)
        self.t2 = nn.ConvTranspose2d(192, 96, 5, (2, 2))
        self.bn2 = nn.BatchNorm2d(96)
        
        self.t3 = nn.ConvTranspose2d(96, 3, 4, (2, 2))
        

    def forward(self, x:torch.Tensor):
#         print(x.shape)
        x = x.view(-1, self.lenz)
#         print(x.shape)
        x = self.l(x).view(-1, 384, 1, 1)
#         print(x.shape)
        x = F.relu(self.bn1(self.t1(x)))
#         print(x.shape)
        x = F.relu(self.bn2(self.t2(x)))
#         print(x.shape)
        x = F.tanh(self.t3(x))
#         print(x.shape)
        return x

class GANWrapper:
    def __init__(self, model, model_checkpoint=None):
        self.model = model

        if model_checkpoint is not None:
            self.model_checkpoint = model_checkpoint
            self.load_model()

    def load_model(self):
        # self.model.eval()  # uncomment to put in eval mode if desired
        self.model = self.model.cuda()

        state_dict = torch.load(self.model_checkpoint)
        self.model.load_state_dict(state_dict)

    def get_noise(self, batch_size):
        # change the noise dimension as required
        z = torch.cuda.FloatTensor(batch_size, 128).normal_()
        return z

    def __call__(self, y):
#         print(y)
#         print(type(y))
#         print(y.shape)
#         y.unsqueeze(1)
        np_gen_label = y.cpu().numpy()
        batch_size = y.size(0)
#         z = self.get_noise(batch_size)
#         samples = self.model(z, y)
        
    
        noise = np.random.normal(0, 1, (batch_size, 100))
        onehot = np.zeros((batch_size, 10))
        onehot[np.arange(batch_size), np_gen_label] = 1
        z = np.concatenate((noise, onehot), axis=1)
        z = torch.from_numpy(z).float().to('cuda')


        gen_imgs = self.model(z)
        return gen_imgs

transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

inception_embedding = InceptionEmbedding(parallel=False)
onehot_embedding = OneHotEmbedding(num_classes=10)
fjds = []
fids = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = oldNetG(110).to('cuda')
    PATH = 'original_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()

    gan = GANWrapper(net_g)
    fjd_metric = FJDMetric(gan=gan,
                           reference_loader=trainloader,
                           condition_loader=testloader,
                           image_embedding=inception_embedding,
                           condition_embedding=onehot_embedding,
                           reference_stats_path='datasets/cifar_train_stats.npz',
                           save_reference_stats=True,
                           samples_per_condition=1,
                           cuda=True)

    fid = fjd_metric.get_fid()
    fjd = fjd_metric.get_fjd()
    print('FID: ', fid)
    print('FJD: ', fjd)
    fids.append(fid)
    fjds.append(fjd)

Files already downloaded and verified
Files already downloaded and verified
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.67it/s]


FID:  453.18124027152896
FJD:  454.9895702136364
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.69it/s]


FID:  133.99499967665682
FJD:  147.29278612807775
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.71it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.69it/s]


FID:  103.22225640096735
FJD:  116.3909947829834
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.54it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.68it/s]


FID:  95.2756787667343
FJD:  110.53598964496905
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.69it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


FID:  101.6863377848386
FJD:  121.05548137514461
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.68it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.68it/s]


FID:  93.06280173352599
FJD:  113.44248829378625
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.67it/s]


FID:  102.82236268852671
FJD:  127.27991817975453
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.74it/s]


FID:  105.1096725185871
FJD:  126.57541131805351
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.65it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.57it/s]


FID:  109.81523429635183
FJD:  138.73378826735916
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


FID:  110.50924722254615
FJD:  129.66991052564845
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.57it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


FID:  117.0128656042956
FJD:  145.11874566658275


In [28]:



def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
    def get_pred(x, resize_1):

        if resize_1:
            x = up(x)

        x = inception_model(x)
        return F.softmax(x).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv, resize)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

isss = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = oldNetG(110).to('cuda')
    PATH = 'original_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()



    batch_size = 500
    noise = np.random.normal(0, 1, (batch_size, 100))
    np_gen_label = np.random.randint(0, 10, batch_size)
    onehot = np.zeros((batch_size, 10))
    onehot[np.arange(batch_size), np_gen_label] = 1
    z = np.concatenate((noise, onehot), axis=1)
    z = torch.from_numpy(z).float().to('cuda')


    gen_imgs = net_g(z)
    iss = inception_score(gen_imgs,resize = True)
    print(iss)
    isss.append(iss)

RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 5.38 GiB total capacity; 4.49 GiB already allocated; 0 bytes free; 4.51 GiB reserved in total by PyTorch)