# Generative adversarial network (using the GAN class)

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 50
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Deterministic
from pixyz.distributions import Normal
from pixyz.models import GAN
from pixyz.utils import print_latex

In [4]:
x_dim = 784
z_dim = 100

# generator model p(x|z)    
class Generator(Deterministic):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, x_dim),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.model(z)
        return {"x": x}

# prior model p(z)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

# generative model
p_g = Generator()
p = (p_g*prior).marginalize_var("z").to(device)

In [5]:
print(p)
print_latex(p)

Distribution:
  p(x) = \int p(x|z)p_{prior}(z)dz
Network architecture:
  p_{prior}(z):
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([100])
    (loc): torch.Size([1, 100])
    (scale): torch.Size([1, 100])
  )
  p(x|z):
  Generator(
    name=p, distribution_name=Deterministic,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (model): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Linear(in_features=256, out_features=512, bias=True)
      (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
    

<IPython.core.display.Math object>

In [6]:
# discriminator model p(t|x)
class Discriminator(Deterministic):
    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")

        self.model = nn.Sequential(
            nn.Linear(x_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        t = self.model(x)
        return {"t": t}
    
d = Discriminator().to(device)

In [7]:
print(d)
print_latex(d)

Distribution:
  d(t|x)
Network architecture:
  Discriminator(
    name=d, distribution_name=Deterministic,
    var=['t'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (model): Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Sigmoid()
    )
  )


<IPython.core.display.Math object>

In [8]:
model = GAN(p, d,
            optimizer=optim.Adam, optimizer_params={"lr":0.0002},
            d_optimizer=optim.Adam, d_optimizer_params={"lr":0.0002})
print(model)
print_latex(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(D_{JS}^{Adv} \left[p_{data}(x)||p(x) \right]) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.0002
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [9]:
def train(epoch):
    train_loss = 0
    train_d_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss, d_loss = model.train({"x": x})
        train_loss += loss
        train_d_loss += d_loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    train_d_loss = train_d_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [10]:
def test(epoch):
    test_loss = 0
    test_d_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss, d_loss = model.test({"x": x})
        test_loss += loss
        test_d_loss += d_loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    test_d_loss = test_d_loss * test_loader.batch_size / len(test_loader.dataset)
    
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss, test_d_loss.item()))
    return test_loss

In [11]:
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p_g.sample({"z": z_sample})["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [12]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [13]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".gan"  + exp_time)

z_sample = torch.randn(64, z_dim).to(device)
_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = _y.to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()

100%|██████████| 469/469 [00:07<00:00, 63.55it/s]

Epoch: 1 Train loss: 8.6437, 0.2713



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 11.0372, 0.0447


100%|██████████| 469/469 [00:07<00:00, 60.16it/s]

Epoch: 2 Train loss: 14.3812, 0.0657



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 14.1872, 0.1656


100%|██████████| 469/469 [00:07<00:00, 61.37it/s]

Epoch: 3 Train loss: 19.1721, 0.0691



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 25.0242, 0.0362


100%|██████████| 469/469 [00:07<00:00, 61.61it/s]


Epoch: 4 Train loss: 33.9162, 0.0576


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 35.8233, 0.0511


100%|██████████| 469/469 [00:07<00:00, 61.47it/s]


Epoch: 5 Train loss: 32.8601, 0.0900


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 31.6670, 0.0740


100%|██████████| 469/469 [00:07<00:00, 62.02it/s]


Epoch: 6 Train loss: 40.2332, 0.0756


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 48.6511, 0.0737


100%|██████████| 469/469 [00:07<00:00, 59.03it/s]

Epoch: 7 Train loss: 35.8131, 0.0928



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 42.2108, 0.0661


100%|██████████| 469/469 [00:07<00:00, 62.21it/s]

Epoch: 8 Train loss: 44.0968, 0.0815



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 51.3868, 0.0596


100%|██████████| 469/469 [00:07<00:00, 60.47it/s]


Epoch: 9 Train loss: 50.3416, 0.0852


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 36.8834, 0.1182


100%|██████████| 469/469 [00:07<00:00, 62.10it/s]

Epoch: 10 Train loss: 41.7826, 0.0944



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 31.6169, 0.1008


100%|██████████| 469/469 [00:07<00:00, 60.39it/s]


Epoch: 11 Train loss: 40.6140, 0.0971


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 44.1410, 0.0713


100%|██████████| 469/469 [00:07<00:00, 61.03it/s]

Epoch: 12 Train loss: 49.0119, 0.0895



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 51.6916, 0.0763


100%|██████████| 469/469 [00:07<00:00, 62.81it/s]

Epoch: 13 Train loss: 47.1412, 0.0927



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 39.6306, 0.0881


100%|██████████| 469/469 [00:07<00:00, 61.23it/s]

Epoch: 14 Train loss: 43.9959, 0.1090



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 53.0351, 0.1310


100%|██████████| 469/469 [00:07<00:00, 62.99it/s]


Epoch: 15 Train loss: 49.4753, 0.1093


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 62.5869, 0.0562


100%|██████████| 469/469 [00:07<00:00, 60.27it/s]

Epoch: 16 Train loss: 53.2745, 0.1042



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 47.8086, 0.1148


100%|██████████| 469/469 [00:07<00:00, 63.34it/s]

Epoch: 17 Train loss: 48.7600, 0.1090



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 56.6560, 0.1614


100%|██████████| 469/469 [00:07<00:00, 61.96it/s]


Epoch: 18 Train loss: 47.4495, 0.1225


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 58.6053, 0.1275


100%|██████████| 469/469 [00:07<00:00, 58.92it/s]

Epoch: 19 Train loss: 46.8771, 0.1390



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 49.7367, 0.0900


100%|██████████| 469/469 [00:07<00:00, 61.71it/s]

Epoch: 20 Train loss: 50.0641, 0.1398



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 47.7229, 0.0970


100%|██████████| 469/469 [00:07<00:00, 62.58it/s]


Epoch: 21 Train loss: 48.9747, 0.1417


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 54.1271, 0.2600


100%|██████████| 469/469 [00:07<00:00, 60.15it/s]


Epoch: 22 Train loss: 49.5728, 0.1398


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 49.8324, 0.1326


100%|██████████| 469/469 [00:07<00:00, 62.36it/s]


Epoch: 23 Train loss: 52.2034, 0.1252


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 49.4642, 0.1631


100%|██████████| 469/469 [00:07<00:00, 61.53it/s]

Epoch: 24 Train loss: 47.4923, 0.1488



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 48.6663, 0.1152


100%|██████████| 469/469 [00:07<00:00, 60.33it/s]

Epoch: 25 Train loss: 50.4620, 0.1369



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 52.3123, 0.1435


100%|██████████| 469/469 [00:07<00:00, 60.03it/s]


Epoch: 26 Train loss: 51.0790, 0.1308


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 44.9480, 0.1216


100%|██████████| 469/469 [00:07<00:00, 60.41it/s]

Epoch: 27 Train loss: 51.3671, 0.1235



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 65.0699, 0.0923


100%|██████████| 469/469 [00:07<00:00, 62.33it/s]

Epoch: 28 Train loss: 58.7469, 0.1119



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 56.8301, 0.1411


100%|██████████| 469/469 [00:07<00:00, 59.65it/s]

Epoch: 29 Train loss: 60.0058, 0.1317



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 66.1869, 0.1141


100%|██████████| 469/469 [00:07<00:00, 60.36it/s]

Epoch: 30 Train loss: 58.6511, 0.1011



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 56.5887, 0.1094


100%|██████████| 469/469 [00:07<00:00, 61.16it/s]

Epoch: 31 Train loss: 56.1996, 0.1187



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 63.5563, 0.0952


100%|██████████| 469/469 [00:07<00:00, 61.58it/s]


Epoch: 32 Train loss: 57.5843, 0.1205


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 61.7895, 0.1437


100%|██████████| 469/469 [00:07<00:00, 59.82it/s]

Epoch: 33 Train loss: 62.4787, 0.1160



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 54.5871, 0.1776


100%|██████████| 469/469 [00:07<00:00, 61.38it/s]

Epoch: 34 Train loss: 54.0180, 0.1609



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 57.4475, 0.1917


100%|██████████| 469/469 [00:08<00:00, 58.32it/s]

Epoch: 35 Train loss: 56.3928, 0.1428



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 50.4163, 0.1117


100%|██████████| 469/469 [00:07<00:00, 58.89it/s]

Epoch: 36 Train loss: 57.7917, 0.1167



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 52.7644, 0.1508


100%|██████████| 469/469 [00:07<00:00, 60.03it/s]


Epoch: 37 Train loss: 51.1880, 0.1477


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 62.3150, 0.1450


100%|██████████| 469/469 [00:07<00:00, 60.84it/s]

Epoch: 38 Train loss: 56.3798, 0.1382



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 52.2229, 0.1590


100%|██████████| 469/469 [00:07<00:00, 61.13it/s]

Epoch: 39 Train loss: 53.5684, 0.1441



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 56.9267, 0.1361


100%|██████████| 469/469 [00:07<00:00, 61.56it/s]

Epoch: 40 Train loss: 51.1577, 0.1493



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 45.9806, 0.1961


100%|██████████| 469/469 [00:07<00:00, 60.65it/s]

Epoch: 41 Train loss: 53.0973, 0.1514



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 49.8693, 0.1864


100%|██████████| 469/469 [00:07<00:00, 60.21it/s]

Epoch: 42 Train loss: 50.0841, 0.1395



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 43.8470, 0.1639


100%|██████████| 469/469 [00:08<00:00, 56.25it/s]

Epoch: 43 Train loss: 51.8269, 0.1588



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 48.5654, 0.1659


100%|██████████| 469/469 [00:07<00:00, 62.44it/s]

Epoch: 44 Train loss: 45.1855, 0.1884



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 41.0840, 0.1772


100%|██████████| 469/469 [00:07<00:00, 59.94it/s]


Epoch: 45 Train loss: 46.8412, 0.1626


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 49.8021, 0.1763


100%|██████████| 469/469 [00:07<00:00, 62.62it/s]

Epoch: 46 Train loss: 51.8166, 0.1731



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 45.2091, 0.2683


100%|██████████| 469/469 [00:07<00:00, 61.02it/s]

Epoch: 47 Train loss: 48.8199, 0.1728



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 46.1852, 0.2404


100%|██████████| 469/469 [00:07<00:00, 60.16it/s]

Epoch: 48 Train loss: 46.3675, 0.1780



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 38.0546, 0.1703


100%|██████████| 469/469 [00:08<00:00, 56.49it/s]

Epoch: 49 Train loss: 47.1948, 0.1699



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 47.7333, 0.1934


100%|██████████| 469/469 [00:07<00:00, 61.05it/s]

Epoch: 50 Train loss: 45.6808, 0.1789





Test loss: 45.9222, 0.1632
