In [37]:
import os
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
from enum import IntEnum
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [38]:
def _default_device() -> str:
    """Pick the compute device name: 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    # torch.backends.mps is the documented availability check; the getattr guard
    # keeps this working on builds that ship without the MPS backend module.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


@dataclass
class Hyperparameters:
    """Run configuration read by the cells of this notebook."""
    # Optimizer learning rate.
    learning_rate: float = 0.001
    # Mini-batch size handed to both DataLoaders.
    batch_size: int = 32
    # Number of passes of the training loop.
    num_epochs: int = 10
    # Not referenced by any cell visible in this notebook.
    dropout_rate: float = 0.5
    # Worker processes per DataLoader.
    num_workers: int = 4
    # Evaluated once, when the class body runs.
    device: str = _default_device()

hp = Hyperparameters()


In [39]:
# Only the training split is augmented (random horizontal flip); both splits
# are converted to tensors.
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 is stored under ./data and downloaded when missing.
cifar_kwargs = dict(root='data', download=True)
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **cifar_kwargs)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(batch_size=hp.batch_size, num_workers=hp.num_workers)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)


Files already downloaded and verified
Files already downloaded and verified


In [40]:
class ResNet(nn.Module):
    """Convolutional network whose output is built from a summed skip path.

    The input is batch-normalised, optionally doubled, concatenated with its
    negation along the channel axis and passed through ReLU. A head convolution
    follows, then ``num_blocks`` residual blocks. The head output and every
    block output each go through a 1x1 convolution and are summed; the sum is
    normalised, passed through ReLU and projected to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        mid_channels: Channels used inside the network.
        out_channels: Channels of the returned tensor.
        num_blocks: Number of residual blocks.
        kernel_size: Kernel size of the head convolution.
        padding: Padding of the head convolution.
        double_after_norm: If true, multiply by 2 after the head norm.
    """

    class ResidualBlock(nn.Module):
        """BatchNorm -> ReLU -> 3x3 conv, twice, added back onto the input."""

        def __init__(self, in_channels, out_channels):
            super().__init__()

            first_conv = nn.utils.weight_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False))
            second_conv = nn.utils.weight_norm(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))

            self.model = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.ReLU(),
                first_conv,
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                second_conv,
            )

        def forward(self, x):
            residual = self.model(x)
            return x + residual

    @staticmethod
    def _pointwise_conv(channels_in, channels_out):
        """Weight-normalised 1x1 convolution with bias."""
        conv = nn.Conv2d(channels_in, channels_out, kernel_size=1, padding=0, bias=True)
        return nn.utils.weight_norm(conv)

    def __init__(self, in_channels, mid_channels, out_channels, num_blocks,
                 kernel_size, padding, double_after_norm):
        super().__init__()

        self.head_norm = nn.BatchNorm2d(in_channels)
        self.double_after_norm = double_after_norm

        # The head sees 2 * in_channels because forward() concatenates x with -x.
        self.head_conv = nn.utils.weight_norm(
            nn.Conv2d(2 * in_channels, mid_channels, kernel_size, padding=padding, bias=True))
        self.head_skip = self._pointwise_conv(mid_channels, mid_channels)

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(self.ResidualBlock(mid_channels, mid_channels))

        self.skips = nn.ModuleList()
        for _ in range(num_blocks):
            self.skips.append(self._pointwise_conv(mid_channels, mid_channels))

        self.tail_norm = nn.BatchNorm2d(mid_channels)
        self.tail_conv = self._pointwise_conv(mid_channels, out_channels)

    def forward(self, x):
        x = self.head_norm(x)
        if self.double_after_norm:
            x *= 2.0

        # Concatenate x with its negation along channels before the ReLU
        # (author's note: looks similar to the U-Net decoding step).
        x = F.relu(torch.cat((x, -x), dim=1))

        x = self.head_conv(x)
        x_skip = self.head_skip(x)

        # blocks and skips are built with the same length, so index them together.
        for idx, block in enumerate(self.blocks):
            x = block(x)
            x_skip += self.skips[idx](x)

        return self.tail_conv(F.relu(self.tail_norm(x_skip)))


![](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-15%2001-45-54-178.png)

![](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-15%2003-25-01-809.png)

왼쪽: 체크 무늬  
오른쪽: Channel-wise

In [41]:
class CouplingLayer(nn.Module):
    """Affine coupling layer with a checkerboard or channel-wise split.

    A ``ResNet`` (``st_net``) reads the part of the input that is left
    untouched and emits a scale ``s`` and a shift ``t`` for the other part.
    The forward direction computes ``(x + t) * exp(s)`` and adds ``sum(s)``
    per sample to the running log-determinant; the reverse direction computes
    ``x * exp(-s) - t`` and leaves the log-determinant as passed in.

    Args:
        in_channels: Channels of the layer input.
        mid_channels: Hidden channels of ``st_net``.
        num_blocks: Residual blocks in ``st_net``.
        mask_type: A ``CouplingLayer.MaskType`` value.
        reverse_mask: Swap which part is transformed and which is kept.
    """

    class MaskType(IntEnum):
        CHECKERBOARD = 0
        CHANNEL_WISE = 1

    class Rescale(nn.Module):
        """Multiply the input by a learnable weight of shape (num_channels, 1, 1)."""

        def __init__(self, num_channels):
            super(CouplingLayer.Rescale, self).__init__()

            self.weight = nn.Parameter(torch.ones(num_channels, 1, 1))

        def forward(self, x):
            return x * self.weight

    def __init__(self, in_channels, mid_channels,
                 num_blocks, mask_type, reverse_mask):
        super(CouplingLayer, self).__init__()

        self.mask_type = mask_type
        self.reverse_mask = reverse_mask

        # A channel-wise split feeds only half of the channels to st_net.
        if self.mask_type == self.MaskType.CHANNEL_WISE:
            in_channels //= 2

        # st_net emits 2 * in_channels: one half is s, the other half is t.
        self.st_net = ResNet(in_channels, mid_channels, 2*in_channels,
                             num_blocks=num_blocks, kernel_size=3, padding=1,
                             double_after_norm=(self.mask_type == self.MaskType.CHECKERBOARD))

        self.rescale = nn.utils.weight_norm(self.Rescale(in_channels))

    @staticmethod
    def _affine(x, s, t, reverse):
        """Apply the coupling transform, or its inverse when ``reverse`` is true."""
        if reverse:
            return x * s.mul(-1).exp() - t
        return (x + t) * s.exp()

    def forward(self, x, sldj=None, reverse=True):
        """Transform ``x`` and update the summed log-determinant.

        Args:
            x: Input tensor; indexed as 4-D (sizes 2 and 3 are read for the
                checkerboard mask, dim 1 is split for the channel-wise mask).
            sldj: Running per-sample log-determinant, or ``None`` to start
                from zero in the forward direction.
            reverse: Run the inverse transform. ``sldj`` is then returned
                unchanged.

        Returns:
            Tuple ``(x, sldj)``. The ``sldj`` argument is never modified in place.
        """
        if self.mask_type == self.MaskType.CHECKERBOARD:
            b = self.checkerboard_mask(x.size(2), x.size(3), self.reverse_mask, device=x.device)
            st = self.st_net(x * b)
            s, t = st.chunk(2, dim=1)
            s = self.rescale(torch.tanh(s))
            # Zero s and t where the mask is 1 so those positions pass through unchanged.
            s = s * (1 - b)
            t = t * (1 - b)
            x = self._affine(x, s, t, reverse)
        else:
            if self.reverse_mask:
                x_id, x_change = x.chunk(2, dim=1)
            else:
                x_change, x_id = x.chunk(2, dim=1)

            st = self.st_net(x_id)
            s, t = st.chunk(2, dim=1)
            s = self.rescale(torch.tanh(s))
            x_change = self._affine(x_change, s, t, reverse)

            if self.reverse_mask:
                x = torch.cat((x_id, x_change), dim=1)
            else:
                x = torch.cat((x_change, x_id), dim=1)

        if not reverse:
            # sldj = log|det(dz/dx)|, accumulated per sample.
            log_det = s.contiguous().view(s.size(0), -1).sum(-1)
            # Out-of-place: tolerates the default sldj=None and does not
            # overwrite the caller's tensor.
            sldj = log_det if sldj is None else sldj + log_det

        return x, sldj

    def checkerboard_mask(self, height, width, reverse=False,
                          dtype=torch.float32, device=None,
                          requires_grad=False):
        """Return a (1, 1, height, width) mask whose entry (i, j) is (i + j) % 2.

        With ``reverse`` the mask is flipped (``1 - mask``).
        """
        rows = torch.arange(height, device=device).view(-1, 1)
        cols = torch.arange(width, device=device).view(1, -1)
        mask = ((rows + cols) % 2).to(dtype)
        if requires_grad:
            mask.requires_grad_(True)

        if reverse:
            mask = 1 - mask

        return mask.view(1, 1, height, width)

In [42]:
# ref: https://github.com/tensorflow/models/blob/master/research/real_nvp/real_nvp_utils.py
def squeeze_2x2(x, reverse=False, alt_order=False):
    """Trade 2x2 spatial blocks for channels, or undo that.

    Forward maps ``(B, C, H, W)`` to ``(B, 4C, H/2, W/2)``; reverse maps
    ``(B, C, H, W)`` to ``(B, C/4, 2H, 2W)``.

    Args:
        x: 4-D tensor.
        reverse: Undo a previous squeeze instead of squeezing.
        alt_order: Use the convolution-based channel ordering, in which the
            output channels are grouped by position inside the 2x2 block
            (top-left, bottom-right, top-right, bottom-left) rather than by
            input channel.

    Returns:
        The squeezed (or un-squeezed) tensor.

    Raises:
        ValueError: If H or W is odd when squeezing, or C is not a multiple
            of 4 when un-squeezing. Without this check the ``alt_order``
            stride-2 convolution would silently drop the last row/column.
    """
    b, c, h, w = x.size()
    if reverse:
        if c % 4 != 0:
            raise ValueError(f'Number of channels must be divisible by 4, got {c}')
    elif h % 2 != 0 or w % 2 != 0:
        raise ValueError(f'Height and width must be even, got {h}x{w}')

    if alt_order:
        if reverse:
            c //= 4

        # One 2x2 one-hot kernel per position inside the block.
        squeeze_matrix = torch.tensor([[[[1., 0.], [0., 0.]]],
                                       [[[0., 0.], [0., 1.]]],
                                       [[[0., 1.], [0., 0.]]],
                                       [[[0., 0.], [1., 0.]]]],
                                      dtype=x.dtype,
                                      device=x.device)
        perm_weight = torch.zeros((4 * c, c, 2, 2), dtype=x.dtype, device=x.device)
        for c_idx in range(c):
            slice_0 = slice(c_idx * 4, (c_idx + 1) * 4)
            slice_1 = slice(c_idx, c_idx + 1)
            perm_weight[slice_0, slice_1, :, :] = squeeze_matrix
        # Reorder output channels so that they are grouped by block position.
        shuffle_channels = torch.tensor([c_idx * 4 for c_idx in range(c)]
                                        + [c_idx * 4 + 1 for c_idx in range(c)]
                                        + [c_idx * 4 + 2 for c_idx in range(c)]
                                        + [c_idx * 4 + 3 for c_idx in range(c)])
        perm_weight = perm_weight[shuffle_channels, :, :, :]

        if reverse:
            x = F.conv_transpose2d(x, perm_weight, stride=2)
        else:
            x = F.conv2d(x, perm_weight, stride=2)
    else:
        # Work channels-last so the 2x2 block can be folded into the channel axis.
        x = x.permute(0, 2, 3, 1)

        if reverse:
            x = x.view(b, h, w, c // 4, 2, 2)
            x = x.permute(0, 1, 4, 2, 5, 3)
            x = x.contiguous().view(b, 2 * h, 2 * w, c // 4)
        else:
            x = x.view(b, h // 2, 2, w // 2, 2, c)
            x = x.permute(0, 1, 3, 5, 2, 4)
            x = x.contiguous().view(b, h // 2, w // 2, c * 4)

        x = x.permute(0, 3, 1, 2)

    return x

In [43]:
class RealNVP(nn.Module):
    """Multi-scale flow built from recursive blocks of coupling layers.

    ``forward(x)`` pre-processes ``x`` (dequantise, rescale, logit) and runs
    the flow, returning the latent and the summed log-determinant.
    ``forward(z, reverse=True)`` skips pre-processing and runs the flow in
    reverse; the log-determinant is then ``None``.

    Args:
        num_scales: Number of nested ``RecursiveBlock`` levels.
        in_channels: Channels of the input images.
        mid_channels: Hidden channels of the first scale's coupling networks.
        num_blocks: Residual blocks in every coupling network.
    """

    class RecursiveBlock(nn.Module):
        """One scale of the flow; every scale but the last owns the next one."""

        def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks):
            super().__init__()

            checkerboard = CouplingLayer.MaskType.CHECKERBOARD
            channel_wise = CouplingLayer.MaskType.CHANNEL_WISE

            self.is_last_block = scale_idx == num_scales - 1

            # Three checkerboard couplings with alternating mask polarity.
            self.in_couplings = nn.ModuleList()
            for flip in (False, True, False):
                self.in_couplings.append(
                    CouplingLayer(in_channels, mid_channels, num_blocks, checkerboard, flip))

            if self.is_last_block:
                # The last scale gets a fourth checkerboard coupling and nothing else.
                self.in_couplings.append(
                    CouplingLayer(in_channels, mid_channels, num_blocks, checkerboard, True))
                return

            # Channel-wise couplings operate on the squeezed tensor (4x channels).
            self.out_couplings = nn.ModuleList()
            for flip in (False, True, False):
                self.out_couplings.append(
                    CouplingLayer(4 * in_channels, 2 * mid_channels, num_blocks, channel_wise, flip))

            self.next_block = RealNVP.RecursiveBlock(
                scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks)

        @staticmethod
        def _apply_couplings(couplings, x, sldj, reverse):
            """Run the couplings in order, or in reversed order when inverting."""
            ordered = reversed(couplings) if reverse else couplings
            for coupling in ordered:
                x, sldj = coupling(x, sldj, reverse)
            return x, sldj

        def _squeezed_couplings(self, x, sldj, reverse):
            """Squeeze, run the channel-wise couplings, un-squeeze."""
            x = squeeze_2x2(x, reverse=False)
            x, sldj = self._apply_couplings(self.out_couplings, x, sldj, reverse)
            x = squeeze_2x2(x, reverse=True)
            return x, sldj

        def _next_scale(self, x, sldj, reverse):
            """Squeeze (alt order), send the first channel half to the next block, re-join."""
            x = squeeze_2x2(x, reverse=False, alt_order=True)
            x, x_split = x.chunk(2, dim=1)
            x, sldj = self.next_block(x, sldj, reverse)
            x = torch.cat((x, x_split), dim=1)
            x = squeeze_2x2(x, reverse=True, alt_order=True)
            return x, sldj

        def forward(self, x, sldj=None, reverse=False):
            if not reverse:
                x, sldj = self._apply_couplings(self.in_couplings, x, sldj, reverse)
                if not self.is_last_block:
                    x, sldj = self._squeezed_couplings(x, sldj, reverse)
                    x, sldj = self._next_scale(x, sldj, reverse)
                return x, sldj

            # Inverse: the same stages in the opposite order.
            if not self.is_last_block:
                x, sldj = self._next_scale(x, sldj, reverse)
                x, sldj = self._squeezed_couplings(x, sldj, reverse)
            x, sldj = self._apply_couplings(self.in_couplings, x, sldj, reverse)
            return x, sldj

    def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8):
        super().__init__()

        self.register_buffer('data_constraint', torch.tensor([0.9], dtype=torch.float32))

        self.flows = self.RecursiveBlock(0, num_scales, in_channels, mid_channels, num_blocks)

    def forward(self, x, reverse=False):
        if reverse:
            sldj = None
        else:
            x, sldj = self.pre_process(x)

        x, sldj = self.flows(x, sldj, reverse)
        return x, sldj

    def pre_process(self, x):
        """Dequantise ``x``, shrink it by ``data_constraint`` and take the logit.

        Returns:
            Tuple ``(y, sldj)``: the logits and the per-sample sum of the
            log-determinant term computed from them.
        """
        # Add uniform noise in [0, 1) on the 0..255 scale, then map back to [0, 1).
        dequantised = (x * 255. + torch.rand_like(x)) / 256.
        # Shrink around 0.5 by data_constraint.
        centred = (dequantised * 2 - 1) * self.data_constraint
        prob = (centred + 1) / 2
        y = prob.log() - (1 - prob).log()

        constraint_term = F.softplus((1 - self.data_constraint).log() - self.data_constraint.log())
        ldj = F.softplus(y) + F.softplus(-y) - constraint_term
        sldj = ldj.view(ldj.size(0), -1).sum(-1)

        return y, sldj

경계 효과를 줄이기 위해 특수한 방식으로 밀도를 모델링  

밀도 함수 변환: 밀도를 모델링할 때, 다음과 같은 변환을 사용합니다.
$$ \operatorname{logit}\left(\alpha + (1 - \alpha)\,\frac{x}{256}\right) $$

- $\alpha$: 0.05
- $x$: 원본 픽셀

최적화: 
- ADAM 최적화 방법을 사용하며 기본 하이퍼파라미터를 그대로 설정
- L2 정규화를 사용하여 가중치의 크기 매개변수를 조정
- 사전 분포 $P_z$는 등방성 유닛 노름 가우시안으로 설정

In [44]:
class RealNVPLoss(nn.Module):
    """Mean negative log-likelihood of latents under a standard normal prior.

    Args:
        k: Number of discrete values per input element; ``log(k)`` is
            subtracted once for every non-batch element of ``z``.
    """

    def __init__(self, k=256):
        super().__init__()
        self.k = k

    def forward(self, z, sldj):
        """Return ``-mean(log p(z) - D * log(k) + sldj)`` over the batch.

        Args:
            z: Latent tensor; dim 0 is the batch.
            sldj: Per-sample summed log-determinant.
        """
        batch_size = z.size(0)
        elements_per_sample = np.prod(z.size()[1:])

        # Element-wise log-density of a unit Gaussian.
        gaussian_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        per_sample_ll = gaussian_ll.contiguous().view(batch_size, -1).sum(-1)
        per_sample_ll = per_sample_ll - np.log(self.k) * elements_per_sample

        ll = per_sample_ll + sldj
        return -ll.mean()
        

In [45]:
model = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8).to(hp.device)
loss_fn = RealNVPLoss()

# Split parameters by name: those ending in 'weight_g' (the magnitude created
# by nn.utils.weight_norm) get weight decay; all others use the optimizer default.
norm_params = [p for n, p in model.named_parameters() if n.endswith('weight_g')]
unnorm_params = [p for n, p in model.named_parameters() if not n.endswith('weight_g')]

param_groups = [{'name': 'normalized', 'params': norm_params, 'weight_decay': 5e-5},
                {'name': 'unnormalized', 'params': unnorm_params}]

# Take the step size from the shared config instead of repeating the literal,
# so changing hp.learning_rate actually changes training.
optimizer = optim.Adam(param_groups, lr=hp.learning_rate)

In [46]:
class AverageMeter(object):
    """Track the latest value and a count-weighted running average.

    Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0.

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [47]:
def sample(model, batch_size, device):
    """Generate images by running ``model`` in reverse on Gaussian noise.

    Args:
        model: Callable invoked as ``model(z, reverse=True)`` and returning a
            pair whose first item is the generated tensor.
        batch_size: Number of latents to draw.
        device: Device on which the latents are created.

    Returns:
        Sigmoid of the model output for latents of shape (batch_size, 3, 32, 32).
    """
    latent_shape = (batch_size, 3, 32, 32)
    z = torch.randn(latent_shape, dtype=torch.float32, device=device)
    generated, _ = model(z, reverse=True)
    return torch.sigmoid(generated)

In [48]:
# Lowest average test loss seen so far. It starts at +inf so the first epoch
# always registers: the logged losses are in the thousands, so a start value of
# 0.0 could never be beaten and the comparison below would never fire.
best_loss = float('inf')

for epoch in range(hp.num_epochs):
    # ---- training pass ----
    model.train()
    loss_meter = AverageMeter()
    with tqdm.tqdm(total=len(trainloader.dataset)) as pbar:
        for x, _ in trainloader:
            x = x.to(hp.device)
            optimizer.zero_grad()

            z, sldj = model(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            # Clip the gradient L2 norm of each parameter group separately.
            for group in optimizer.param_groups:
                torch.nn.utils.clip_grad_norm_(group['params'], 100.0, 2)

            optimizer.step()

            # bpd = average loss / (ln 2 * number of elements per sample).
            pbar.set_postfix(loss=loss_meter.avg,
                             bpd=loss_meter.avg / (np.log(2) * np.prod(x.size()[1:])))
            pbar.update(x.size(0))

    # ---- evaluation pass and sampling, both without gradients ----
    model.eval()
    loss_meter = AverageMeter()
    with torch.no_grad():
        with tqdm.tqdm(total=len(testloader.dataset)) as pbar:
            for x, _ in testloader:
                x = x.to(hp.device)

                z, sldj = model(x, reverse=False)
                loss = loss_fn(z, sldj)
                loss_meter.update(loss.item(), x.size(0))

                pbar.set_postfix(loss=loss_meter.avg,
                                 bpd=loss_meter.avg / (np.log(2) * np.prod(x.size()[1:])))
                pbar.update(x.size(0))

        if loss_meter.avg < best_loss:
            best_loss = loss_meter.avg

        # Write an 8x8 grid of generated images for this epoch.
        images = sample(model, 64, hp.device)
        os.makedirs('samples', exist_ok=True)
        images_concat = torchvision.utils.make_grid(images, nrow=8, padding=2, pad_value=255)
        torchvision.utils.save_image(images_concat, f'samples/{str(epoch).zfill(3)}.png')
    

100%|██████████| 50000/50000 [06:31<00:00, 127.74it/s, bpd=4.66, loss=9.93e+3]
100%|██████████| 10000/10000 [00:27<00:00, 360.47it/s, bpd=4.43, loss=9.44e+3]
100%|██████████| 50000/50000 [06:27<00:00, 129.07it/s, bpd=4.18, loss=8.91e+3]
100%|██████████| 10000/10000 [00:26<00:00, 371.12it/s, bpd=6.93, loss=1.48e+4]
100%|██████████| 50000/50000 [06:25<00:00, 129.61it/s, bpd=4.04, loss=8.61e+3]
100%|██████████| 10000/10000 [00:26<00:00, 375.10it/s, bpd=10.1, loss=2.14e+4]
100%|██████████| 50000/50000 [06:23<00:00, 130.30it/s, bpd=3.95, loss=8.42e+3]
100%|██████████| 10000/10000 [00:26<00:00, 374.17it/s, bpd=8.24, loss=1.75e+4]
100%|██████████| 50000/50000 [06:27<00:00, 129.01it/s, bpd=3.91, loss=8.32e+3]
100%|██████████| 10000/10000 [00:27<00:00, 369.35it/s, bpd=4.13, loss=8.8e+3]
100%|██████████| 50000/50000 [06:29<00:00, 128.53it/s, bpd=3.86, loss=8.22e+3]
100%|██████████| 10000/10000 [00:26<00:00, 372.33it/s, bpd=4.01, loss=8.55e+3]
100%|██████████| 50000/50000 [06:27<00:00, 129.19it/s

**Result**

![](https://raw.githubusercontent.com/crlotwhite/ML_Study/main/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/realnvp_result.gif)