In [1]:
import math
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn

In [2]:
def check_dims(x, shape):
    '''
    Verify that `x` has exactly the dimensions given by `shape`.

    x: object exposing `.size()` (e.g. a torch.Tensor)
    shape: expected dimensions, as a tuple, list or torch.Size

    Raises AssertionError naming both shapes when they differ.
    '''
    # torch.Size is a tuple subclass, so comparing it with a list is always
    # unequal; normalise both sides to plain tuples first.
    if tuple(x.size()) != tuple(shape):
        # Raised explicitly so the check is not stripped under `python -O`.
        raise AssertionError(f'Expected {shape}, got {x.size()}')

## Weight-Standardized Convolution

In [3]:
class WSConv2d(nn.Conv2d):
    '''
    Weight-Standardized Convolution
    https://arxiv.org/abs/1903.10520

    Works like nn.Conv2d, except that on every forward pass each output
    filter is standardized to zero mean and unit variance before it is
    applied. Accepts the same arguments as nn.Conv2d, plus:

    eps: constant added to the variance before the reciprocal square root
    '''
    def __init__(self, *args, eps=1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x):
        # self.weight is (out_channels, in_channels // groups, kH, kW).
        # Statistics are taken per output filter over all remaining axes
        # (input channels and both kernel axes). Reducing over dim=1 alone
        # makes the standardized kernel exactly zero whenever that axis has
        # size 1 (single-channel input, depthwise convolution).
        stat_dims = (1, 2, 3)
        mean = self.weight.mean(dim=stat_dims, keepdim=True)
        var = self.weight.var(dim=stat_dims, correction=0, keepdim=True)
        norm_weight = (self.weight - mean) * torch.rsqrt(var + self.eps)
        # Functional conv so the stored parameter itself stays untouched.
        return F.conv2d(
            x, norm_weight,
            self.bias, self.stride, self.padding, self.dilation, self.groups
        )

### Debug

In [None]:
# Smoke test: a 3x3 WSConv2d with padding=1 keeps the 128x128 resolution
# and only changes the channel count (32 -> 64).
torch.manual_seed(3985)
x = torch.rand([2, 32, 128, 128])
ws_conv = WSConv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1).to(x.device)
check_dims(ws_conv(x), (2, 64, 128, 128))

## Time-conditioned ResNet Block

In [4]:
class TimeResNetBlock(nn.Module):
    '''
    Residual block made of two WSConv2d + GroupNorm stages. The output of
    the first stage is modulated by a scale and a shift that are projected
    from the time embedding, then passed through SiLU.

    Shape suffixes:
    B: Batch size
    D: in_dim
    E: out_dim
    F: out_dim * 2
    T: t_dim
    H, W: Last 2 dimensions of x
    '''
    def __init__(self, in_dim, out_dim, t_dim):
        super().__init__()

        def conv_norm(n_in):
            # 3x3 weight-standardized conv followed by an 8-group GroupNorm.
            return nn.Sequential(
                WSConv2d(n_in, out_dim, kernel_size=3, padding=1),
                nn.GroupNorm(num_groups=8, num_channels=out_dim)
            )

        # A single linear layer emits scale and shift together (out_dim*2).
        self.proj_t = nn.Linear(t_dim, out_dim*2)
        self.conv1 = conv_norm(in_dim)
        self.conv2 = conv_norm(out_dim)
        # The skip path needs a 1x1 conv only when the channel count changes.
        self.rconv = (
            nn.Conv2d(in_dim, out_dim, kernel_size=1)
            if in_dim != out_dim
            else nn.Identity()
        )

    def forward(self, x_BDHW, t_embd_BT):
        # Two trailing singleton axes let scale/shift broadcast over H and W.
        scale_shift_BF11 = self.proj_t(t_embd_BT)[..., None, None]
        scale_BE11, shift_BE11 = scale_shift_BF11.chunk(2, dim=1)

        h_BEHW = self.conv1(x_BDHW) * (scale_BE11 + 1) + shift_BE11
        h_BEHW = self.conv2(F.silu(h_BEHW))
        return h_BEHW + self.rconv(x_BDHW)

### Debug

In [10]:
B = 2
D = 32
E = 64
F_ = 2 * E
T_ = D * 4
H, W = 128, 128

# Prefer the GPU but fall back to the CPU, so the cell also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(3985)
x_BDHW = torch.rand([B, D, H, W], device=device)
t_BT = torch.rand([B, T_], device=x_BDHW.device)
model = TimeResNetBlock(D, E, T_).to(x_BDHW.device)

In [11]:
# `t_BT` is the time embedding created in this section's setup cell; the
# name `t_embd_BT` is only defined in a later section of the notebook, so
# using it here fails on a fresh kernel.
t_embd_BF = model.proj_t(t_BT)
check_dims(t_embd_BF, (B, F_))

In [12]:
# Append two singleton axes so scale and shift broadcast over H and W,
# then split the 2*E channels into a scale half and a shift half.
scale_BE, shift_BE = t_embd_BF.reshape(*t_embd_BF.shape, 1, 1).chunk(2, dim=1)
check_dims(scale_BE, (B, E, 1, 1))
check_dims(shift_BE, (B, E, 1, 1))

In [13]:
# First WSConv2d + GroupNorm stage: D input channels -> E output channels.
h_BEHW = model.conv1(x_BDHW)
check_dims(h_BEHW, (B, E, H, W))

In [14]:
# Time conditioning followed by SiLU; the (B, E, 1, 1) scale and shift
# broadcast against (B, E, H, W), so the shape is unchanged.
h_BEHW = F.silu(h_BEHW * (scale_BE + 1) + shift_BE)
check_dims(h_BEHW, (B, E, H, W))

In [15]:
# Second WSConv2d + GroupNorm stage: E -> E channels, shape unchanged.
h_BEHW = model.conv2(h_BEHW)
check_dims(h_BEHW, (B, E, H, W))

In [16]:
# The skip path and the main path must have identical sizes before they are added.
model.rconv(x_BDHW).size(), h_BEHW.size()

(torch.Size([2, 64, 128, 128]), torch.Size([2, 64, 128, 128]))

In [20]:
# Residual sum, mirroring the last step of TimeResNetBlock.forward.
x_BEHW = h_BEHW + model.rconv(x_BDHW)
check_dims(x_BEHW, (B, E, H, W))

In [21]:
# End-to-end check of the whole block on the same inputs.
check_dims(model(x_BDHW, t_BT), (B, E, H, W))

## Attention

In [5]:
class Attention(nn.Module):
    '''
    Multi-head attention on a (B, C, H, W) feature map, with 1x1
    convolutions as input and output projections.

    Inside every head the attention matrix is d_head x d_head: queries and
    keys are compared channel against channel, contracting over the H*W
    spatial positions, and the softmax weights then mix the value channels.

    d_embd: number of channels of the input and of the output
    n_heads: number of attention heads
    d_head: number of channels per head
    '''
    def __init__(self, d_embd, n_heads=4, d_head=32):
        super().__init__()
        d_hid = n_heads * d_head

        self.n_heads = n_heads
        self.d_head = d_head
        # One bias-free projection produces q, k and v stacked along channels.
        self.attn_proj = nn.Conv2d(d_embd, d_hid*3, kernel_size=1, bias=False)
        self.scale = d_head ** -0.5
        self.out_proj = nn.Conv2d(d_hid, d_embd, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.size()
        qkv = self.attn_proj(x).chunk(3, dim=1)
        # (B, n_heads*d_head, H, W) -> (B, n_heads, d_head, H*W)
        to_attn_head = lambda z: z.reshape(B, self.n_heads, self.d_head, -1)
        q, k, v = map(to_attn_head, qkv)

        # (B, n_heads, d_head, d_head)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        score = F.softmax(attn, dim=-1)
        # (B, n_heads, d_head, H*W): channels already precede positions.
        y = score @ v
        # Merge the heads back into the channel axis. No transpose here:
        # swapping the last two axes before this reshape would interleave
        # channel and spatial elements in the output map.
        # NOTE(review): if attention over spatial positions (H*W x H*W) was
        # intended, q/k/v would have to be (B, n_heads, H*W, d_head) instead
        # - confirm the intended variant.
        y = y.reshape(B, -1, H, W)
        out = self.out_proj(y)

        return out

## UNet Downsample

In [6]:
class DownsampleOutProject(nn.Module):
    '''
    Halve the spatial resolution by folding every 2x2 pixel neighbourhood
    into the channel axis (space-to-depth), then mix the resulting
    4*in_dim channels down to out_dim with a 1x1 convolution.

    (B, in_dim, H, W) -> (B, out_dim, H // 2, W // 2); H and W must be even.
    '''
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.out_proj = nn.Conv2d(4*in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        # pixel_unshuffle places the four pixels of each 2x2 neighbourhood
        # in four consecutive channels. A bare reshape to the same shape
        # would instead cut every channel plane into bands of rows, so the
        # output positions would no longer line up with input positions.
        x = F.pixel_unshuffle(x, 2)
        x = self.out_proj(x)
        return x


class UNetDownsample(nn.Module):
    '''
    Encoder stage: two time-conditioned residual blocks, a residual
    attention layer on a GroupNorm-ed input, then a resampling layer.

    Shape suffixes:
    B: Batch size
    D: in_dim
    E: out_dim
    T: t_dim
    H, W: Last 2 dimensions of x
    Ho, Wo: (H, W) if is_last else (H // 2, W // 2)
    '''
    def __init__(self, in_dim, out_dim, t_dim, is_last=False):
        super().__init__()
        self.block1 = TimeResNetBlock(in_dim, in_dim, t_dim)
        self.block2 = TimeResNetBlock(in_dim, in_dim, t_dim)
        self.norm = nn.GroupNorm(num_groups=1, num_channels=in_dim)
        self.attn = Attention(in_dim)
        # The last stage keeps the resolution and only changes the channels.
        self.dsample = (
            nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
            if is_last
            else DownsampleOutProject(in_dim, out_dim)
        )

    def forward(self, x_BDHW, t_embd_BT):
        '''
        Returns a 3-tuple: the resampled output (B, E, Ho, Wo), the output
        of block1 (B, D, H, W) and the output of block2 (B, D, H, W).
        '''
        fmap1_BDHW = self.block1(x_BDHW, t_embd_BT)
        fmap2_BDHW = self.block2(fmap1_BDHW, t_embd_BT)
        attn_BDHW = self.attn(self.norm(fmap2_BDHW))
        x_BEHoWo = self.dsample(attn_BDHW + fmap2_BDHW)
        return x_BEHoWo, fmap1_BDHW, fmap2_BDHW

### Debug

#### Case 1: `is_last = False`

In [48]:
B = 2
D = 32
H, W = 128, 128
E = 64
T_ = 128
Ho, Wo = H // 2, W // 2

# Prefer the GPU but fall back to the CPU, so the cell also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(3985)
x_BDHW = torch.rand([B, D, H, W], device=device)
t_embd_BT = torch.rand([B, T_], device=x_BDHW.device)
model = UNetDownsample(D, E, T_).to(x_BDHW.device)

In [49]:
# block1 keeps both the channel count and the resolution.
fmap1_BDHW = model.block1(x_BDHW, t_embd_BT)
check_dims(fmap1_BDHW, (B, D, H, W))

In [50]:
# block2 consumes block1's output; the shape is unchanged.
fmap2_BDHW = model.block2(fmap1_BDHW, t_embd_BT)
check_dims(fmap2_BDHW, (B, D, H, W))

In [51]:
# Residual attention on the normalised block2 output.
# Note: this rebinds x_BDHW, which was the stage input until now.
x_BDHW = model.attn(model.norm(fmap2_BDHW)) + fmap2_BDHW
check_dims(x_BDHW, (B, D, H, W))

In [52]:
# Not the last stage: the resampling layer halves H and W and maps D -> E channels.
x_BEHoWo = model.dsample(x_BDHW)
check_dims(x_BEHoWo, (B, E, Ho, Wo))

In [54]:
# Full forward pass: the resampled output plus the outputs of block1 and block2.
a, b, c = model(x_BDHW, t_embd_BT)
check_dims(a, (B, E, Ho, Wo))
check_dims(b, (B, D, H, W))
check_dims(c, (B, D, H, W))

#### Case 2: `is_last = True`

In [55]:
# With is_last=True the resampling layer is a 3x3 conv with padding=1, so
# the resolution is kept. Reuses B, D, E, T_, x_BDHW and t_embd_BT from Case 1.
Ho, Wo = H, W
model = UNetDownsample(D, E, T_, is_last=True).to(x_BDHW.device)

In [37]:
# block1 keeps both the channel count and the resolution.
fmap1_BDHW = model.block1(x_BDHW, t_embd_BT)
check_dims(fmap1_BDHW, (B, D, H, W))

In [38]:
# block2 consumes block1's output; the shape is unchanged.
fmap2_BDHW = model.block2(fmap1_BDHW, t_embd_BT)
check_dims(fmap2_BDHW, (B, D, H, W))

In [39]:
# Residual attention on the normalised block2 output (rebinds x_BDHW).
x_BDHW = model.attn(model.norm(fmap2_BDHW)) + fmap2_BDHW
check_dims(x_BDHW, (B, D, H, W))

In [40]:
# Last stage: only the channel count changes (D -> E); Ho, Wo equal H, W here.
x_BEHoWo = model.dsample(x_BDHW)
check_dims(x_BEHoWo, (B, E, Ho, Wo))

In [56]:
# Full forward pass of the last-stage variant.
a, b, c = model(x_BDHW, t_embd_BT)
check_dims(a, (B, E, Ho, Wo))
check_dims(b, (B, D, H, W))
check_dims(c, (B, D, H, W))

## UNet Upsample

In [7]:
class UNetUpsample(nn.Module):
    '''
    Decoder stage: each of two time-conditioned residual blocks first
    concatenates one skip feature map onto its input; a residual attention
    layer and a resampling layer follow.

    Shape suffixes:
    B: Batch size
    D: in_dim (channels of the skip feature maps and of the output)
    E: out_dim (channels of the incoming x)
    F: in_dim + out_dim
    T: t_dim
    H, W: Last 2 dimensions of x
    Ho, Wo: (H, W) if is_last else (H * 2, W * 2)
    '''
    def __init__(self, in_dim, out_dim, t_dim, is_last=False):
        super().__init__()
        cat_dim = in_dim + out_dim
        self.block1 = TimeResNetBlock(cat_dim, out_dim, t_dim)
        self.block2 = TimeResNetBlock(cat_dim, out_dim, t_dim)
        self.norm = nn.GroupNorm(num_groups=1, num_channels=out_dim)
        self.attn = Attention(out_dim)

        if not is_last:
            # Nearest-neighbour 2x upsampling followed by a 3x3 conv.
            self.usample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(out_dim, in_dim, kernel_size=3, padding=1)
            )
        else:
            # The last stage keeps the resolution.
            self.usample = nn.Conv2d(out_dim, in_dim, kernel_size=3, padding=1)

    def forward(self, x_BEHW, fmap1_BDHW, fmap2_BDHW, t_embd_BT):
        h_BEHW = x_BEHW
        # block1 is paired with fmap1, block2 with fmap2.
        stages = ((self.block1, fmap1_BDHW), (self.block2, fmap2_BDHW))
        for block, fmap_BDHW in stages:
            h_BFHW = torch.cat([h_BEHW, fmap_BDHW], dim=1)
            h_BEHW = block(h_BFHW, t_embd_BT)

        h_BEHW = self.attn(self.norm(h_BEHW)) + h_BEHW
        return self.usample(h_BEHW)

### Debug

#### Case 1: `is_last = False`

In [8]:
B = 2
D = 32
H, W = 128, 128
E = 64
F_ = D + E
T_ = 128
Ho, Wo = H * 2, W * 2

# Prefer the GPU but fall back to the CPU, so the cell also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(3985)
x_BEHW = torch.rand([B, E, H, W], device=device)
fmap1_BDHW = torch.rand([B, D, H, W], device=x_BEHW.device)
fmap2_BDHW = torch.rand([B, D, H, W], device=x_BEHW.device)
t_embd_BT = torch.rand([B, T_], device=x_BEHW.device)
model = UNetUpsample(D, E, T_).to(x_BEHW.device)

In [17]:
# Concatenate the first skip feature map along channels: E + D = F_.
x_BFHW = torch.cat([x_BEHW, fmap1_BDHW], dim=1)
check_dims(x_BFHW, (B, F_, H, W))

In [18]:
# block1 reduces the concatenated F_ channels back to E (rebinds x_BEHW).
x_BEHW = model.block1(x_BFHW, t_embd_BT)
check_dims(x_BEHW, (B, E, H, W))

In [19]:
# Concatenate the second skip feature map along channels.
x_BFHW = torch.cat([x_BEHW, fmap2_BDHW], dim=1)
check_dims(x_BFHW, (B, F_, H, W))

In [20]:
# block2 reduces the concatenated F_ channels back to E.
x_BEHW = model.block2(x_BFHW, t_embd_BT)
check_dims(x_BEHW, (B, E, H, W))

In [21]:
# Residual attention on the normalised block2 output; shape unchanged.
x_BEHW = model.attn(model.norm(x_BEHW)) + x_BEHW
check_dims(x_BEHW, (B, E, H, W))

In [22]:
# Not the last stage: 2x nearest-neighbour upsampling, then a conv mapping E -> D channels.
x_BDHoWo = model.usample(x_BEHW)
check_dims(x_BDHoWo, (B, D, Ho, Wo))

In [9]:
# Full forward pass of the upsampling stage.
check_dims(model(x_BEHW, fmap1_BDHW, fmap2_BDHW, t_embd_BT), (B, D, Ho, Wo))

#### Case 2: `is_last = True`

In [10]:
# With is_last=True the resampling layer is a plain 3x3 conv, so the
# resolution is kept. Reuses the tensors created for Case 1.
Ho, Wo = H, W
model = UNetUpsample(D, E, T_, is_last=True).to(x_BEHW.device)

In [11]:
# Full forward pass of the last-stage variant: D channels at the input resolution.
check_dims(model(x_BEHW, fmap1_BDHW, fmap2_BDHW, t_embd_BT), (B, D, Ho, Wo))

## UNet Block

In [12]:
class UNetBlock(nn.Module):
    '''
    Shape-preserving stage: two time-conditioned residual blocks followed
    by a residual attention layer on a GroupNorm-ed input.

    Shape suffixes:
    B: Batch size
    D: dim
    H, W: Last 2 dimensions of x
    T: t_dim
    '''
    def __init__(self, dim, t_dim):
        super().__init__()
        self.block1 = TimeResNetBlock(dim, dim, t_dim)
        self.block2 = TimeResNetBlock(dim, dim, t_dim)
        self.norm = nn.GroupNorm(num_groups=1, num_channels=dim)
        self.attn = Attention(dim)

    def forward(self, x_BDHW, t_embd_BT):
        h_BDHW = x_BDHW
        for block in (self.block1, self.block2):
            h_BDHW = block(h_BDHW, t_embd_BT)
        # Attention is added on top of its own (un-normalised) input.
        return self.attn(self.norm(h_BDHW)) + h_BDHW

### Debug

In [14]:
B = 2
D = 32
H, W = 128, 128
T_ = 128

# Prefer the GPU but fall back to the CPU, so the cell also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(3985)
x_BDHW = torch.rand([B, D, H, W], device=device)
t_embd_BT = torch.rand([B, T_], device=x_BDHW.device)
model = UNetBlock(D, T_).to(x_BDHW.device)
# The block must preserve the input shape.
check_dims(model(x_BDHW, t_embd_BT), (B, D, H, W))

## UNet

In [15]:
class UNet(nn.Module):
    '''
    Time-conditioned U-Net: a 1x1 input conv, four encoder stages, one
    middle block, four decoder stages fed by the encoder feature maps, and
    a final residual block + 1x1 conv back to n_channels.

    Shape suffixes:
    D: dim
    C: n_channels
    H, W: Last 2 dimensions of x
    T: t_dim
    F: dim // 2

    D0 = D
    H0 = H // 2
    W0 = W // 2

    D1 = D * 2
    H1 = H // 4
    W1 = W // 4

    D2 = D * 4
    H2 = H // 8
    W2 = W // 8

    D3 = D * 8
    '''
    def __init__(self, dim, n_channels):
        super().__init__()

        self.in_conv = nn.Conv2d(n_channels, dim, kernel_size=1, padding=0)

        # Geometric frequency schedule for the sinusoidal timestep
        # embedding; dim // 2 frequencies are stored as a buffer.
        t_dim = dim * 4
        half_dim = dim // 2
        amp = math.log(1e4) / (half_dim - 1)
        freqs_F = torch.exp(torch.arange(half_dim) * -amp)
        self.register_buffer('freqs_F', freqs_F)
        self.proj_t = nn.Sequential(
            nn.Linear(dim, t_dim),
            nn.GELU(),
            nn.Linear(t_dim, t_dim)
        )

        # Channel widths of the four levels: dim, 2*dim, 4*dim, 8*dim.
        dim0, dim1, dim2, dim3 = (dim * 2 ** level for level in range(4))

        self.dsample0 = UNetDownsample(dim0, dim0, t_dim)
        self.dsample1 = UNetDownsample(dim0, dim1, t_dim)
        self.dsample2 = UNetDownsample(dim1, dim2, t_dim)
        self.dsample3 = UNetDownsample(dim2, dim3, t_dim, is_last=True)

        self.mblock = UNetBlock(dim3, t_dim)

        self.usample3 = UNetUpsample(dim2, dim3, t_dim)
        self.usample2 = UNetUpsample(dim1, dim2, t_dim)
        self.usample1 = UNetUpsample(dim0, dim1, t_dim)
        self.usample0 = UNetUpsample(dim0, dim0, t_dim, is_last=True)

        self.out_resblk = TimeResNetBlock(2*dim, dim, t_dim)
        self.out_conv = nn.Conv2d(dim, n_channels, kernel_size=1)

    def forward(self, x_BCHW, t_B):
        # Sinusoidal timestep embedding: sin and cos halves side by side.
        pos_embd_BF = t_B.unsqueeze(1) * self.freqs_F.unsqueeze(0)
        t_embd_BD = torch.cat([pos_embd_BF.sin(), pos_embd_BF.cos()], dim=-1)
        t_embd_BT = self.proj_t(t_embd_BD)

        x_BDHW = self.in_conv(x_BCHW)
        # Copy of the stem output, concatenated back in before the head.
        r_BDHW = x_BDHW.clone()

        # Encoder: every stage yields its output and two feature maps
        # (first and second residual block), pushed on a stack.
        skips = []
        h = x_BDHW
        for down in (self.dsample0, self.dsample1, self.dsample2, self.dsample3):
            h, first_fmap, second_fmap = down(h, t_embd_BT)
            skips.append((first_fmap, second_fmap))

        h = self.mblock(h, t_embd_BT)

        # Decoder: stages consume the stack in reverse. Within a stage the
        # maps are also handed over in reverse order: the encoder's second
        # block output goes to the decoder's first block.
        for up in (self.usample3, self.usample2, self.usample1, self.usample0):
            first_fmap, second_fmap = skips.pop()
            h = up(h, second_fmap, first_fmap, t_embd_BT)

        x_BD1HW = torch.cat([h, r_BDHW], dim=1)
        return self.out_conv(self.out_resblk(x_BD1HW, t_embd_BT))

### Debug

In [18]:
B = 2
C = 3
H = 128
W = 128
D = 32
T_ = 4 * D
# Integer division: F_ is a tensor dimension and must be an int, not a float.
F_ = D // 2

# Channel width and resolution per level, matching the UNet docstring.
D0 = D
H0 = H // 2
W0 = W // 2

D1 = D * 2
H1 = H // 4
W1 = W // 4

D2 = D * 4
H2 = H // 8
W2 = W // 8

D3 = D * 8

In [19]:
# Prefer the GPU but fall back to the CPU, so the cell also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(3985)
# Batch size comes from B so it cannot drift away from the shape checks below.
x_BCHW = torch.rand([B, C, H, W], device=device)
t_B = torch.randint(0, 1000, [B], device=x_BCHW.device)
model = UNet(dim=D, n_channels=C).to(x_BCHW.device)

In [50]:
# Broadcasted product of timesteps (B, 1) and frequencies (1, F_).
pos_embd_BF = t_B.unsqueeze(1) * model.freqs_F.unsqueeze(0)
check_dims(pos_embd_BF, (B, F_))

In [51]:
# Sin and cos halves side by side give a D-wide embedding.
t_embd_BD = torch.cat([pos_embd_BF.sin(), pos_embd_BF.cos()], dim=-1)
check_dims(t_embd_BD, (B, D))

In [52]:
# Linear -> GELU -> Linear projection from D to T_ features.
t_embd_BT = model.proj_t(t_embd_BD)
check_dims(t_embd_BT, (B, T_))

In [53]:
# 1x1 input conv lifts the C image channels to D features.
x_BDHW = model.in_conv(x_BCHW)
check_dims(x_BDHW, (B, D, H, W))

In [54]:
# Copy of the stem output; it is concatenated back in before the output head.
r_BDHW = x_BDHW.clone()
check_dims(r_BDHW, (B, D, H, W))

In [55]:
# Encoder level 0. The names follow the return order of
# UNetDownsample.forward (output, block1 map, block2 map).
# NOTE(review): UNet.forward unpacks the same tuple with the fmap1/fmap2
# names swapped, so this walk-through pairs the skip maps with the decoder
# blocks differently from the model itself - confirm which one is intended.
x_BD0H0W0, fmap1_BD0HW, fmap2_BD0HW = model.dsample0(x_BDHW, t_embd_BT)
check_dims(x_BD0H0W0, (B, D0, H0, W0))
check_dims(fmap1_BD0HW, (B, D0, H, W))
check_dims(fmap2_BD0HW, (B, D0, H, W))

In [56]:
# Encoder level 1: D0 -> D1 channels, resolution halved again.
x_BD1H1W1, fmap1_BD0H0W0, fmap2_BD0H0W0 = model.dsample1(x_BD0H0W0, t_embd_BT)
check_dims(x_BD1H1W1, (B, D1, H1, W1))
check_dims(fmap1_BD0H0W0, (B, D0, H0, W0))
check_dims(fmap2_BD0H0W0, (B, D0, H0, W0))

In [57]:
# Encoder level 2: D1 -> D2 channels, resolution halved again.
x_BD2H2W2, fmap1_BD1H1W1, fmap2_BD1H1W1 = model.dsample2(x_BD1H1W1, t_embd_BT)
check_dims(x_BD2H2W2, (B, D2, H2, W2))
check_dims(fmap1_BD1H1W1, (B, D1, H1, W1))
check_dims(fmap2_BD1H1W1, (B, D1, H1, W1))

In [58]:
# Encoder level 3 is built with is_last=True: D2 -> D3 channels, resolution kept.
x_BD3H2W2, fmap1_BD2H2W2, fmap2_BD2H2W2 = model.dsample3(x_BD2H2W2, t_embd_BT)
check_dims(x_BD3H2W2, (B, D3, H2, W2))
check_dims(fmap1_BD2H2W2, (B, D2, H2, W2))
check_dims(fmap2_BD2H2W2, (B, D2, H2, W2))

In [59]:
# Middle block: shape-preserving.
x_BD3H2W2 = model.mblock(x_BD3H2W2, t_embd_BT)
check_dims(x_BD3H2W2, (B, D3, H2, W2))

In [60]:
# Decoder level 3: D3 -> D2 channels, resolution doubled.
x_BD2H1W1 = model.usample3(x_BD3H2W2, fmap1_BD2H2W2, fmap2_BD2H2W2, t_embd_BT)
check_dims(x_BD2H1W1, (B, D2, H1, W1))

In [61]:
# Decoder level 2: D2 -> D1 channels, resolution doubled.
x_BD1H0W0 = model.usample2(x_BD2H1W1, fmap1_BD1H1W1, fmap2_BD1H1W1, t_embd_BT)
check_dims(x_BD1H0W0, (B, D1, H0, W0))

In [63]:
# Decoder level 1: D1 -> D0 channels, back at the full input resolution.
x_BD0HW = model.usample1(x_BD1H0W0, fmap1_BD0H0W0, fmap2_BD0H0W0, t_embd_BT)
check_dims(x_BD0HW, (B, D0, H, W))

In [65]:
# Decoder level 0 is built with is_last=True: channels and resolution are kept.
x_BDHW = model.usample0(x_BD0HW, fmap1_BD0HW, fmap2_BD0HW, t_embd_BT)
check_dims(x_BDHW, (B, D, H, W))

In [66]:
# Concatenate the decoder output with the saved stem output: D + D = D1 channels.
x_BD1HW = torch.cat([x_BDHW, r_BDHW], dim=1)
check_dims(x_BD1HW, (B, D1, H, W))

In [67]:
# Output head: residual block (2*D -> D) then a 1x1 conv back to C channels.
# Note: this rebinds x_BCHW, which was the network input until now.
x_BCHW = model.out_conv(model.out_resblk(x_BD1HW, t_embd_BT))
check_dims(x_BCHW, (B, C, H, W))

In [21]:
# Full forward pass. x_BCHW was rebound to the model output in the previous
# cell; it has the same shape as the original input, so the check still holds.
check_dims(model(x_BCHW, t_B), (B, C, H, W))