In [1]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt

In [2]:
class ActNorm(nn.Module):
    def __init__(self, n_channel):
        super().__init__()
        self.n_channel = n_channel

        self.scale = Parameter(torch.Tensor(1,n_channel,1,1))
        self.bias = Parameter(torch.Tensor(1,n_channel,1,1))
        self.initialize = False
    def forward(self, x):
        #x: bxcxhxw Tensor
        #return: output(bxcxhxw Tensor), log_det(b Tensor)
        b, c, h, w = x.shape
        assert x.shape[1] == self.n_channel
        output = torch.clone(x)
        # Initialize into zero-mean, unit variance of mini-batch
        if not self.initialize: 
            data = output.transpose(0,1).reshape(self.n_channel,-1)
            std, mean = torch.std_mean(data, dim = -1)
            std, mean = std.view(1, self.n_channel, 1, 1), mean.view(1, self.n_channel, 1, 1)
            self.scale.data.copy_(1/(std+1e-9))
            self.bias.data.copy_(-mean)
            self.initialize = True
        output += self.bias
        output *= self.scale

        log_det = h * w * self.scale.abs().log().sum()
        log_det = log_det.repeat(b)
        return output, log_det
    def reverse(self, z): 
        #왠지 모르겟지만 Clone을 쓰지 않을 경우, input으로 들어가는 Tensor 객체가 바뀜
        output = torch.clone(z)
        output /= self.scale
        output -= self.bias
        return output


In [3]:
b, c, h, w = 4, 10, 32, 32
x = torch.randn(b, c, h, w)
model = ActNorm(c)
z, log_det = model(x)
x_reverse = model.reverse(z)
#check initialize into 0 mean, Unit Variance
z_std, z_mean = torch.std_mean(z.transpose(0,1).reshape(c, -1), dim = -1)
assert torch.allclose(z_mean, torch.zeros_like(z_mean), atol = 1e-7)
assert torch.allclose(z_std, torch.ones_like(z_std), atol = 1e-7)
#check log_det
print(f"log_det, {log_det[0].item():.2f}")
assert log_det.shape == (b, )
#check invertible
assert torch.allclose(x_reverse, x, atol = 1e-7)

log_det, -18.91


In [4]:
class ImageAffineCouplingLayer(nn.Module):
    def __init__(self, n_channel):
        super().__init__()
        self.n_channel = n_channel
        #split along channel
        self.n_split = n_channel // 2

        self.nn = nn.Sequential(
            nn.Conv2d(self.n_split, 512, 3, padding =1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 1),
            nn.ReLU(),
            nn.Conv2d(512, 2*(self.n_channel - self.n_split), 3, padding = 1)
        )
        #init last weight into zero
        nn.init.constant_(self.nn[-1].weight, 0)
        nn.init.constant_(self.nn[-1].bias, 0)
    def forward(self, x):
        # x: bxcxhxw
        b,c,h,w = x.shape
        assert self.n_channel == c
        x_a, x_b = x[:,:self.n_split], x[:,self.n_split:]
        nn_result = self.nn(x_a) #bx2(D-d) x h x w
        #log_s, t: bx(D-d)xhxw
        log_s, t = nn_result[:,0::2,:,:], nn_result[:,1::2,:,:]
        # s = torch.exp(log_s) #log_s는 initially 0
        s = torch.sigmoid(log_s + 2) #torch.exp대문에 잘 안되는 듯??
        y_a, y_b = x_a, s*x_b + t
        y = torch.cat((y_a, y_b), dim = 1)
        log_det = s.view(b, -1).abs().log().sum(dim = 1)

        return y, log_det
    def reverse(self, z):
        # z: bxcxhxw
        z_a, z_b = z[:,:self.n_split], z[:,self.n_split:]
        nn_result = self.nn(z_a) #bx2(D-d) x h x w
        log_s, t = nn_result[:,0::2,:,:], nn_result[:,1::2,:,:]
        # s = torch.exp(log_s)
        s = torch.sigmoid(log_s + 2)
        x_a, x_b = z_a, (z_b-t)/s
        x = torch.cat((x_a, x_b), dim = 1)
        return x


In [5]:
b, c, h, w = 4, 10, 32, 32
split = 3
x = torch.randn(b,c,h,w)
model = ImageAffineCouplingLayer(n_channel = c)
z, log_det = model(x)
x2 = model.reverse(z)
#check z == x at first
# assert torch.allclose(x, z, atol = 1e-7)
#check invertible
assert torch.allclose(x, x2, atol = 1e-7)
#check log_det
assert log_det.shape == (b,)
print(log_det.mean().item())

-649.8717651367188


In [6]:
class Invertible1to1Conv(nn.Module):
    def __init__(self, n_channel):
        super().__init__()
        self.n_channel = n_channel
        #LDU decomposition안해도 괜찮은 성능인 듯 하여 안씀
        self.matrix = Parameter(torch.Tensor(self.n_channel, self.n_channel))
        
        #initialize with random permutation matrix
        init_matrix = torch.eye(self.n_channel)
        randperm = torch.randperm(self.n_channel)
        init_matrix = init_matrix[:, randperm]
        self.matrix.data.copy_(init_matrix)
    def forward(self, x):
        #x: bxcxhxw
        b,c,h,w = x.shape 
        output = x.transpose(1, -1) # bxhxwxc
        output = torch.matmul(output, self.matrix) #bxhxwxc
        log_det = h*w*self.matrix.det().abs().log().repeat(b)
        return output.transpose(1, -1), log_det
    def reverse(self, z):
        output = z.transpose(1, -1)
        output = torch.matmul(output, self.matrix.inverse())
        return output.transpose(1, -1)


In [7]:
b, c, h, w = 4, 1000, 32, 32
x = torch.randn(b,c,h,w)
model = Invertible1to1Conv(n_channel = c)
z, log_det = model(x)
x2 = model.reverse(z)
#check invertible
assert torch.allclose(x, x2, atol = 1e-7)
#check model matrix, permutation으로 det = 0
assert torch.allclose(log_det , torch.zeros(b), atol = 1e-7)

In [8]:
class GlowBlock(nn.Module):
    def __init__(self, n_channel):
        super().__init__()
        self.step = nn.ModuleList([
            ActNorm(n_channel = n_channel),
            Invertible1to1Conv(n_channel = n_channel),
            ImageAffineCouplingLayer(n_channel = n_channel),
        ])
    def forward(self, x):
        b,c,h,w = x.shape
        output, log_det = x, 0 
        for layer in self.step:
            output, log_det_ = layer(output)
            log_det += log_det_
        return output, log_det
    def reverse(self, z):
        output = z
        for layer in self.step[::-1]:
            output = layer.reverse(output)
        return output

In [9]:
b, c, h, w = 4, 10, 32, 32
x = torch.randn(b,c,h,w)
model = GlowBlock(n_channel= c)
z, log_det = model(x)
x2 = model.reverse(z)
assert torch.allclose(x, x2, atol = 1e-7)

In [10]:
class GlowLevel(nn.Module):
    def __init__(self, n_channel, n_flow, split = True):
        super().__init__()
        self.n_channel, self.n_flow, self.split = n_channel, n_flow, split
        self.step = nn.ModuleList([GlowBlock(n_channel = n_channel* 4) for _ in range(n_flow)])
    def forward(self, x):
        b, c, h, w = x.shape
        c_out, h_out, w_out = c*4, h//2, w//2
        output = x.view(b,c,h_out,2,w_out,2).permute(0,1,3,5,2,4).reshape(b, c_out, h_out, w_out)
        log_det = 0
        for layer in self.step:
            output, log_det_ = layer(output)
            log_det += log_det_
        if self.split:
            z_new, output = output.chunk(2 , dim = 1)
            return (z_new, output), log_det
        else:
            return output, log_det
    def reverse(self, z):
        output = None
        if self.split:
            z1, z2 = z
            output = torch.cat([z1,z2], dim = 1)
        else:
            output = z
        b, c, h, w = output.shape
        for layer in self.step[::-1]:
            output = layer.reverse(output)
        output = output.view(b, c//4, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3)
        output = output.reshape(b , c//4, h*2, w*2)
        return output
        

In [11]:
b, c, h, w = 4, 10, 32, 32
x = torch.randn(b, c,h,w)
model = GlowLevel(n_channel= c, n_flow = 10)
z, log_det = model(x)
z1, z2 = z
x2 = model.reverse(z)
assert z1.shape == z2.shape
assert torch.allclose(x, x2, atol = 1e-7)
model = GlowLevel(n_channel= c, n_flow = 10, split = False)
z, log_det = model(x)
x2 = model.reverse(z)
assert z.shape == (b, 4*c, h//2, w//2)
assert torch.allclose(x, x2, atol = 1e-7)

In [12]:
class Glow(nn.Module):
    def __init__(self, n_channel, n_flow, n_level):
        super().__init__()
        self.n_level, self.n_flow, self.n_channel = n_level, n_flow, n_channel
        self.blocks = nn.ModuleList([GlowLevel(n_channel = self.n_channel *(2**idx), 
                                               split = idx!= self.n_level-1, n_flow = n_flow) for idx in range(self.n_level)])
    def forward(self, x):
        b,c,h,w = x.shape
        hidden, z_arr, log_det = x, [], 0
        for layer in self.blocks[:-1]:
            (z, hidden), log_det_= layer(hidden)
            z_arr.append(z)
            log_det += log_det_
        z, log_det_ = self.blocks[-1](hidden)
        log_det += log_det_
        z_arr.append(z)
        return z_arr, log_det
    def reverse(self, z):
        hidden = self.blocks[-1].reverse(z[-1])
        for idx in range(2, self.n_level+1):
            hidden = self.blocks[-idx].reverse((z[-idx], hidden))
        return hidden

In [13]:
def make_latent(n_level = 3, image_shape= (10, 3, 32, 32), device = torch.device("cpu")):
    b, c, h, w = image_shape
    z_arr =[]
    for idx in range(n_level):
        multiple = 2**(idx+1)
        channel = c*multiple
        if n_level -1 == idx:
            channel *= 2
        z_arr.append(torch.randn(b, channel, h//multiple, w//multiple, device = device))
    return z_arr

In [14]:
def check_model_size(model):
    n_params = np.sum([p.numel() for p in model.parameters()])
    print(f"{n_params* 4 /10**6:2f}MB")

In [15]:
b, c, h, w = 4, 3, 32, 32
x = torch.randn(b,c,h,w)
model = Glow(n_channel= c, n_flow= 32, n_level = 3)
z, log_det = model(x)
x2 = model.reverse(z)
assert torch.allclose(x, x2, atol = 1e-7)
#check make_latent make same shape
new_z = make_latent(3, (b,c,h,w))
assert len(z) == len(new_z)
for idx in range(len(z)):
    assert z[idx].shape == new_z[idx].shape, f"{idx} error"
n_params = np.sum([p.numel() for p in model.parameters()])
print(f"{n_params* 4 /10**6:2f}MB")

175.793664MB


In [16]:
device = torch.device("cuda:0")
# model.to(device), 1.3GB

In [17]:
from time import time
def time_check(start):
    total_time = round(time() - start)
    min, seconds = divmod(total_time, 60)
    return "{:02}:{:02}".format(int(min),int(seconds))

In [18]:
from torch.optim import Adam
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from torchvision.utils import save_image
from torch.utils.data import DataLoader

train_data = CIFAR10("./data", train= True, download = True, transform = T.Compose([T.ToTensor()]))
train_loader = DataLoader(train_data, batch_size= 64, shuffle= True, num_workers= 4)
#Temperature

device = torch.device("cuda:0")

model = Glow(n_channel= 3, n_flow = 32, n_level= 3)
model.to(device)
optimizer = Adam(model.parameters(), lr = 5e-5)

n_pixel = 32*32*3
n_bits = 5

train_loss_arr = []
n_iter = 0
for ep in range(100):
    model.train()
    ep_train_loss_arr = []
    start = time()
    for idx, (img, label) in enumerate(train_loader):
        # if idx> 30:
        #     break
        img = img.to(device)
        batch_size = img.shape[0]
        optimizer.zero_grad()
        output, log_det = model(img)
        output = [latent.reshape(batch_size, -1) for latent in output]
        output = torch.cat(output, dim = 1)
        loss_prior = ((output ** 2+ np.log(2 * np.pi))/2).sum(dim = 1)
        loss = loss_prior - log_det + n_pixel *np.log(256)
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        n_iter += 1
        ep_train_loss_arr.append(loss.item())
    train_loss_arr += ep_train_loss_arr
    model.eval()
    with torch.no_grad():
        z = make_latent(n_level = 3, image_shape = (10, 3, 32, 32), device = device)
        generated = model.reverse(z)
        save_image(generated, f"Samples/{ep}_images.jpg",nrow = 10)
    print(f"[{ep},{n_iter}]: time: {time_check(start)} train_loss: {np.mean(ep_train_loss_arr):.4f}")


Files already downloaded and verified
[0,31]: time: 00:24 train_loss: 14851.8680
[1,62]: time: 00:24 train_loss: 13204.4499
[2,93]: time: 00:24 train_loss: 12019.0395
[3,124]: time: 00:24 train_loss: 11498.7102
[4,155]: time: 00:24 train_loss: 10896.4355
[5,186]: time: 00:24 train_loss: 10492.8645
[6,217]: time: 00:24 train_loss: 10174.2690
[7,248]: time: 00:24 train_loss: 10091.5697
[8,279]: time: 00:24 train_loss: 10029.0329
[9,310]: time: 00:24 train_loss: 9939.2910
[10,341]: time: 00:24 train_loss: 9747.4722
[11,372]: time: 00:24 train_loss: 9718.5530
[12,403]: time: 00:24 train_loss: 9764.2184
[13,434]: time: 00:24 train_loss: 9683.0403
[14,465]: time: 00:24 train_loss: 9680.5428
[15,496]: time: 00:24 train_loss: 9513.7864
[16,527]: time: 00:24 train_loss: 9532.6306
[17,558]: time: 00:24 train_loss: 9464.0624
[18,589]: time: 00:24 train_loss: 9506.5529
[19,620]: time: 00:24 train_loss: 9380.0016
[20,651]: time: 00:24 train_loss: 9379.7689
[21,682]: time: 00:24 train_loss: 9693.883

KeyboardInterrupt: 