In [3]:
import sys
sys.path.append('../utils')

import libraries
# from libraries import *
import utils_general
from utils_general import *

%matplotlib inline

#########

from __future__ import print_function
import argparse
import torch
import torch.nn.parallel

import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--dataroot', type=str, help='path to dataset')
parser.add_argument('--workers', type=int,
                    help='number of data loading workers', default=8)
parser.add_argument('--batch_size', type=int,
                    default=64, help='batch size')
parser.add_argument('--image_size', type=int, default=32,
                    help='the resolution of the input image to network')
parser.add_argument('--nz', type=int, default=100,
                    help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int)

parser.add_argument('--nepoch', type=int, default=25,
                    help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002,
                    help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 for adam. default=0.5')
parser.add_argument('--cpu', action='store_true',
                    help='use CPU instead of GPU')
parser.add_argument('--ngpu', type=int, default=1,
                    help='number of GPUs to use')

parser.add_argument('--netG', default='',
                    help="path to netG config")
parser.add_argument('--netE', default='',
                    help="path to netE config")
parser.add_argument('--netG_chp', default='',
                    help="path to netG (to continue training)")
parser.add_argument('--netE_chp', default='',
                    help="path to netE (to continue training)")

parser.add_argument('--save_dir', default='.',
                    help='folder to output images and model checkpoints')
parser.add_argument('--criterion', default='param',
                    help='param|nonparam, How to estimate KL')
parser.add_argument('--KL', default='qp', help='pq|qp')
parser.add_argument('--noise', default='sphere', help='normal|sphere')
parser.add_argument('--match_z', default='cos', help='none|L1|L2|cos')
parser.add_argument('--match_x', default='L1', help='none|L1|L2|cos')

parser.add_argument('--drop_lr', default=5, type=int, help='')
parser.add_argument('--save_every', default=50, type=int, help='')

parser.add_argument('--manual_seed', type=int, default=123, help='manual seed')
parser.add_argument('--start_epoch', type=int, default=0, help='epoch number to start with')

parser.add_argument(
    '--e_updates', default="1;KL_fake:1,KL_real:1,match_z:0,match_x:0",
    help='Update plan for encoder <number of updates>;[<term:weight>]'
)

parser.add_argument(
    '--g_updates', default="2;KL_fake:1,match_z:1,match_x:0",
    help='Update plan for generator <number of updates>;[<term:weight>]'
)

# opt = parser.parse_args()

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --lr 0.0002 
# --nz 64 
# --batch_size 64 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 5 
# --drop_lr 5 
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

# python age.py 
# --dataset celeba 
# --dataroot <data_root> 
# --image_size 64 
# --save_dir <save_dir> 
# --start_epoch 5 
# --lr 0.0002 
# --nz 64 
# --batch_size 256 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 6 
# --drop_lr 5   
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0' 
# --netE_chp  <save_dir>/netE_epoch_5.pth 
# --netG_chp <save_dir>/netG_epoch_5.pth

_StoreAction(option_strings=['--g_updates'], dest='g_updates', nargs=None, const=None, default='2;KL_fake:1,match_z:1,match_x:0', type=None, choices=None, help='Update plan for generator <number of updates>;[<term:weight>]', metavar=None)

In [5]:
# print(get_available_gpus())
print(psutil.virtual_memory())

svmem(total=64388890624, available=63188000768, percent=1.9, used=652193792, free=52520173568, active=5063626752, inactive=6215786496, buffers=81018880, cached=11135504384, shared=9723904)


#### Setting up the options

In [6]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
# from .dataset import FolderWithImages
import random
import torch.backends.cudnn as cudnn
from PIL import Image

In [9]:
# setup function @utils.py

cuda = True # not opt.cpu
torch.set_num_threads(4)
dataset = 'helen'
save_dir = 'output/helen'
try:
    os.makedirs(save_dir)
except OSError:
    print('Directory was not created.')

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
# torch.cuda.manual_seed_all(manual_seed)

cudnn.benchmark = True

if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device,"
            "so you should probably run with --cuda")

e_updates = '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
g_updates = '3;KL_fake:1,match_z:1000,match_x:0'
    
updates = {'e': {}, 'g': {}}
updates['e']['num_updates'] = int(e_updates.split(';')[0])
updates['e'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in e_updates.split(';')[1].split(',')})

updates['g']['num_updates'] = int(g_updates.split(';')[0])
updates['g'].update({x.split(':')[0]: float(x.split(':')[1]) 
                     for x in g_updates.split(';')[1].split(',')})


###################

image_size = 64 # 512
batch_size = 32 # 64
workers = 8
shuffle = True
drop_last = True
train = True

nc = 3
nz = 64 #100 #'size of the latent z vector')
ngf = 64
ndf = 64
ngpu = 0 # 1 # 'number of GPUs to use')
pin_memory = False # True | If ``True``, the data loader will copy tensors into CUDA pinned memory before returning them.

noise = 'sphere' #help='normal|sphere')

netG = 'dcgan64px' 
netE = 'dcgan64px' 

netG_chp = '' # "path to netG (to continue training)"
netE_chp = '' # "path to netE (to continue training)"

lr = 0.0002 
drop_lr = 20
beta1 = 0.5 # help='beta1 for adam. default=0.5'
criterion = 'param'# help='param|nonparam, How to estimate KL'
KL ='qp' # help='pq|qp'
match_z = 'cos' # help='none|L1|L2|cos'
match_x = 'L1' # help='none|L1|L2|cos'

start_epoch = 0
nepoch = 100

save_every = 10
dataroot = '../data/helen_iso_r'

# --lr 0.0002 
# --nz 64 
# --batch_size 64 
# --netG dcgan64px 
# --netE dcgan64px 
# --nepoch 5 
# --drop_lr 5 
# --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' 
# --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

Directory was not created.


In [10]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data
import importlib
import random
import os
import torch.backends.cudnn as cudnn
from PIL import Image

##
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = cv2.imread(filepath) #, cv2.IMREAD_UNCHANGED)
#     img = Image.open(filepath).convert('RGB')
    return img

class FolderWithImages(data.Dataset):
    def __init__(self, root, input_transform=None, target_transform=None):
        super(FolderWithImages, self).__init__()
        self.image_filenames = [join(root, x)
                                for x in listdir(root) if is_image_file(x.lower())]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
    
##

def setup_dataset(dataset, dataroot, train=True, shuffle=True, drop_last=True):
    '''
    Setups dataset.
    '''
    # Usual transform
    t = transforms.Compose([
        transforms.Scale([image_size, image_size]),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if dataset in ['imagenet', 'folder', 'lfw']:
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = dset.ImageFolder(root=dataroot, transform=t)
    elif dataset == 'lsun':
        dataset = dset.LSUN(db_path=dataroot,
                            classes=['bedroom_train'],
                            train=train,
                            transform=t)
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root='data/raw/cifar10',
                               download=True,
                               train=train,
                               transform=t
                               )
    elif dataset == 'mnist':
        dataset = dset.MNIST(root='data/raw/mnist',
                             download=True,
                             train=train,
                             transform=t
                             )
    elif dataset == 'svhn':
        dataset = dset.SVHN(root='data/raw/svhn',
                            download=True,
                            train=train,
                            transform=t)
    elif dataset == 'celeba':
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       ALICropAndScale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    elif dataset == 'helen':
        imdir = 'train' if train else 'val'
        dataroot = os.path.join(dataroot, imdir)

        dataset = FolderWithImages(root=dataroot,
                                   input_transform=transforms.Compose([
                                       Scale(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]),
                                   target_transform=transforms.ToTensor()
                                   )
    else:
        assert False, 'Wrong dataset name.'

    assert len(dataset) > 0, 'No images found, check your paths.'

    # Shuffle and drop last when training
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=int(workers),
                                             pin_memory=pin_memory,
                                             drop_last=drop_last)
#     print(len(dataloader))
#     for i in dataloader: print(i[0].shape, i[1].shape)
        
    return InfiniteDataLoader(dataloader)


class InfiniteDataLoader(object):
    """docstring for InfiniteDataLoader"""

    def __init__(self, dataloader):
        super(InfiniteDataLoader, self).__init__()
        self.dataloader = dataloader
        self.data_iter = None

    def next(self):
        try:
            data = self.data_iter.next()
        except Exception:
            # Reached end of the dataset
            self.data_iter = iter(self.dataloader)
            data = self.data_iter.next()

        return data

    def __len__(self):
        return len(self.dataloader)

class ALICropAndScale(object):
    def __call__(self, img):
        return img.resize((64, 78), Image.ANTIALIAS).crop((0, 7, 64, 64 + 7))
    
class Scale(object):
    def __call__(self, img):
        img = cv2.resize(img, (64, 64), cv2.INTER_AREA)
        return img 

# cv::INTER_AREA interpolation, whereas to
# .   enlarge an image, it will generally look best with cv::INTER_CUBIC (slow) or cv::INTER_LINEAR
# .   (faster but still looks OK).
    
# Setup dataset
dataloader = dict(train=setup_dataset(dataset, dataroot, train=True),
                  val=setup_dataset(dataset, dataroot, train=False))



In [11]:
dl = setup_dataset(dataset, dataroot, train=True)
len(dataloader['val'])



2

In [12]:
def normalize(x, dim=0):
    '''
    Projects points to a sphere.
    '''
    return x.div(x.norm(2, dim=dim).expand_as(x))

def normalize_(x, dim=0):
    '''
    Projects points to a sphere inplace.
    '''
    x.div_(x.norm(2, dim=dim).expand_as(x))

In [13]:
def weights_init(m):
    '''
    Custom weights initialization called on netG and netE
    '''
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class _netE_Base(nn.Module):
    def __init__(self, main):
        super(_netE_Base, self).__init__()
        self.noise = noise
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):
        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            output = nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            output = self.main(input)
        output = output.view(output.size(0), -1)
        if self.noise == 'sphere':
            output = normalize(output)
        return output
    
class _netG_Base(nn.Module):
    def __init__(self, main):
        super(_netG_Base, self).__init__()
        self.ngpu = ngpu
        self.main = main

    def forward(self, input):

        # Check input is either (B,C,1,1) or (B,C)
        assert input.nelement() == input.size(0) * input.size(1), 'wtf'
        input = input.view(input.size(0), input.size(1), 1, 1)

        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 0:
            gpu_ids = range(self.ngpu)
            return nn.parallel.data_parallel(self.main, input, gpu_ids)
        else:
            return self.main(input)


def _netE():
#     ndf = opt.ndf
#     nc = opt.nc
#     nz = opt.nz

    main = nn.Sequential(
        # input is (nc) x 64 x 64
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf) x 16 x 16
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*2) x 8 x 8
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*4) x 4 x 4
        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*8) x 4 x 4
        nn.Conv2d(ndf * 8, nz, 4, 1, 0, bias=True),
    )

    return _netE_Base(main)

def _netG():
#     ngf = opt.ngf
#     nc = opt.nc
#     nz = opt.nz

    main = nn.Sequential(
        # input is Z, going into a convolution
        nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf * 8),
        nn.ReLU(True),
        # state size. (ngf*8) x 4 x 4
        nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 4),
        nn.ReLU(True),
        # state size. (ngf*4) x 8 x 8
        nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf * 2),
        nn.ReLU(True),
        # state size. (ngf*2) x 16 x 16
        nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        # state size. (ngf) x 32 x 32
        nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
        nn.Tanh()
        # state size. (nc) x 64 x 64
    )

    return _netG_Base(main)

def load_G():
    '''
    Loads generator model.
    '''
    netG = _netG()
    netG.apply(weights_init)
    netG.train()
    if netG_chp != '':
        netG.load_state_dict(torch.load(netG_chp).state_dict())

    print('Generator\n', netG)
    return netG

def load_E():
    '''
    Loads encoder model.
    '''
    netE = _netE()
    netE.apply(weights_init)
    netE.train()
    if netE_chp != '':
        netE.load_state_dict(torch.load(netE_chp).state_dict())

    print('Encoder\n', netE)

    return netE

# Load generator
netG = load_G()

# Load encoder
netE = load_E()

# RuntimeError: Given transposed=1, weight[64, 512, 4, 4], so expected input[64, 256, 1, 1] to have 64 channels,
# but got 256 channels instead

Generator
 _netG_Base(
  (main): Sequential(
    (0): ConvTranspose2d (64, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d (512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d (256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d (128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d (64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
Encoder
 _netE_Base(
  (main): Sequential(
    (0): Conv2d (3, 64, kernel_size=(4, 4), stride=(2, 2), paddin

#### Pretraining preparation

In [14]:
def var(x, dim=0):
    '''
    Calculates variance.
    '''
    x_zero_meaned = x - x.mean(dim).expand_as(x)
    return x_zero_meaned.pow(2).mean(dim)


class KLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(KLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.nelement() == samples.size(1) * samples.size(0), 'wtf?'

        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        samples_mean = self.samples_mean
        samples_var = self.samples_var

        if self.direction == 'pq':
            # mu_1 = 0; sigma_1 = 1

            t1 = (1 + samples_mean.pow(2)) / (2 * samples_var.pow(2))
            t2 = samples_var.log()

            KL = (t1 + t2 - 0.5).mean()
        else:
            # mu_2 = 0; sigma_2 = 1

            t1 = (samples_var.pow(2) + samples_mean.pow(2)) / 2
            t2 = -samples_var.log()

            KL = (t1 + t2 - 0.5).mean()

        if not self.minimize:
            KL *= -1

        return KL

# ----------------------------

def pairwise_euclidean(samples):

    B = samples.size(0)

    samples_norm = samples.mul(samples).sum(1)
    samples_norm = samples_norm.expand(B, B)

    dist_mat = samples.mm(samples.t()).mul(-2) + \
        samples_norm.add(samples_norm.t())
    return dist_mat

def sample_entropy(samples):

        # Assume B x C input

    dist_mat = pairwise_euclidean(samples)

    # Get max and add it to diag
    m = dist_mat.max().detach()
    dist_mat_d = dist_mat + \
        Variable(torch.eye(dist_mat.size(0)) * (m.data[0] + 1)).cuda()

    entropy = (dist_mat_d.min(1)[0] + 1e-4).log().sum()

    entropy *= (samples.size(1) + 0.) / samples.size(0)

    return entropy

class SampleKLN01Loss(torch.nn.Module):

    def __init__(self, direction, minimize):
        super(SampleKLN01Loss, self).__init__()
        self.minimize = minimize
        assert direction in ['pq', 'qp'], 'direction?'

        self.direction = direction

    def forward(self, samples):

        assert samples.ndimension == 2, 'wft'
        samples = samples.view(samples.size(0), -1)

        self.samples_var = var(samples)
        self.samples_mean = samples.mean(0)

        if self.direction == 'pq':
            assert False, 'not possible'
        else:
            entropy = sample_entropy(samples)

            cross_entropy = - samples.pow(2).mean() / 2.

            KL = - cross_entropy - entropy

        if not self.minimize:
            KL *= -1

        return KL

In [15]:
x = torch.FloatTensor(batch_size, nc, image_size, image_size)
z = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_z = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
z = fixed_z

In [16]:
x = torch.FloatTensor(batch_size, nc, image_size, image_size)
z = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_z = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)

def match(x, y, dist):
    '''
    Computes distance between corresponding points points in `x` and `y`
    using distance `dist`.
    '''
    if dist == 'L2':
        return (x - y).pow(2).mean()
    elif dist == 'L1':
        return (x - y).abs().mean()
    elif dist == 'cos':
        x_n = normalize(x)
        y_n = normalize(y)

        return 2 - (x_n).mul(y_n).mean()
    else:
        assert dist == 'none', 'wtf ?'

def normalize(x, dim=0):
    '''
    Projects points to a sphere.
    '''
    return x.div(x.norm(2, dim=dim).expand_as(x))

def normalize_(x, dim=0):
    '''
    Projects points to a sphere inplace.
    '''
    x.div_(x.norm(2, dim=dim).expand_as(x))

if noise == 'sphere':
    normalize_(fixed_z)

# if cuda:
#     netE.cuda()
#     netG.cuda()
#     x = x.cuda()
#     z, fixed_z = z.cuda(), fixed_z.cuda()

x = Variable(x)
z = Variable(z)
fixed_z = Variable(fixed_z)

# Setup optimizers
optimizerD = optim.Adam(netE.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Setup criterions
if criterion == 'param':
    print('Using parametric criterion KL_%s' % KL)
    KL_minimizer = KLN01Loss(direction=KL, minimize=True)
    KL_maximizer = KLN01Loss(direction=KL, minimize=False)
elif criterion == 'nonparam':
    print('Using NON-parametric criterion KL_%s' % KL)
    KL_minimizer = SampleKLN01Loss(direction=KL, minimize=True)
    KL_maximizer = SampleKLN01Loss(direction=KL, minimize=False)
else:
    assert False, 'criterion?'

real_cpu = torch.FloatTensor()

def save_images(epoch):

    real_cpu.resize_(x.data.size()).copy_(x.data)

    # Real samples
    save_path = '%s/real_samples.png' % save_dir
    vutils.save_image(real_cpu[:64] / 2 + 0.5, save_path)

    netG.eval()
    fake = netG(fixed_z)

    # Fake samples
    save_path = '%s/fake_samples_epoch_%03d.png' % (save_dir, epoch)
    vutils.save_image(fake.data[:64] / 2 + 0.5, save_path)

    # Save reconstructions
    populate_x(x, dataloader['val'])
    gex = netG(netE(x)) # here - at G entry

    t = torch.FloatTensor(x.size(0) * 2, x.size(1),
                          x.size(2), x.size(3))

    t[0::2] = x.data[:]
    t[1::2] = gex.data[:]

    save_path = '%s/reconstructions_epoch_%03d.png' % (save_dir, epoch)
    grid = vutils.save_image(t[:64] / 2 + 0.5, save_path)

#     torch.Size([64, 3, 64, 64])
    
    netG.train()

def adjust_lr(epoch):
    if epoch % drop_lr == (drop_lr - 1):
        ###
        assert optimizerD.param_groups[0]['lr'] == optimizerG.param_groups[0]['lr']
        print('Adjusting learning rate from %f to %f on E and G' % (lr, lr / 2))
        optimizerD.param_groups[0]['lr'] /= 2
        optimizerG.param_groups[0]['lr'] /= 2
        ###
#         lr /= 2
#         for param_group in optimizerD.param_groups:
#             param_group['lr'] = lr

#         for param_group in optimizerG.param_groups:
#             param_group['lr'] = lr

Using parametric criterion KL_qp


In [17]:
def populate_x(x, dataloader):
    '''
    Fills input variable `x` with data generated with dataloader
    '''
    real_cpu, _ = dataloader.next()
    x.data.resize_(real_cpu.size()).copy_(real_cpu)

def populate_z(z):
    '''
    Fills noise variable `z` with noise U(S^M)
    '''
    z.data.resize_(batch_size, nz, 1, 1)
    z.data.normal_(0, 1)
    if noise == 'sphere':
        normalize_(z.data)

In [None]:
print(len(dataloader['train'].next()))
print(len(dataloader['val'].next()))
print(dataloader['train'].next()[0].shape, dataloader['train'].next()[1].shape)
print(dataloader['train'].next()[0].shape, dataloader['train'].next()[1].shape)

2
2
torch.Size([32, 3, 64, 64]) torch.Size([32, 3, 512, 512])
torch.Size([32, 3, 64, 64]) torch.Size([32, 3, 512, 512])


In [None]:
stats = {}
for epoch in range(start_epoch, nepoch):

    # Adjust learning rate
    adjust_lr(epoch)

    for i in range(len(dataloader['train'])):

        # ---------------------------
        #        Optimize over e
        # ---------------------------

        for e_iter in range(updates['e']['num_updates']):
            e_losses = []
            netE.zero_grad()

            # X
            populate_x(x, dataloader['train'])
            # e(X)
            ex = netE(x)

            # KL_real: - \Delta( e(X) , Z ) -> max_e
            KL_real = KL_minimizer(ex)
            e_losses.append(KL_real * updates['e']['KL_real'])

            if updates['e']['match_x'] != 0:
                # g(e(X))
                gex = netG(ex) # HERE:::

                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = match(gex, x, match_x)
                e_losses.append(err * updates['e']['match_x'])

            # Save some stats
            stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
            stats['real_var'] = KL_minimizer.samples_var.data.mean()
            stats['KL_real'] = KL_real.data[0]

            # ================================================

            # Z
            populate_z(z)
            # g(Z)
            fake = netG(z).detach()
            # e(g(Z))
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
            KL_fake = KL_maximizer(egz)
            e_losses.append(KL_fake * updates['e']['KL_fake'])

            if updates['e']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_e
                err = match(egz, z, match_z)
                e_losses.append(err * updates['e']['match_z'])

            # Save some stats
            stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
            stats['fake_var'] = KL_maximizer.samples_var.data.mean()
            stats['KL_fake'] = -KL_fake.data[0]

            # Update e
            sum(e_losses).backward()
            optimizerD.step()

        # ---------------------------
        #        Minimize over g
        # ---------------------------

        for g_iter in range(updates['g']['num_updates']):
            g_losses = []
            netG.zero_grad()

            # Z
            populate_z(z)
            # g(Z)
            fake = netG(z)
            # e(g(Z))
            egz = netE(fake)

            # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
            KL_fake_g = KL_minimizer(egz)
            g_losses.append(KL_fake_g * updates['g']['KL_fake'])

            if updates['g']['match_z'] != 0:
                # match_z: E_z||e(g(z)) - z|| -> min_g
                err = match(egz, z, match_z)
                err = err * updates['g']['match_z']
                g_losses.append(err)

            # ==================================

            if updates['g']['match_x'] != 0:
                # X
                populate_x(x, dataloader['train'])
                # e(X)
                ex = netE(x)

                # g(e(X))
                gex = netG(ex)

                # match_x: E_x||g(e(x)) - x|| -> min_g
                err = match(gex, x, match_x)
                err = err * updates['g']['match_x']
                g_losses.append(err)

            # Step g
            sum(g_losses).backward()
            optimizerG.step()

        print('[{epoch}/{nepoch}][{iter}/{niter}] '
              'KL_real/fake: {KL_real:.3f}/{KL_fake:.3f} '
              'mean_real/fake: {real_mean:.3f}/{fake_mean:.3f} '
              'var_real/fake: {real_var:.3f}/{fake_var:.3f} '
              ''.format(epoch=epoch,
                        nepoch=nepoch,
                        iter=i,
                        niter=len(dataloader['train']),
                        **stats))

        if i % save_every == 0:
            save_images(epoch)

        # If an epoch takes long time, dump intermediate
        if dataset in ['lsun', 'imagenet'] and (i % 5000 == 0):
            torch.save(netG, '%s/netG_epoch_%d_it_%d.pth' %
                       (save_dir, epoch, i))
            torch.save(netE, '%s/netE_epoch_%d_it_%d.pth' %
                       (save_dir, epoch, i))

    # do checkpointing
    torch.save(netG, '%s/netG_epoch_%d.pth' % (save_dir, epoch))
    torch.save(netE, '%s/netE_epoch_%d.pth' % (save_dir, epoch))

[0/100][0/4] KL_real/fake: 3.389/3.330 mean_real/fake: -0.011/-0.009 var_real/fake: 0.022/0.023 
[0/100][1/4] KL_real/fake: 3.450/3.162 mean_real/fake: 0.004/-0.005 var_real/fake: 0.021/0.026 
[0/100][2/4] KL_real/fake: 3.211/3.185 mean_real/fake: -0.003/-0.019 var_real/fake: 0.025/0.026 
[0/100][3/4] KL_real/fake: 3.100/3.186 mean_real/fake: -0.007/-0.017 var_real/fake: 0.028/0.026 


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


[1/100][0/4] KL_real/fake: 3.058/3.304 mean_real/fake: -0.012/-0.031 var_real/fake: 0.029/0.024 
[1/100][1/4] KL_real/fake: 3.047/3.495 mean_real/fake: -0.013/-0.022 var_real/fake: 0.029/0.021 
[1/100][2/4] KL_real/fake: 3.062/3.613 mean_real/fake: -0.008/-0.029 var_real/fake: 0.029/0.020 
[1/100][3/4] KL_real/fake: 3.051/3.851 mean_real/fake: -0.009/-0.020 var_real/fake: 0.029/0.017 
[2/100][0/4] KL_real/fake: 3.032/4.021 mean_real/fake: -0.006/-0.035 var_real/fake: 0.029/0.015 
[2/100][1/4] KL_real/fake: 3.053/4.217 mean_real/fake: -0.009/-0.045 var_real/fake: 0.029/0.013 
[2/100][2/4] KL_real/fake: 3.096/4.401 mean_real/fake: -0.020/-0.040 var_real/fake: 0.028/0.012 
[2/100][3/4] KL_real/fake: 3.026/4.452 mean_real/fake: -0.016/-0.031 var_real/fake: 0.030/0.011 
[3/100][0/4] KL_real/fake: 3.053/4.467 mean_real/fake: -0.008/-0.039 var_real/fake: 0.029/0.010 
[3/100][1/4] KL_real/fake: 3.055/4.654 mean_real/fake: -0.008/-0.027 var_real/fake: 0.029/0.009 
[3/100][2/4] KL_real/fake: 3.0

[22/100][0/4] KL_real/fake: 4.234/5.044 mean_real/fake: -0.023/-0.026 var_real/fake: 0.012/0.007 
[22/100][1/4] KL_real/fake: 4.181/5.313 mean_real/fake: -0.023/-0.027 var_real/fake: 0.012/0.005 
[22/100][2/4] KL_real/fake: 4.073/5.054 mean_real/fake: -0.021/-0.023 var_real/fake: 0.013/0.006 
[22/100][3/4] KL_real/fake: 4.112/5.234 mean_real/fake: -0.021/-0.026 var_real/fake: 0.012/0.006 
[23/100][0/4] KL_real/fake: 4.255/5.259 mean_real/fake: -0.023/-0.027 var_real/fake: 0.011/0.005 
[23/100][1/4] KL_real/fake: 4.177/5.204 mean_real/fake: -0.022/-0.026 var_real/fake: 0.012/0.005 
[23/100][2/4] KL_real/fake: 4.185/5.199 mean_real/fake: -0.022/-0.025 var_real/fake: 0.012/0.005 
[23/100][3/4] KL_real/fake: 4.037/5.138 mean_real/fake: -0.022/-0.026 var_real/fake: 0.013/0.006 
[24/100][0/4] KL_real/fake: 4.176/5.145 mean_real/fake: -0.024/-0.025 var_real/fake: 0.013/0.006 
[24/100][1/4] KL_real/fake: 4.094/5.305 mean_real/fake: -0.023/-0.026 var_real/fake: 0.013/0.005 
[24/100][2/4] KL_rea

[42/100][3/4] KL_real/fake: 4.053/5.807 mean_real/fake: -0.021/-0.026 var_real/fake: 0.013/0.003 
[43/100][0/4] KL_real/fake: 3.780/5.712 mean_real/fake: -0.018/-0.027 var_real/fake: 0.016/0.003 
[43/100][1/4] KL_real/fake: 3.989/5.761 mean_real/fake: -0.021/-0.027 var_real/fake: 0.014/0.003 
[43/100][2/4] KL_real/fake: 3.896/5.682 mean_real/fake: -0.023/-0.027 var_real/fake: 0.014/0.003 
[43/100][3/4] KL_real/fake: 3.917/5.711 mean_real/fake: -0.021/-0.027 var_real/fake: 0.015/0.003 
[44/100][0/4] KL_real/fake: 4.152/5.930 mean_real/fake: -0.022/-0.027 var_real/fake: 0.012/0.002 
[44/100][1/4] KL_real/fake: 3.817/5.650 mean_real/fake: -0.020/-0.026 var_real/fake: 0.015/0.003 
[44/100][2/4] KL_real/fake: 3.895/5.647 mean_real/fake: -0.021/-0.027 var_real/fake: 0.015/0.003 
[44/100][3/4] KL_real/fake: 3.985/5.800 mean_real/fake: -0.022/-0.027 var_real/fake: 0.014/0.003 
[45/100][0/4] KL_real/fake: 4.177/5.849 mean_real/fake: -0.021/-0.027 var_real/fake: 0.011/0.003 
[45/100][1/4] KL_rea

[63/100][2/4] KL_real/fake: 3.982/6.073 mean_real/fake: -0.022/-0.027 var_real/fake: 0.014/0.002 
[63/100][3/4] KL_real/fake: 3.903/6.017 mean_real/fake: -0.019/-0.027 var_real/fake: 0.015/0.002 
[64/100][0/4] KL_real/fake: 3.719/5.980 mean_real/fake: -0.018/-0.026 var_real/fake: 0.017/0.002 
[64/100][1/4] KL_real/fake: 3.957/6.050 mean_real/fake: -0.018/-0.027 var_real/fake: 0.014/0.002 
[64/100][2/4] KL_real/fake: 3.837/5.954 mean_real/fake: -0.017/-0.027 var_real/fake: 0.015/0.002 
[64/100][3/4] KL_real/fake: 3.913/5.910 mean_real/fake: -0.019/-0.027 var_real/fake: 0.015/0.002 
[65/100][0/4] KL_real/fake: 3.923/5.928 mean_real/fake: -0.021/-0.027 var_real/fake: 0.014/0.002 
[65/100][1/4] KL_real/fake: 3.690/5.893 mean_real/fake: -0.017/-0.027 var_real/fake: 0.017/0.002 
[65/100][2/4] KL_real/fake: 3.981/6.069 mean_real/fake: -0.019/-0.027 var_real/fake: 0.014/0.002 
[65/100][3/4] KL_real/fake: 3.761/5.971 mean_real/fake: -0.017/-0.027 var_real/fake: 0.016/0.002 
[66/100][0/4] KL_rea

In [None]:
print(real_cpu.shape); print(real_cpu[:64].shape); print(real_cpu[0].shape)
# vutils.save_image(real_cpu[:64] / 2 + 0.5, save_path)
plt.imshow(real_cpu[0].numpy().reshape((64,64,3))) #[[2, 1, 0],:,:])

torch.Size([32, 3, 64, 64])
torch.Size([32, 3, 64, 64])
torch.Size([3, 64, 64])


<matplotlib.image.AxesImage at 0x7fe7bb1eb390>

Error in callback <function install_repl_displayhook.<locals>.post_execute at 0x7fe8011fb598> (for post_execute):


ValueError: Floating point image RGB values must be in the 0..1 range.

ValueError: Floating point image RGB values must be in the 0..1 range.

<matplotlib.figure.Figure at 0x7fe7bd2f3978>