In [1]:
from torch.nn import Module,Sequential,Conv2d,GroupNorm,SiLU,Identity,Linear,LayerNorm,MultiheadAttention,ConvTranspose2d,ModuleList
import torch

In [2]:
def convBlock(in_dim,out_dim):
  """Build a 3x3 convolution -> GroupNorm(8 groups) -> SiLU stack.

  Args:
    in_dim: number of input channels of the convolution.
    out_dim: number of output channels; GroupNorm splits these into 8 groups,
      so it has to be divisible by 8.

  Returns:
    A ``Sequential`` holding the three layers in that order.
  """
  layers = [
      # kernel 3 with padding 1 and the default stride keeps height/width unchanged
      Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
      GroupNorm(num_groups=8, num_channels=out_dim),
      SiLU(),
  ]
  return Sequential(*layers)
def timeEmeddingMLP(dim):
  """Build the MLP applied after the sinusoidal time embedding.

  The network widens ``dim`` to ``4 * dim``, applies SiLU, and projects back
  to ``dim``.

  Args:
    dim: input and output feature width.

  Returns:
    A ``Sequential`` of Linear -> SiLU -> Linear.
  """
  hidden_dim = dim * 4
  expand = Linear(dim, hidden_dim)
  project = Linear(hidden_dim, dim)
  return Sequential(expand, SiLU(), project)
def FCN(dim):
  """Build a position-wise feed-forward network: dim -> 4*dim -> dim.

  Args:
    dim: feature width of the input and of the output.

  Returns:
    A ``Sequential`` of Linear, SiLU and Linear layers.
  """
  narrow, wide = dim, dim * 4
  return Sequential(
      Linear(narrow, wide),
      SiLU(),
      Linear(wide, narrow),
  )



class ResnetLayer(Module):
  """Residual layer made of two conv blocks with optional time conditioning.

  The input goes through ``block1``; if a time projection was configured its
  output is added to the feature map; the result goes through ``block2`` and
  is finally summed with a residual branch computed from the original input.

  Args:
    in_dim: channels of the incoming feature map.
    out_dim: channels produced by the layer.
    time_dim: width of the time embedding; when falsy (``None`` or ``0``) no
      time projection is created and ``timeLatent`` is ignored.
  """
  def __init__(self, in_dim, out_dim, time_dim=None):
    super().__init__()

    self.block1 = convBlock(in_dim, out_dim)
    self.block2 = convBlock(out_dim, out_dim)

    # The residual branch only needs its own conv block when the channel
    # count changes; otherwise the input is passed through untouched.
    if in_dim == out_dim:
      self.resBlock = Identity()
    else:
      self.resBlock = convBlock(in_dim, out_dim)

    if time_dim:
      self.timeLinear = Sequential(SiLU(), Linear(time_dim, out_dim))
    else:
      self.timeLinear = None

  def forward(self,X,timeLatent=None):
    """Run the layer.

    Args:
      X: input feature map (presumably ``(batch, in_dim, H, W)`` -- the conv
        blocks are 2-D convolutions).
      timeLatent: time embedding fed to the time projection; required when
        the layer was built with a truthy ``time_dim``.

    Returns:
      The sum of the main branch and the residual branch.
    """
    hidden = self.block1(X)
    if self.timeLinear is not None:
      time_bias = self.timeLinear(timeLatent)
      # two trailing singleton axes so the bias broadcasts over the spatial dims
      hidden = hidden + time_bias[..., None, None]
    hidden = self.block2(hidden)

    return hidden + self.resBlock(X)

class SinusoidalEmbeddings(Module):
  """Sinusoidal (cos/sin) embedding of diffusion timesteps.

  Args:
    dim: requested embedding width. The output has ``2 * (dim // 2)``
      features: the first half are cosines, the second half sines.
  """
  def __init__(self, dim):

    super().__init__()

    self.dim = dim

  def forward(self,timestep):
    """Embed one timestep or a batch of timesteps.

    Args:
      timestep: a Python/NumPy scalar, or a tensor holding one timestep per
        sample. Tensors of any shape are flattened to ``(B,)``.

    Returns:
      Float32 tensor of shape ``(B, 2 * (dim // 2))`` (``B == 1`` for a
      scalar), located on the same device as ``timestep`` when it is a tensor.
    """
    half_dim = self.dim//2

    # Accept both scalars and (batched) tensors; keep tensors on their device
    # instead of re-wrapping them in a CPU tensor.
    if torch.is_tensor(timestep):
      steps = timestep.to(dtype=torch.float32)
    else:
      steps = torch.tensor(timestep, dtype=torch.float32)
    steps = steps.reshape(-1)

    # Geometric frequency ladder 10000^(-i / half_dim), built on the same
    # device as the timesteps so the product below never mixes devices.
    exponents = -torch.arange(half_dim, dtype=torch.float32, device=steps.device) / half_dim
    freqs = torch.pow(10000, exponents)

    angles = steps[:, None] * freqs[None]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)






In [3]:
class UnetAttention(Module):
  """Transformer-style self-attention block over the spatial positions of a feature map.

  Pipeline: 1x1 conv -> flatten H*W into a token axis -> pre-norm multi-head
  self-attention with residual -> pre-norm feed-forward with residual ->
  restore the image layout -> 1x1 conv -> add the block input.

  Args:
    dim: channel count of the feature map (also the token width).
    heads: number of attention heads; ``MultiheadAttention`` requires ``dim``
      to be divisible by it.
  """
  def __init__(self, dim, heads):
    super().__init__()

    self.layerNorm1 = LayerNorm(dim)
    self.layerNorm2 = LayerNorm(dim)

    self.conv_in = Conv2d(dim, dim ,kernel_size=1, padding=0)
    self.conv_out = Conv2d(dim, dim ,kernel_size=1, padding=0)

    self.qkv_linear = Linear(dim,dim*3)
    # Tokens are laid out (batch, h*w, dim) in forward(). MultiheadAttention
    # defaults to (seq, batch, dim), which would attend across the batch axis
    # instead of across spatial positions, so batch_first must be enabled.
    self.MHA = MultiheadAttention(dim , heads, batch_first=True)
    self.fcn = FCN(dim)

  def attn(self,X):
    """Self-attention over the token axis of ``X`` with shape (batch, tokens, dim)."""
    q,k,v = self.qkv_linear(X).chunk(3,dim=-1)
    # Only the attended values are used, so skip computing the averaged
    # attention weights.
    return self.MHA(q,k,v, need_weights=False)[0]

  def forward(self,X):
    """Apply the attention block.

    Args:
      X: feature map of shape (batch, dim, H, W).

    Returns:
      Tensor with the same shape as ``X``.
    """
    y = self.conv_in(X)

    # convert image format to sequence format: (b, c, h, w) -> (b, h*w, c)

    b,c,h,w = y.shape

    y = y.view((b, c, h * w))
    y_r = y.transpose(-1,-2)

    # attention sub-layer (pre-norm) with residual connection
    y = self.layerNorm1(y_r)
    y = self.attn(y)

    y_r = y + y_r

    # feed-forward sub-layer (pre-norm) with residual connection
    y = self.layerNorm2(y_r)
    y = self.fcn(y)

    y = y + y_r

    # back to image format: (b, h*w, c) -> (b, c, h, w); reshape copes with the
    # non-contiguous result of transpose
    y = y.transpose(-1,-2)
    y = y.reshape((b, c, h , w))

    y = self.conv_out(y)

    # the block input is added back at the end (long residual around the whole block)

    return y + X
class SamplingBlock(Module):
  """Shared body of the encoder/decoder stages: two ResnetLayers then attention.

  Args:
    indim: channels of the incoming feature map.
    outdim: channels produced by the stage.
    timedim: width of the time embedding passed to both ResnetLayers.
    no_heads: number of heads of the attention block.
  """
  def __init__(self,indim, outdim, timedim, no_heads):
    super().__init__()

    self.res_block1 = ResnetLayer(in_dim=indim, out_dim=outdim, time_dim=timedim)
    self.res_block2 = ResnetLayer(in_dim=outdim, out_dim=outdim, time_dim=timedim)
    self.attn_block = UnetAttention(dim=outdim, heads=no_heads)

  def forward(self, X, time):
    """Run both time-conditioned residual layers, then the attention block."""
    features = X
    for res_block in (self.res_block1, self.res_block2):
      features = res_block(features, time)

    return self.attn_block(features)

In [4]:
class downSamplingBlock(SamplingBlock):
  """Encoder stage: a SamplingBlock optionally followed by a strided convolution.

  Args:
    indim, outdim, timedim, no_heads: forwarded to ``SamplingBlock``.
    down_sample: when truthy, append a kernel-4 / stride-2 / padding-1
      convolution (halves even spatial sizes); otherwise the resolution is kept.
  """
  def __init__(self,indim, outdim, timedim, no_heads, down_sample=False):
    super().__init__(indim, outdim, timedim, no_heads)
    if down_sample:
      self.down_sample = Conv2d(outdim, outdim, kernel_size=4, stride=2, padding=1)
    else:
      self.down_sample = None


  def forward(self,X,time):
    """Apply the shared stage body, then the down-sampling conv if one exists."""
    features = super().forward(X,time)
    if self.down_sample is None:
      return features
    return self.down_sample(features)

class upSamplingBlock(SamplingBlock):
  """Decoder stage: a SamplingBlock optionally followed by a transposed convolution.

  Args:
    indim, outdim, timedim, no_heads: forwarded to ``SamplingBlock``.
    up_sample: when truthy, append a kernel-4 / stride-2 / padding-1
      transposed convolution (doubles the spatial size); otherwise the
      resolution is kept.
  """
  def __init__(self,indim, outdim, timedim, no_heads, up_sample=False):
    super().__init__(indim, outdim, timedim, no_heads)
    if up_sample:
      self.up_sample = ConvTranspose2d(outdim, outdim, kernel_size=4, stride=2, padding=1)
    else:
      self.up_sample = None


  def forward(self,X,time):
    """Apply the shared stage body, then the up-sampling conv if one exists."""
    features = super().forward(X,time)
    if self.up_sample is None:
      return features
    return self.up_sample(features)

In [5]:

def bottleNeckBlock(dim, heads):
  """Build the U-Net bottleneck: ResnetLayer -> UnetAttention -> ResnetLayer.

  Both ResnetLayers are created without a time projection (``time_dim=None``),
  so the returned ``Sequential`` takes the feature map as its only input.

  Args:
    dim: channel count kept constant through the bottleneck.
    heads: number of attention heads.
  """
  stages = [
      ResnetLayer(dim, dim, None),
      UnetAttention(dim, heads),
      ResnetLayer(dim, dim, None),
  ]
  return Sequential(*stages)

In [6]:
class UNET(Module):
  """Time-conditioned U-Net with attention in every stage.

  Args:
    config: object exposing ``base_dim``, ``channels``, ``heads``,
      ``downsamplingDimensions`` / ``upsamplingDimensions`` (lists of
      ``(in_channels, out_channels)`` pairs, one per stage) and
      ``downsamplingDimensionsLast`` / ``upsamplingDimensionsLast`` (channel
      counts entering the bottleneck and the final convolution). The decoder
      input widths must already include the concatenated skip connections.
  """
  def __init__(self, config):

    super().__init__()

    self.timeEmbedding  = Sequential(SinusoidalEmbeddings(config.base_dim),timeEmeddingMLP(config.base_dim))

    self.init_conv =Conv2d(config.channels, config.base_dim, 7, padding=3)

    # Every encoder stage halves the resolution except the last one ...
    n_down = len(config.downsamplingDimensions)
    self.encoder = ModuleList([
        downSamplingBlock(indim, outdim, config.base_dim, config.heads, index != n_down)
        for index, (indim, outdim) in enumerate(config.downsamplingDimensions, start=1)
    ])
    self.bottleNeck = bottleNeckBlock(config.downsamplingDimensionsLast,config.heads)
    # ... and, mirroring it, every decoder stage doubles it except the first.
    self.decoder = ModuleList([
        upSamplingBlock(indim, outdim, config.base_dim, config.heads, index != 1)
        for index, (indim, outdim) in enumerate(config.upsamplingDimensions, start=1)
    ])

    outdim = config.upsamplingDimensionsLast
    self.finalConv = Sequential(
        convBlock(outdim, outdim),
        # padding=1 keeps the spatial size with a 3x3 kernel. The 4th
        # positional argument of Conv2d is the stride, not the padding, so it
        # has to be passed by keyword.
        Conv2d(outdim, config.channels, 3, padding=1)
    )

  def forward(self,X, time_step):
    """Run the network.

    Args:
      X: image batch with ``config.channels`` channels.
      time_step: timestep value forwarded to the time-embedding stack.

    Returns:
      Tensor with ``config.channels`` channels produced by the final convolution.
    """
    time_Embedding = self.timeEmbedding(time_step)
    skip_connections = []

    y = self.init_conv(X)
    for layer in self.encoder:
      y = layer(y,time_Embedding)
      skip_connections.append(y)

    y = self.bottleNeck(y)

    for layer in self.decoder:
      # skips are consumed in reverse order and concatenated along the channel axis
      y = torch.cat([y, skip_connections.pop()],dim=1)
      y = layer(y,time_Embedding)

    y = self.finalConv(y)

    return y






In [7]:
class DiffusionModelConfig():
  """Hyper-parameters read by ``UNET``."""

  # Width of the time embedding and output channels of the initial convolution.
  base_dim = 64
  # Channels of the images going in and coming out of the network.
  channels = 3
  # Attention heads used by every attention block.
  heads = 4

  # (in_channels, out_channels) of each encoder stage, in order.
  downsamplingDimensions = [
      (64, 128),
      (128, 256),
      (256, 512),
      (512, 1024),
  ]
  # (in_channels, out_channels) of each decoder stage; the input width is the
  # previous output plus the skip connection concatenated in front of the stage.
  upsamplingDimensions = [
      (2048, 1024),
      (1536, 512),
      (768, 256),
      (384, 128),
  ]

  # Output width of the last encoder / decoder stage.
  downsamplingDimensionsLast = downsamplingDimensions[-1][1]
  upsamplingDimensionsLast = upsamplingDimensions[-1][1]







In [8]:
# Instantiate the U-Net; the config class itself (not an instance) is passed in.
model = UNET(config=DiffusionModelConfig)
# Random tensor of shape (1, 3, 256, 256) -- presumably a dummy image batch for a forward-pass check.
X = torch.randn((1, 3, 256, 256))