In [138]:
import numpy as np
import matplotlib.pyplot as plt
import os

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

from torch.distributions.multivariate_normal import MultivariateNormal

print('PyTorch version:', torch.__version__)
print('torchvision verseion:', torchvision.__version__)
print('Is GPU avaibale:', torch.cuda.is_available())

PyTorch version: 1.0.0
torchvision verseion: 0.2.1
Is GPU avaibale: True


In [139]:
# general settings (バッチサイズとデバイス)
batchsize = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seed = 1
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

In [140]:
# データセットの準備
# Tensorにしつつ、 (-1 ~ 1)の範囲に正規化

#def preprocess(tensor):
#    return tensor - 0.5

tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# データセットをロード(今回はMNIST)
# 本当はtraining data, validation data, test dataに分けるべきだが、今回は簡便のため2つに分ける.
mnist_train = datasets.MNIST(root = '../../data/MNIST',
                                 train = True,
                                 transform = tf,
                                 download = False)
mnist_validation = datasets.MNIST(root = '../../data/MNIST',
                                      train = False,
                                      transform = tf)

# データローダーを作成
mnist_train_loader = DataLoader(mnist_train, batch_size = batchsize, shuffle = True)
mnist_validation_loader = DataLoader(mnist_validation, batch_size = batchsize, shuffle = False)

print('the number of training data:', len(mnist_train))
print('the number of validation data:', len(mnist_validation))

the number of training data: 60000
the number of validation data: 10000


In [141]:
# Actnormの実装
class ActNorm2d(nn.Module):
    def __init__(self, num_features):
        super(ActNorm2d, self).__init__()
        size = [1, num_features, 1, 1]
        self.register_parameter('bias', nn.Parameter(torch.zeros(*size)))
        self.register_parameter('log_s', nn.Parameter(torch.zeros(*size)))
        self.inited = False
        
    def forward(self, x):
        if not self.inited:
            self.initialize_parameters(x)
        
        z = torch.exp(self.log_s) * (x + self.bias)
        log_det_jacobian = self.calculate_log_det_jacobian(x)
        return z, log_det_jacobian
    
    def inverse(self, z):
        x = z * torch.exp(-self.log_s) - self.bias
        return x

    def calculate_log_det_jacobian(self, x):
        h, w = x.size(2), x.size(3)
        return h * w * torch.sum(self.log_s)
    
    def initialize_parameters(self, first_minibatch_x):
        with torch.no_grad():
            bias = -1.0 * self.multidim_mean(first_minibatch_x.clone(), dims=[0, 2, 3])
            var_s = self.multidim_mean((first_minibatch_x.clone() + bias) ** 2, dims=[0, 2, 3])
            log_s = torch.log(1 / (torch.sqrt(var_s) + 1e-6))
        
            self.bias.data.copy_(bias.data)
            self.log_s.data.copy_(log_s.data)
        
            self.inited = True
            
    def multidim_mean(self, tensor, dims):
        dims = sorted(dims)
        for d in dims:
            tensor = tensor.mean(dim=d, keepdim=True)
        return tensor

In [142]:
# invertible 1x1 convolutionの実装
class Invertible1x1Conv2d(nn.Module):
    def __init__(self, num_features):
        super(Invertible1x1Conv2d, self).__init__()
        self.conv = nn.Conv2d(num_features, num_features, kernel_size=1, stride=1, padding=0, bias=False)
        
        W = torch.qr(torch.FloatTensor(num_features, num_features).normal_())[0]
        
        if torch.det(W) < 0:
            W[:,0] = -W[:,0]

        self.conv.weight.data = W.view(num_features, num_features, 1, 1)
        
    def forward(self, x):
        z = self.conv(x)
        log_det_jacobian = self.calculate_log_det_jacobian(x)
        return z, log_det_jacobian
        
    def inverse(self, z, train_finished=False):
        if train_finished:
            if not hasattr(self, 'W_inverse'):
                W = self.conv.weight.squeeze()
                W_inverse = W.inverse()
                self.W_inverse = W_inverse.view(*W_inverse.size(), 1, 1)
            x = F.conv2d(z, self.W_inverse, bias=None, stride=1, padding=0)
        else:
            W = self.conv.weight.squeeze()
            W_inverse = W.inverse().view(*W.size(), 1, 1)
            x = F.conv2d(z, W_inverse, bias=None, stride=1, padding=0)
        return x
        
    def calculate_log_det_jacobian(self, x):
        W = self.conv.weight.squeeze()
        h, w = x.size(2), x.size(3)
        return h * w * torch.logdet(W)

In [143]:
# coupling layerで使われるCNN
class CNN(nn.Module):
    def __init__(self, n_in, n_hidden, affine=True):
        super(CNN, self).__init__()
        self.affine = affine
        if affine:
            n_out = n_in*2
        else:
            n_out = n_in
            
        self.cv1 = nn.Conv2d(n_in, n_hidden, kernel_size=3, stride=1, padding=1)
        self.ac1 = ActNorm2d(n_hidden)
        self.cv2 = nn.Conv2d(n_hidden, n_hidden, kernel_size=1, stride=1, padding=0)
        self.ac2 = ActNorm2d(n_hidden)
        self.cv3 = nn.Conv2d(n_hidden, n_out, kernel_size=3, stride=1, padding=1)
        self.init_weights()
        
    def forward(self, CNN_input):
        out = F.relu(self.ac1(self.cv1(CNN_input))[0])
        out = F.relu(self.ac2(self.cv2(out))[0])
        # out = F.relu(self.cv1(CNN_input))
        # out = F.relu(self.cv2(out))
        if self.affine:
            out = self.cv3(out)
            n_half = int(out.size(1) / 2)
            log_s = torch.tanh(out[:,:n_half,:,:])
            bias = out[:,n_half:,:,:]
            return [log_s, bias]
        else:
            bias = self.cv3(out)
            return bias
        
    def init_weights(self):
        self.cv1.weight.data.normal_(0, 0.05)
        self.cv1.bias.data.zero_()
        self.cv2.weight.data.normal_(0, 0.05)
        self.cv2.bias.data.zero_()
        self.cv3.weight.data.zero_()
        self.cv3.bias.data.zero_()

In [144]:
# coupling layerの実装
class CouplingLayer(nn.Module):
    def __init__(self, num_features, n_hidden, affine=True):
        super(CouplingLayer, self).__init__()
        
        assert num_features % 2 == 0
        self.n_half = int(num_features / 2)
        self.affine = affine
        
        self.CNN = CNN(self.n_half, n_hidden, affine)
            
    def forward(self, x):
        x_a = x[:,:self.n_half,:,:]
        x_b = x[:,self.n_half:,:,:]
        
        CNN_output = self.CNN(x_a)
        
        if self.affine:
            log_s = CNN_output[0]
            bias = CNN_output[1]
            z_b = torch.exp(log_s) * (x_b + bias)
        else:
            log_s = None
            z_b = x_b + CNN_output
            
        z = torch.cat([x_a, z_b], dim=1)
        log_det_jacobian = self.calculate_log_det_jacobian(log_s)
        return z, log_det_jacobian
        
    def inverse(self, z):
        z_a = z[:,:self.n_half,:,:]
        z_b = z[:,self.n_half:,:,:]
        
        CNN_output = self.CNN(z_a)
        
        if self.affine:
            log_s = CNN_output[0]
            bias = CNN_output[1]
            x_b = z_b * torch.exp(-log_s) - bias
        else:
            x_b = z_b - CNN_output
            
        x = torch.cat([z_a, x_b], dim=1)
        return x
        
    def calculate_log_det_jacobian(self, log_s):
        if self.affine:
            return torch.sum(log_s) / log_s.size(0)
        else:
            return 0.0

In [145]:
# 上３つをまとめたFlow
class StepofFlow(nn.Module):
    def __init__(self, num_features, n_hidden, affine=True):
        super(StepofFlow, self).__init__()
        self.actnorm = ActNorm2d(num_features)
        self.invertible1x1conv = Invertible1x1Conv2d(num_features)
        self.couplinglayer = CouplingLayer(num_features, n_hidden, affine)

    def forward(self, x):
        x, ldj_actnorm  = self.actnorm(x)
        # print('after_actnorm', torch.mean(torch.abs(x)))
        # print('act', ldj_actnorm)
        x, ldj_1x1conv  = self.invertible1x1conv(x)
        # print('after_1x1conv', torch.mean(torch.abs(x)))
        # print('1x1conv', ldj_1x1conv)
        z, ldj_coupling = self.couplinglayer(x)
        # print('after_coupling', torch.mean(torch.abs(z)))
        # print('coupling', ldj_coupling)
        log_det_jacobian = ldj_actnorm + ldj_1x1conv + ldj_coupling
        return z, log_det_jacobian
    
    def inverse(self, z, train_finished=False):
        z = self.couplinglayer.inverse(z)
        z = self.invertible1x1conv.inverse(z, train_finished)
        x = self.actnorm.inverse(z)
        return x

In [146]:
# Glow本体
class Glow(nn.Module):
    def __init__(self, L, K, num_input_features, n_hidden_list, affine=True):
        super(Glow, self).__init__()
        self.L = L
        self.K = K
        
        num_features = num_input_features
        assert len(n_hidden_list) == L*K
        
        self.flow = torch.nn.ModuleList()
        for l in range(L):
            # squeeze
            num_features *= 4
            for k in range(K):
                # step of flow
                self.flow.append(StepofFlow(num_features, int(n_hidden_list[l*K + k]), affine))
            # split
            num_features = num_features // 2
        
    def forward(self, x):
        z = []
        log_det_jacobian = 0
        
        for l in range(self.L):
            # squeeze
            x = self.squeeze(x)
            for k in range(self.K):
                # step of flow
                x, ldj = self.flow[l*self.K + k](x)
                log_det_jacobian += ldj
            # split
            if l == self.L-1:
                z.append(x.view(x.size(0), -1))
            else:
                z.append(x[:,:x.size(1)//2,:,:].view(x.size(0), -1))
                x = x[:,x.size(1)//2:,:,:]
        
        z = torch.cat(z, dim=1)
        if not hasattr(self, 'Z'):
            batchsize, Z_dim = z.size()
            self.Z = MultivariateNormal(torch.zeros(Z_dim).to(device), torch.eye(Z_dim).to(device))
            self.last_z_shape = x.size()[1:]
            
        return z, log_det_jacobian
        
    def inverse(self, z, train_finished=False):
        x_dim = self.last_z_shape[0] * self.last_z_shape[1] * self.last_z_shape[2]
        for l in reversed(range(self.L)):
            if l == self.L-1:
                x = z[:,-x_dim:].view(-1, *self.last_z_shape)
            else:
                z_in = z[:,-x_dim*2:-x_dim]
                x = torch.cat([z_in.view(*x.size()), x], dim = 1)
                x_dim = x_dim*2
                
            for k in reversed(range(self.K)):
                x = self.flow[l*self.K + k].inverse(x, train_finished)
                
            x = self.unsqueeze(x)
        return x
                
    def squeeze(self, x, factor=2):
        batchsize, channels, height, width = x.size()
        assert height % factor == 0
        assert width % factor == 0
        z = x.view(batchsize, channels, height // factor, factor, width // factor, factor)
        z = z.permute(0, 1, 3, 5, 2, 4)
        z = z.contiguous().view(batchsize, channels * factor**2, height // factor, width // factor)
        return z
    
    def unsqueeze(self, z, factor=2):
        batchsize, channels, height, width = z.size()
        x = z.view(batchsize, channels // (factor**2), factor, factor, height, width)
        x = x.permute(0, 1, 4, 2, 5, 3)
        x = x.contiguous().view(batchsize, channels // (factor**2), height * factor, width * factor)
        return x

In [56]:
net = Glow(L=2, K=16, num_input_features=1, n_hidden_list=np.ones(2*16)*128, affine=True)
net = net.to(device)

# warm_up_epochs = 10
learning_rate = 0.001 # * (10 ** (-warm_up_epochs))
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

n_epochs = 100
save_image_interval = 1
n_save_image = 25
save_dir = '../../data/glow_MNIST/'

num_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('The number of parameters:', num_trainable_params)

The number of parameters: 882496


In [112]:
def train(train_loader):
    net.train()
    running_loss = 0
    for batch_index, sample_x in enumerate(train_loader):
        sample_x = sample_x[0].to(device)
        
        optimizer.zero_grad()
        predict_z, log_det_jacobian = net(sample_x)
        log_p_z = net.Z.log_prob(predict_z)
        log_p_z_mean = torch.mean(log_p_z) / predict_z.size(1)
        log_det_jacobian = log_det_jacobian / predict_z.size(1)
        loss = (log_p_z_mean - log_det_jacobian)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        print('batch_index[%3d/%3d] log_p_z_mean:%1.5f log_det_jacobian:%1.5f' \
            % (batch_index+1, len(train_loader), log_p_z_mean.item(), log_det_jacobian.item()))
        
    return running_loss / len(train_loader)

In [58]:
def validation(validation_loader, epoch):
    net.eval()
    running_loss = 0
    with torch.no_grad():
        for sample_x, _ in validation_loader:
            sample_x = sample_x.to(device)
            
            predict_z, log_det_jacobian = net(sample_x)
            log_p_z = net.Z.log_prob(predict_z)
            log_p_z_mean = torch.mean(log_p_z) / predict_z.size(1)
            log_det_jacobian = log_det_jacobian / predict_z.size(1)
            loss = -(log_p_z_mean + log_det_jacobian)
            running_loss += loss.item()
        
        if epoch % save_image_interval == 0:
            sample_z = net.Z.sample((n_save_image,))
            predict_x = net.inverse(sample_z)
            print(torch.mean(predict_x))
            save_image(predict_x.data.cpu(), '{}/epoch_{}.png'.format(save_dir, epoch), nrow=5, normalize=True)
            
    return running_loss / len(validation_loader)

In [59]:
train_nll_list = []
validation_nll_list = []

for epoch in range(n_epochs):
    train_nll = train(mnist_train_loader)
    validation_nll = validation(mnist_validation_loader, epoch)
    
    train_nll_list.append(train_nll)
    validation_nll_list.append(validation_nll)
    
    # if epoch < warm_up_epochs:
    #   optimizer.param_groups[0]['lr'] *= 10
    
    # print(optimizer.param_groups[0]['lr'])
    print('epoch[%2d/%2d] train_nll:%1.4f validation_nll:%1.4f' % (epoch+1, n_epochs, train_nll, validation_nll))

torch.save(net.state_dict(), save_dir + 'glow_model.pth')
torch.save(optimizer.state_dict(), save_dir + 'glow_optimizer.pth')

np.save(save_dir + 'train_nll_list.npy', np.array(train_nll_list))
np.save(save_dir + 'validation_nll_list.npy', np.array(validation_nll_list))

batch_index[  1/469] log_p_z_mean:-1.41766 log_det_jacobian:1.01319
batch_index[  2/469] log_p_z_mean:-0.94205 log_det_jacobian:-1.64657
batch_index[  3/469] log_p_z_mean:-0.93810 log_det_jacobian:-4.29338
batch_index[  4/469] log_p_z_mean:-0.95632 log_det_jacobian:-7.44805
batch_index[  5/469] log_p_z_mean:-1.74472 log_det_jacobian:-9.99125
batch_index[  6/469] log_p_z_mean:-1.32199 log_det_jacobian:-10.97683
batch_index[  7/469] log_p_z_mean:-2.46990 log_det_jacobian:-11.22617
batch_index[  8/469] log_p_z_mean:-0.96171 log_det_jacobian:-11.17433
batch_index[  9/469] log_p_z_mean:-0.96874 log_det_jacobian:-11.15997
batch_index[ 10/469] log_p_z_mean:-0.99012 log_det_jacobian:-11.22556
batch_index[ 11/469] log_p_z_mean:-1.01690 log_det_jacobian:-11.32691
batch_index[ 12/469] log_p_z_mean:-1.03288 log_det_jacobian:-11.43548
batch_index[ 13/469] log_p_z_mean:-1.02215 log_det_jacobian:-11.53787
batch_index[ 14/469] log_p_z_mean:-0.99108 log_det_jacobian:-11.61552
batch_index[ 15/469] log_p

KeyboardInterrupt: 