# Glow （CIFAR10）

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.utils.data import DataLoader 
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 32
epochs = 5
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
num_workers = 8

transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

train_loader = DataLoader(datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train),
                          batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_loader = DataLoader(datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test),
                         batch_size=batch_size, shuffle=False, num_workers=num_workers)

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz



0it [00:00, ?it/s][A
  0%|          | 114688/170498071 [00:00<02:28, 1143530.36it/s][A

Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz



  2%|▏         | 3489792/170498071 [00:00<01:43, 1610213.58it/s][A
  7%|▋         | 11403264/170498071 [00:00<01:09, 2280400.21it/s][A
 11%|█▏        | 19210240/170498071 [00:00<00:47, 3217433.90it/s][A
 16%|█▌        | 27156480/170498071 [00:00<00:31, 4517897.07it/s][A
 21%|██        | 35078144/170498071 [00:00<00:21, 6300035.80it/s][A
 25%|██▌       | 43032576/170498071 [00:00<00:14, 8704523.62it/s][A
 30%|██▉       | 50946048/170498071 [00:00<00:10, 11873804.22it/s][A
 35%|███▍      | 58892288/170498071 [00:00<00:07, 15937396.90it/s][A
 39%|███▉      | 66781184/170498071 [00:01<00:04, 20950907.89it/s][A
 44%|████▍     | 74694656/170498071 [00:01<00:03, 26879441.58it/s][A
 48%|████▊     | 82640896/170498071 [00:01<00:02, 33535689.07it/s][A
 53%|█████▎    | 90497024/170498071 [00:01<00:01, 40435995.51it/s][A
 58%|█████▊    | 98459648/170498071 [00:01<00:01, 47436362.22it/s][A
 62%|██████▏   | 106397696/170498071 [00:01<00:01, 53949068.14it/s][A
 67%|██████▋   | 11431936

Files already downloaded and verified


In [3]:
from pixyz.distributions import Normal, InverseTransformedDistribution
from pixyz.flows import AffineCoupling, FlowList, Squeeze, Unsqueeze, Preprocess, ActNorm2d, ChannelConv
from pixyz.layers import ResNet
from pixyz.models import ML
from pixyz.utils import print_latex

In [4]:
in_channels = 3
mid_channels = 64
num_scales = 2
input_dim = 32

In [5]:
# prior model p(z)
prior = Normal(loc= torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[in_channels, input_dim, input_dim], name="p_prior")

In [6]:
class ScaleTranslateNet(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super().__init__()
        self.resnet = ResNet(in_channels=in_channels, mid_channels=mid_channels, out_channels=in_channels*2,
                             num_blocks=8, kernel_size=3, padding=1,
                             double_after_norm=True)

    def forward(self, x):
        s_t = self.resnet(x)
        log_s, t = torch.chunk(s_t, 2, dim=1)
        log_s = torch.tanh(log_s)
        return log_s, t

In [7]:
flow_list = []

flow_list.append(Preprocess())

# Squeeze -> 3x coupling (channel-wise)
flow_list.append(Squeeze())

for i in range(3):
    flow_list.append(ActNorm2d(in_channels*4))
    flow_list.append(ChannelConv(in_channels*4))
    flow_list.append(AffineCoupling(in_features=in_channels*4, mask_type="channel_wise",
                                    scale_translate_net=ScaleTranslateNet(in_channels*4, mid_channels*2),
                                    inverse_mask=False))
flow_list.append(Unsqueeze())

f = FlowList(flow_list)

In [8]:
# inverse transformed distribution (z -> f^-1 -> x)
p = InverseTransformedDistribution(prior=prior, flow=f, var=["x"]).to(device)
print(p)
print_latex(p)

Distribution:
  p(x) = p(x=f^{-1}_{flow}(z))
Network architecture:
  InverseTransformedDistribution(
    name=p, distribution_name=InverseTransformedDistribution,
    var=['x'], cond_var=[], input_var=[], features_shape=torch.Size([])
    (prior): Normal(
      name=p_{prior}, distribution_name=Normal,
      var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([3, 32, 32])
      (loc): torch.Size([1, 3, 32, 32])
      (scale): torch.Size([1, 3, 32, 32])
    )
    (flow): FlowList(
      (0): Preprocess()
      (1): Squeeze()
      (2): ActNorm2d()
      (3): ChannelConv()
      (4): AffineCoupling(
        in_features=12, mask_type=channel_wise, inverse_mask=False
        (scale_translate_net): ScaleTranslateNet(
          (resnet): ResNet(
            (in_norm): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (in_conv): WNConv2d(
              (conv): Conv2d(24, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )

<IPython.core.display.Math object>

In [9]:
model = ML(p, optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean \left(- \log p(x) \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [10]:
def train(epoch):
    train_loss = 0
    
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss

    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [11]:
def test(epoch):
    test_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [12]:
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.inverse(z_sample).cpu()
        return sample
    

def plot_reconstrunction(x):
    with torch.no_grad():
        z = p.forward(x, compute_jacobian=False)
        recon_batch = p.inverse(z)
    
        comparison = torch.cat([x.view(-1, 3, 32, 32), recon_batch]).cpu()
        return comparison

In [None]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [13]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".glow" + exp_time)

z_sample = torch.randn(64, 3, 32, 32).to(device)
_x, _ = iter(test_loader).next()
_x = _x.to(device)

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    writer.add_images('Image_reconstrunction', recon, epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()

100%|██████████| 1563/1563 [07:15<00:00,  3.59it/s]

Epoch: 1 Train loss: -7293.8525





Test loss: -8135.8296


100%|██████████| 1563/1563 [07:10<00:00,  3.63it/s]


Epoch: 2 Train loss: -8241.7881
Test loss: -8486.4766


100%|██████████| 1563/1563 [07:11<00:00,  3.62it/s]


Epoch: 3 Train loss: -8486.0371
Test loss: -8624.7988


100%|██████████| 1563/1563 [07:08<00:00,  3.64it/s]


Epoch: 4 Train loss: -8603.4814
Test loss: -8681.9209


 71%|███████   | 1106/1563 [05:02<02:05,  3.65it/s]