In [3]:
# mount google drive folder
import os
import sys

from google.colab import drive

drive.mount('/content/gdrive/', force_remount=True)

project_dir = '/content/gdrive/My Drive/idgan'

sys.path.append(project_dir)

Mounted at /content/gdrive/


In [4]:
import random
from math import sqrt

import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch import distributions
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import torchvision.datasets as datasets

from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Preprocess

We use the CelebA dataset, resize them to 64, 128 and 256. We will use 64x64 images to train Beta VAE.

In [5]:
from pathlib import Path

celeba_64_dir = Path(os.path.join(project_dir, 'CelebA_64'))
celeba_128_dir = Path(os.path.join(project_dir, 'CelebA_128'))
celeba_256_dir = Path(os.path.join(project_dir, 'CelebA_256'))

In [6]:
from PIL import Image

def preprocess_celeba(path):
    crop = transforms.CenterCrop((160, 160))
    resample = Image.LANCZOS
    img = Image.open(path)
    img = crop(img)

    img_256_path = celeba_256_dir / path.name 
    img.resize((256, 256), resample=resample).save(img_256_path)

    img_128_path = celeba_128_dir / path.name 
    img.resize((128, 128), resample=resample).save(img_128_path)

    img_64_path = celeba_64_dir / path.name 
    img.resize((64, 64), resample=resample).save(img_64_path)

    return None

In [None]:
# from multiprocessing import Pool

# paths = list(Path(os.path.join(project_dir, 'img_align_celeba')).glob('*.jpg'))

# with Pool(16) as pool:
#     pool.map(preprocess_celeba, paths)


In [7]:
!ls /content/gdrive/My\ Drive/idgan/img_align_celeba | wc -l
!ls /content/gdrive/My\ Drive/idgan/CelebA_64/images | wc -l
!ls /content/gdrive/My\ Drive/idgan/CelebA_128/images | wc -l
!ls /content/gdrive/My\ Drive/idgan/CelebA_256/images | wc -l

ls: cannot open directory '/content/gdrive/My Drive/idgan/img_align_celeba': Input/output error
0
15031
15031
15031


## Dataloader

PyTorch dataloader

In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

class CustomImageFolder(ImageFolder):
    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)

    def __getitem__(self, index):
        path = self.imgs[index][0]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img

def return_data(image_size = 64):

    batch_size = 64
    num_workers = 1

    root = os.path.join(project_dir, 'CelebA_{}'.format(image_size))
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        ])
    train_kwargs = {'root':root, 'transform':transform}

    train_data = CustomImageFolder(**train_kwargs)
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)

    return train_loader

In [10]:
celeba64dataloader = return_data()

for idx, img in enumerate(celeba64dataloader):
    print(img.shape)
    break

len(celeba64dataloader) # total batch

torch.Size([64, 3, 64, 64])


234

## BetaVAE_H

Define Beta VAE model

In [11]:
def normal_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        m.weight.data.normal_(mean=0, std=0.02)
        if m.bias.data is not None:
            m.bias.data.zero_()


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = std.data.new(std.size()).normal_()
    return mu + std*eps


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size
    def forward(self, tensor):
        return tensor.view(self.size)


class Encoder(nn.Module):
    def __init__(self, c_dim=10, nc=3, infodistil_mode=False):
        super(Encoder, self).__init__()
        self.c_dim = c_dim
        self.nc = nc
        self.infodistil_mode = infodistil_mode
        self.layer = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256*1*1)),                 # B, 256
            nn.Linear(256, c_dim*2),             # B, c_dim*2
        )

    def forward(self, x):
        if self.infodistil_mode:
            x = x.add(1).div(2)
            if (x.size(2) > 64) or (x.size(3) > 64):
                x = F.adaptive_avg_pool2d(x, (64, 64))

        h = self.layer(x)
        return h


class Decoder(nn.Module):
    def __init__(self, c_dim=10, nc=3):
        super(Decoder, self).__init__()
        self.c_dim = c_dim
        self.nc = nc
        self.layer = nn.Sequential(
            nn.Linear(c_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )


    def forward(self, c):
        x = self.layer(c)
        return x

In [12]:
class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""

    def __init__(self, c_dim=10, nc=3, infodistil_mode=False):
        super(BetaVAE_H, self).__init__()
        self.c_dim = c_dim
        self.nc = nc
        self.encoder = Encoder(c_dim, nc, infodistil_mode)
        self.decoder = Decoder(c_dim, nc)
        self.apply(normal_init)

    def forward(self, x, c, encode_only, decode_only):
        if encode_only:
            c, mu, logvar = self._encode(x)
            return c, mu, logvar
        elif decode_only:
            x_recon = self._decode(c)
            return x_recon
        else:
            c, mu, logvar = self._encode(x)
            x_recon = self._decode(c)
            return x_recon, c, mu, logvar

    def __call__(self, x=None, c=None, encode_only=False, decode_only=False):
        return self.forward(x, c, encode_only, decode_only)

    def _encode(self, x):
        distributions = self.encoder(x)
        mu = distributions[:, :self.c_dim]
        logvar = distributions[:, self.c_dim:]
        c = reparametrize(mu, logvar)
        return c, mu, logvar

    def _decode(self, c):
        return self.decoder(c)

## Solver class

VAE training class, the loss id KLD loss and MSE loss

In [13]:
def reconstruction_loss(x, x_recon, distribution):
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        x_recon = torch.sigmoid(x_recon)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        raise NotImplementedError

    return recon_loss


def kl_divergence(mu, logvar):
    batch_size = mu.size(0)
    assert batch_size != 0

    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    kld = (-0.5*(1 + logvar - mu.pow(2) - logvar.exp())).sum(1).mean(0, True)
    return kld


class Solver(object):
    def __init__(self, max_iter, data_loader, load_ckpt=None):
        
        self.global_iter = 0

        self.max_iter = max_iter # 1e6 #3e5
        self.ckpt_save_iter = 10000
        self.log_line_iter = 100
        self.log_img_iter = 100
        # image output dir
        # self.output_dir = os.path.join(project_dir, 'dvae_celeba_output')
        # checkpoint save dir
        self.ckpt_dir = os.path.join(project_dir, 'ckpt')
            
        self.log_dir = os.path.join(project_dir, 'tensorboard')
        self.writer = SummaryWriter(self.log_dir)

        self.data_loader = data_loader

        self.nc = 3

        self.c_dim = 20
        self.beta = 6.4
        
        self.dec_dist = 'gaussian'

        self.net = BetaVAE_H(self.c_dim, self.nc).to(device)
        self.net.apply(normal_init)

        self.optim = optim.Adam(self.net.parameters(), lr=1e-4,
                                betas=(0.9, 0.999))
        
        if load_ckpt is not None:
            self.load_checkpoint(str(load_ckpt))

    def train(self):
        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        out = False if self.global_iter < self.max_iter else True
        while not out:
            for x in self.data_loader:
                self.net.train()
                self.global_iter += 1
                pbar.update(1)

                x = x.to(device)
                x_recon, c, mu, logvar = self.net(x)

                recon_loss = reconstruction_loss(x, x_recon, self.dec_dist)
                kld = kl_divergence(mu, logvar)
                beta_vae_loss = recon_loss + self.beta*kld

                self.optim.zero_grad()
                beta_vae_loss.backward()
                self.optim.step()

                pbar.set_description('[{}] recon_loss:{:.3f} kld:{:.3f}'.format(
                    self.global_iter, recon_loss.item(), kld.item()))

                if self.global_iter % self.log_line_iter == 0:
                    self.writer.add_scalar('recon_loss', recon_loss, self.global_iter)
                    self.writer.add_scalar('kld', kld, self.global_iter)

                if self.global_iter % self.log_img_iter == 0:
                    # visualize reconstruction 
                    x = make_grid(x, nrow=int(sqrt(x.size(0))), padding=2, pad_value=1)
                    x_recon = make_grid(x_recon.sigmoid(), nrow=int(sqrt(x_recon.size(0))), padding=2, pad_value=1)
                    x_vis = make_grid(torch.stack([x, x_recon]), nrow=2, padding=2, pad_value=0)
                    self.writer.add_image('reconstruction', x_vis, self.global_iter)

                    # visualize traverse
                    self.traverse(c_post=mu[:1], c_prior=torch.randn_like(mu[:1]))


                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint()
                    pbar.write('Saved checkpoint (iter:{})'.format(self.global_iter))

                if self.global_iter >= self.max_iter:
                    self.save_checkpoint()
                    pbar.write('Saved checkpoint (iter:{})'.format(self.global_iter))
                    out = True
                    break

        pbar.write("[Training Finished]")
        pbar.close()

    def traverse(self, c_post, c_prior, limit=3, npoints=7, pos=-1):
        assert isinstance(pos, (int, list, tuple))

        self.net.eval()
        c_dict = {'c_posterior':c_post, 'c_prior':c_prior}
        interpolation = torch.linspace(-limit, limit, npoints)

        for c_key in c_dict:
            c_ori = c_dict[c_key]
            samples = []
            for row in range(self.c_dim):
                if pos != -1 and row not in pos:
                    continue

                c = c_ori.clone()
                for val in interpolation:
                    c[:, row] = val
                    sample = self.net(c=c, decode_only=True).sigmoid().data
                    samples.append(sample)

            samples = torch.cat(samples, dim=0).cpu()
            samples = make_grid(samples, nrow=npoints, padding=2, pad_value=1)
            tag = 'latent_traversal_{}'.format(c_key)
            self.writer.add_image(tag, samples, self.global_iter)

        self.net.train()

    def save_checkpoint(self):
        model_states = {'net':self.net.state_dict(),
                        'c_dim':self.c_dim,
                        'nc':self.nc}
        optim_states = {'optim':self.optim.state_dict(),}
        states = {'iter':self.global_iter,
                  'model_states':model_states,
                  'optim_states':optim_states}

        file_path = os.path.join(self.ckpt_dir, str(self.global_iter))
        with open(file_path, mode='wb+') as f:
            torch.save(states, f)

        # file_path = os.path.join(self.ckpt_dir, 'last')
        # with open(file_path, mode='wb+') as f:
        #     torch.save(states, f)

    def load_checkpoint(self, filename):

        file_path = os.path.join(self.ckpt_dir, filename)

        if os.path.isfile(file_path):
            
            checkpoint = torch.load(file_path)
            self.global_iter = checkpoint['iter']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])

            tqdm.write("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
        else:
            tqdm.write("=> no checkpoint found at '{}'".format(file_path))

## VAE training

In [None]:
"""
那么cuDNN使用的非确定性算法就会自动寻找最适合当前配置的高效算法，来达到优化运行效率的问题

一般来讲，应该遵循以下准则：

如果网络的输入数据维度或类型上变化不大，设置  torch.backends.cudnn.benchmark = true  可以增加运行效率；
如果网络的输入数据在每次 iteration 都变化的话，会导致 cnDNN 每次都会去寻找一遍最优配置，这样反而会降低运行效率。
"""

torch.backends.cudnn.enabled =True  # 说明设置为使用使用非确定性算法
torch.backends.cudnn.benchmark = True

seed = 46

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [None]:
net = Solver(max_iter=500000, data_loader=celeba64dataloader, load_ckpt=250000)

net.train()

=> loaded checkpoint '/content/gdrive/My Drive/idgan/ckpt/250000 (iter 250000)'


HBox(children=(FloatProgress(value=0.0, max=500000.0), HTML(value='')))

Saved checkpoint (iter:260000)
Saved checkpoint (iter:270000)
Saved checkpoint (iter:280000)
Saved checkpoint (iter:290000)
Saved checkpoint (iter:300000)
Saved checkpoint (iter:310000)
Saved checkpoint (iter:320000)
Saved checkpoint (iter:330000)
Saved checkpoint (iter:340000)
Saved checkpoint (iter:350000)
Saved checkpoint (iter:360000)
Saved checkpoint (iter:370000)
Saved checkpoint (iter:380000)
Saved checkpoint (iter:390000)
Saved checkpoint (iter:400000)
Saved checkpoint (iter:410000)
Saved checkpoint (iter:420000)
Saved checkpoint (iter:430000)
Saved checkpoint (iter:440000)
Saved checkpoint (iter:450000)
Saved checkpoint (iter:460000)
Saved checkpoint (iter:470000)
Saved checkpoint (iter:480000)
Saved checkpoint (iter:490000)
Saved checkpoint (iter:500000)
Saved checkpoint (iter:500000)
[Training Finished]



## Resnet3

Define genrator and discriminator

In [8]:
class Generator(nn.Module):
    def __init__(self, z_dim, size, nfilter=64, nfilter_max=512, **kwargs):
        super().__init__()
        self.z_dim = z_dim

        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max

        # Submodules
        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2**nlayers)

        self.fc = nn.Linear(z_dim, self.nf0*s0*s0)

        blocks = []
        for i in range(nlayers):
            nf0 = min(nf * 2**(nlayers-i), nf_max)
            nf1 = min(nf * 2**(nlayers-i-1), nf_max)
            blocks += [
                ResnetBlock(nf0, nf1),
                nn.Upsample(scale_factor=2)
            ]

        blocks += [
            ResnetBlock(nf, nf),
        ]

        self.resnet = nn.Sequential(*blocks)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

    def forward(self, z):
        batch_size = z.size(0)
        out = self.fc(z)
        out = out.view(batch_size, self.nf0, self.s0, self.s0)
        out = self.resnet(out)
        out = self.conv_img(actvn(out))
        out = torch.tanh(out)
        return out


class Discriminator(nn.Module):
    def __init__(self, z_dim, size, nfilter=64, nfilter_max=1024):
        super().__init__()
        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max

        # Submodules
        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2**nlayers)

        blocks = [
            ResnetBlock(nf, nf)
        ]

        for i in range(nlayers):
            nf0 = min(nf * 2**i, nf_max)
            nf1 = min(nf * 2**(i+1), nf_max)
            blocks += [
                nn.AvgPool2d(3, stride=2, padding=1),
                ResnetBlock(nf0, nf1),
            ]

        self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1)
        self.resnet = nn.Sequential(*blocks)
        self.fc = nn.Linear(self.nf0*s0*s0, 1)

    def forward(self, x):
        batch_size = x.size(0)
        out = self.conv_img(x)
        out = self.resnet(out)
        out = out.view(batch_size, self.nf0*self.s0*self.s0)
        out = self.fc(actvn(out))
        return out


class ResnetBlock(nn.Module):
    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        # Attributes
        self.is_bias = is_bias
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        if fhidden is None:
            self.fhidden = min(fin, fout)
        else:
            self.fhidden = fhidden

        # Submodules
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        x_s = self._shortcut(x)
        dx = self.conv_0(actvn(x))
        dx = self.conv_1(actvn(dx))
        out = x_s + 0.1*dx

        return out

    def _shortcut(self, x):
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s


def actvn(x):
    out = F.leaky_relu(x, 2e-1)
    return out

## GAN training

Start training. To train our GAN model, we need a disentangled representation learned by VAE, we fix the trained encoder and train a Generator for high-fidelity synthesis while distilling the learned disentanglement by optimizing the following objective

$\min_G \max_D \mathcal L_{GAN}(D, G) - \lambda \mathcal R_{ID}(G)$

$\mathcal L_{GAN} (D, G) = \mathbb E_{x \sim p(x)} [\log D(x)] + \mathbb E_{s \sim p(s), c \sim q_{\phi}(c) } [\log (1- D(G(s,c)))]$

$\mathcal R_{ID}(G) = \mathbb E_{c \sim q_{\phi}(c), x \sim G(s,c)} [\log q_{\phi} (c|x)] + H_{q_{\phi}}(c) $

Where $q_{\phi}(c) = \frac{1}{N} \sum_{i} q_{\phi}(c | x_i)$ is the aggregated posterior of the encoder network.

In [14]:
batch_size = 64
d_steps = 1
restart_every = -1
inception_every = -1
save_every = 1000
backup_every = 100000

# out_dir = os.path.join(project_dir, 'ckpt')
checkpoint_dir = os.path.join(project_dir, 'ckpt')

c_dim = 20
z_dist_dim = 256
nc = 3
img_size = 64

nfilter_generator = 64
nfilter_max_generator = 512

nfilter_discriminator = 64
nfilter_max_discriminator = 512

dvae = BetaVAE_H(
    c_dim=c_dim,
    nc=nc,
    infodistil_mode=True
)
generator = Generator(
    z_dim=z_dist_dim + c_dim,
    size=img_size,
    nfilter=nfilter_generator, 
    nfilter_max=nfilter_max_generator
)
discriminator = Discriminator(
    z_dim=z_dist_dim + c_dim,
    size=img_size,
    nfilter=nfilter_discriminator, 
    nfilter_max=nfilter_max_discriminator
)

In [15]:
dvae_ckpt = torch.load(os.path.join(project_dir, 'ckpt/500000'))['model_states']['net']
dvae.load_state_dict(dvae_ckpt)

<All keys matched successfully>

In [16]:
dvae = dvae.to(device)
generator = generator.to(device)
discriminator = discriminator.to(device)

In [17]:
g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001, alpha=0.99, eps=1e-8)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001, alpha=0.99, eps=1e-8)

In [18]:
def get_zdist(dim, device=None):
    # Get distribution
    mu = torch.zeros(dim, device=device)
    scale = torch.ones(dim, device=device)
    zdist = distributions.Normal(mu, scale)

    # Add dim attribute
    zdist.dim = dim

    return zdist

cdist = get_zdist(c_dim, device=device)
zdist = get_zdist(z_dist_dim, device=device)

In [19]:
def build_lr_scheduler(optimizer, last_epoch=-1):
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=150000,
        gamma=1,
        last_epoch=last_epoch
    )
    return lr_scheduler

# Learning rate anneling
g_scheduler = build_lr_scheduler(g_optimizer, last_epoch=-1)
d_scheduler = build_lr_scheduler(d_optimizer, last_epoch=-1)

In [20]:
from torch import autograd

class Trainer(object):
    def __init__(self, dvae, generator, discriminator, g_optimizer, d_optimizer,
                 reg_param, w_info):
        self.dvae = dvae
        self.generator = generator
        self.discriminator = discriminator
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.reg_param = reg_param
        self.w_info = w_info

    def generator_trainstep(self, z, cs):
        toogle_grad(self.generator, True)
        toogle_grad(self.dvae, True)
        toogle_grad(self.discriminator, False)
        self.generator.train()
        self.discriminator.train()
        self.dvae.train()
        self.dvae.zero_grad()
        self.g_optimizer.zero_grad()

        loss = 0.
        c, c_mu, c_logvar = cs
        z_ = torch.cat([z, c], 1)
        x_fake = self.generator(z_)
        d_fake = self.discriminator(x_fake)

        gloss = self.compute_loss(d_fake, 1)
        loss += gloss

        chs = self.dvae(x_fake, encode_only=True)
        encloss = self.compute_infomax(cs, chs)
        loss += self.w_info*encloss

        loss.backward()
        self.g_optimizer.step()

        return gloss.item(), encloss.item()

    def discriminator_trainstep(self, x_real, z):
        toogle_grad(self.generator, False)
        toogle_grad(self.dvae, False)
        toogle_grad(self.discriminator, True)
        self.generator.train()
        self.discriminator.train()
        self.dvae.train()
        self.d_optimizer.zero_grad()

        # On real data
        x_real.requires_grad_()

        d_real = self.discriminator(x_real)
        dloss_real = self.compute_loss(d_real, 1)
        dloss_real.backward(retain_graph=True)
        reg = self.reg_param * compute_grad2(d_real, x_real).mean()
        reg.backward()

        # On fake data
        with torch.no_grad():
            c, c_mu, c_logvar = cs = self.dvae(x_real, encode_only=True)
            z_ = torch.cat([z, c], 1)
            x_fake = self.generator(z_)

        x_fake.requires_grad_()
        d_fake = self.discriminator(x_fake)
        dloss_fake = self.compute_loss(d_fake, 0)
        dloss_fake.backward()

        self.d_optimizer.step()
        toogle_grad(self.discriminator, False)

        # Output
        dloss = (dloss_real + dloss_fake)

        return dloss.item(), reg.item(), cs

    def compute_loss(self, d_out, target):
        targets = d_out.new_full(size=d_out.size(), fill_value=target)
        loss = F.binary_cross_entropy_with_logits(d_out, targets)
        return loss

    def compute_infomax(self, cs, chs):
        c, c_mu, c_logvar = cs
        ch, ch_mu, ch_logvar = chs
        loss = (math.log(2*math.pi) + ch_logvar + (c-ch_mu).pow(2).div(ch_logvar.exp()+1e-8)).div(2).sum(1).mean()
        return loss


# Utility functions
def toogle_grad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)


def compute_grad2(d_out, x_in):
    batch_size = x_in.size(0)
    grad_dout = autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg


def update_average(model_tgt, model_src, beta):
    toogle_grad(model_src, False)
    toogle_grad(model_tgt, False)

    param_dict_src = dict(model_src.named_parameters())

    for p_name, p_tgt in model_tgt.named_parameters():
        p_src = param_dict_src[p_name]
        assert(p_src is not p_tgt)
        p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)

In [21]:
class CheckpointIO(object):
    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    def register_modules(self, **kwargs):
        self.module_dict.update(kwargs)

    def save(self, it, filename):
        filename = os.path.join(self.checkpoint_dir, filename)

        outdict = {'it': it}
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)

    def load(self, filename):
        filename = os.path.join(self.checkpoint_dir, filename)

        if os.path.exists(filename):
            tqdm.write('=> Loading checkpoint...')
            out_dict = torch.load(filename)
            it = out_dict['it']
            for k, v in self.module_dict.items():
                if k in out_dict:
                    v.load_state_dict(out_dict[k])
                else:
                    tqdm.write('Warning: Could not find %s in checkpoint!' % k)
        else:
            it = -1

        return it

checkpoint_io = CheckpointIO(
    checkpoint_dir=os.path.join(project_dir, 'ckpt')
)

# Register modules to checkpoint
checkpoint_io.register_modules(
    generator=generator,
    discriminator=discriminator,
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
)

In [22]:
import torchvision.datasets as datasets

transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),
])

train_dataset = datasets.ImageFolder(celeba_256_dir, transform)

train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=16,
        shuffle=True, pin_memory=True, sampler=None, drop_last=True
)

  cpuset_checked))


In [23]:
trainer = Trainer(
    dvae, generator, discriminator, g_optimizer, d_optimizer,
    reg_param=10,
    w_info = 0.001
)

In [24]:
import math
import time

max_iter = 60000 # 300000
pbar = tqdm(total=max_iter)
it = -1
epoch_idx = -1
tstart = t0 = time.time()

# it = epoch_idx = checkpoint_io.load(os.path.join(checkpoint_dir, 'model_00030000.pt'))

out = False
while not out:
    epoch_idx += 1
    tqdm.write('Start epoch %d...' % epoch_idx)

    for x_real, _ in train_loader:
        it += 1
        pbar.update(1)
        g_scheduler.step()
        d_scheduler.step()

        d_lr = d_optimizer.param_groups[0]['lr']
        g_lr = g_optimizer.param_groups[0]['lr']

        x_real = x_real.to(device)

        # Discriminator updates
        z = zdist.sample((batch_size,))
        dloss, reg, cs = trainer.discriminator_trainstep(x_real, z)

        # Generators updates
        if ((it + 1) % d_steps) == 0:
            z = zdist.sample((batch_size,))
            gloss, encloss = trainer.generator_trainstep(z, cs)

        # (iii) Backup if necessary
        if ((it + 1) % backup_every) == 0:
            tqdm.write('Saving backup...')
            checkpoint_io.save(it, 'model_%08d.pt' % it)
            checkpoint_io.save(it, 'model.pt')

        # (iv) Save checkpoint if necessary
        if time.time() - t0 > save_every:
            tqdm.write('Saving checkpoint...')
            checkpoint_io.save(it, 'model.pt')
            t0 = time.time()

        if it >= max_iter:
            tqdm.write('Saving backup...')
            checkpoint_io.save(it, 'model_%08d.pt' % it)
            # logger.save_stats('stats_%08d.p' % it)
            checkpoint_io.save(it, 'model.pt')
            # logger.save_stats('stats.p')
            out = True
            break

Start epoch 149...
Start epoch 150...
Start epoch 151...
Start epoch 152...
Start epoch 153...
Start epoch 154...
Saving checkpoint...
Start epoch 155...
Start epoch 156...
Start epoch 157...
Start epoch 158...
Start epoch 159...
Start epoch 160...
Start epoch 161...
Start epoch 162...
Start epoch 163...
Start epoch 164...
Start epoch 165...
Start epoch 166...
Start epoch 167...
Start epoch 168...
Start epoch 169...
Start epoch 170...
Start epoch 171...
Start epoch 172...
Start epoch 173...
Start epoch 174...
Saving checkpoint...
Start epoch 175...
Start epoch 176...
Start epoch 177...
Start epoch 178...
Start epoch 179...
Start epoch 180...
Start epoch 181...
Start epoch 182...
Start epoch 183...
Start epoch 184...
Start epoch 185...
Start epoch 186...
Start epoch 187...
Start epoch 188...
Start epoch 189...
Start epoch 190...
Start epoch 191...
Start epoch 192...
Start epoch 193...
Start epoch 194...
Start epoch 195...
Saving checkpoint...
Start epoch 196...
Start epoch 197...
Start 