In [1]:
import sys
sys.path.append('../utils')

import libraries
import utils_general
from utils_general import *

%matplotlib inline
import imageio

from __future__ import print_function
import argparse
import torch
import torch.nn.parallel

import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', type=str, help='path to dataset')
parser.add_argument('--workers', type=int,
                    help='number of data loading workers', default=8)
parser.add_argument('--batch_size', type=int,
                    default=64, help='batch size')
parser.add_argument('--image_size', type=int, default=32,
                    help='the resolution of the input image to network')
parser.add_argument('--nz', type=int, default=100,
                    help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int)

parser.add_argument('--nepoch', type=int, default=25,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002,
                    help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 for adam. default=0.5')
parser.add_argument('--cpu', action='store_true',
                    help='use CPU instead of GPU')
parser.add_argument('--ngpu', type=int, default=1,
                    help='number of GPUs to use')

parser.add_argument('--netG', default='',
                    help="path to netG config")
parser.add_argument('--netE', default='',
                    help="path to netE config")
parser.add_argument('--netG_chp', default='',
                    help="path to netG (to continue training)")
parser.add_argument('--netE_chp', default='',
                    help="path to netE (to continue training)")

parser.add_argument('--save_dir', default='.',
                    help='folder to output images and model checkpoints')
parser.add_argument('--criterion', default='param',
                    help='param|nonparam, How to estimate KL')
parser.add_argument('--KL', default='qp', help='pq|qp')
parser.add_argument('--noise', default='sphere', help='normal|sphere')
parser.add_argument('--match_z', default='cos', help='none|L1|L2|cos')
parser.add_argument('--match_x', default='L1', help='none|L1|L2|cos')

parser.add_argument('--drop_lr', default=5, type=int, help='')
parser.add_argument('--save_every', default=50, type=int, help='')

parser.add_argument('--manual_seed', type=int, default=123, help='manual seed')
parser.add_argument('--start_epoch', type=int, default=0, help='epoch number to start with')

parser.add_argument(
    '--e_updates', default="1;KL_fake:1,KL_real:1,match_z:0,match_x:0",
    help='Update plan for encoder <number of updates>;[<term:weight>]'
)

parser.add_argument(
    '--g_updates', default="2;KL_fake:1,match_z:1,match_x:0",
    help='Update plan for generator <number of updates>;[<term:weight>]'
)

# opt = parser.parse_args()

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --lr 0.0002 
# --nz 64 
# --batch_size 64 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 5 
# --drop_lr 5 
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --start_epoch 5 
# --lr 0.0002 
# --nz 64 
# --batch_size 256 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 6 
# --drop_lr 5   
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0' 
# --netE_chp  <save_dir>/netE_epoch_5.pth 
# --netG_chp <save_dir>/netG_epoch_5.pth

_StoreAction(option_strings=['--g_updates'], dest='g_updates', nargs=None, const=None, default='2;KL_fake:1,match_z:1,match_x:0', type=None, choices=None, help='Update plan for generator <number of updates>;[<term:weight>]', metavar=None)

In [2]:
# print(get_available_gpus())
print(psutil.virtual_memory())

svmem(total=64388890624, available=63350505472, percent=1.6, used=556662784, free=63168655360, active=497311744, inactive=469098496, buffers=49991680, cached=613580800, shared=9654272)


#### Setting up the options

In [3]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
# from .dataset import FolderWithImages
import random
import torch.backends.cudnn as cudnn
from PIL import Image

from datetime import datetime

In [35]:
# setup function @utils.py

cuda = True # not opt.cpu
torch.set_num_threads(4)
dataset = 'helen'
run = ''
save_dir = 'output/helen/'
try:
    os.makedirs(save_dir)
except OSError:
    print('Directory was not created.')

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)

cudnn.benchmark = True

if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device,"
            "so you should probably run with --cuda")

e_updates = '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
g_updates = '3;KL_fake:1,match_z:1000,match_x:1'
    
updates = {'e': {}, 'g': {}}
updates['e']['num_updates'] = int(e_updates.split(';')[0])
updates['e'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in e_updates.split(';')[1].split(',')})

updates['g']['num_updates'] = int(g_updates.split(';')[0])
updates['g'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in g_updates.split(';')[1].split(',')})


image_size = 64 # 512
batch_size = 128 # 64
workers = 16
shuffle = True
drop_last = True
train = True

nc = 3
nz = 64 #100 #'size of the latent z vector')
ngf = 64
ndf = 64
ngpu = 0 # 'number of GPUs to use')
pin_memory = True # True | If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them.

noise = 'sphere' #help='normal|sphere')

# netG = 'dcgan64px' 
# netE = 'dcgan64px' 

netG_chp = 'output/helen/models_netG_last_epoch_64_128_15_dlr:100.pth' # "path to netG (to continue training)"
netE_chp = 'output/helen/models_netE_last_epoch_64_128_15_dlr:100.pth' # "path to netE (to continue training)"

lr = 0.0002
drop_lr = 25
beta1 = 0.5 # help='beta1 for adam. default=0.5'
criterion = 'param'# help='param|nonparam, How to estimate KL'
KL ='qp' # help='pq|qp'
match_z = 'cos' # help='none|L1|L2|cos'
match_x = 'L1' # help='none|L1|L2|cos'

start_epoch = 0
nepoch = 100

save_every_e = 9
save_every_b = None
dataroot = '../data/data'

now = str(datetime.now()).split('.')[0].split(' ')[0] + '_' + str(datetime.now()).split('.')[0].split(' ')[1] + '_Exg=1'

run = '%d_%d_%d_dlr:%d_%s' %(image_size, batch_size, nepoch, drop_lr, now); print(run)
# run = '%d_%d_%d_dlr:%d' %(image_size, batch_size, nepoch, drop_lr); print(run)

# --lr 0.0002 
# --nz 64 
# --batch_size 64 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 5 
# --drop_lr 5 
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

Directory was not created.
64_128_100_dlr:25_2017-12-19_18:19:33_Exg=1


[Go to trainning](#train)
<a id='init'.</a>

In [36]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
import random
import os
import torch.backends.cudnn as cudnn
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

##
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
#     img = cv2.imread(filepath) #, cv2.IMREAD_UNCHANGED)
    img = Image.open(filepath).convert('RGB')
    return img

class FolderWithImages(data.Dataset):
    def __init__(self, root, input_transform=None, target_transform=None):
        super(FolderWithImages, self).__init__()
        self.image_filenames = [join(root, x)
                                for x in listdir(root) if is_image_file(x.lower())]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)

    
def setup_dataset(dataset, dataroot, train=True, shuffle=True, drop_last=True):
    '''
    Setups dataset.
    '''
    # Usual transform
    t = transforms.Compose([
        transforms.Scale([image_size, image_size]),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if dataset in ['imagenet', 'folder', 'lfw']:
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = dset.ImageFolder(root=dataroot, transform=t)
    elif dataset == 'lsun':
        dataset = dset.LSUN(db_path=dataroot,
                            classes=['bedroom_train'],
                            train=train,
                            transform=t)
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root='data/raw/cifar10',
                               download=True,
                               train=train,
                               transform=t
                               )
    elif dataset == 'mnist':
        dataset = dset.MNIST(root='data/raw/mnist',
                             download=True,
                             train=train,
                             transform=t
                             )
    elif dataset == 'svhn':
        dataset = dset.SVHN(root='data/raw/svhn',
                            download=True,
                            train=train,
                            transform=t)
    elif dataset == 'celeba':
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    elif dataset == 'helen':
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       Scale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    else:
        assert False, 'Wrong dataset name.'

    assert len(dataset) > 0, 'No images found, check your paths.'

    # Shuffle and drop last when training
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=int(workers),
                                             pin_memory=pin_memory,
                                             drop_last=drop_last)
#     print(len(dataloader))
#     for i in dataloader: print(i[0].shape, i[1].shape)
        
    return InfiniteDataLoader(dataloader)
#     return dataloader

class InfiniteDataLoader(object):
    """docstring for InfiniteDataLoader"""

    def __init__(self, dataloader):
        super(InfiniteDataLoader, self).__init__()
        self.dataloader = dataloader
        self.data_iter = None

    def next(self):
        try:
            data = self.data_iter.next()
        except Exception:
            # Reached end of the dataset
            self.data_iter = iter(self.dataloader)
            data = self.data_iter.next()

        return data

    def __len__(self):
        return len(self.dataloader)

class ALICropAndScale(object):
    def __call__(self, img):
        return img.resize((64, 78), Image.ANTIALIAS).crop((0, 7, 64, 64 + 7))
    
class Scale(object):
    def __call__(self, img):
#         img = cv2.resize(img, (64, 64), cv2.INTER_AREA)
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#         img = Image.fromarray(img)
        img = img.resize((image_size, image_size), Image.ANTIALIAS)
        return img 

# cv::INTER_AREA interpolation, whereas to
# .   enlarge an image, it will generally look best with cv::INTER_CUBIC (slow) or cv::INTER_LINEAR
# .   (faster but still looks OK).
    
# Setup dataset
dataloader = dict(train=setup_dataset(dataset, dataroot, train=True),
                  val=setup_dataset(dataset, dataroot, train=False))



In [37]:
# dl = setup_dataset(dataset, dataroot, train=True)
# # len(dataloader['val'])

# def save_image(tensor, filename, nrow=8, padding=2,
#                normalize=False, range=None, scale_each=False, pad_value=0):
#     """Save a given Tensor into an image file.
#     Args:
#         tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
#             saves the tensor as a grid of images by calling ``make_grid``.
#         **kwargs: Other arguments are documented in ``make_grid``.
#     """
#     from PIL import Image
#     grid = vutils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
#                      normalize=normalize, range=range, scale_each=scale_each)
#     ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() # 
# #     im = Image.fromarray(ndarr)
#     im = Image.fromarray(ndarr[:,:,::-1])    
#     im.save(filename)


# def save_images(epoch):

#     real_cpu.resize_(x.data.size()).copy_(x.data)

#     # Real samples
#     save_path = '%s/real_samples.png' % save_dir
#     save_image(real_cpu[:64] / 2 + 0.5, save_path)

#     netG.eval()
#     fake = netG(fixed_z)

#     # Fake samples
#     save_path = '%s/fake_samples_epoch_%03d.png' % (save_dir, epoch)
#     save_image(fake.data[:64] / 2 + 0.5, save_path)

#     # Save reconstructions
#     populate_x(x, dataloader['val'])
#     gex = netG(netE(x)) # here - at G entry

#     t = torch.FloatTensor(x.size(0) * 2, x.size(1),
#                           x.size(2), x.size(3))

#     t[0::2] = x.data[:]
#     t[1::2] = gex.data[:]

#     save_path = '%s/reconstructions_epoch_%03d.png' % (save_dir, epoch)
#     grid = save_image(t[:64] / 2 + 0.5, save_path)
    
#     netG.train()


In [38]:
def normalize(x, dim=0):
    '''
    Projects points to a sphere.
    '''
    return x.div(x.norm(2, dim=dim).expand_as(x))

def normalize_(x, dim=0):
    '''
    Projects points to a sphere inplace.
    '''
    x.div_(x.norm(2, dim=dim).expand_as(x))

In [39]:
def weights_init(m):
    '''
    Custom weights initialization called on netG and netE
    '''
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class _netE_Base(nn.Module):
    def __init__(self, main):
        super(_netE_Base, self).__init__()
        self.noise = noise
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):
        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            output = nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            output = self.main(input)
        output = output.view(output.size(0), -1)
        if self.noise == 'sphere':
            output = normalize(output)
        return output
    
class _netG_Base(nn.Module):
    def __init__(self, main):
        super(_netG_Base, self).__init__()
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):

        # Check input is either (B,C,1,1) or (B,C)
        assert input.nelement() == input.size(0) * input.size(1), 'wtf'
        input = input.view(input.size(0), input.size(1), 1, 1)

        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            return nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            return self.main(input)


def _netE():
#     ndf = opt.ndf
#     nc = opt.nc
#     nz = opt.nz

    main = nn.Sequential(
        # input is (nc) x 64 x 64
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf) x 16 x 16
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*2) x 8 x 8
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*4) x 4 x 4
        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*8) x 4 x 4
        nn.Conv2d(ndf * 8, nz, 4, 1, 0, bias=True),
    )

    return _netE_Base(main)

def _netG():
#     ngf = opt.ngf
#     nc = opt.nc
#     nz = opt.nz

    main = nn.Sequential(
        # input is Z, going into a convolution
        nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.ReLU(True),
        # state size. (ngf*8) x 4 x 4
        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.ReLU(True),
        # state size. (ngf*4) x 8 x 8
        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
        # state size. (ngf*2) x 16 x 16
        nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        # state size. (ngf) x 32 x 32
        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh()
        # state size. (nc) x 64 x 64
    )

    return _netG_Base(main)

def load_G():
    '''
    Loads generator model.
    '''
    netG = _netG()
    netG.apply(weights_init)
    netG.train()
    if netG_chp != '':
        netG.load_state_dict(torch.load(netG_chp).state_dict())

    print('Generator\n', netG)
    return netG

def load_E():
    '''
    Loads encoder model.
    '''
    netE = _netE()
    netE.apply(weights_init)
    netE.train()
    if netE_chp != '':
        netE.load_state_dict(torch.load(netE_chp).state_dict())

    print('Encoder\n', netE)

    return netE

# Load generator
netG = load_G()

# Load encoder
netE = load_E()

# RuntimeError: Given transposed=1, weight[64, 512, 4, 4], so expected input[64, 256, 1, 1] to have 64 channels,
# but got 256 channels instead

Generator
 _netG_Base(
  (main): Sequential(
    (0): ConvTranspose2d (64, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d (512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d (256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d (128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d (64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
Encoder
 _netE_Base(
  (main): Sequential(
    (0): Conv2d (3, 64, kernel_size=(4, 4), stride=(2, 2), paddin

#### Pretraining preparation

In [40]:
def var(x, dim=0):
    '''
    Calculates variance.
    '''
    x_zero_meaned = x - x.mean(dim).expand_as(x)
    return x_zero_meaned.pow(2).mean(dim)


class KLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(KLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.nelement() == samples.size(1) * samples.size(0), 'wtf?'

        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        samples_mean = self.samples_mean
        samples_var = self.samples_var

        if self.direction == 'pq':
            # mu_1 = 0; sigma_1 = 1

            t1 = (1 + samples_mean.pow(2)) / (2 * samples_var.pow(2))
            t2 = samples_var.log()

            KL = (t1 + t2 - 0.5).mean()
        else:
            # mu_2 = 0; sigma_2 = 1

            t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2
            t2 = -samples_var.log()

            KL = (t1 + t2 - 0.5).mean()

        if not self.minimize:
            KL *= -1

        return KL

# ----------------------------

def pairwise_euclidean(samples):

    B = samples.size(0)

    samples_norm = samples.mul(samples).sum(1)
    samples_norm = samples_norm.expand(B, B)

    dist_mat = samples.mm(samples.t()).mul(-2) + \
        samples_norm.add(samples_norm.t())
    return dist_mat

def sample_entropy(samples):

        # Assume B x C input

    dist_mat = pairwise_euclidean(samples)

    # Get max and add it to diag
    m = dist_mat.max().detach()
    dist_mat_d = dist_mat + \
        Variable(torch.eye(dist_mat.size(0)) * (m.data[0] + 1)).cuda()

    entropy = (dist_mat_d.min(1)[0] + 1e-4).log().sum()

    entropy *= (samples.size(1) + 0.) / samples.size(0)

    return entropy

class SampleKLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(SampleKLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.ndimension == 2, 'wft'
        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        if self.direction == 'pq':
            assert False, 'not possible'
        else:
            entropy = sample_entropy(samples)

            cross_entropy = - samples.pow(2).mean() / 2.

            KL = - cross_entropy - entropy

        if not self.minimize:
            KL *= -1

        return KL

In [41]:
# x = torch.FloatTensor(batch_size, nc, image_size, image_size)
# z = torch.FloatTensor(batch_size, nz, 1, 1)
# fixed_z = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
# z = fixed_z

In [42]:
x = torch.FloatTensor(batch_size, nc, image_size, image_size)
z = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_z = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)

def match(x, y, dist):
    '''
    Computes distance between corresponding points in `x` and `y`
    using distance `dist`.
    '''
    if dist == 'L2':
        return (x - y).pow(2).mean()
    elif dist == 'L1':
        return (x - y).abs().mean()
    elif dist == 'cos':
        x_n = normalize(x)
        y_n = normalize(y)

        return 2 - (x_n).mul(y_n).mean()
    else:
        assert dist == 'none', 'wtf ?'

def normalize(x, dim=0):
    '''
    Projects points to a sphere.
    '''
    return x.div(x.norm(2, dim=dim).expand_as(x))

def normalize_(x, dim=0):
    '''
    Projects points to a sphere inplace.
    '''
    x.div_(x.norm(2, dim=dim).expand_as(x))

if noise == 'sphere':
    normalize_(fixed_z)

if cuda:
    netE.cuda()
    netG.cuda()
    x = x.cuda()
    z, fixed_z = z.cuda(), fixed_z.cuda()

x = Variable(x)
z = Variable(z)
fixed_z = Variable(fixed_z)

# Setup optimizers
optimizerD = optim.Adam(netE.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Setup criterions
if criterion == 'param':
    print('Using parametric criterion KL_%s' % KL)
    KL_minimizer = KLN01Loss(direction=KL, minimize=True)
    KL_maximizer = KLN01Loss(direction=KL, minimize=False)
elif criterion == 'nonparam':
    print('Using NON-parametric criterion KL_%s' % KL)
    KL_minimizer = SampleKLN01Loss(direction=KL, minimize=True)
    KL_maximizer = SampleKLN01Loss(direction=KL, minimize=False)
else:
    assert False, 'criterion?'

real_cpu = torch.FloatTensor()


def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    """Save a given Tensor into an image file.
    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    from PIL import Image
    grid = vutils.make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() # 
    im = Image.fromarray(ndarr)
#     im = Image.fromarray(ndarr[:,:,::-1])    
    im.save(filename)


def save_images(epoch):

    real_cpu.resize_(x.data.size()).copy_(x.data)

    # Real samples
    save_path = '%s/real_samples.png' % save_dir
    save_image(real_cpu[:64] / 2 + 0.5, save_path)

    netG.eval()
    fake = netG(fixed_z)

    # Fake samples
    save_path = '%s/fake_samples_epoch_%03d.png' % (save_dir, epoch)
    save_image(fake.data[:64] / 2 + 0.5, save_path)

    # Save reconstructions
    populate_x(x, dataloader['val'])
    gex = netG(netE(x)) # here - at G entry

    t = torch.FloatTensor(x.size(0) * 2, x.size(1),
                          x.size(2), x.size(3))

    t[0::2] = x.data[:]
    t[1::2] = gex.data[:]

    save_path = '%s/reconstructions_epoch_%03d.png' % (save_dir, epoch)
    grid = save_image(t[:64] / 2 + 0.5, save_path)
    
    netG.train()

def adjust_lr(epoch, drop_coef=2):
    if (epoch + 1) % drop_lr == 0:
        ###
        assert optimizerD.param_groups[0]['lr'] == optimizerG.param_groups[0]['lr']
        lr = optimizerD.param_groups[0]['lr']
        print('Adjusting learning rate from %f to %f on E and G' % (lr, lr / drop_coef))
        optimizerD.param_groups[0]['lr'] /= drop_coef
        optimizerG.param_groups[0]['lr'] /= drop_coef
        ###
#         lr /= 2
#         for param_group in optimizerD.param_groups:
#             param_group['lr'] = lr

#         for param_group in optimizerG.param_groups:
#             param_group['lr'] = lr

Using parametric criterion KL_qp


In [43]:
def populate_x(x, dataloader):
    '''
    Fills input variable `x` with data generated with dataloader
    '''
    real_cpu, _ = dataloader.next()
    x.data.resize_(real_cpu.size()).copy_(real_cpu)

def populate_z(z):
    '''
    Fills noise variable `z` with noise U(S^M)
    '''
    z.data.resize_(batch_size, nz, 1, 1)
    z.data.normal_(0, 1)
    if noise == 'sphere':
        normalize_(z.data)

In [44]:
# print(len(dataloader['train'].next()))
# print(len(dataloader['val'].next()))
# print(dataloader['train'].next()[0].shape, dataloader['train'].next()[1].shape)
print(dataloader['train'].next()[0].shape, dataloader['train'].next()[1].shape)

# populate_x(x, dataloader['val'])
# nepoch = 10
# lr = 0.0002

sim = dataloader['train'].next()[0][0]


torch.Size([128, 3, 64, 64]) torch.Size([128, 3, 512, 512])


In [None]:
stats = {}
losses = {'KL_real': [],
         'Ex_e': [],
         'KL_fake': [], 
         'Ez_e': [],
         'KL_fake_g': [],
         'Ex_g': [],
         'Ez_g_10e3': [],
        }

for epoch in range(start_epoch, nepoch):

    for i in range(len(dataloader['train'])):

        # ---------------------------
        #        Optimize over e
        # ---------------------------
        
        # e_updates = '1; KL_fake:1, KL_real:1, match_z:0, match_x:10' 
        
        KL_real_ = []
        KL_fake_ = []
        Ex_e_ = []
        Ez_e_ = []

        for e_iter in range(updates['e']['num_updates']):
            e_losses = []
            netE.zero_grad()

            # X - e(X)
            populate_x(x, dataloader['train'])
            ex = netE(x)

            # KL_real: - \Delta( e(X) , Z ) -> max_e
            KL_real = KL_minimizer(ex)
            e_losses.append(KL_real * updates['e']['KL_real'])
            # === stats ===    
            KL_real_.append(KL_real.data[0])
            
            if updates['e']['match_x'] != 0:
                # g(e(X))
                gex = netG(ex)

                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = match(gex, x, match_x)
                e_losses.append(err * updates['e']['match_x'])
                # === stats === 
                Ex_e_.append(err.data[0])

            # Save some stats
            stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
            stats['real_var'] = KL_minimizer.samples_var.data.mean()
            stats['KL_real'] = KL_real.data[0]

            # ================================================

            # Z - g(Z) - e(g(Z))
            populate_z(z)
            fake = netG(z).detach()
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
            KL_fake = KL_maximizer(egz)
            e_losses.append(KL_fake * updates['e']['KL_fake'])
            # === stats === 
            KL_fake_.append(-KL_fake.data[0])
            
            if updates['e']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_e
                err = match(egz, z, match_z)
                e_losses.append(err * updates['e']['match_z'])
                # === stats === 
                Ez_e_.append(err.data[0])
            
            # Update e
            sum(e_losses).backward()
            optimizerD.step()

            # === stats === 
            stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
            stats['fake_var'] = KL_maximizer.samples_var.data.mean()
            stats['KL_fake'] = -KL_fake.data[0]
            
            if KL_real_: losses['KL_real'].append(np.mean(KL_real_))
            if KL_fake_: losses['KL_fake'].append(np.mean(KL_fake_))
            if Ex_e_: losses['Ex_e'].append(np.mean(Ex_e_))
            if Ez_e_: losses['Ez_e'].append(np.mean(Ez_e_))
            
        # ---------------------------
        #        Minimize over g
        # ---------------------------
        
        # g_updates = '3; KL_fake:1, match_z:1000, match_x:0'

        KL_fake_g_ = []
        Ex_g_ = []
        Ez_g_ = []
        
        for g_iter in range(updates['g']['num_updates']):
            g_losses = []
            netG.zero_grad()

            # Z - g(Z) - e(g(Z))
            populate_z(z)
            fake = netG(z)
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
            KL_fake_g = KL_minimizer(egz)
            g_losses.append(KL_fake_g * updates['g']['KL_fake'])
            # === stats === 
            KL_fake_g_.append(KL_fake_g.data[0])
            
            if updates['g']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_g
                err = match(egz, z, match_z)
                err = err * updates['g']['match_z']
                g_losses.append(err)
                # === stats === 
#                 Ez_g_.append(err.data[0])
                Ez_g_.append(err.data[0] / 1000)
                
            # ==================================

            if updates['g']['match_x'] != 0:
                # X - e(X) - g(e(X))
                populate_x(x, dataloader['train'])
                ex = netE(x)
                gex = netG(ex)

                # match_x: E_x||g(e(x)) - x|| -> min_g
                err = match(gex, x, match_x)
                err = err * updates['g']['match_x']
                g_losses.append(err)
                # === stats === 
                Ex_g_.append(err.data[0])

            # Step g
            sum(g_losses).backward()
            optimizerG.step()
        
        # === stats === 
        if KL_fake_g_: losses['KL_fake_g'].append(np.mean(KL_fake_g_))
        if Ez_g_: losses['Ez_g_10e3'].append(np.mean(Ez_g_))
        if Ex_g_: losses['Ex_g'].append(np.mean(Ex_g_))
        
        # === stdout === 
        print('[{epoch}/{nepoch}][{iter}/{niter}] '
              'KL_real/fake: {KL_real:.3f}/{KL_fake:.3f} '
              'mean_real/fake: {real_mean:.3f}/{fake_mean:.3f} '
              'var_real/fake: {real_var:.3f}/{fake_var:.3f} '
              ''.format(epoch=epoch,
                        nepoch=nepoch,
                        iter=i,
                        niter=len(dataloader['train']),
                        **stats))
        
        
# #         if save_every_b:
# #             if i % save_every_b == 0: save_images(epoch)
            
#         # If an epoch takes long time, dump intermediate
#         if dataset in ['lsun', 'imagenet'] and (i % 5000 == 0):
#             torch.save(netG, '%s/netG_epoch_%d_it_%d.pth' %
#                        (save_dir, epoch, i))
#             torch.save(netE, '%s/netE_epoch_%d_it_%d.pth' %
#                        (save_dir, epoch, i))

    # === adjusting learning rate ===
    adjust_lr(epoch)
    
    # === saving === 
    if epoch % save_every_e == 0: 
        save_images(epoch)
        # do checkpointing
        torch.save(netG, '%s/netG_epoch_%d.pth' % (save_dir, epoch))
        torch.save(netE, '%s/netE_epoch_%d.pth' % (save_dir, epoch))

[0/100][0/43] KL_real/fake: 4.570/4.543 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.007 
[0/100][1/43] KL_real/fake: 4.607/4.572 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.007 
[0/100][2/43] KL_real/fake: 4.614/4.586 mean_real/fake: -0.008/-0.005 var_real/fake: 0.006/0.006 
[0/100][3/43] KL_real/fake: 4.599/4.567 mean_real/fake: -0.006/-0.007 var_real/fake: 0.006/0.007 
[0/100][4/43] KL_real/fake: 4.603/4.578 mean_real/fake: -0.005/-0.007 var_real/fake: 0.006/0.006 
[0/100][5/43] KL_real/fake: 4.602/4.572 mean_real/fake: -0.003/-0.005 var_real/fake: 0.006/0.007 
[0/100][6/43] KL_real/fake: 4.612/4.582 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.007 
[0/100][7/43] KL_real/fake: 4.564/4.543 mean_real/fake: -0.002/-0.001 var_real/fake: 0.007/0.007 
[0/100][8/43] KL_real/fake: 4.543/4.532 mean_real/fake: 0.001/-0.002 var_real/fake: 0.007/0.007 
[0/100][9/43] KL_real/fake: 4.552/4.530 mean_real/fake: -0.002/-0.005 var_real/fake: 0.007/0.007 
[0/100][10/43] KL_rea

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


[1/100][0/43] KL_real/fake: 4.686/4.637 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.006 
[1/100][1/43] KL_real/fake: 4.651/4.620 mean_real/fake: -0.003/-0.002 var_real/fake: 0.006/0.006 
[1/100][2/43] KL_real/fake: 4.652/4.610 mean_real/fake: 0.001/-0.004 var_real/fake: 0.006/0.006 
[1/100][3/43] KL_real/fake: 4.619/4.597 mean_real/fake: -0.001/-0.004 var_real/fake: 0.006/0.006 
[1/100][4/43] KL_real/fake: 4.612/4.578 mean_real/fake: -0.004/-0.002 var_real/fake: 0.006/0.006 
[1/100][5/43] KL_real/fake: 4.581/4.544 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[1/100][6/43] KL_real/fake: 4.558/4.541 mean_real/fake: -0.002/-0.002 var_real/fake: 0.007/0.007 
[1/100][7/43] KL_real/fake: 4.564/4.539 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[1/100][8/43] KL_real/fake: 4.578/4.559 mean_real/fake: -0.004/-0.002 var_real/fake: 0.007/0.007 
[1/100][9/43] KL_real/fake: 4.611/4.578 mean_real/fake: -0.003/-0.001 var_real/fake: 0.006/0.006 
[1/100][10/43] KL_rea

[2/100][41/43] KL_real/fake: 4.612/4.598 mean_real/fake: -0.003/-0.003 var_real/fake: 0.006/0.006 
[2/100][42/43] KL_real/fake: 4.606/4.594 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 
[3/100][0/43] KL_real/fake: 4.626/4.614 mean_real/fake: -0.005/-0.008 var_real/fake: 0.006/0.006 
[3/100][1/43] KL_real/fake: 4.637/4.618 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[3/100][2/43] KL_real/fake: 4.640/4.618 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.006 
[3/100][3/43] KL_real/fake: 4.652/4.618 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 
[3/100][4/43] KL_real/fake: 4.668/4.632 mean_real/fake: -0.010/-0.009 var_real/fake: 0.006/0.006 
[3/100][5/43] KL_real/fake: 4.667/4.634 mean_real/fake: -0.010/-0.008 var_real/fake: 0.006/0.006 
[3/100][6/43] KL_real/fake: 4.660/4.630 mean_real/fake: -0.009/-0.009 var_real/fake: 0.006/0.006 
[3/100][7/43] KL_real/fake: 4.649/4.626 mean_real/fake: -0.011/-0.007 var_real/fake: 0.006/0.006 
[3/100][8/43] KL_r

[4/100][38/43] KL_real/fake: 4.571/4.547 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[4/100][39/43] KL_real/fake: 4.568/4.552 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[4/100][40/43] KL_real/fake: 4.575/4.560 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.007 
[4/100][41/43] KL_real/fake: 4.576/4.565 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.007 
[4/100][42/43] KL_real/fake: 4.580/4.563 mean_real/fake: -0.003/-0.006 var_real/fake: 0.006/0.007 
[5/100][0/43] KL_real/fake: 4.588/4.572 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[5/100][1/43] KL_real/fake: 4.580/4.567 mean_real/fake: -0.007/-0.004 var_real/fake: 0.006/0.007 
[5/100][2/43] KL_real/fake: 4.591/4.570 mean_real/fake: -0.007/-0.009 var_real/fake: 0.006/0.007 
[5/100][3/43] KL_real/fake: 4.579/4.566 mean_real/fake: -0.006/-0.007 var_real/fake: 0.006/0.007 
[5/100][4/43] KL_real/fake: 4.570/4.549 mean_real/fake: -0.008/-0.006 var_real/fake: 0.007/0.007 
[5/100][5/43] K

[6/100][35/43] KL_real/fake: 4.669/4.644 mean_real/fake: -0.008/-0.008 var_real/fake: 0.006/0.006 
[6/100][36/43] KL_real/fake: 4.665/4.634 mean_real/fake: -0.007/-0.006 var_real/fake: 0.006/0.006 
[6/100][37/43] KL_real/fake: 4.652/4.628 mean_real/fake: -0.005/-0.008 var_real/fake: 0.006/0.006 
[6/100][38/43] KL_real/fake: 4.655/4.629 mean_real/fake: -0.007/-0.007 var_real/fake: 0.006/0.006 
[6/100][39/43] KL_real/fake: 4.667/4.640 mean_real/fake: -0.009/-0.009 var_real/fake: 0.006/0.006 
[6/100][40/43] KL_real/fake: 4.675/4.648 mean_real/fake: -0.008/-0.006 var_real/fake: 0.006/0.006 
[6/100][41/43] KL_real/fake: 4.668/4.647 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[6/100][42/43] KL_real/fake: 4.662/4.647 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[7/100][0/43] KL_real/fake: 4.656/4.637 mean_real/fake: -0.005/-0.007 var_real/fake: 0.006/0.006 
[7/100][1/43] KL_real/fake: 4.655/4.639 mean_real/fake: -0.002/-0.004 var_real/fake: 0.006/0.006 
[7/100][2/43

[8/100][32/43] KL_real/fake: 4.617/4.602 mean_real/fake: -0.006/-0.006 var_real/fake: 0.006/0.006 
[8/100][33/43] KL_real/fake: 4.622/4.611 mean_real/fake: -0.008/-0.005 var_real/fake: 0.006/0.006 
[8/100][34/43] KL_real/fake: 4.637/4.621 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[8/100][35/43] KL_real/fake: 4.649/4.633 mean_real/fake: -0.007/-0.004 var_real/fake: 0.006/0.006 
[8/100][36/43] KL_real/fake: 4.647/4.640 mean_real/fake: -0.003/-0.002 var_real/fake: 0.006/0.006 
[8/100][37/43] KL_real/fake: 4.654/4.630 mean_real/fake: -0.002/-0.004 var_real/fake: 0.006/0.006 
[8/100][38/43] KL_real/fake: 4.653/4.625 mean_real/fake: -0.002/-0.002 var_real/fake: 0.006/0.006 
[8/100][39/43] KL_real/fake: 4.641/4.625 mean_real/fake: -0.001/-0.004 var_real/fake: 0.006/0.006 
[8/100][40/43] KL_real/fake: 4.641/4.618 mean_real/fake: -0.003/-0.006 var_real/fake: 0.006/0.006 
[8/100][41/43] KL_real/fake: 4.629/4.602 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.006 
[8/100][42

[10/100][29/43] KL_real/fake: 4.671/4.644 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[10/100][30/43] KL_real/fake: 4.651/4.641 mean_real/fake: -0.003/-0.005 var_real/fake: 0.006/0.006 
[10/100][31/43] KL_real/fake: 4.649/4.632 mean_real/fake: -0.002/-0.005 var_real/fake: 0.006/0.006 
[10/100][32/43] KL_real/fake: 4.658/4.636 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.006 
[10/100][33/43] KL_real/fake: 4.642/4.624 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.006 
[10/100][34/43] KL_real/fake: 4.637/4.621 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.006 
[10/100][35/43] KL_real/fake: 4.631/4.612 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.006 
[10/100][36/43] KL_real/fake: 4.626/4.609 mean_real/fake: -0.003/-0.003 var_real/fake: 0.006/0.006 
[10/100][37/43] KL_real/fake: 4.629/4.600 mean_real/fake: -0.003/-0.003 var_real/fake: 0.006/0.006 
[10/100][38/43] KL_real/fake: 4.626/4.605 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.006 


[12/100][26/43] KL_real/fake: 4.610/4.600 mean_real/fake: -0.008/-0.007 var_real/fake: 0.006/0.006 
[12/100][27/43] KL_real/fake: 4.619/4.599 mean_real/fake: -0.009/-0.008 var_real/fake: 0.006/0.006 
[12/100][28/43] KL_real/fake: 4.614/4.590 mean_real/fake: -0.008/-0.007 var_real/fake: 0.006/0.006 
[12/100][29/43] KL_real/fake: 4.603/4.584 mean_real/fake: -0.008/-0.007 var_real/fake: 0.006/0.006 
[12/100][30/43] KL_real/fake: 4.596/4.583 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 
[12/100][31/43] KL_real/fake: 4.585/4.576 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[12/100][32/43] KL_real/fake: 4.587/4.571 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.006 
[12/100][33/43] KL_real/fake: 4.592/4.573 mean_real/fake: -0.007/-0.004 var_real/fake: 0.006/0.006 
[12/100][34/43] KL_real/fake: 4.600/4.585 mean_real/fake: -0.007/-0.007 var_real/fake: 0.006/0.006 
[12/100][35/43] KL_real/fake: 4.616/4.594 mean_real/fake: -0.007/-0.006 var_real/fake: 0.006/0.006 


[14/100][23/43] KL_real/fake: 4.615/4.593 mean_real/fake: -0.003/-0.006 var_real/fake: 0.006/0.006 
[14/100][24/43] KL_real/fake: 4.621/4.592 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[14/100][25/43] KL_real/fake: 4.630/4.599 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[14/100][26/43] KL_real/fake: 4.623/4.591 mean_real/fake: -0.008/-0.005 var_real/fake: 0.006/0.006 
[14/100][27/43] KL_real/fake: 4.631/4.605 mean_real/fake: -0.008/-0.006 var_real/fake: 0.006/0.006 
[14/100][28/43] KL_real/fake: 4.625/4.606 mean_real/fake: -0.007/-0.008 var_real/fake: 0.006/0.006 
[14/100][29/43] KL_real/fake: 4.642/4.610 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 
[14/100][30/43] KL_real/fake: 4.635/4.619 mean_real/fake: -0.004/-0.003 var_real/fake: 0.006/0.006 
[14/100][31/43] KL_real/fake: 4.640/4.625 mean_real/fake: -0.003/-0.006 var_real/fake: 0.006/0.006 
[14/100][32/43] KL_real/fake: 4.645/4.633 mean_real/fake: -0.002/-0.003 var_real/fake: 0.006/0.006 


[16/100][20/43] KL_real/fake: 4.637/4.596 mean_real/fake: -0.009/-0.006 var_real/fake: 0.006/0.006 
[16/100][21/43] KL_real/fake: 4.636/4.610 mean_real/fake: -0.008/-0.009 var_real/fake: 0.006/0.006 
[16/100][22/43] KL_real/fake: 4.639/4.616 mean_real/fake: -0.006/-0.008 var_real/fake: 0.006/0.006 
[16/100][23/43] KL_real/fake: 4.639/4.615 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[16/100][24/43] KL_real/fake: 4.651/4.622 mean_real/fake: -0.009/-0.005 var_real/fake: 0.006/0.006 
[16/100][25/43] KL_real/fake: 4.642/4.620 mean_real/fake: -0.007/-0.006 var_real/fake: 0.006/0.006 
[16/100][26/43] KL_real/fake: 4.648/4.617 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.006 
[16/100][27/43] KL_real/fake: 4.634/4.616 mean_real/fake: -0.002/-0.003 var_real/fake: 0.006/0.006 
[16/100][28/43] KL_real/fake: 4.636/4.619 mean_real/fake: 0.000/-0.002 var_real/fake: 0.006/0.006 
[16/100][29/43] KL_real/fake: 4.629/4.614 mean_real/fake: -0.001/-0.005 var_real/fake: 0.006/0.006 
[

[18/100][17/43] KL_real/fake: 4.618/4.599 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.006 
[18/100][18/43] KL_real/fake: 4.610/4.601 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.006 
[18/100][19/43] KL_real/fake: 4.616/4.593 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.006 
[18/100][20/43] KL_real/fake: 4.615/4.593 mean_real/fake: -0.007/-0.004 var_real/fake: 0.006/0.006 
[18/100][21/43] KL_real/fake: 4.617/4.599 mean_real/fake: -0.008/-0.006 var_real/fake: 0.006/0.006 
[18/100][22/43] KL_real/fake: 4.619/4.601 mean_real/fake: -0.008/-0.007 var_real/fake: 0.006/0.006 
[18/100][23/43] KL_real/fake: 4.609/4.597 mean_real/fake: -0.007/-0.005 var_real/fake: 0.006/0.006 
[18/100][24/43] KL_real/fake: 4.610/4.598 mean_real/fake: -0.005/-0.007 var_real/fake: 0.006/0.006 
[18/100][25/43] KL_real/fake: 4.609/4.594 mean_real/fake: -0.006/-0.006 var_real/fake: 0.006/0.006 
[18/100][26/43] KL_real/fake: 4.605/4.590 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 


[20/100][14/43] KL_real/fake: 4.617/4.607 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.006 
[20/100][15/43] KL_real/fake: 4.617/4.607 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.006 
[20/100][16/43] KL_real/fake: 4.614/4.600 mean_real/fake: -0.007/-0.006 var_real/fake: 0.006/0.006 
[20/100][17/43] KL_real/fake: 4.614/4.604 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[20/100][18/43] KL_real/fake: 4.609/4.599 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.006 
[20/100][19/43] KL_real/fake: 4.607/4.600 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.006 
[20/100][20/43] KL_real/fake: 4.622/4.592 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.006 
[20/100][21/43] KL_real/fake: 4.612/4.591 mean_real/fake: -0.006/-0.003 var_real/fake: 0.006/0.006 
[20/100][22/43] KL_real/fake: 4.600/4.586 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.006 
[20/100][23/43] KL_real/fake: 4.598/4.582 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 


[22/100][11/43] KL_real/fake: 4.605/4.593 mean_real/fake: -0.002/-0.005 var_real/fake: 0.006/0.006 
[22/100][12/43] KL_real/fake: 4.603/4.594 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.006 
[22/100][13/43] KL_real/fake: 4.602/4.591 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[22/100][14/43] KL_real/fake: 4.604/4.591 mean_real/fake: -0.004/-0.006 var_real/fake: 0.006/0.006 
[22/100][15/43] KL_real/fake: 4.604/4.593 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[22/100][16/43] KL_real/fake: 4.601/4.585 mean_real/fake: -0.002/-0.005 var_real/fake: 0.006/0.006 
[22/100][17/43] KL_real/fake: 4.589/4.577 mean_real/fake: -0.003/-0.005 var_real/fake: 0.006/0.006 
[22/100][18/43] KL_real/fake: 4.583/4.567 mean_real/fake: -0.003/-0.005 var_real/fake: 0.006/0.007 
[22/100][19/43] KL_real/fake: 4.578/4.562 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[22/100][20/43] KL_real/fake: 4.585/4.557 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.007 


[24/100][8/43] KL_real/fake: 4.586/4.565 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.006 
[24/100][9/43] KL_real/fake: 4.581/4.561 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.007 
[24/100][10/43] KL_real/fake: 4.585/4.557 mean_real/fake: -0.001/-0.005 var_real/fake: 0.006/0.007 
[24/100][11/43] KL_real/fake: 4.577/4.559 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.007 
[24/100][12/43] KL_real/fake: 4.584/4.559 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.007 
[24/100][13/43] KL_real/fake: 4.583/4.559 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.007 
[24/100][14/43] KL_real/fake: 4.576/4.561 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.007 
[24/100][15/43] KL_real/fake: 4.576/4.560 mean_real/fake: -0.005/-0.007 var_real/fake: 0.006/0.007 
[24/100][16/43] KL_real/fake: 4.575/4.557 mean_real/fake: -0.005/-0.006 var_real/fake: 0.006/0.007 
[24/100][17/43] KL_real/fake: 4.565/4.555 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[2

[26/100][4/43] KL_real/fake: 4.578/4.573 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[26/100][5/43] KL_real/fake: 4.577/4.572 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[26/100][6/43] KL_real/fake: 4.574/4.571 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[26/100][7/43] KL_real/fake: 4.581/4.571 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.006 
[26/100][8/43] KL_real/fake: 4.576/4.569 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.006 
[26/100][9/43] KL_real/fake: 4.575/4.571 mean_real/fake: -0.004/-0.004 var_real/fake: 0.006/0.006 
[26/100][10/43] KL_real/fake: 4.577/4.570 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[26/100][11/43] KL_real/fake: 4.578/4.570 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.006 
[26/100][12/43] KL_real/fake: 4.577/4.571 mean_real/fake: -0.005/-0.004 var_real/fake: 0.006/0.006 
[26/100][13/43] KL_real/fake: 4.571/4.569 mean_real/fake: -0.004/-0.006 var_real/fake: 0.006/0.006 
[26/10

[28/100][1/43] KL_real/fake: 4.569/4.563 mean_real/fake: -0.003/-0.004 var_real/fake: 0.006/0.007 
[28/100][2/43] KL_real/fake: 4.565/4.561 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[28/100][3/43] KL_real/fake: 4.558/4.555 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[28/100][4/43] KL_real/fake: 4.560/4.553 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 
[28/100][5/43] KL_real/fake: 4.553/4.548 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 
[28/100][6/43] KL_real/fake: 4.556/4.543 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[28/100][7/43] KL_real/fake: 4.558/4.539 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[28/100][8/43] KL_real/fake: 4.550/4.542 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[28/100][9/43] KL_real/fake: 4.548/4.545 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[28/100][10/43] KL_real/fake: 4.556/4.551 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[28/100][

[29/100][41/43] KL_real/fake: 4.559/4.552 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[29/100][42/43] KL_real/fake: 4.555/4.551 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[30/100][0/43] KL_real/fake: 4.553/4.551 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[30/100][1/43] KL_real/fake: 4.553/4.549 mean_real/fake: -0.005/-0.007 var_real/fake: 0.007/0.007 
[30/100][2/43] KL_real/fake: 4.550/4.545 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[30/100][3/43] KL_real/fake: 4.549/4.544 mean_real/fake: -0.005/-0.007 var_real/fake: 0.007/0.007 
[30/100][4/43] KL_real/fake: 4.548/4.547 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[30/100][5/43] KL_real/fake: 4.552/4.547 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[30/100][6/43] KL_real/fake: 4.559/4.552 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[30/100][7/43] KL_real/fake: 4.552/4.548 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[30/100]

[31/100][38/43] KL_real/fake: 4.548/4.544 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[31/100][39/43] KL_real/fake: 4.548/4.545 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[31/100][40/43] KL_real/fake: 4.547/4.544 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[31/100][41/43] KL_real/fake: 4.545/4.540 mean_real/fake: -0.007/-0.004 var_real/fake: 0.007/0.007 
[31/100][42/43] KL_real/fake: 4.544/4.541 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[32/100][0/43] KL_real/fake: 4.541/4.539 mean_real/fake: -0.002/-0.004 var_real/fake: 0.007/0.007 
[32/100][1/43] KL_real/fake: 4.548/4.540 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[32/100][2/43] KL_real/fake: 4.541/4.537 mean_real/fake: -0.003/-0.004 var_real/fake: 0.007/0.007 
[32/100][3/43] KL_real/fake: 4.538/4.533 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[32/100][4/43] KL_real/fake: 4.538/4.531 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[32/1

[33/100][35/43] KL_real/fake: 4.537/4.533 mean_real/fake: -0.006/-0.007 var_real/fake: 0.007/0.007 
[33/100][36/43] KL_real/fake: 4.538/4.532 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[33/100][37/43] KL_real/fake: 4.536/4.532 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[33/100][38/43] KL_real/fake: 4.537/4.531 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[33/100][39/43] KL_real/fake: 4.537/4.535 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[33/100][40/43] KL_real/fake: 4.539/4.535 mean_real/fake: -0.006/-0.007 var_real/fake: 0.007/0.007 
[33/100][41/43] KL_real/fake: 4.537/4.534 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[33/100][42/43] KL_real/fake: 4.537/4.534 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[34/100][0/43] KL_real/fake: 4.539/4.531 mean_real/fake: -0.002/-0.006 var_real/fake: 0.007/0.007 
[34/100][1/43] KL_real/fake: 4.536/4.528 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[3

[35/100][32/43] KL_real/fake: 4.560/4.553 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[35/100][33/43] KL_real/fake: 4.559/4.554 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[35/100][34/43] KL_real/fake: 4.559/4.558 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[35/100][35/43] KL_real/fake: 4.565/4.558 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[35/100][36/43] KL_real/fake: 4.559/4.557 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[35/100][37/43] KL_real/fake: 4.557/4.549 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[35/100][38/43] KL_real/fake: 4.550/4.543 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[35/100][39/43] KL_real/fake: 4.552/4.543 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[35/100][40/43] KL_real/fake: 4.544/4.536 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[35/100][41/43] KL_real/fake: 4.548/4.537 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 


[37/100][29/43] KL_real/fake: 4.570/4.563 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.007 
[37/100][30/43] KL_real/fake: 4.566/4.564 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[37/100][31/43] KL_real/fake: 4.561/4.557 mean_real/fake: -0.004/-0.007 var_real/fake: 0.007/0.007 
[37/100][32/43] KL_real/fake: 4.558/4.554 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[37/100][33/43] KL_real/fake: 4.560/4.555 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[37/100][34/43] KL_real/fake: 4.563/4.553 mean_real/fake: -0.006/-0.004 var_real/fake: 0.006/0.007 
[37/100][35/43] KL_real/fake: 4.555/4.550 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[37/100][36/43] KL_real/fake: 4.556/4.551 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[37/100][37/43] KL_real/fake: 4.553/4.549 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[37/100][38/43] KL_real/fake: 4.556/4.555 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 


[39/100][26/43] KL_real/fake: 4.555/4.549 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[39/100][27/43] KL_real/fake: 4.552/4.550 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[39/100][28/43] KL_real/fake: 4.544/4.543 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[39/100][29/43] KL_real/fake: 4.556/4.539 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[39/100][30/43] KL_real/fake: 4.541/4.535 mean_real/fake: -0.003/-0.004 var_real/fake: 0.007/0.007 
[39/100][31/43] KL_real/fake: 4.549/4.535 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[39/100][32/43] KL_real/fake: 4.536/4.534 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[39/100][33/43] KL_real/fake: 4.540/4.532 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 
[39/100][34/43] KL_real/fake: 4.538/4.531 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[39/100][35/43] KL_real/fake: 4.538/4.533 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 


[41/100][23/43] KL_real/fake: 4.550/4.540 mean_real/fake: -0.007/-0.004 var_real/fake: 0.007/0.007 
[41/100][24/43] KL_real/fake: 4.550/4.547 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[41/100][25/43] KL_real/fake: 4.556/4.551 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[41/100][26/43] KL_real/fake: 4.566/4.560 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[41/100][27/43] KL_real/fake: 4.578/4.569 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.007 
[41/100][28/43] KL_real/fake: 4.583/4.574 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.006 
[41/100][29/43] KL_real/fake: 4.577/4.567 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.007 
[41/100][30/43] KL_real/fake: 4.572/4.563 mean_real/fake: -0.006/-0.006 var_real/fake: 0.006/0.007 
[41/100][31/43] KL_real/fake: 4.570/4.559 mean_real/fake: -0.005/-0.008 var_real/fake: 0.006/0.007 
[41/100][32/43] KL_real/fake: 4.564/4.556 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 


[43/100][20/43] KL_real/fake: 4.546/4.536 mean_real/fake: -0.008/-0.007 var_real/fake: 0.007/0.007 
[43/100][21/43] KL_real/fake: 4.552/4.539 mean_real/fake: -0.006/-0.007 var_real/fake: 0.007/0.007 
[43/100][22/43] KL_real/fake: 4.551/4.545 mean_real/fake: -0.008/-0.004 var_real/fake: 0.007/0.007 
[43/100][23/43] KL_real/fake: 4.555/4.545 mean_real/fake: -0.007/-0.006 var_real/fake: 0.007/0.007 
[43/100][24/43] KL_real/fake: 4.557/4.549 mean_real/fake: -0.005/-0.007 var_real/fake: 0.007/0.007 
[43/100][25/43] KL_real/fake: 4.561/4.556 mean_real/fake: -0.006/-0.008 var_real/fake: 0.007/0.007 
[43/100][26/43] KL_real/fake: 4.562/4.552 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[43/100][27/43] KL_real/fake: 4.556/4.550 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[43/100][28/43] KL_real/fake: 4.553/4.550 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[43/100][29/43] KL_real/fake: 4.553/4.548 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 


[45/100][17/43] KL_real/fake: 4.559/4.555 mean_real/fake: -0.007/-0.005 var_real/fake: 0.007/0.007 
[45/100][18/43] KL_real/fake: 4.566/4.558 mean_real/fake: -0.006/-0.005 var_real/fake: 0.006/0.007 
[45/100][19/43] KL_real/fake: 4.568/4.552 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.007 
[45/100][20/43] KL_real/fake: 4.557/4.550 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[45/100][21/43] KL_real/fake: 4.561/4.556 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[45/100][22/43] KL_real/fake: 4.559/4.553 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[45/100][23/43] KL_real/fake: 4.555/4.550 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[45/100][24/43] KL_real/fake: 4.558/4.548 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[45/100][25/43] KL_real/fake: 4.546/4.545 mean_real/fake: -0.004/-0.007 var_real/fake: 0.007/0.007 
[45/100][26/43] KL_real/fake: 4.549/4.541 mean_real/fake: -0.005/-0.008 var_real/fake: 0.007/0.007 


[47/100][14/43] KL_real/fake: 4.556/4.536 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[47/100][15/43] KL_real/fake: 4.554/4.543 mean_real/fake: -0.007/-0.005 var_real/fake: 0.007/0.007 
[47/100][16/43] KL_real/fake: 4.552/4.540 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[47/100][17/43] KL_real/fake: 4.550/4.543 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[47/100][18/43] KL_real/fake: 4.549/4.538 mean_real/fake: -0.007/-0.005 var_real/fake: 0.007/0.007 
[47/100][19/43] KL_real/fake: 4.548/4.540 mean_real/fake: -0.007/-0.004 var_real/fake: 0.007/0.007 
[47/100][20/43] KL_real/fake: 4.544/4.542 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[47/100][21/43] KL_real/fake: 4.544/4.541 mean_real/fake: -0.004/-0.007 var_real/fake: 0.007/0.007 
[47/100][22/43] KL_real/fake: 4.548/4.543 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 
[47/100][23/43] KL_real/fake: 4.555/4.546 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 


[49/100][11/43] KL_real/fake: 4.555/4.551 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[49/100][12/43] KL_real/fake: 4.562/4.548 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[49/100][13/43] KL_real/fake: 4.558/4.553 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[49/100][14/43] KL_real/fake: 4.558/4.548 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[49/100][15/43] KL_real/fake: 4.554/4.546 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[49/100][16/43] KL_real/fake: 4.552/4.546 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[49/100][17/43] KL_real/fake: 4.551/4.542 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[49/100][18/43] KL_real/fake: 4.544/4.540 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[49/100][19/43] KL_real/fake: 4.548/4.539 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[49/100][20/43] KL_real/fake: 4.547/4.541 mean_real/fake: -0.004/-0.003 var_real/fake: 0.007/0.007 


[51/100][7/43] KL_real/fake: 4.554/4.553 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[51/100][8/43] KL_real/fake: 4.555/4.551 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[51/100][9/43] KL_real/fake: 4.554/4.553 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[51/100][10/43] KL_real/fake: 4.552/4.547 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[51/100][11/43] KL_real/fake: 4.556/4.546 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[51/100][12/43] KL_real/fake: 4.548/4.546 mean_real/fake: -0.003/-0.006 var_real/fake: 0.007/0.007 
[51/100][13/43] KL_real/fake: 4.547/4.547 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[51/100][14/43] KL_real/fake: 4.548/4.549 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[51/100][15/43] KL_real/fake: 4.549/4.547 mean_real/fake: -0.004/-0.007 var_real/fake: 0.007/0.007 
[51/100][16/43] KL_real/fake: 4.546/4.545 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[51

[53/100][4/43] KL_real/fake: 4.556/4.551 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[53/100][5/43] KL_real/fake: 4.553/4.549 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[53/100][6/43] KL_real/fake: 4.560/4.553 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[53/100][7/43] KL_real/fake: 4.559/4.552 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[53/100][8/43] KL_real/fake: 4.558/4.555 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[53/100][9/43] KL_real/fake: 4.562/4.558 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[53/100][10/43] KL_real/fake: 4.563/4.558 mean_real/fake: -0.005/-0.005 var_real/fake: 0.006/0.007 
[53/100][11/43] KL_real/fake: 4.563/4.557 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[53/100][12/43] KL_real/fake: 4.562/4.559 mean_real/fake: -0.004/-0.005 var_real/fake: 0.006/0.007 
[53/100][13/43] KL_real/fake: 4.566/4.558 mean_real/fake: -0.005/-0.003 var_real/fake: 0.006/0.007 
[53/10

[55/100][1/43] KL_real/fake: 4.551/4.547 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[55/100][2/43] KL_real/fake: 4.552/4.546 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[55/100][3/43] KL_real/fake: 4.553/4.544 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[55/100][4/43] KL_real/fake: 4.548/4.545 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[55/100][5/43] KL_real/fake: 4.548/4.545 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[55/100][6/43] KL_real/fake: 4.548/4.546 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[55/100][7/43] KL_real/fake: 4.550/4.543 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[55/100][8/43] KL_real/fake: 4.546/4.543 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[55/100][9/43] KL_real/fake: 4.547/4.543 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[55/100][10/43] KL_real/fake: 4.550/4.548 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[55/100][

[56/100][41/43] KL_real/fake: 4.550/4.548 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[56/100][42/43] KL_real/fake: 4.553/4.546 mean_real/fake: -0.003/-0.005 var_real/fake: 0.007/0.007 
[57/100][0/43] KL_real/fake: 4.545/4.546 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[57/100][1/43] KL_real/fake: 4.548/4.542 mean_real/fake: -0.006/-0.005 var_real/fake: 0.007/0.007 
[57/100][2/43] KL_real/fake: 4.543/4.542 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[57/100][3/43] KL_real/fake: 4.550/4.547 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[57/100][4/43] KL_real/fake: 4.551/4.546 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[57/100][5/43] KL_real/fake: 4.549/4.545 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[57/100][6/43] KL_real/fake: 4.546/4.544 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[57/100][7/43] KL_real/fake: 4.548/4.541 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[57/100]

[58/100][38/43] KL_real/fake: 4.549/4.547 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[58/100][39/43] KL_real/fake: 4.553/4.547 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[58/100][40/43] KL_real/fake: 4.548/4.545 mean_real/fake: -0.004/-0.004 var_real/fake: 0.007/0.007 
[58/100][41/43] KL_real/fake: 4.546/4.547 mean_real/fake: -0.006/-0.004 var_real/fake: 0.007/0.007 
[58/100][42/43] KL_real/fake: 4.555/4.548 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[59/100][0/43] KL_real/fake: 4.548/4.548 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[59/100][1/43] KL_real/fake: 4.552/4.545 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[59/100][2/43] KL_real/fake: 4.544/4.543 mean_real/fake: -0.005/-0.004 var_real/fake: 0.007/0.007 
[59/100][3/43] KL_real/fake: 4.543/4.542 mean_real/fake: -0.004/-0.005 var_real/fake: 0.007/0.007 
[59/100][4/43] KL_real/fake: 4.541/4.539 mean_real/fake: -0.003/-0.006 var_real/fake: 0.007/0.007 
[59/1

[60/100][35/43] KL_real/fake: 4.539/4.538 mean_real/fake: -0.004/-0.006 var_real/fake: 0.007/0.007 
[60/100][36/43] KL_real/fake: 4.544/4.538 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[60/100][37/43] KL_real/fake: 4.543/4.542 mean_real/fake: -0.004/-0.007 var_real/fake: 0.007/0.007 
[60/100][38/43] KL_real/fake: 4.544/4.543 mean_real/fake: -0.006/-0.007 var_real/fake: 0.007/0.007 
[60/100][39/43] KL_real/fake: 4.543/4.541 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[60/100][40/43] KL_real/fake: 4.541/4.544 mean_real/fake: -0.005/-0.005 var_real/fake: 0.007/0.007 
[60/100][41/43] KL_real/fake: 4.543/4.538 mean_real/fake: -0.006/-0.006 var_real/fake: 0.007/0.007 
[60/100][42/43] KL_real/fake: 4.544/4.537 mean_real/fake: -0.007/-0.005 var_real/fake: 0.007/0.007 
[61/100][0/43] KL_real/fake: 4.539/4.533 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[61/100][1/43] KL_real/fake: 4.536/4.530 mean_real/fake: -0.005/-0.006 var_real/fake: 0.007/0.007 
[6

[Go to init](#init)
<a id='train'.</a>

In [None]:
# https://github.com/pytorch/pytorch/issues/1355

In [None]:
# o_f = 'output/helen/train_hist_ims_64_64_25.gif'
# from IPython.display import HTML
# # HTML('<img src="%s">' %o_f)

In [None]:
def log_train_hist(hist, run, show=True, save=True, o_p='output/helen'):
    
    nepoch = int(run.split('_')[2])
    batch_size = int(run.split('_')[1])
    
    plt.figure(figsize = (12, 6))
    x = range(len(hist['KL_real']))
    for i, j in hist.items():
        if j: plt.plot(x, j, label=i)
            
    plt.xlabel('Epoch')
    plt.ylabel('Losses')
    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save: plt.savefig(os.path.join(o_p, 'train_hist_%s.png' %run))
    if show: plt.show()
    else: plt.close()
        
def log_last_epoch(netG, netE, run, o_p='output/helen'):
    le = round((nepoch / save_every_e)) * save_every_e
    if le < 10: le_str = 'epoch_00%d' %le
    elif le < 100: le_str = 'epoch_0%d' %le
    else: le_str = 'epoch_%d' %le
    
    fr_name = os.path.join(o_p, 'reconstructions_' + le_str + '.png'); print(fr_name)
    fn_name = os.path.join(o_p, 'fake_samples_' + le_str + '.png'); print(fn_name)
    
    imr = cv2.imread(fr_name, cv2.IMREAD_UNCHANGED); cv2.imwrite(os.path.join(o_p, 'ims_rec_last_epoch_%s.png' %run), imr)
    imn = cv2.imread(fn_name, cv2.IMREAD_UNCHANGED); cv2.imwrite(os.path.join(o_p, 'ims_noi_last_epoch_%s.png' %run), imn)
        
    torch.save(netG, '%s/models_netG_last_epoch_%s.pth' % (o_p, run))
    torch.save(netE, '%s/models_netE_last_epoch_%s.pth' % (o_p, run))
    
    plot_ims(np.array([imr, imn]))

In [None]:
log_train_hist(losses, run)
log_last_epoch(netG, netE, run)

In [29]:
o_p='output/helen'
torch.save(netG, '%s/models_netG_last_epoch_%s.pth' % (o_p, run))
torch.save(netE, '%s/models_netE_last_epoch_%s.pth' % (o_p, run))

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
# def train_hist_ims_gif(run, save=True, show=True, o_p='output/helen'):
    
#     ims_rec = []
#     for i in range(nepoch):
#         if i < 10:
#             f_name = os.path.join(o_p, 'reconstructions_epoch_00' + str(i) + '.png')
#             ims_rec.append(imageio.imread(f_name))
#         elif i < 100:
#             f_name = os.path.join(o_p, 'reconstructions_epoch_0' + str(i) + '.png')
#             ims_rec.append(imageio.imread(f_name))                     
#         else:
#             f_name = os.path.join(o_p, 'reconstructions_epoch_' + str(i) + '.png')
#             ims_rec.append(imageio.imread(f_name))  
#         print(len(ims_rec))
#     o_f = os.path.join(o_p, 'train_hist_ims_%s.gif' %run); print(o_f)
    
#     if save: 
#         imageio.mimsave(o_f, ims_rec, fps=5)
    
#     if show:
#         from IPython.display import HTML
#         HTML('<img src="%s">' %o_f)