In [220]:
import dataclasses

from pygments.styles.dracula import background


@dataclasses.dataclass
class Hyperparameter:
    batch_size: int = 64
    num_workers: int = 2
    num_epochs: int = 20
    d_hidden: int = 64
    lr: float = 1e-3
    lr_decay: float = 0.999995
    num_res: int = 4
    buffer: int = 1024
    num_attn: int = 10
    num_mix: int = 5
    dropout: float = 0.5
    
    
hp = Hyperparameter

In [221]:
import torch
import torchvision
import torchvision.transforms as T


trainsform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='C:\\Users\\tama0\\.data', train=True, download=True, transform=trainsform)
test_dataset = torchvision.datasets.CIFAR10(root='C:\\Users\\tama0\\.data', train=False, download=True, transform=trainsform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=hp.num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=hp.num_workers)

Files already downloaded and verified
Files already downloaded and verified


In [222]:
import torch.nn as nn
import torch.nn.functional as F
from enum import IntEnum

class ShiftType(IntEnum):
    DOWN = 0
    RIGHT = 1


class ShiftLayer(torch.nn.Module):
    def __init__(self, shift_type):
        super(ShiftLayer, self).__init__()
        self.shift_type = shift_type
        
    def forward(self, x):
        match self.shift_type:
            case ShiftType.DOWN:
                return F.pad(x, (0,0,1,0))[:,:,:-1,:]
            case ShiftType.RIGHT:
                return F.pad(x, (1,0))[:,:,:,:-1]
        

In [223]:
class Conv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        nn.utils.weight_norm(self)
        
class DownShiftedConv2d(nn.Conv2d):
    def forward(self, x):
        hk, wk = self.kernel_size
        x = F.pad(x, ((wk-1)//2, (wk-1)//2, hk-1, 0))
        return super().forward(x)
    
class DownRightShiftedConv2d(Conv2d):
    def forward(self, x):
        hk, wk = self.kernel_size
        x = F.pad(x, (wk-1, 0, hk-1, 0))
        return super().forward(x)
        

In [224]:
class GatedResidualBlock(nn.Module):
    def __init__(self, conv, n_channels, kernel_size, 
                 drop_rate=0, shortcut_channels=None, n_cond_classes=None):
        super().__init__()
        
        self.conv_1 = conv(2 * n_channels, n_channels, kernel_size)
        if shortcut_channels:
            self.conv_1_sc = Conv2d(2 * shortcut_channels, n_channels, kernel_size=1)
        if drop_rate > 0:
            self.drop_1 = nn.Dropout(drop_rate)
        self.conv_2 = conv(2 * n_channels, 2 * n_channels, kernel_size)
        
        if n_cond_classes:
            self.proj_h = nn.Linear(n_cond_classes, 2 * n_channels)
            
    def forward(self, x, a=None, h=None):
        c1 = self.conv_1(x)
        if a is not None:
            c1 = c1 + self.c1c(F.elu(torch.cat([a, -a], dim=1)))
        c1 = F.elu(torch.cat([c1, -c1], dim=1))
        if hasattr(self, 'drop_1'):
            c1 = self.drop_1(c1)
        c2 = self.conv_2(c1)
        if h is not None:
            c2 += self.proj_h(h)[:, :, None, None]
        a, b = c2.chunk(2, 1)
        out = x + a * torch.sigmoid(b)
        
        return out

In [225]:
class AttentionGatedResidualBlock(nn.Module):
    def __init__(self, n_channels, n_background_ch, n_res_layers, n_cond_classes,
                 drop_rate, n_hidden, d_query, d_value, attn_drop_rate):
        super().__init__()
        
        self.n_hidden = n_hidden
        self.d_query = d_query
        self.d_value = d_value
        self.attn_drop_rate = attn_drop_rate
        
        self.input_gated_resnet = nn.ModuleList([
            GatedResidualBlock(DownShiftedConv2d, n_channels, (2, 2), 
                               drop_rate, None, n_cond_classes) 
            for _ in range(n_res_layers)])
        
        self.in_proj_kv = nn.Sequential(
            GatedResidualBlock(Conv2d, 2 * n_channels + n_background_ch, 1, 
                               drop_rate, None, n_cond_classes), 
            Conv2d(2 * n_channels + n_background_ch, d_query + d_value, 1)
        )
        self.in_proj_q = nn.Sequential(
            GatedResidualBlock(Conv2d, n_channels + n_background_ch, 1, 
                               drop_rate, None, n_cond_classes),
            Conv2d(n_channels + n_background_ch, d_query, 1)
        )
        self.out_proj = GatedResidualBlock(Conv2d, n_channels, 1, drop_rate, 
                                           d_value, n_cond_classes)
    
    def forward(self, x, background, attn_mask, h=None):
        ul = x
        for layer in self.input_gated_resnet:
            ul = layer(ul, h=h)
        
        kv = self.in_proj_kv(torch.cat([x, ul, background], dim=1))
        k, v = kv.split([self.d_query, self.d_value], dim=1)
        q = self.in_proj_q(torch.cat([ul, background], dim=1))
        
        B, dq, H, W = q.shape
        _, dv, _, _ = v.shape
        
        flat_q = q.reshape(B, self.n_hidden,  dq//self.n_hidden, H, W).flatten(3) \
                 * (dq//self.n_hidden)**-0.5
        flat_k = k.reshape(B, self.n_hidden,  dq//self.n_hidden, H, W).flatten(3)
        flat_v = v.reshape(B, self.n_hidden,  dq//self.n_hidden, H, W).flatten(3)
        
        logits = flat_q.transpose(2, 3) @ flat_k
        logits = F.dropout(logits, p=self.attn_drop_rate, 
                           training=self.training, inplace=True)
        logits = logits.masked_fill(attn_mask == 0, -1e10)
        weights = F.softmax(logits, dim=-1)
        
        attn_out = weights @ flat_v.transpose(2, 3)
        attn_out = attn_out.transpose(2, 3)
        attn_out = attn_out.reshape(B, -1, H, W)
        return self.out_proj(ul, attn_out)
        

In [226]:
class PixelSNAIL(nn.Module):
    def __init__(self, image_dims=(3, 32, 32), n_channels=128, n_res_layers=5, 
                 attn_n_layers=12, attn_n_hidden=1, attn_d_query=16, 
                 attn_d_value=128, attn_drop_rate=0, n_logistic_mix=10, 
                 n_cond_classes=None, drop_rate=0.5):
        super().__init__()
        
        C, H, W = image_dims
        background_v = (((torch.arange(H, dtype=torch.float) - H / 2) / 2)
                        .view(1, 1, -1, 1).expand(1, C, H, W))
        background_h = (((torch.arange(W, dtype=torch.float) - W / 2) / 2)
                        .view(1, 1, 1, -1).expand(1, C, H, W))
        self.register_buffer('background', torch.cat([background_v, background_h], 1))
        attn_mask = torch.tril(torch.ones(1, 1, H*W, H*W), diagonal=-1).byte()
        self.register_buffer('attn_mask', attn_mask)
        
        self.ul_input_d = DownShiftedConv2d(C+1, n_channels, kernel_size=(1, 3))
        self.ul_input_dr = DownRightShiftedConv2d(C+1, n_channels, kernel_size=(2, 1))
        
        self.ul_modules = nn.ModuleList([
            AttentionGatedResidualBlock(
                n_channels, self.background.shape[1], n_res_layers, n_cond_classes, 
                drop_rate, attn_n_hidden, attn_d_query, attn_d_value, attn_drop_rate) 
            for _ in range(attn_n_layers)])
        self.output_conv = Conv2d(n_channels, (3 * C + 1) * n_logistic_mix, 1)
        
    def forward(self, x, h=None):
        x = F.pad(x, (0, 0, 0, 0, 0, 1), value=1)
        
        ul = self.down_shift(self.ul_input_d(x)) \
             + self.right_shift(self.ul_input_dr(x))
        
        for module in self.ul_modules:
            ul = module(ul, self.background.expand(x.shape[0],-1,-1,-1), 
                        self.attn_mask, h)
            
        return self.output_conv(F.elu(ul))
    
    def down_shift(self, x):
        return F.pad(x, (0,0,1,0))[:,:,:-1,:]

    def right_shift(self, x):
        return F.pad(x, (1,0))[:,:,:,:-1]

In [None]:
# https://github.com/kamenbliznashki/pixel_models/tree/master

In [227]:
class DiscretizedLogisticMixLoss(nn.Module):
    def __init__(self):
        super(DiscretizedLogisticMixLoss, self).__init__()

    def forward(self, y_true, y_pred):
        # y_true shape (batch_size, H, W, C)
        _, H, W, C = y_true.size()
        num_pixels = float(H * W * C)

        if C == 1:
            pi, mu, logvar = torch.split(y_pred, 3, dim=-1)
            mu = torch.unsqueeze(mu, dim=3)
            logvar = torch.unsqueeze(logvar, dim=3)
        else:  # C == 3
            (pi, mu_r, mu_g, mu_b, logvar_r, logvar_g, logvar_b, alpha,
             beta, gamma) = torch.split(y_pred, 10, dim=-1)

            alpha = torch.tanh(alpha)
            beta = torch.tanh(beta)
            gamma = torch.tanh(gamma)

            red = y_true[:,:,:,0:1]
            green = y_true[:,:,:,1:2]

            mu_g = mu_g + alpha * red
            mu_b = mu_b + beta * red + gamma * green
            mu = torch.stack([mu_r, mu_g, mu_b], dim=3)
            logvar = torch.stack([logvar_r, logvar_g, logvar_b], dim=3)

        logvar = torch.clamp(logvar, min=-7.)

        # Add extra-dim for broadcasting channel-wise
        y_true = torch.unsqueeze(y_true, dim=-1)

        def cdf(x):  # logistic cdf
            return torch.sigmoid((x - mu) * torch.exp(-logvar))

        def log_cdf(x):  # log logistic cdf
            return F.logsigmoid((x - mu) * torch.exp(-logvar))

        def log_one_minus_cdf(x):  # log one minus logistic cdf
            return -F.softplus((x - mu) * torch.exp(-logvar))

        def log_pdf(x):  # log logistic pdf
            norm = (x - mu) * torch.exp(-logvar)
            return norm - logvar - 2. * F.softplus(norm)

        half_pixel = 1 / 255.

        cdf_plus = cdf(y_true + half_pixel)
        cdf_min = cdf(y_true - half_pixel)

        log_cdf_plus = log_cdf(y_true + half_pixel)
        log_one_minus_cdf_min = log_one_minus_cdf(y_true - half_pixel)

        cdf_delta = cdf_plus - cdf_min
        cdf_delta = torch.clamp(cdf_delta, min=1e-12)

        # At small probabilities the interval difference is approximated
        # as the pdf value at the center
        approx_log_cdf_delta = log_pdf(y_true) - torch.log(torch.tensor(127.5))
        log_probs = torch.where(cdf_delta > 1e-5, torch.log(cdf_delta), approx_log_cdf_delta)

        # Deal with edge cases
        log_probs = torch.where(y_true > 0.999, log_one_minus_cdf_min, log_probs)
        log_probs = torch.where(y_true < -0.999, log_cdf_plus, log_probs)

        log_probs = torch.sum(log_probs, dim=3)  # whole pixel prob per component
        log_probs += F.log_softmax(pi, dim=-1)  # multiply by mixture components
        log_probs = torch.logsumexp(log_probs, dim=-1)  # add components probs
        log_probs = torch.sum(log_probs, dim=[1, 2])

        # Convert to bits per dim
        bits_per_dim = -log_probs / num_pixels / torch.log(torch.tensor(2.))

        return bits_per_dim

In [228]:
steps_per_epochs = len(train_dataset) // hp.batch_size
decay_per_epoch = hp.lr_decay ** steps_per_epochs

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = PixelSNAIL(3, hp.d_hidden, hp.num_attn, hp.dropout, hp.num_res, hp.num_mix).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr) 
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, decay_per_epoch)

In [229]:
losses = []

for epoch in range(hp.num_epochs):
    model.train()
    total_loss = 0
    
    for  image, _ in tqdm(train_loader):
        image = image.to(device)
        
        optimizer.zero_grad()
        output = model(image)
        
        loss = DiscretizedLogisticMixLoss()(image, output)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    scheduler.step()
    
    ava_loss = total_loss / len(train_loader)
    losses.append(ava_loss)

torch.save(model.state_dict(), 'pixelsnail_model.pth')
print('done')

  0%|          | 0/782 [00:00<?, ?it/s]

torch.Size([64, 64, 31, 32])
torch.Size([64, 64, 34, 30])


  0%|          | 0/782 [00:06<?, ?it/s]


RuntimeError: The size of tensor a (32) must match the size of tensor b (31) at non-singleton dimension 3