In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision
import numpy as np

from fjd_metric import FJDMetric
from embeddings import OneHotEmbedding, InceptionEmbedding
from torchvision.models.inception import inception_v3
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy
from msssim import MultiScaleSSIM

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
batch_size = 100


class oldNetG(nn.Module):
    def __init__(self, lenz):
        super(oldNetG, self).__init__()
        self.lenz = lenz
        self.l = nn.Linear(110,384)

        self.t1 = nn.ConvTranspose2d(384, 192, 6, (2,2))
        self.bn1 = nn.BatchNorm2d(192)
        self.t2 = nn.ConvTranspose2d(192, 96, 5, (2, 2))
        self.bn2 = nn.BatchNorm2d(96)
        
        self.t3 = nn.ConvTranspose2d(96, 3, 4, (2, 2))
        

    def forward(self, x:torch.Tensor):
#         print(x.shape)
        x = x.view(-1, self.lenz)
#         print(x.shape)
        x = self.l(x).view(-1, 384, 1, 1)
#         print(x.shape)
        x = F.relu(self.bn1(self.t1(x)))
#         print(x.shape)
        x = F.relu(self.bn2(self.t2(x)))
#         print(x.shape)
        x = F.tanh(self.t3(x))
#         print(x.shape)
        return x

class GANWrapper:
    def __init__(self, model, model_checkpoint=None):
        self.model = model

        if model_checkpoint is not None:
            self.model_checkpoint = model_checkpoint
            self.load_model()

    def load_model(self):
        # self.model.eval()  # uncomment to put in eval mode if desired
        self.model = self.model.cuda()

        state_dict = torch.load(self.model_checkpoint)
        self.model.load_state_dict(state_dict)

    def get_noise(self, batch_size):
        # change the noise dimension as required
        z = torch.cuda.FloatTensor(batch_size, 128).normal_()
        return z

    def __call__(self, y):
#         print(y)
#         print(type(y))
#         print(y.shape)
#         y.unsqueeze(1)
        np_gen_label = y.cpu().numpy()
        batch_size = y.size(0)
#         z = self.get_noise(batch_size)
#         samples = self.model(z, y)
        
    
        noise = np.random.normal(0, 1, (batch_size, 100))
        onehot = np.zeros((batch_size, 10))
        onehot[np.arange(batch_size), np_gen_label] = 1
        z = np.concatenate((noise, onehot), axis=1)
        z = torch.from_numpy(z).float().to('cuda')


        gen_imgs = self.model(z)
        return gen_imgs
    
class newNetG(nn.Module):
    def __init__(self, lenz):
        super(newNetG, self).__init__()
        self.lenz = lenz
        self.l = nn.Linear(110,512)

        self.t1 = nn.ConvTranspose2d(512, 384, 4, (2,2),1)
        self.bn1 = nn.BatchNorm2d(384)
        self.t2 = nn.ConvTranspose2d(384, 192, 4, (2, 2),1)
        self.bn2 = nn.BatchNorm2d(192)
        self.t3 = nn.ConvTranspose2d(192, 148, 4, (2, 2),1)
        self.bn3 = nn.BatchNorm2d(148)
        self.t4 = nn.ConvTranspose2d(148, 92, 4, (2, 2),1)
        
        self.bn4 = nn.BatchNorm2d(92)
        self.t5 = nn.ConvTranspose2d(92, 3, 4, (2, 2),1)
        

    def forward(self, x:torch.Tensor):
#         print(x.shape)
        x = x.view(-1, self.lenz)
#         print(x.shape)
        x = self.l(x).view(-1, 512, 1, 1)
#         print(x.shape)
        x = F.relu(self.bn1(self.t1(x)))
#         print(x.shape)
        x = F.relu(self.bn2(self.t2(x)))
#         print(x.shape)
        x = F.relu(self.bn3(self.t3(x)))
        x = F.relu(self.bn4(self.t4(x)))
#         print(x.shape)
        x = F.tanh(self.t5(x))
#         print(x.shape)
        return x

In [3]:
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

inception_embedding = InceptionEmbedding(parallel=False)
onehot_embedding = OneHotEmbedding(num_classes=10)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
fjds = []
fids = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = oldNetG(110).to('cuda')
    PATH = 'original_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()

    gan = GANWrapper(net_g)
    fjd_metric = FJDMetric(gan=gan,
                           reference_loader=trainloader,
                           condition_loader=testloader,
                           image_embedding=inception_embedding,
                           condition_embedding=onehot_embedding,
                           reference_stats_path='datasets/cifar_train_stats.npz',
                           save_reference_stats=True,
                           samples_per_condition=1,
                           cuda=True)

    fid = fjd_metric.get_fid()
    fjd = fjd_metric.get_fjd()
    print('FID: ', fid)
    print('FJD: ', fjd)
    print(fjd_metric.alpha)
    fids.append(fid)
    fjds.append(fjd)

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:16<00:00,  5.94it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.80it/s]


FID:  485.62789502169403
FJD:  487.88277249131124
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.77it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.84it/s]


FID:  86.92183322781841
FJD:  100.28399270743967
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.82it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.81it/s]


FID:  82.6511854205275
FJD:  95.07041197309036
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.76it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]


FID:  79.37391519366861
FJD:  91.66352567546346
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.73it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]


FID:  78.55823796639191
FJD:  88.22626166489476
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.78it/s]
Computing generated distribution:   0%|          | 0/100 [00:00<?, ?it/s]

FID:  80.16680078565037
FJD:  90.23355326517867
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.82it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]


FID:  69.98073857340347
FJD:  79.9778990042987
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.79it/s]


FID:  68.75876184948407
FJD:  78.2382495143363
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.70it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.77it/s]


FID:  70.1439803882796
FJD:  78.18880871299757
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.73it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.76it/s]


FID:  67.7486926299519
FJD:  76.51653078745812
tensor(21.7352, device='cuda:0', dtype=torch.float64)
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.68it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.77it/s]


FID:  68.01113164618539
FJD:  77.16531961736177
tensor(21.7352, device='cuda:0', dtype=torch.float64)


In [19]:
new_fjds = []
new_fids = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = newNetG(110).to('cuda')
    PATH = 'new_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()

    gan = GANWrapper(net_g)
    fjd_metric = FJDMetric(gan=gan,
                           reference_loader=trainloader,
                           condition_loader=testloader,
                           image_embedding=inception_embedding,
                           condition_embedding=onehot_embedding,
                           reference_stats_path='datasets/cifar_train_stats.npz',
                           save_reference_stats=True,
                           samples_per_condition=1,
                           cuda=True)

    fid = fjd_metric.get_fid()
    fjd = fjd_metric.get_fjd()
    print('FID: ', fid)
    print('FJD: ', fjd)
    new_fids.append(fid)
    new_fjds.append(fjd)

Computing generated distribution:   0%|          | 0/100 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.83it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.74it/s]


FID:  483.2364051485665
FJD:  484.77989003480866
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.72it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.75it/s]


FID:  130.58560001574222
FJD:  160.23102420293435
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.58it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


FID:  123.3490309028249
FJD:  154.12803149100864
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


FID:  114.30469661685078
FJD:  141.67969992260896
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.56it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.66it/s]


FID:  114.80089643373623
FJD:  145.568975978141
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]


FID:  113.78075578669407
FJD:  146.2069855187417
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.59it/s]


FID:  123.48969677493272
FJD:  160.7689515382424
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.52it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.35it/s]


FID:  121.10538853611044
FJD:  154.81216104988653
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.48it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.49it/s]


FID:  147.72691906471675
FJD:  181.44155527631415
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.49it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


FID:  155.24264264852695
FJD:  196.09805606794134
Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|██████████| 100/100 [00:17<00:00,  5.56it/s]
Computing generated distribution: 100%|██████████| 100/100 [00:18<00:00,  5.39it/s]


FID:  125.96540365597787
FJD:  166.96340534691535


In [6]:
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
    def get_pred(x, resize_1):

        if resize_1:
            x = up(x)

        x = inception_model(x)
        return F.softmax(x).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv, resize)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

In [5]:
isss = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = oldNetG(110).to('cuda')
    PATH = 'models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()



    batch_size = 500
    noise = np.random.normal(0, 1, (batch_size, 100))
    np_gen_label = np.random.randint(0, 10, batch_size)
    onehot = np.zeros((batch_size, 10))
    onehot[np.arange(batch_size), np_gen_label] = 1
    z = np.concatenate((noise, onehot), axis=1)
    z = torch.from_numpy(z).float().to('cuda')


    gen_imgs = net_g(z)
    iss = inception_score(gen_imgs,resize = True)
    print(iss)
    isss.append(iss)

  "See the documentation of nn.Upsample for details.".format(mode))


(1.0707358432321514, 0.0)
(2.6277085355430523, 0.0)
(3.062116927683054, 0.0)
(2.981237645431223, 0.0)
(3.5163301157166953, 0.0)
(3.286465983775069, 0.0)
(3.4853269111524883, 0.0)
(3.4182588323310035, 0.0)
(3.526760380834381, 0.0)
(3.600546285805776, 0.0)
(3.422761106399059, 0.0)


In [20]:
newisss = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = newNetG(110).to('cuda')
    PATH = 'new_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()



    batch_size = 500
    noise = np.random.normal(0, 1, (batch_size, 100))
    np_gen_label = np.random.randint(0, 10, batch_size)
    onehot = np.zeros((batch_size, 10))
    onehot[np.arange(batch_size), np_gen_label] = 1
    z = np.concatenate((noise, onehot), axis=1)
    z = torch.from_numpy(z).float().to('cuda')


    gen_imgs = net_g(z)
    iss = inception_score(gen_imgs,resize = True)
    print(iss)
    newisss.append(iss)



(1.0361979844563123, 0.0)
(3.1250639132742104, 0.0)
(2.6706251168500152, 0.0)
(3.2054184668702494, 0.0)
(3.0327701959704916, 0.0)
(2.835451290012061, 0.0)
(3.0152388156052834, 0.0)
(3.0543817028837155, 0.0)
(3.2982044649736846, 0.0)
(3.2207540985145253, 0.0)
(3.1392144695467854, 0.0)


In [15]:
class _netG_CIFAR10(nn.Module):
    def __init__(self, ngpu, nz):
        super(_netG_CIFAR10, self).__init__()
        self.ngpu = ngpu
        self.nz = nz

        # first linear layer
        self.fc1 = nn.Linear(110, 384)
        # Transposed Convolution 2
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(384, 192, 4, 1, 0, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        # Transposed Convolution 3
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
        )
        # Transposed Convolution 4
        self.tconv4 = nn.Sequential(
            nn.ConvTranspose2d(96, 48, 4, 2, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(True),
        )
        # Transposed Convolution 4
        self.tconv5 = nn.Sequential(
            nn.ConvTranspose2d(48, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )
    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            input = input.view(-1, self.nz)
            fc1 = nn.parallel.data_parallel(self.fc1, input, range(self.ngpu))
            fc1 = fc1.view(-1, 384, 1, 1)
            tconv2 = nn.parallel.data_parallel(self.tconv2, fc1, range(self.ngpu))
            tconv3 = nn.parallel.data_parallel(self.tconv3, tconv2, range(self.ngpu))
            tconv4 = nn.parallel.data_parallel(self.tconv4, tconv3, range(self.ngpu))
            tconv5 = nn.parallel.data_parallel(self.tconv5, tconv4, range(self.ngpu))
            output = tconv5
        else:
            input = input.view(-1, self.nz)
            fc1 = self.fc1(input)
            fc1 = fc1.view(-1, 384, 1, 1)
            tconv2 = self.tconv2(fc1)
            tconv3 = self.tconv3(tconv2)
            tconv4 = self.tconv4(tconv3)
            tconv5 = self.tconv5(tconv4)
            output = tconv5
        return output
   
batch_size = 500

netG = _netG_CIFAR10(1, 110).to('cuda')
PATH = 'premade/netG_epoch_100.pth'
nz= 110
num_classes = 10

netG.load_state_dict(torch.load(PATH))
netG.eval()

input = torch.FloatTensor(batch_size, 3,32, 32)
noise = torch.FloatTensor(batch_size, nz, 1, 1)
eval_noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
dis_label = torch.FloatTensor(batch_size)
aux_label = torch.LongTensor(batch_size)
real_label = 1
fake_label = 0

eval_noise_ = np.random.normal(0, 1, (batch_size, nz))
eval_label = np.random.randint(0, num_classes, batch_size)
eval_onehot = np.zeros((batch_size, num_classes))
eval_onehot[np.arange(batch_size), eval_label] = 1
eval_noise_[np.arange(batch_size), :num_classes] = eval_onehot[np.arange(batch_size)]
eval_noise_ = (torch.from_numpy(eval_noise_))
eval_noise.data.copy_(eval_noise_.view(batch_size, nz, 1, 1))

noise.data.resize_(batch_size, nz, 1, 1).normal_(0, 1)
label = np.random.randint(0, num_classes, batch_size)
noise_ = np.random.normal(0, 1, (batch_size, nz))
class_onehot = np.zeros((batch_size, num_classes))
class_onehot[np.arange(batch_size), label] = 1
noise_[np.arange(batch_size), :num_classes] = class_onehot[np.arange(batch_size)]
noise_ = (torch.from_numpy(noise_))
noise.data.copy_(noise_.view(batch_size, nz, 1, 1))
aux_label.data.resize_(batch_size).copy_(torch.from_numpy(label))
noise = noise.to("cuda")
fake = netG(noise)

iss = inception_score(fake,resize = True)
print(iss)



(3.4297642099890937, 0.0)


In [6]:
avgmsssims = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = oldNetG(110).to('cuda')
    PATH = 'original_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()
    
    model_msssims = []
    for j in range(10):
        batch_size = 200
        noise = np.random.normal(0, 1, (batch_size, 100))
        np_gen_label = np.ones(batch_size)*j
        onehot = np.zeros((batch_size, 10))
        onehot[np.arange(batch_size), np_gen_label.astype('int')] = 1
        z = np.concatenate((noise, onehot), axis=1)
        z = torch.from_numpy(z).float().to('cuda')


        gen_imgs = net_g(z)
        gen_imgs = torch.clamp(gen_imgs, -1,1)

        gen_imgs = torch.unsqueeze(gen_imgs,1)

        ims = [i.cpu().detach().numpy() for i in gen_imgs]

        msssims = []
        for i in range(100):
            msssim = MultiScaleSSIM(ims[int(i*2)], ims[int(i*2)+1], 2)
            if msssim <1 and msssim>0:
                msssims.append(msssim)
        avg = np.average(msssims)
        model_msssims.append(avg)
    avgmsssims.append(model_msssims)
    print(np.average(model_msssims), np.std(model_msssims))



0.9759814052685813 0.0002860229374768393


  return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) *
  (mssim[levels-1] ** weights[levels-1]))


0.6388277528071316 0.06139795373997667
0.595976417319075 0.05681847116750847
0.5961661314803359 0.059277975068045115
0.6051750304621626 0.05680380883212092
0.5851483882665158 0.061651454372759006
0.5867779518003859 0.05004555419271459
0.5827149470800661 0.05952185603112699
0.5934325529576886 0.05880507057072949
0.6001077277801714 0.06569094071497811
0.5880500863399958 0.06034135064267417


In [21]:
avgmsssims = []
for epoch in range(0,550,50):
    epoch = str(epoch)
    net_g = newNetG(110).to('cuda')
    PATH = 'new_model/models/netG_epoch_'+epoch+'.pth'

    net_g.load_state_dict(torch.load(PATH))
    net_g.eval()

    model_msssims = []
    for j in range(10):
        batch_size = 200
        noise = np.random.normal(0, 1, (batch_size, 100))
        np_gen_label = np.ones(batch_size)*j
        onehot = np.zeros((batch_size, 10))
        onehot[np.arange(batch_size), np_gen_label.astype('int')] = 1
        z = np.concatenate((noise, onehot), axis=1)
        z = torch.from_numpy(z).float().to('cuda')


        gen_imgs = net_g(z)
        gen_imgs = torch.clamp(gen_imgs, -1,1)

        gen_imgs = torch.unsqueeze(gen_imgs,1)

        ims = [i.cpu().detach().numpy() for i in gen_imgs]

        msssims = []
        for i in range(100):
            msssim = MultiScaleSSIM(ims[int(i*2)], ims[int(i*2)+1], 2)
            if msssim <1 and msssim>0:
                msssims.append(msssim)
        avg = np.average(msssims)
        model_msssims.append(avg)
    avgmsssims.append(model_msssims)
    print(np.average(model_msssims), np.std(model_msssims))

0.9857179210336777 0.0002565075509477855
0.7882750255767357 0.06186495478357706
0.7994986122045458 0.05468561743246797
0.7938468413019109 0.07985497094817709
0.7863713294244428 0.07403534739045659
0.7784688819813901 0.058945102195056526
0.8058418902480018 0.06865579647816059
0.8030375860621118 0.07923368814750369
0.8223005283471216 0.08579992006643335
0.8028884247690687 0.08011746228864343
0.7856904546235237 0.05294809416163964
