In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', type=str, help='path to dataset')
parser.add_argument('--workers', type=int,
                    help='number of data loading workers', default=8)
parser.add_argument('--batch_size', type=int,
                    default=64, help='batch size')
parser.add_argument('--image_size', type=int, default=32,
                    help='the resolution of the input image to network')
parser.add_argument('--nz', type=int, default=100,
                    help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int)

parser.add_argument('--nepoch', type=int, default=25,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002,
                    help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 for adam. default=0.5')
parser.add_argument('--cpu', action='store_true',
                    help='use CPU instead of GPU')
parser.add_argument('--ngpu', type=int, default=1,
                    help='number of GPUs to use')

parser.add_argument('--netG', default='',
                    help="path to netG config")
parser.add_argument('--netE', default='',
                    help="path to netE config")
parser.add_argument('--netG_chp', default='',
                    help="path to netG (to continue training)")
parser.add_argument('--netE_chp', default='',
                    help="path to netE (to continue training)")

parser.add_argument('--save_dir', default='.',
                    help='folder to output images and model checkpoints')
parser.add_argument('--criterion', default='param',
                    help='param|nonparam, How to estimate KL')
parser.add_argument('--KL', default='qp', help='pq|qp')
parser.add_argument('--noise', default='sphere', help='normal|sphere')
parser.add_argument('--match_z', default='cos', help='none|L1|L2|cos')
parser.add_argument('--match_x', default='L1', help='none|L1|L2|cos')

parser.add_argument('--drop_lr', default=5, type=int, help='')
parser.add_argument('--save_every', default=50, type=int, help='')

parser.add_argument('--manual_seed', type=int, default=123, help='manual seed')
parser.add_argument('--start_epoch', type=int, default=0, help='epoch number to start with')

parser.add_argument(
    '--e_updates', default="1;KL_fake:1,KL_real:1,match_z:0,match_x:0",
    help='Update plan for encoder <number of updates>;[<term:weight>]'
)

parser.add_argument(
    '--g_updates', default="2;KL_fake:1,match_z:1,match_x:0",
    help='Update plan for generator <number of updates>;[<term:weight>]'
)

# opt = parser.parse_args()

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --lr 0.0002 
# --nz 64 
# --batch_size 64 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 5 
# --drop_lr 5 
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --start_epoch 5 
# --lr 0.0002 
# --nz 64 
# --batch_size 256 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 6 
# --drop_lr 5   
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0' 
# --netE_chp  <save_dir>/netE_epoch_5.pth 
# --netG_chp <save_dir>/netG_epoch_5.pth

_StoreAction(option_strings=['--g_updates'], dest='g_updates', nargs=None, const=None, default='2;KL_fake:1,match_z:1,match_x:0', type=None, choices=None, help='Update plan for generator <number of updates>;[<term:weight>]', metavar=None)

#### Setting up the options

In [2]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
# from .dataset import FolderWithImages
import random
import os
import torch.backends.cudnn as cudnn
from PIL import Image

In [3]:
# setup function @utils.py

cuda = False # not opt.cpu
torch.set_num_threads(4)
dataset = 'mnist'
save_dir = 'output/mnist'
try:
    os.makedirs(save_dir)
except OSError:
    print('Directory was not created.')

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
# torch.cuda.manual_seed_all(manual_seed)

cudnn.benchmark = True

if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device,"
            "so you should probably run with --cuda")

e_updates = "1;KL_fake:1,KL_real:1,match_z:0,match_x:0"
updates = {'e': {}, 'g': {}}
updates['e']['num_updates'] = int(e_updates.split(';')[0])
updates['e'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in e_updates.split(';')[1].split(',')})

g_updates = "2;KL_fake:1,match_z:1,match_x:0"
updates['g']['num_updates'] = int(g_updates.split(';')[0])
updates['g'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in g_updates.split(';')[1].split(',')})

print(updates)

image_size = 32 # 512
batch_size = 64
workers = 8
shuffle = True
drop_last = True
train = True

nc = 1 
nz = 64 #100 #'size of the latent z vector') size == batchsize??
ngf = 32
ndf = 32
ngpu = 0 # 1 # 'number of GPUs to use')
pin_memory= False # True | If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them.

noise = 'sphere' #help='normal|sphere')

netG = 'dcgan32px' 
netE = 'dcgan32px' 

netG_chp = '' # "path to netG (to continue training)"
netE_chp = '' # "path to netE (to continue training)"

lr = 0.0002 
drop_lr = 1 #5
beta1 = 0.5 # help='beta1 for adam. default=0.5'
criterion = 'param'# help='param|nonparam, How to estimate KL'
KL ='qp' # help='pq|qp'
match_z = 'cos' # help='none|L1|L2|cos'
match_x = 'L1' # help='none|L1|L2|cos'

start_epoch = 0
nepoch = 1 # 5

save_every = 5

Directory was not created.
{'e': {'num_updates': 1, 'KL_fake': 1.0, 'KL_real': 1.0, 'match_z': 0.0, 'match_x': 0.0}, 'g': {'num_updates': 2, 'KL_fake': 1.0, 'match_z': 1.0, 'match_x': 0.0}}


In [4]:
# setup_dataset @utils.py
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
# from .dataset import FolderWithImages
import random
import os
import torch.backends.cudnn as cudnn
from PIL import Image

##
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img

class FolderWithImages(data.Dataset):
    def __init__(self, root, input_transform=None, target_transform=None):
        super(FolderWithImages, self).__init__()
        self.image_filenames = [join(root, x)
                                for x in listdir(root) if is_image_file(x.lower())]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
    
##

def setup_dataset(dataset, train=True, shuffle=True, drop_last=True):
    '''
    Setups dataset.
    '''
    # Usual transform
    t = transforms.Compose([
        transforms.Scale([image_size, image_size]),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if dataset in ['imagenet', 'folder', 'lfw']:
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = dset.ImageFolder(root=dataroot, transform=t)
    elif dataset == 'lsun':
        dataset = dset.LSUN(db_path=dataroot,
                            classes=['bedroom_train'],
                            train=train,
                            transform=t)
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root='data/raw/cifar10',
                               download=True,
                               train=train,
                               transform=t
                               )
    elif dataset == 'mnist':
        dataset = dset.MNIST(root='data/raw/mnist',
                             download=True,
                             train=train,
                             transform=t
                             )
    elif dataset == 'svhn':
        dataset = dset.SVHN(root='data/raw/svhn',
                            download=True,
                            train=train,
                            transform=t)
    elif dataset == 'celeba':
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )

    else:
        assert False, 'Wrong dataset name.'

    assert len(dataset) > 0, 'No images found, check your paths.'

    # Shuffle and drop last when training
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=int(workers),
                                             pin_memory=pin_memory,
                                             drop_last=drop_last)

    return InfiniteDataLoader(dataloader)


class InfiniteDataLoader(object):
    """docstring for InfiniteDataLoader"""

    def __init__(self, dataloader):
        super(InfiniteDataLoader, self).__init__()
        self.dataloader = dataloader
        self.data_iter = None

    def next(self):
        try:
            data = self.data_iter.next()
        except Exception:
            # Reached end of the dataset
            self.data_iter = iter(self.dataloader)
            data = self.data_iter.next()

        return data

    def __len__(self):
        return len(self.dataloader)

# Setup dataset
dataloader = dict(train=setup_dataset(dataset, train=True),
                  val=setup_dataset(dataset, train=False))



In [5]:
len(dataloader['train'])

937

In [6]:
def weights_init(m):
    '''
    Custom weights initialization called on netG and netE
    '''
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class _netE_Base(nn.Module):
    def __init__(self, main):
        super(_netE_Base, self).__init__()
        self.noise = noise
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):
        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            output = nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            output = self.main(input)

        output = output.view(output.size(0), -1)
        if self.noise == 'sphere':
            output = normalize(output)

        return output

class _netG_Base(nn.Module):
    def __init__(self, main):
        super(_netG_Base, self).__init__()
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):

        # Check input is either (B,C,1,1) or (B,C)
        assert input.nelement() == input.size(0) * input.size(1), 'wtf'
        input = input.view(input.size(0), input.size(1), 1, 1)

        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            return nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            return self.main(input)

def _netE():

    main = nn.Sequential(
        # input is (nc) x 32 x 32
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        
        nn.Conv2d(ndf * 4, nz, 4, 2, 1, bias=True),
        nn.AvgPool2d(2),
    )

    return _netE_Base(main)

def _netG():

    main = nn.Sequential(
        # input is Z, going into a convolution
        nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.ReLU(True),
        # state size. (ngf*8) x 4 x 4
        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.ReLU(True),
        # state size. (ngf*4) x 8 x 8
        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
        # state size. (ngf*2) x 16 x 16
        nn.ConvTranspose2d(ngf * 2, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),

        nn.Conv2d(ngf * 2, nc, 1, bias=True),
        nn.Tanh()
    )

    return _netG_Base(main)

def load_G():
    '''
    Loads generator model.
    '''
    netG = _netG()
    netG.apply(weights_init)
    netG.train()
    if netG_chp != '':
        netG.load_state_dict(torch.load(netG_chp).state_dict())

    print('Generator\n', netG)
    return netG

def load_E():
    '''
    Loads encoder model.
    '''
    netE = _netE()
    netE.apply(weights_init)
    netE.train()
    if netE_chp != '':
        netE.load_state_dict(torch.load(netE_chp).state_dict())

    print('Encoder\n', netE)

    return netE

# Load generator
netG = load_G()

# Load encoder
netE = load_E()

# RuntimeError: 
# Given input size: (512 x 2 x 2). 
# Calculated output size: (64 x -1 x -1). 
# Output size is too small at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/THNN/generic/SpatialConvolutionMM.c:45


Generator
 _netG_Base(
  (main): Sequential(
    (0): ConvTranspose2d (64, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d (256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d (128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d (64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU(inplace)
    (12): Conv2d (64, 1, kernel_size=(1, 1), stride=(1, 1))
    (13): Tanh()
  )
)
Encoder
 _netE_Base(
  (main): Sequential(
    (0): Conv2d (1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU

#### Pretraining preparation

In [7]:
import torch
import torch.nn
from torch.autograd import Variable

def var(x, dim=0):
    '''
    Calculates variance.
    '''
    x_zero_meaned = x - x.mean(dim).expand_as(x)
    return x_zero_meaned.pow(2).mean(dim)


class KLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(KLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.nelement() == samples.size(1) * samples.size(0), 'wtf?'

        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        samples_mean = self.samples_mean
        samples_var = self.samples_var

        if self.direction == 'pq':
            # mu_1 = 0; sigma_1 = 1

            t1 = (1 + samples_mean.pow(2)) / (2 * samples_var.pow(2))
            t2 = samples_var.log()

            KL = (t1 + t2 - 0.5).mean()
        else:
            # mu_2 = 0; sigma_2 = 1

            t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2
            t2 = -samples_var.log()

            KL = (t1 + t2 - 0.5).mean()

        if not self.minimize:
            KL *= -1

        return KL

# ----------------------------

def pairwise_euclidean(samples):

    B = samples.size(0)

    samples_norm = samples.mul(samples).sum(1)
    samples_norm = samples_norm.expand(B, B)

    dist_mat = samples.mm(samples.t()).mul(-2) + \
        samples_norm.add(samples_norm.t())
    return dist_mat

def sample_entropy(samples):

        # Assume B x C input

    dist_mat = pairwise_euclidean(samples)

    # Get max and add it to diag
    m = dist_mat.max().detach()
    dist_mat_d = dist_mat + \
        Variable(torch.eye(dist_mat.size(0)) * (m.data[0] + 1)).cuda()

    entropy = (dist_mat_d.min(1)[0] + 1e-4).log().sum()

    entropy *= (samples.size(1) + 0.) / samples.size(0)

    return entropy

class SampleKLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(SampleKLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.ndimension == 2, 'wft'
        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        if self.direction == 'pq':
            assert False, 'not possible'
        else:
            entropy = sample_entropy(samples)

            cross_entropy = - samples.pow(2).mean() / 2.

            KL = - cross_entropy - entropy

        if not self.minimize:
            KL *= -1

        return KL

In [8]:
x = torch.FloatTensor(batch_size, nc, image_size, image_size)
z = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_z = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)

print(fixed_z.shape)

def normalize_(x, dim=1):
    '''
    Projects points to a sphere inplace.
    '''
    x.div_(x.norm(2, dim=dim).expand_as(x))

if noise == 'sphere':
    normalize_(fixed_z)

print(fixed_z.shape)
    
# if cuda:
#     netE.cuda()
#     netG.cuda()
#     x = x.cuda()
#     z, fixed_z = z.cuda(), fixed_z.cuda()

x = Variable(x)
z = Variable(z)
fixed_z = Variable(fixed_z)

# Setup optimizers
optimizerD = optim.Adam(netE.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Setup criterions
if criterion == 'param':
    print('Using parametric criterion KL_%s' % KL)
    KL_minimizer = KLN01Loss(direction=KL, minimize=True)
    KL_maximizer = KLN01Loss(direction=KL, minimize=False)
elif criterion == 'nonparam':
    print('Using NON-parametric criterion KL_%s' % KL)
    KL_minimizer = SampleKLN01Loss(direction=KL, minimize=True)
    KL_maximizer = SampleKLN01Loss(direction=KL, minimize=False)
else:
    assert False, 'criterion?'

real_cpu = torch.FloatTensor()

def save_images(epoch):

    real_cpu.resize_(x.data.size()).copy_(x.data)

    # Real samples
    save_path = '%s/real_samples.png' % save_dir
    vutils.save_image(real_cpu[:64] / 2 + 0.5, save_path)

    netG.eval()
    fake = netG(fixed_z)

    # Fake samples
    save_path = '%s/fake_samples_epoch_%03d.png' % (save_dir, epoch)
    vutils.save_image(fake.data[:64] / 2 + 0.5, save_path)

    # Save reconstructions
    populate_x(x, dataloader['val'])
    gex = netG(netE(x))

    t = torch.FloatTensor(x.size(0) * 2, x.size(1),
                          x.size(2), x.size(3))

    t[0::2] = x.data[:]
    t[1::2] = gex.data[:]

    save_path = '%s/reconstructions_epoch_%03d.png' % (save_dir, epoch)
    grid = vutils.save_image(t[:64] / 2 + 0.5, save_path)

    netG.train()

## ORIGINAL
    
# def adjust_lr(epoch):
#     if epoch % drop_lr == (drop_lr - 1):
#         lr /= 2
#         for param_group in optimizerD.param_groups:
#             param_group['lr'] = lr

#         for param_group in optimizerG.param_groups:
#             param_group['lr'] = lr
            
            
def adjust_lr(epoch):
    print(epoch, drop_lr)
    if epoch % drop_lr == (drop_lr - 1):
        ###
        assert optimizerD.param_groups[0]['lr'] == optimizerG.param_groups[0]['lr']
        print('Adjusting learning rate from %f to %f on E and G' % (lr, lr / 2))
        optimizerD.param_groups[0]['lr'] /= 2
        optimizerG.param_groups[0]['lr'] /= 2
        ###
#         lr /= 2
#         for param_group in optimizerD.param_groups:
#             param_group['lr'] = lr

#         for param_group in optimizerG.param_groups:
#             param_group['lr'] = lr

torch.Size([64, 64, 1, 1])
torch.Size([64, 64, 1, 1])
Using parametric criterion KL_qp


In [9]:
def populate_x(x, dataloader):
    '''
    Fills input variable `x` with data generated with dataloader
    '''
    real_cpu, _ = dataloader.next()
    x.data.resize_(real_cpu.size()).copy_(real_cpu)


def populate_z(z):
    '''
    Fills noise variable `z` with noise U(S^M)
    '''
    z.data.resize_(batch_size, nz, 1, 1)
    z.data.normal_(0, 1)
    if noise == 'sphere':
        normalize_(z.data)

In [10]:
def match(x, y, dist):
    '''
    Computes distance between corresponding points points in `x` and `y`
    using distance `dist`.
    '''
    if dist == 'L2':
        return (x - y).pow(2).mean()
    elif dist == 'L1':
        return (x - y).abs().mean()
    elif dist == 'cos':
        x_n = normalize(x)
        y_n = normalize(y)

        return 2 - (x_n).mul(y_n).mean()
    else:
        assert dist == 'none', 'wtf ?'


def normalize(x, dim=1):
    '''
    Projects points to a sphere.
    '''
    return x.div(x.norm(2, dim=dim).expand_as(x))

In [11]:
stats = {}
for epoch in range(start_epoch, nepoch):

    # Adjust learning rate
    adjust_lr(epoch)

    for i in range(len(dataloader['train'])):

        # ---------------------------
        #        Optimize over e
        # ---------------------------

        for e_iter in range(updates['e']['num_updates']):
            e_losses = []
            netE.zero_grad()

            # X
            populate_x(x, dataloader['train'])
            # e(X)
#             print(x.shape) # torch.Size([64, 1, 32, 32])
            ex = netE(x)

            # KL_real: - \Delta( e(X) , Z ) -> max_e
            KL_real = KL_minimizer(ex)
            e_losses.append(KL_real * updates['e']['KL_real'])

            if updates['e']['match_x'] != 0:
                # g(e(X))
                gex = netG(ex)

                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = match(gex, x, opt.match_x)
                e_losses.append(err * updates['e']['match_x'])

            # Save some stats
            stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
            stats['real_var'] = KL_minimizer.samples_var.data.mean()
            stats['KL_real'] = KL_real.data[0]

            # ================================================

            # Z
            populate_z(z)
            # g(Z)
            fake = netG(z).detach()
            # e(g(Z))
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
            KL_fake = KL_maximizer(egz)
            e_losses.append(KL_fake * updates['e']['KL_fake'])

            if updates['e']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_e
                err = match(egz, z, match_z)
                e_losses.append(err * updates['e']['match_z'])

            # Save some stats
            stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
            stats['fake_var'] = KL_maximizer.samples_var.data.mean()
            stats['KL_fake'] = -KL_fake.data[0]

            # Update e
            sum(e_losses).backward()
            optimizerD.step()

        # ---------------------------
        #        Minimize over g
        # ---------------------------

        for g_iter in range(updates['g']['num_updates']):
            g_losses = []
            netG.zero_grad()

            # Z
            populate_z(z)
            # g(Z)
            fake = netG(z)
            # e(g(Z))
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
            KL_fake_g = KL_minimizer(egz)
            g_losses.append(KL_fake_g * updates['g']['KL_fake'])

            if updates['g']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_g
                err = match(egz, z, match_z)
                err = err * updates['g']['match_z']
                g_losses.append(err)

            # ==================================

            if updates['g']['match_x'] != 0:
                # X
                populate_x(x, dataloader['train'])
                # e(X)
                ex = netE(x)

                # g(e(X))
                gex = netG(ex)

                # match_x: E_x||g(e(x)) - x|| -> min_g
                err = match(gex, x, match_x)
                err = err * updates['g']['match_x']
                g_losses.append(err)

            # Step g
            sum(g_losses).backward()
            optimizerG.step()

        print('[{epoch}/{nepoch}][{iter}/{niter}] '
              'KL_real/fake: {KL_real:.3f}/{KL_fake:.3f} '
              'mean_real/fake: {real_mean:.3f}/{fake_mean:.3f} '
              'var_real/fake: {real_var:.3f}/{fake_var:.3f} '
              ''.format(epoch=epoch,
                        nepoch=nepoch, #(this should be either nepoch-1 or start the logging display at 1 instead of 0)
                        iter=i,
                        niter=len(dataloader['train']),
                        **stats))

        if i % save_every == 0:
            save_images(epoch)

        # If an epoch takes long time, dump intermediate
        if dataset in ['lsun', 'imagenet'] and (i % 5000 == 0):
            torch.save(netG, '%s/netG_epoch_%d_it_%d.pth' %
                       (save_dir, epoch, i))
            torch.save(netE, '%s/netE_epoch_%d_it_%d.pth' %
                       (save_dir, epoch, i))

    # do checkpointing
    torch.save(netG, '%s/netG_epoch_%d.pth' % (save_dir, epoch))
    torch.save(netE, '%s/netE_epoch_%d.pth' % (save_dir, epoch))

0 1
Adjusting learning rate from 0.000200 to 0.000100 on E and G
[0/1][0/937] KL_real/fake: 4.343/4.313 mean_real/fake: 0.005/0.017 var_real/fake: 0.009/0.009 
[0/1][1/937] KL_real/fake: 4.188/4.181 mean_real/fake: 0.007/0.021 var_real/fake: 0.010/0.010 
[0/1][2/937] KL_real/fake: 4.045/4.131 mean_real/fake: 0.013/0.019 var_real/fake: 0.012/0.010 
[0/1][3/937] KL_real/fake: 3.908/4.093 mean_real/fake: 0.012/0.018 var_real/fake: 0.014/0.011 
[0/1][4/937] KL_real/fake: 3.886/4.120 mean_real/fake: 0.021/0.023 var_real/fake: 0.015/0.010 
[0/1][5/937] KL_real/fake: 3.804/4.076 mean_real/fake: 0.017/0.024 var_real/fake: 0.016/0.011 
[0/1][6/937] KL_real/fake: 3.734/3.998 mean_real/fake: 0.019/0.023 var_real/fake: 0.017/0.012 
[0/1][7/937] KL_real/fake: 3.754/4.020 mean_real/fake: 0.022/0.022 var_real/fake: 0.018/0.012 
[0/1][8/937] KL_real/fake: 3.668/4.050 mean_real/fake: 0.021/0.025 var_real/fake: 0.019/0.011 
[0/1][9/937] KL_real/fake: 3.672/4.088 mean_real/fake: 0.016/0.028 var_real/fake

[0/1][85/937] KL_real/fake: 3.387/5.402 mean_real/fake: 0.040/0.041 var_real/fake: 0.045/0.003 
[0/1][86/937] KL_real/fake: 3.248/4.811 mean_real/fake: 0.052/0.040 var_real/fake: 0.047/0.007 
[0/1][87/937] KL_real/fake: 3.967/5.032 mean_real/fake: 0.070/0.040 var_real/fake: 0.026/0.006 
[0/1][88/937] KL_real/fake: 3.809/5.038 mean_real/fake: 0.027/0.047 var_real/fake: 0.034/0.005 
[0/1][89/937] KL_real/fake: 3.439/4.724 mean_real/fake: 0.044/0.048 var_real/fake: 0.048/0.009 
[0/1][90/937] KL_real/fake: 3.655/4.812 mean_real/fake: 0.046/0.047 var_real/fake: 0.035/0.008 
[0/1][91/937] KL_real/fake: 3.385/4.312 mean_real/fake: 0.059/0.046 var_real/fake: 0.054/0.016 
[0/1][92/937] KL_real/fake: 4.172/5.050 mean_real/fake: 0.074/0.051 var_real/fake: 0.022/0.006 
[0/1][93/937] KL_real/fake: 3.920/5.499 mean_real/fake: 0.068/0.029 var_real/fake: 0.029/0.003 
[0/1][94/937] KL_real/fake: 3.753/5.312 mean_real/fake: 0.059/0.045 var_real/fake: 0.033/0.004 
[0/1][95/937] KL_real/fake: 3.112/5.251 

[0/1][170/937] KL_real/fake: 3.487/5.360 mean_real/fake: 0.083/0.044 var_real/fake: 0.057/0.005 
[0/1][171/937] KL_real/fake: 3.708/5.304 mean_real/fake: 0.059/0.035 var_real/fake: 0.047/0.004 
[0/1][172/937] KL_real/fake: 3.882/5.559 mean_real/fake: 0.065/0.041 var_real/fake: 0.043/0.004 
[0/1][173/937] KL_real/fake: 3.998/5.547 mean_real/fake: 0.064/0.044 var_real/fake: 0.027/0.003 
[0/1][174/937] KL_real/fake: 3.444/5.118 mean_real/fake: 0.030/0.036 var_real/fake: 0.050/0.006 
[0/1][175/937] KL_real/fake: 3.747/5.286 mean_real/fake: 0.049/0.045 var_real/fake: 0.040/0.005 
[0/1][176/937] KL_real/fake: 4.069/5.450 mean_real/fake: 0.059/0.040 var_real/fake: 0.025/0.003 
[0/1][177/937] KL_real/fake: 3.275/5.154 mean_real/fake: 0.067/0.043 var_real/fake: 0.059/0.006 
[0/1][178/937] KL_real/fake: 4.128/5.581 mean_real/fake: 0.034/0.045 var_real/fake: 0.029/0.003 
[0/1][179/937] KL_real/fake: 3.851/5.563 mean_real/fake: 0.044/0.038 var_real/fake: 0.032/0.003 
[0/1][180/937] KL_real/fake: 3

[0/1][255/937] KL_real/fake: 3.191/4.883 mean_real/fake: 0.058/0.044 var_real/fake: 0.076/0.008 
[0/1][256/937] KL_real/fake: 4.437/6.147 mean_real/fake: 0.045/0.050 var_real/fake: 0.013/0.002 
[0/1][257/937] KL_real/fake: 4.081/5.992 mean_real/fake: 0.060/0.039 var_real/fake: 0.039/0.002 
[0/1][258/937] KL_real/fake: 3.133/4.896 mean_real/fake: 0.050/0.068 var_real/fake: 0.067/0.009 
[0/1][259/937] KL_real/fake: 4.822/6.331 mean_real/fake: 0.027/0.039 var_real/fake: 0.012/0.001 
[0/1][260/937] KL_real/fake: 3.976/5.808 mean_real/fake: 0.053/0.035 var_real/fake: 0.037/0.002 
[0/1][261/937] KL_real/fake: 4.028/5.836 mean_real/fake: 0.042/0.046 var_real/fake: 0.030/0.002 
[0/1][262/937] KL_real/fake: 3.514/5.270 mean_real/fake: 0.048/0.043 var_real/fake: 0.056/0.005 
[0/1][263/937] KL_real/fake: 4.056/5.789 mean_real/fake: 0.084/0.041 var_real/fake: 0.036/0.002 
[0/1][264/937] KL_real/fake: 3.863/5.499 mean_real/fake: 0.090/0.047 var_real/fake: 0.053/0.004 
[0/1][265/937] KL_real/fake: 3

[0/1][340/937] KL_real/fake: 4.880/6.452 mean_real/fake: 0.046/0.041 var_real/fake: 0.007/0.001 
[0/1][341/937] KL_real/fake: 4.082/5.949 mean_real/fake: 0.095/0.047 var_real/fake: 0.031/0.002 
[0/1][342/937] KL_real/fake: 3.262/4.734 mean_real/fake: 0.038/0.044 var_real/fake: 0.053/0.009 
[0/1][343/937] KL_real/fake: 4.410/5.982 mean_real/fake: 0.053/0.045 var_real/fake: 0.020/0.002 
[0/1][344/937] KL_real/fake: 3.851/5.819 mean_real/fake: 0.045/0.048 var_real/fake: 0.031/0.003 
[0/1][345/937] KL_real/fake: 3.137/5.237 mean_real/fake: 0.091/0.046 var_real/fake: 0.080/0.005 
[0/1][346/937] KL_real/fake: 4.405/6.008 mean_real/fake: 0.003/0.040 var_real/fake: 0.027/0.002 
[0/1][347/937] KL_real/fake: 4.144/5.873 mean_real/fake: 0.048/0.048 var_real/fake: 0.024/0.003 
[0/1][348/937] KL_real/fake: 3.086/5.166 mean_real/fake: 0.062/0.058 var_real/fake: 0.082/0.006 
[0/1][349/937] KL_real/fake: 4.516/5.946 mean_real/fake: 0.056/0.044 var_real/fake: 0.015/0.002 
[0/1][350/937] KL_real/fake: 3

[0/1][425/937] KL_real/fake: 5.874/6.979 mean_real/fake: 0.038/0.040 var_real/fake: 0.002/0.001 
[0/1][426/937] KL_real/fake: 5.384/6.377 mean_real/fake: 0.041/0.041 var_real/fake: 0.004/0.001 
[0/1][427/937] KL_real/fake: 4.242/5.308 mean_real/fake: 0.089/0.052 var_real/fake: 0.034/0.006 
[0/1][428/937] KL_real/fake: 3.894/4.914 mean_real/fake: 0.065/0.049 var_real/fake: 0.035/0.017 
[0/1][429/937] KL_real/fake: 3.637/4.517 mean_real/fake: 0.056/0.072 var_real/fake: 0.073/0.024 
[0/1][430/937] KL_real/fake: 4.418/6.220 mean_real/fake: 0.054/0.043 var_real/fake: 0.016/0.002 
[0/1][431/937] KL_real/fake: 4.507/5.873 mean_real/fake: 0.054/0.048 var_real/fake: 0.014/0.002 
[0/1][432/937] KL_real/fake: 3.685/5.153 mean_real/fake: 0.057/0.051 var_real/fake: 0.037/0.006 
[0/1][433/937] KL_real/fake: 3.044/3.976 mean_real/fake: 0.038/0.070 var_real/fake: 0.068/0.034 
[0/1][434/937] KL_real/fake: 4.306/5.881 mean_real/fake: 0.044/0.045 var_real/fake: 0.018/0.003 
[0/1][435/937] KL_real/fake: 4

[0/1][510/937] KL_real/fake: 3.714/5.519 mean_real/fake: 0.050/0.055 var_real/fake: 0.035/0.004 
[0/1][511/937] KL_real/fake: 3.536/5.104 mean_real/fake: 0.057/0.039 var_real/fake: 0.046/0.006 
[0/1][512/937] KL_real/fake: 3.560/5.134 mean_real/fake: 0.052/0.034 var_real/fake: 0.055/0.005 
[0/1][513/937] KL_real/fake: 3.515/5.209 mean_real/fake: 0.080/0.040 var_real/fake: 0.064/0.005 
[0/1][514/937] KL_real/fake: 4.212/5.894 mean_real/fake: 0.070/0.050 var_real/fake: 0.022/0.002 
[0/1][515/937] KL_real/fake: 3.334/5.001 mean_real/fake: 0.031/0.034 var_real/fake: 0.044/0.007 
[0/1][516/937] KL_real/fake: 4.308/5.715 mean_real/fake: 0.049/0.049 var_real/fake: 0.016/0.003 
[0/1][517/937] KL_real/fake: 4.069/5.733 mean_real/fake: 0.084/0.043 var_real/fake: 0.033/0.003 
[0/1][518/937] KL_real/fake: 4.000/5.221 mean_real/fake: 0.048/0.046 var_real/fake: 0.034/0.005 
[0/1][519/937] KL_real/fake: 3.523/5.438 mean_real/fake: 0.072/0.032 var_real/fake: 0.053/0.004 
[0/1][520/937] KL_real/fake: 5

[0/1][595/937] KL_real/fake: 3.781/5.096 mean_real/fake: 0.072/0.046 var_real/fake: 0.064/0.007 
[0/1][596/937] KL_real/fake: 3.138/4.090 mean_real/fake: 0.052/0.086 var_real/fake: 0.081/0.037 
[0/1][597/937] KL_real/fake: 4.902/6.181 mean_real/fake: 0.036/0.036 var_real/fake: 0.007/0.002 
[0/1][598/937] KL_real/fake: 5.594/6.540 mean_real/fake: 0.043/0.046 var_real/fake: 0.003/0.001 
[0/1][599/937] KL_real/fake: 5.354/6.496 mean_real/fake: 0.050/0.044 var_real/fake: 0.004/0.001 
[0/1][600/937] KL_real/fake: 4.574/5.823 mean_real/fake: 0.044/0.044 var_real/fake: 0.011/0.003 
[0/1][601/937] KL_real/fake: 3.620/4.866 mean_real/fake: 0.051/0.056 var_real/fake: 0.052/0.014 
[0/1][602/937] KL_real/fake: 3.475/5.245 mean_real/fake: 0.055/0.053 var_real/fake: 0.041/0.005 
[0/1][603/937] KL_real/fake: 3.324/4.925 mean_real/fake: 0.061/0.042 var_real/fake: 0.061/0.008 
[0/1][604/937] KL_real/fake: 4.299/5.411 mean_real/fake: 0.049/0.048 var_real/fake: 0.022/0.005 
[0/1][605/937] KL_real/fake: 3

[0/1][680/937] KL_real/fake: 3.526/4.675 mean_real/fake: 0.057/0.062 var_real/fake: 0.039/0.010 
[0/1][681/937] KL_real/fake: 4.049/5.513 mean_real/fake: 0.037/0.040 var_real/fake: 0.025/0.003 
[0/1][682/937] KL_real/fake: 3.971/5.095 mean_real/fake: 0.058/0.038 var_real/fake: 0.028/0.006 
[0/1][683/937] KL_real/fake: 3.529/4.612 mean_real/fake: 0.097/0.063 var_real/fake: 0.058/0.015 
[0/1][684/937] KL_real/fake: 3.988/5.636 mean_real/fake: 0.049/0.048 var_real/fake: 0.029/0.003 
[0/1][685/937] KL_real/fake: 3.629/5.401 mean_real/fake: 0.083/0.046 var_real/fake: 0.058/0.005 
[0/1][686/937] KL_real/fake: 3.089/4.297 mean_real/fake: 0.055/0.060 var_real/fake: 0.064/0.025 
[0/1][687/937] KL_real/fake: 5.351/6.023 mean_real/fake: 0.029/0.038 var_real/fake: 0.004/0.002 
[0/1][688/937] KL_real/fake: 4.912/6.030 mean_real/fake: 0.045/0.046 var_real/fake: 0.008/0.002 
[0/1][689/937] KL_real/fake: 3.826/4.905 mean_real/fake: 0.071/0.065 var_real/fake: 0.036/0.010 
[0/1][690/937] KL_real/fake: 3

[0/1][765/937] KL_real/fake: 3.483/4.211 mean_real/fake: 0.055/0.073 var_real/fake: 0.052/0.026 
[0/1][766/937] KL_real/fake: 3.837/5.068 mean_real/fake: 0.129/0.053 var_real/fake: 0.075/0.008 
[0/1][767/937] KL_real/fake: 4.165/5.178 mean_real/fake: 0.061/0.037 var_real/fake: 0.018/0.005 
[0/1][768/937] KL_real/fake: 3.042/5.029 mean_real/fake: 0.045/0.042 var_real/fake: 0.117/0.007 
[0/1][769/937] KL_real/fake: 5.022/6.197 mean_real/fake: 0.023/0.036 var_real/fake: 0.007/0.002 
[0/1][770/937] KL_real/fake: 5.212/6.274 mean_real/fake: 0.047/0.046 var_real/fake: 0.005/0.002 
[0/1][771/937] KL_real/fake: 4.151/5.414 mean_real/fake: 0.065/0.047 var_real/fake: 0.032/0.004 
[0/1][772/937] KL_real/fake: 3.434/4.611 mean_real/fake: 0.106/0.053 var_real/fake: 0.080/0.017 
[0/1][773/937] KL_real/fake: 4.214/5.567 mean_real/fake: 0.084/0.044 var_real/fake: 0.035/0.004 
[0/1][774/937] KL_real/fake: 3.897/5.509 mean_real/fake: 0.081/0.059 var_real/fake: 0.039/0.005 
[0/1][775/937] KL_real/fake: 3

[0/1][850/937] KL_real/fake: 4.328/5.441 mean_real/fake: 0.064/0.060 var_real/fake: 0.020/0.005 
[0/1][851/937] KL_real/fake: 4.345/5.396 mean_real/fake: 0.053/0.063 var_real/fake: 0.017/0.006 
[0/1][852/937] KL_real/fake: 3.397/4.450 mean_real/fake: 0.034/0.061 var_real/fake: 0.050/0.018 
[0/1][853/937] KL_real/fake: 3.699/4.927 mean_real/fake: 0.064/0.044 var_real/fake: 0.044/0.008 
[0/1][854/937] KL_real/fake: 3.804/5.386 mean_real/fake: 0.062/0.052 var_real/fake: 0.036/0.005 
[0/1][855/937] KL_real/fake: 4.327/5.665 mean_real/fake: 0.033/0.048 var_real/fake: 0.016/0.003 
[0/1][856/937] KL_real/fake: 3.595/5.008 mean_real/fake: 0.051/0.050 var_real/fake: 0.048/0.008 
[0/1][857/937] KL_real/fake: 3.922/5.419 mean_real/fake: 0.056/0.055 var_real/fake: 0.031/0.004 
[0/1][858/937] KL_real/fake: 4.304/5.559 mean_real/fake: 0.035/0.041 var_real/fake: 0.015/0.003 
[0/1][859/937] KL_real/fake: 3.287/4.388 mean_real/fake: 0.047/0.035 var_real/fake: 0.056/0.017 
[0/1][860/937] KL_real/fake: 4

[0/1][935/937] KL_real/fake: 3.973/4.806 mean_real/fake: 0.058/0.046 var_real/fake: 0.035/0.010 
[0/1][936/937] KL_real/fake: 3.709/4.717 mean_real/fake: 0.109/0.049 var_real/fake: 0.067/0.015 


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
