<a href="https://colab.research.google.com/github/mitran27/GenerativeNetworks/blob/main/DiffusionModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from torch.nn import Module,Sequential,Conv2d,GroupNorm,SiLU,Identity,Linear,LayerNorm,MultiheadAttention,ConvTranspose2d

In [None]:
def convBlock(in_dim, out_dim):
  """Build a 3x3 conv -> GroupNorm(8 groups) -> SiLU block.

  Args:
    in_dim: number of input channels.
    out_dim: number of output channels; must be divisible by 8 for GroupNorm.

  Returns:
    ``Sequential`` whose conv uses padding=1, so spatial size is preserved.
  """
  # Sequential takes modules as positional args; passing a list raises TypeError.
  return Sequential(
      Conv2d(in_dim, out_dim, 3, padding=1),
      GroupNorm(8, out_dim),
      SiLU(),
  )
def timeEmeddingMLP(dim):
  """Build a Linear(dim, 4*dim) -> SiLU -> Linear(4*dim, dim) MLP for time embeddings.

  Args:
    dim: input and output feature size.

  Returns:
    ``Sequential`` acting on the last dimension of its input.
  """
  # Sequential takes modules as positional args; passing a list raises TypeError.
  return Sequential(
      Linear(dim, dim * 4),
      SiLU(),
      Linear(dim * 4, dim),
  )
def FCN(dim):
  """Build the position-wise feed-forward network: Linear(dim, 4*dim) -> SiLU -> Linear(4*dim, dim).

  Args:
    dim: input and output feature size.

  Returns:
    ``Sequential`` acting on the last dimension of its input.
  """
  # Sequential takes modules as positional args; passing a list raises TypeError.
  return Sequential(
      Linear(dim, dim * 4),
      SiLU(),
      Linear(dim * 4, dim),
  )



class ResnetLayer(Module):
  """Residual block: two conv blocks plus an optional time-embedding bias.

  Args:
    in_dim: number of input channels.
    out_dim: number of output channels.
    time_dim: size of the time embedding; when falsy, no time projection is built.
  """

  def __init__(self, in_dim, out_dim, time_dim=None):
    # Module.__init__ must run before any submodule is assigned.
    super().__init__()

    self.block1 = convBlock(in_dim, out_dim)
    # block2 consumes block1's output, so its input width is out_dim.
    self.block2 = convBlock(out_dim, out_dim)

    # 1x1 projection so the skip path matches out_dim; Identity must be an instance.
    self.resBlock = Conv2d(in_dim, out_dim, kernel_size=1) if in_dim != out_dim else Identity()
    self.timeLinear = Sequential(SiLU(), Linear(time_dim, out_dim)) if time_dim else None

  def forward(self, X, timeLatent=None):
    """Apply the residual block.

    Args:
      X: feature map; assumed (batch, in_dim, H, W) — TODO confirm with caller.
      timeLatent: time embedding, assumed (batch, time_dim). It is optional so the
        layer can sit inside ``Sequential``, which forwards only one argument.
        It is ignored when the layer was built without ``time_dim``.

    Returns:
      Tensor with out_dim channels: block2(block1(X) [+ time bias]) + skip(X).
    """
    y = self.block1(X)
    if self.timeLinear is not None and timeLatent is not None:
      # Project to out_dim and broadcast the per-channel bias over the spatial dims.
      y = y + self.timeLinear(timeLatent)[:, :, None, None]
    y = self.block2(y)
    return y + self.resBlock(X)



In [None]:
class UnetAttention(Module):
  """Transformer-style self-attention block applied over the spatial positions of a feature map.

  A 1x1 conv is applied first. The map is then flattened to a sequence of H*W tokens with
  ``dim`` features each. Pre-norm multi-head attention and a feed-forward network follow,
  each with a residual connection. Finally the sequence is reshaped back to an image and
  passed through a 1x1 conv, and the block input is added.

  Args:
    dim: channel count; ``MultiheadAttention`` requires it to be divisible by ``heads``.
    heads: number of attention heads.
  """

  def __init__(self, dim, heads):
    # Module.__init__ must run before any submodule is assigned.
    super().__init__()

    self.layerNorm1 = LayerNorm(dim)
    self.layerNorm2 = LayerNorm(dim)

    self.conv_in = Conv2d(dim, dim, kernel_size=1, padding=0)
    self.conv_out = Conv2d(dim, dim, kernel_size=1, padding=0)

    # batch_first=True because the sequence below is laid out as (batch, tokens, channels).
    self.attn = MultiheadAttention(dim, heads, batch_first=True)
    self.fcn = FCN(dim)

  def forward(self, X, timeStamp=None):
    """Apply self-attention over spatial positions.

    Args:
      X: feature map of shape (batch, dim, H, W), as unpacked from conv_in's output.
      timeStamp: unused. It defaults to None so the block can be called from ``Sequential``.

    Returns:
      Tensor with the same shape as X.
    """
    y = self.conv_in(X)

    # Convert the image format to sequence format: (b, c, h, w) -> (b, h*w, c).
    b, c, h, w = y.shape
    seq = y.reshape(b, c, h * w).transpose(1, 2)

    # Pre-norm self-attention. MultiheadAttention needs query/key/value and returns
    # (output, weights).
    normed = self.layerNorm1(seq)
    attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
    seq = seq + attn_out

    # Pre-norm feed-forward with residual.
    seq = seq + self.fcn(self.layerNorm2(seq))

    # Back to image format. Use reshape because the transposed tensor is not contiguous,
    # so .view would fail.
    y = seq.transpose(1, 2).reshape(b, c, h, w)
    y = self.conv_out(y)

    # The transformer inside the U-Net adds its input back at the end.
    return y + X


In [None]:
def SamplingBlock(indim, outdim, timedim, no_heads):
  """Build two ResnetLayers followed by a UnetAttention block.

  Args:
    indim: input channels of the first ResnetLayer.
    outdim: output channels of every sub-block.
    timedim: time-embedding size passed to both ResnetLayers.
    no_heads: attention heads for UnetAttention.

  Returns:
    ``Sequential`` of the three sub-blocks.
  """
  # NOTE(review): Sequential forwards only a single tensor, so a time embedding
  # cannot reach the ResnetLayers through this container. Confirm the intended
  # wiring before relying on timedim.
  # Sequential takes modules as positional args; passing a list raises TypeError.
  return Sequential(
      ResnetLayer(indim, outdim, timedim),
      ResnetLayer(outdim, outdim, timedim),
      UnetAttention(outdim, no_heads),
  )
def downSamplingBlock(indim, outdim, timedim, no_heads, down_sample=False):
  """Build a SamplingBlock, optionally followed by a stride-2 conv that halves H and W.

  Args:
    indim: input channels.
    outdim: output channels.
    timedim: time-embedding size forwarded to SamplingBlock.
    no_heads: attention heads forwarded to SamplingBlock.
    down_sample: when True, append Conv2d(outdim, outdim, 3, stride=2, padding=1).

  Returns:
    The ``Sequential`` built by SamplingBlock, possibly with the extra conv appended.
  """
  # SamplingBlock requires all four arguments.
  model = SamplingBlock(indim, outdim, timedim, no_heads)
  if down_sample:
    # Sequential exposes append(), not add().
    model.append(Conv2d(outdim, outdim, kernel_size=3, stride=2, padding=1))
  return model
def upSamplingBlock(indim, outdim, timedim, no_heads, upsample=False):
  """Build a SamplingBlock, optionally followed by a transposed conv that doubles H and W.

  Args:
    indim: input channels.
    outdim: output channels.
    timedim: time-embedding size forwarded to SamplingBlock.
    no_heads: attention heads forwarded to SamplingBlock.
    upsample: when True, append a stride-2 ConvTranspose2d.

  Returns:
    The ``Sequential`` built by SamplingBlock, possibly with the extra layer appended.
  """
  # SamplingBlock requires all four arguments.
  model = SamplingBlock(indim, outdim, timedim, no_heads)
  if upsample:
    # With k=3, s=2, p=1 the output is 2H-1. output_padding=1 makes it exactly 2H,
    # which mirrors the stride-2 downsampling conv.
    # Sequential exposes append(), not add().
    model.append(
        ConvTranspose2d(outdim, outdim, kernel_size=3, stride=2, padding=1, output_padding=1)
    )
  return model

def bottleNeckBlock(dim, heads):
  """Build the U-Net bottleneck: ResnetLayer -> UnetAttention -> ResnetLayer at constant width.

  Args:
    dim: channel count kept throughout the block.
    heads: attention heads for UnetAttention.

  Returns:
    ``Sequential`` of the three sub-blocks (no time embedding is used).
  """
  # Sequential takes modules as positional args; passing a list raises TypeError.
  return Sequential(
      ResnetLayer(dim, dim, None),
      UnetAttention(dim, heads),
      ResnetLayer(dim, dim, None),
  )

In [None]:
class UnetDownBlock(Module):
  def __init__(self):
    pass
  def forward(self,x):
    #