In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

In [6]:
def nonlinearity(x):
    """Apply the SiLU activation to ``x`` and return the result."""
    activated = F.silu(x)
    return activated

In [7]:
def normalize(x, temb, name):
    """Apply group normalisation (32 groups, ``eps=1e-6``) over dim 1 of ``x``.

    A fresh ``nn.GroupNorm`` is constructed on every call, so its affine
    parameters are never shared between calls. The layer is created on the
    same device and with the same dtype as ``x`` so that non-CPU or
    non-float32 inputs do not hit a device/dtype mismatch.

    Args:
        x: Input tensor; ``x.shape[1]`` is used as the channel count and
            must be divisible by 32.
        temb: Unused; accepted so all layer helpers share a call shape.
        name: Unused; accepted for call-site compatibility.

    Returns:
        The normalised tensor, same shape as ``x``.
    """
    norm = nn.GroupNorm(
        num_groups=32,
        num_channels=x.shape[1],
        eps=1e-6,
        affine=True,
        device=x.device,
        dtype=x.dtype,
    )
    return norm(x)

In [8]:
def conv2d(x, num_units, kernel_size=3, stride=1, init_scale=1.0):
    """Run ``x`` through a freshly built weight-normalised 2-D convolution.

    The layer is constructed on every call (its parameters are not kept),
    using ``kernel_size // 2`` padding and the device/dtype of ``x``.

    Args:
        x: Input tensor; ``x.shape[1]`` is the number of input channels.
        num_units: Number of output channels.
        kernel_size: Integer kernel size.
        stride: Convolution stride.
        init_scale: Multiplier applied to the Kaiming-normal initial weights;
            ``0.`` yields an all-zero kernel (output is the bias only).

    Returns:
        The convolution output.
    """
    conv = weight_norm(
        nn.Conv2d(
            x.shape[1],
            num_units,
            kernel_size,
            stride,
            padding=kernel_size // 2,
            device=x.device,
            dtype=x.dtype,
        )
    )
    # weight_norm recomputes `conv.weight` from `weight_g` / `weight_v` right
    # before every forward pass, so initialising `conv.weight` itself would be
    # discarded. Initialise the direction (`weight_v`) and fold the scale into
    # the magnitude (`weight_g`) so the effective kernel is init_scale * v.
    with torch.no_grad():
        nn.init.kaiming_normal_(conv.weight_v, a=0, mode='fan_in', nonlinearity='leaky_relu')
        per_filter_norm = conv.weight_v.flatten(1).norm(dim=1)
        conv.weight_g.copy_(init_scale * per_filter_norm.view_as(conv.weight_g))
    return conv(x)

In [9]:
def upsample(x, with_conv):
    """Double the two trailing dimensions of a 4-D ``x`` by nearest-neighbour interpolation.

    When ``with_conv`` is truthy, the interpolated tensor is additionally
    passed through a 3x3, stride-1 ``conv2d`` that keeps the channel count.
    """
    _, channels, _, _ = x.shape
    doubled = F.interpolate(x, scale_factor=2, mode='nearest')
    if not with_conv:
        return doubled
    return conv2d(doubled, num_units=channels, kernel_size=3, stride=1)

In [10]:
def downsample(x, with_conv):
    """Reduce the spatial size of ``x`` with a stride of 2.

    Uses a 3x3, stride-2 ``conv2d`` that keeps the channel count when
    ``with_conv`` is truthy, otherwise 2x2 average pooling with stride 2.
    """
    if not with_conv:
        return F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    channels = x.shape[1]
    return conv2d(x, num_units=channels, kernel_size=3, stride=2)

In [11]:
def nin(x, num_units):
    """Apply a freshly built weight-normalised 1x1 convolution to a 4-D ``x``.

    The layer is constructed on every call (its parameters are not kept) and
    is created on the device and with the dtype of ``x`` so that non-CPU or
    non-float32 inputs do not raise a mismatch error.

    Args:
        x: 4-D input tensor; dim 1 is the number of input channels.
        num_units: Number of output channels.

    Returns:
        Tensor with ``num_units`` channels and the spatial size of ``x``.
    """
    _, in_channels, _, _ = x.shape
    layer = weight_norm(
        nn.Conv2d(
            in_channels,
            num_units,
            kernel_size=1,
            stride=1,
            padding=0,
            device=x.device,
            dtype=x.dtype,
        )
    )
    return layer(x)

In [12]:
def resnet_block(x, temb, out_ch=None, conv_shortcut=False, dropout=0.0):
    """Residual block conditioned on a timestep embedding.

    Pipeline: normalise -> SiLU -> conv -> add projected ``temb`` ->
    normalise -> SiLU -> dropout -> conv (built with ``init_scale=0.``),
    then add the (optionally projected) input.

    Args:
        x: 4-D input tensor; dim 1 is the channel count.
        temb: 2-D embedding tensor; it is projected to ``out_ch`` features and
            broadcast over the two trailing dimensions.
        out_ch: Output channel count; defaults to the input channel count.
        conv_shortcut: When the channel count changes, project the skip path
            with a 3x3 ``conv2d`` if true, otherwise with a 1x1 ``nin``.
        dropout: Dropout probability. Dropout is applied unconditionally
            (``training=True``), so pass ``0.0`` to disable it.

    Returns:
        Tensor with ``out_ch`` channels and the spatial size of ``x``.
    """
    B, C, H, W = x.shape
    if out_ch is None:
        out_ch = C

    h = x
    h = nonlinearity(normalize(h, temb, name='norm1'))
    h = conv2d(h, num_units=out_ch)
    # `nn.linear` does not exist in torch.nn; use this file's `dense` helper
    # to project the embedding to one value per output channel.
    h = h + dense(nonlinearity(temb), out_ch)[:, :, None, None]

    h = nonlinearity(normalize(h, temb, name='norm2'))
    h = F.dropout(h, p=dropout, training=True)
    h = conv2d(h, num_units=out_ch, init_scale=0.)

    # Match the skip path's channel count to the residual branch.
    if C != out_ch:
        if conv_shortcut:
            x = conv2d(x, num_units=out_ch)
        else:
            x = nin(x, out_ch)

    return x + h

In [13]:
def dense(x, num_units):
    """Apply a freshly built weight-normalised linear layer to ``x``.

    The layer is constructed on every call (its parameters are not kept) and
    is created on the device and with the dtype of ``x`` so that non-CPU or
    non-float32 inputs do not raise a mismatch error.

    Args:
        x: Input tensor; its last dimension is the input feature count.
        num_units: Number of output features.

    Returns:
        Tensor shaped like ``x`` with the last dimension replaced by
        ``num_units``.
    """
    layer = weight_norm(
        nn.Linear(x.shape[-1], num_units, device=x.device, dtype=x.dtype)
    )
    return layer(x)

In [14]:
def attn_block(x, temb):
    """Single-head self-attention over the spatial positions of a 4-D ``x``.

    Queries, keys and values are 1x1 projections (``nin``) of the normalised
    input. Each position ``(h, w)`` attends over all positions ``(H, W)`` via
    a softmax of the scaled channel dot-product, and the attended values go
    through one more 1x1 projection before being added back to ``x``.

    Args:
        x: 4-D input tensor, unpacked as ``(B, C, H, W)``.
        temb: Forwarded to ``normalize``; otherwise unused here.

    Returns:
        Tensor with the same shape as ``x``.
    """
    B, C, H, W = x.shape
    h = normalize(x, temb=temb, name='norm')
    q = nin(h, C)
    k = nin(h, C)
    v = nin(h, C)

    # Contract over the shared channel label `c`. Using two different labels
    # (`c` and `C`) would sum q and k over channels independently instead of
    # taking their dot product.
    w = torch.einsum('bchw,bcHW->bhwHW', q, k) * (C ** -0.5)
    # Softmax over all key positions at once, then restore the 2-D key grid.
    w = w.reshape(B, H, W, H * W)
    w = F.softmax(w, dim=-1)
    w = w.reshape(B, H, W, H, W)

    # `v` is laid out (B, C, H, W); sum over key positions and emit the result
    # channels-first so the following 1x1 projection sees C channels in dim 1.
    h = torch.einsum('bhwHW,bcHW->bchw', w, v)
    h = nin(h, C)

    return x + h

In [15]:
def get_timestep_embedding(t, dim):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``dim // 2`` columns are sines and the next ``dim // 2`` are
    cosines of ``t * exp(-log(10000) * i / (dim // 2))`` for
    ``i = 0 .. dim // 2 - 1``. When ``dim`` is odd, a trailing zero column is
    appended so the result always has exactly ``dim`` columns.

    Args:
        t: 1-D tensor of timesteps; converted to float.
        dim: Width of the embedding.

    Returns:
        Float tensor of shape ``(len(t), dim)`` on the device of ``t``.
    """
    half_dim = dim // 2
    # Build the frequency table on t's device to avoid a device mismatch in
    # the multiplication below.
    log_max_period = torch.log(torch.tensor(10000.0, device=t.device))
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=t.device)
        * -(log_max_period / half_dim)
    )
    args = t.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:
        # sin/cos halves only cover 2 * (dim // 2) columns; zero-pad the last.
        emb = F.pad(emb, (0, 1))
    return emb


In [16]:
class Model(nn.Module):
    """U-Net style network conditioned on a timestep.

    Only ``conv_in`` and ``conv_out`` are registered sub-modules. Every other
    layer is applied through this file's functional helpers
    (``resnet_block``, ``attn_block``, ``downsample``, ``upsample``,
    ``dense``, ``normalize``), which build their layers when called.

    Args:
        num_classes: Stored on the instance; not used by ``forward``.
        ch: Base channel width; also the width of the sinusoidal timestep
            embedding.
        out_ch: Number of output channels.
        ch_mult: Per-level multiplier of ``ch``; its length is the number of
            resolution levels.
        num_res_blocks: Residual blocks per level on the way down (the way up
            uses one more per level to consume every skip connection).
        attn_resolutions: Levels with ``2 ** i_level`` in this collection get
            an attention block after each residual block.
        dropout: Dropout probability handed to the residual blocks while the
            module is in training mode.
        resamp_with_conv: Whether down/up-sampling uses a convolution.
    """

    def __init__(self, num_classes, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks=2, attn_resolutions=(), dropout=0.0, resamp_with_conv=True):
        super().__init__()
        self.num_classes = num_classes
        self.ch = ch
        self.out_ch = out_ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        # Copied into a tuple so a caller's list is never shared or mutated.
        self.attn_resolutions = tuple(attn_resolutions)
        self.dropout = dropout
        self.resamp_with_conv = resamp_with_conv

        # The functional layer helpers are plain functions and cannot be
        # stored in an nn.ModuleList; forward() calls them directly instead.
        self.conv_in = weight_norm(nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1))
        # The last up-level emits ch * ch_mult[0] channels.
        self.conv_out = weight_norm(nn.Conv2d(ch * ch_mult[0], out_ch, kernel_size=3, stride=1, padding=1))

    def forward(self, x, t, y=None):
        """Run the network.

        Args:
            x: 4-D input tensor with 3 channels in dim 1 (``conv_in`` is
                built for 3 input channels).
            t: 1-D tensor of timesteps, one per batch element.
            y: Must be ``None``; conditioning is not supported.

        Returns:
            Tensor with ``out_ch`` channels.
        """
        assert y is None, 'not supported'

        num_resolutions = len(self.ch_mult)
        # Residual blocks always apply the dropout rate they are given, so
        # switch it off here when the module is in eval mode.
        drop = self.dropout if self.training else 0.0

        temb = get_timestep_embedding(t, self.ch)
        temb = dense(temb, self.ch * 4)
        temb = dense(nonlinearity(temb), self.ch * 4)

        # Downsampling path: keep every intermediate activation as a skip.
        hs = [self.conv_in(x)]
        for i_level in range(num_resolutions):
            width = self.ch * self.ch_mult[i_level]
            use_attn = 2 ** i_level in self.attn_resolutions
            for _ in range(self.num_res_blocks):
                h = resnet_block(hs[-1], temb, out_ch=width, dropout=drop)
                if use_attn:
                    h = attn_block(h, temb)
                hs.append(h)
            if i_level != num_resolutions - 1:
                hs.append(downsample(hs[-1], self.resamp_with_conv))

        # Middle.
        h = hs[-1]
        h = resnet_block(h, temb, dropout=drop)
        h = attn_block(h, temb)
        h = resnet_block(h, temb, dropout=drop)

        # Upsampling path: only residual blocks consume a skip connection;
        # attention and upsampling operate on `h` alone.
        for i_level in reversed(range(num_resolutions)):
            width = self.ch * self.ch_mult[i_level]
            use_attn = 2 ** i_level in self.attn_resolutions
            for _ in range(self.num_res_blocks + 1):
                h = resnet_block(torch.cat([h, hs.pop()], dim=1), temb, out_ch=width, dropout=drop)
                if use_attn:
                    h = attn_block(h, temb)
            if i_level != 0:
                h = upsample(h, self.resamp_with_conv)

        h = nonlinearity(normalize(h, temb=temb, name='norm_out'))
        h = self.conv_out(h)
        return h