In [125]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [126]:
# NOTE(review): when CUDA is unavailable this only *constructs* an Exception and shows it as the cell
# output; nothing is raised, so execution continues. Use `raise` if the notebook must stop without a GPU.
"GPU is there" if torch.cuda.is_available() else Exception("GPU is missing") 

'GPU is there'

### Build Conditional UNet from Lecture 24

In [127]:
def noop(x):
    """Identity helper: hand `x` back unchanged."""
    return x

In [128]:
def timestep_embedding(tsteps, emb_dim, max_period=10000):
    """Sinusoidal embedding of a 1-D tensor of timesteps.

    `emb_dim // 2` frequencies are spaced geometrically from 1 down to
    `1 / max_period`. Each timestep is multiplied by every frequency and the
    sines and cosines of those products are concatenated along the last axis.

    Args:
        tsteps: 1-D tensor of timesteps; the frequencies are created on its device.
        emb_dim: width of the returned embedding.
        max_period: reciprocal of the smallest frequency.

    Returns:
        Tensor of shape `(len(tsteps), emb_dim)`. For odd `emb_dim` one column
        of zeros is padded on the right.
    """
    n_freqs = emb_dim // 2
    ramp = torch.linspace(0, 1, n_freqs, device=tsteps.device)
    freqs = (-math.log(max_period) * ramp).exp()
    angles = tsteps[:, None].float() * freqs[None, :]
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if emb_dim % 2 == 1:
        # sin/cos halves only cover an even width; pad the final column with zeros
        emb = F.pad(emb, (0, 1, 0, 0))
    return emb

In [129]:
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Build a pre-activation linear stage: optional norm, optional activation, then `nn.Linear`.

    Args:
        ni: input feature count (also passed to `norm` when given).
        nf: output feature count.
        act: activation class instantiated with no arguments; falsy to skip it.
        norm: normalisation class called as `norm(ni)`; falsy to skip it.
        bias: forwarded to `nn.Linear`.

    Returns:
        An `nn.Sequential` holding the selected layers in that order.
    """
    stages = []
    if norm:
        stages.append(norm(ni))
    if act:
        stages.append(act())
    stages.append(nn.Linear(ni, nf, bias=bias))
    return nn.Sequential(*stages)

In [130]:
def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Build a pre-activation conv stage: optional norm, optional activation, then `nn.Conv2d`.

    Args:
        ni: input channel count (also passed to `norm` when given).
        nf: output channel count.
        ks: kernel size; padding is `ks // 2`.
        stride: convolution stride.
        act: activation class instantiated with no arguments; falsy to skip it.
        norm: normalisation class called as `norm(ni)`; falsy to skip it.
        bias: forwarded to `nn.Conv2d`.

    Returns:
        An `nn.Sequential` holding the selected layers in that order.
    """
    stages = []
    if norm:
        stages.append(norm(ni))
    if act:
        stages.append(act())
    conv = nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks // 2, bias=bias)
    stages.append(conv)
    return nn.Sequential(*stages)

In [131]:
def heads_to_batch(x, heads):
    """Fold attention heads into the batch axis.

    The last axis of a `(n, sl, d)` tensor is split into `heads` equal slices
    and each slice becomes its own batch entry, giving `(n*heads, sl, d/heads)`.
    """
    bs, seq_len, _ = x.shape
    per_head = x.reshape(bs, seq_len, heads, -1).transpose(2, 1)
    return per_head.reshape(bs * heads, seq_len, -1)

def batch_to_heads(x, heads):
    """Undo `heads_to_batch`: move heads out of the batch axis back into the feature axis.

    A `(n*heads, sl, d)` tensor becomes `(n, sl, d*heads)`, with the slices of
    consecutive heads laid side by side along the last axis.
    """
    _, seq_len, head_dim = x.shape
    unfolded = x.reshape(-1, heads, seq_len, head_dim).transpose(2, 1)
    return unfolded.reshape(-1, seq_len, head_dim * heads)

class SelfAttention(nn.Module):
    """Multi-head self-attention over a rank-3 input.

    Args:
        ni: feature width seen by LayerNorm and the linear layers.
        attn_chans: features per head; the head count is `ni // attn_chans`.
        transpose: when True, axes 1 and 2 of the input are swapped before the
            layers run and swapped back on the result.
    """

    def __init__(self, ni, attn_chans, transpose=True):
        super().__init__()
        self.nheads = ni // attn_chans
        # ni / nheads is the per-head width; its square root divides the attention scores
        self.scale = math.sqrt(ni / self.nheads)
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, 3 * ni)
        self.proj = nn.Linear(ni, ni)
        self.t = transpose

    def forward(self, x):
        """Return the attention output with the same layout as `x` (no residual is added here)."""
        _n, _c, _s = x.shape  # unpacking requires a rank-3 input
        multi_head = self.nheads != 1
        if self.t:
            x = x.transpose(1, 2)
        qkv = self.qkv(self.norm(x))
        if multi_head:
            qkv = heads_to_batch(qkv, self.nheads)
        q, k, v = qkv.chunk(3, dim=-1)
        scores = (q @ k.transpose(1, 2)) / self.scale
        out = scores.softmax(dim=-1) @ v
        if multi_head:
            out = batch_to_heads(out, self.nheads)
        out = self.proj(out)
        if self.t:
            out = out.transpose(1, 2)
        return out

class SelfAttention2D(SelfAttention):
    """`SelfAttention` for `(n, c, h, w)` inputs: the two spatial axes are flattened into one."""

    def forward(self, x):
        n, c, h, w = x.shape
        flat = x.view(n, c, -1)
        attended = super().forward(flat)
        return attended.reshape(n, c, h, w)

In [132]:
class EmbResBlock(nn.Module):
    """Residual block of two `pre_conv` stages whose activations are modulated by an embedding.

    Args:
        n_emb: width of the embedding vector passed to `forward`.
        ni: input channel count.
        nf: output channel count; defaults to `ni`.
        ks: kernel size of both convolutions.
        act: activation class forwarded to `pre_conv`.
        norm: normalisation class forwarded to `pre_conv`.
        attn_chans: channels per attention head; 0 disables the attention step.
    """

    def __init__(self, n_emb, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_chans=0):
        super().__init__()
        if nf is None: nf = ni
        # One projection yields both halves (scale and shift), hence nf*2 outputs.
        self.emb_proj = nn.Linear(n_emb, nf*2)
        self.conv1 = pre_conv(ni, nf, ks, act=act, norm=norm)
        self.conv2 = pre_conv(nf, nf, ks, act=act, norm=norm)
        # Keep the function object itself: it is invoked with the block input in forward().
        # Writing `noop()` here would call it with no argument and raise TypeError when ni == nf.
        self.idconv = noop if ni==nf else nn.Conv2d(ni, nf, 1)
        self.attn = False
        if attn_chans: self.attn = SelfAttention2D(nf, attn_chans)

    def forward(self, x, emb):
        """Apply conv1, the scale/shift modulation, conv2, the residual add and optional attention.

        Args:
            x: input feature map with `ni` channels.
            emb: embedding of width `n_emb`, one row per batch item.

        Returns:
            Feature map with `nf` channels.
        """
        inp = x
        x = self.conv1(x)
        # Project the embedding, then add two trailing unit axes so it broadcasts over the feature map.
        emb = self.emb_proj(F.silu(emb))[:, :, None, None]
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x)
        # idconv is the identity when ni == nf, otherwise a 1x1 conv matching the channel count.
        x = x + self.idconv(inp)
        if self.attn: x = x + self.attn(x)
        return x

In [133]:
def saved(m, blk):
    """Patch `m.forward` so that every result is also appended to `blk.saved`.

    Args:
        m: object with a `forward` callable; it is modified in place.
        blk: object whose `saved` list collects the outputs. The list is looked
            up on each call, so `blk.saved` may be reassigned between calls.

    Returns:
        `m` itself, with `forward` replaced by the recording wrapper.
    """
    # Local import: `wraps` is not among the notebook's top-level imports.
    from functools import wraps

    inner = m.forward

    @wraps(inner)
    def _f(*args, **kwargs):
        res = inner(*args, **kwargs)
        blk.saved.append(res)
        return res

    m.forward = _f
    return m

In [134]:
class DownBlock(nn.Module):
    """Down stage: `num_layers` `EmbResBlock`s followed by an optional stride-2 convolution.

    Every layer is wrapped with `saved`, so each layer output of a forward pass
    is collected in `self.saved`.
    """

    def __init__(self, n_emb, ni, nf, add_down=True, num_layers=1, attn_chans=0):
        super().__init__()
        blocks = []
        for idx in range(num_layers):
            # only the first block sees the incoming channel count
            in_ch = nf if idx else ni
            blocks.append(saved(EmbResBlock(n_emb, in_ch, nf, attn_chans=attn_chans), self))
        self.resnets = nn.ModuleList(blocks)
        if add_down:
            self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self)
        else:
            self.down = nn.Identity()

    def forward(self, x, emb):
        # start a fresh record for this pass; the wrapped layers append to it
        self.saved = []
        for blk in self.resnets:
            x = blk(x, emb)
        return self.down(x)

In [135]:
def upsample(nf):
    """2x `nn.Upsample` followed by a 3x3, padding-1 convolution that keeps `nf` channels."""
    grow = nn.Upsample(scale_factor=2.)
    conv = nn.Conv2d(nf, nf, 3, padding=1)
    return nn.Sequential(grow, conv)

class UpBlock(nn.Module):
    """Up stage: `num_layers` `EmbResBlock`s, each fed its input concatenated with a popped tensor.

    Args:
        n_emb: embedding width passed to each `EmbResBlock`.
        ni: channel count of the tensor concatenated in the last layer.
        prev_nf: channel count of the tensor entering the first layer.
        nf: output channel count of every layer.
        add_up: append `upsample(nf)` when True, otherwise `nn.Identity`.
        num_layers: number of residual blocks.
        attn_chans: forwarded to each `EmbResBlock`.
    """

    def __init__(self, n_emb, ni, prev_nf, nf, add_up=True, num_layers=2, attn_chans=0):
        super().__init__()
        last = num_layers - 1
        blocks = []
        for idx in range(num_layers):
            running = prev_nf if idx == 0 else nf
            popped = ni if idx == last else nf
            blocks.append(EmbResBlock(n_emb, running + popped, nf, attn_chans=attn_chans))
        self.resnets = nn.ModuleList(blocks)
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        # `ups` is consumed from its end: one entry is popped per residual block
        for blk in self.resnets:
            joined = torch.cat([x, ups.pop()], dim=1)
            x = blk(joined, t)
        return self.up(x)

In [136]:
class CondUNetModel(nn.Module):
    """Class-conditional UNet built from `DownBlock`, `EmbResBlock` and `UpBlock`.

    Args:
        n_classes: number of rows in the class-embedding table.
        in_channels: channels of the input image.
        out_channels: channels of the output.
        nfs: channel widths, one per level; `nfs[0]` is also the timestep-embedding width.
        num_layers: residual blocks per down stage (up stages get one more).
    """

    def __init__(self, n_classes, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        last = len(nfs) - 1
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nfs[0]
        n_emb = self.n_temb * 4
        self.cond_emb = nn.Embedding(n_classes, n_emb)
        self.emb_mlp = nn.Sequential(
            lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
            lin(n_emb, n_emb),
        )

        # Down path: each level takes the previous level's width (the first takes nfs[0]).
        self.downs = nn.ModuleList()
        for i, nf in enumerate(nfs):
            ni = nfs[max(i - 1, 0)]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i != last, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        # Up path walks the widths in reverse; neighbours are clamped at both ends.
        rev_nfs = list(reversed(nfs))
        self.ups = nn.ModuleList()
        for i, nf in enumerate(rev_nfs):
            prev_nf = rev_nfs[max(i - 1, 0)]
            ni = rev_nfs[min(i + 1, last)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i != last, num_layers=num_layers + 1))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        """Run the model on `inp`, a triple `(x, t, c)` of images, timesteps and class ids."""
        x, t, c = inp
        time_emb = timestep_embedding(t, self.n_temb)
        class_emb = self.cond_emb(c)
        emb = self.emb_mlp(time_emb) + class_emb
        x = self.conv_in(x)
        skips = [x]
        for down in self.downs:
            x = down(x, emb)
        # gather what every down stage recorded, in stage order
        for down in self.downs:
            skips.extend(down.saved)
        x = self.mid_block(x, emb)
        for up in self.ups:
            x = up(x, emb, skips)
        return self.conv_out(x)

### Understanding Embs from Lecture 24

In [137]:
# Number of filters at each level of the UNet
nfs = (224, 448, 672, 896)

In [138]:
# Embedding sizes: the timestep embedding is nfs[0] wide, the combined embedding four times that
nf = nfs[0]
n_temb = nf
n_emb = 4 * nf

In [139]:
# Mock inputs for the walkthrough
x0 = torch.randn(64, 3, 28, 28)  # stand-in batch of images
n = x0.shape[0]  # batch size: 64
t = torch.rand(n).to(x0).clamp(min=0, max=0.999)  # one random timestep per image

In [140]:
# Shapes of the timesteps and of the image batch
(t.shape, x0.shape)

(torch.Size([64]), torch.Size([64, 3, 28, 28]))

In [141]:
# Sinusoidal embedding of the timesteps, nf values per timestep
t_emb = timestep_embedding(tsteps=t, emb_dim=nf)
t_emb.shape

torch.Size([64, 224])

In [142]:
# Two linear stages widen the timestep embedding to n_emb, the size of the conditional embedding
emb_mlp = nn.Sequential(
    lin(n_temb, n_emb, norm=nn.BatchNorm1d),
    lin(n_emb, n_emb),
)
t_emb_mlp = emb_mlp(t_emb)
t_emb_mlp.shape

torch.Size([64, 896])

In [143]:
# Conditional embedding: a table with one n_emb-wide row per class, looked up for class 0
n_classes = 10
cond_emb = nn.Embedding(num_embeddings=n_classes, embedding_dim=n_emb)
class_id = torch.tensor(0)
c_emb = cond_emb(class_id)
c_emb.shape

torch.Size([896])

In [144]:
# Combined embedding: the single class vector broadcasts across the batch of timestep embeddings
emb = torch.add(t_emb_mlp, c_emb)
emb.shape

torch.Size([64, 896])

In [145]:
# SiLU then a linear layer; the output is 2*nf wide so it can be split into scale and shift
linear = nn.Linear(n_emb, 2 * nf)
emb2 = linear(F.silu(emb))[..., None, None]  # two trailing unit axes for broadcasting over H and W
emb2.shape

torch.Size([64, 448, 1, 1])

In [146]:
# Split the projected embedding in half along the channel axis: first half scale, second half shift
scale, shift = emb2.chunk(2, dim=1)
scale.shape, shift.shape

(torch.Size([64, 224, 1, 1]), torch.Size([64, 224, 1, 1]))

In [147]:
# scale and shift must have as many channels as x has after conv1 of the ResBlock
# (nf, i.e. 224 here), so that they broadcast against it
ni = x0.size(1)
ks = 3
conv1 = pre_conv(ni, nf, ks=ks, norm=nn.BatchNorm2d, act=nn.SiLU)
x = conv1(x0)
x.shape

torch.Size([64, 224, 28, 28])

In [148]:
# Modulate the activations with the embedding (as I understand it, the exact formula is a design choice)
x = x * (scale + 1) + shift
x.shape

torch.Size([64, 224, 28, 28])

In [149]:
# Finish ResBlock: second conv, then add the input back as the residual (as in EmbResBlock.forward)
conv2 = pre_conv(nf, nf, ks, act=nn.SiLU, norm=nn.BatchNorm2d)
x = conv2(x)
# x0 has ni channels but x has nf, so a 1x1 conv lifts x0 to nf channels before the add
convid = nn.Conv2d(ni, nf, 1)
x = x + convid(x0)
x.shape

torch.Size([64, 224, 28, 28])