<a href="https://colab.research.google.com/github/mitran27/GenerativeNetworks/blob/main/StableDiffusionPytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Diffusion Models**
* Generate high quality images

1) Unconditional: generate samples without any guiding input.

2) Conditional: generation is guided (conditioned) on an input such as text, an image, or a class label.




***Latent diffusion Model***

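*  Instead of running the diffusion process in pixel space, first compress the image with a VAE encoder into a smaller latent representation. In this notebook that is a 4-channel feature map at 1/8 of the image resolution. Run diffusion in that latent space, and finally convert the result back to pixel space with the VAE decoder.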

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Sequential,Module,ModuleList
from torch.nn import Conv2d,Identity,GroupNorm,Upsample

In [None]:
# Latent scaling constant: VAE_Sample multiplies the sampled latent by it and
# VAE_Decoder.forward divides it back out before decoding. The VAE_Sample
# comment describes it as roughly 1 / std(z), chosen so latents have about
# unit variance (presumably the Stable Diffusion v1 VAE value - confirm).
# The trailing "#@param" is a Colab form annotation exposing it as a UI field.
scale_factor = 0.18215#@param {type:"number"}



In [None]:
class ConvBlock(Module):
  """GroupNorm -> SiLU -> Conv2d, the pre-activation conv unit used throughout.

  Args:
    in_channels: input channels; must be divisible by ``num_groups``.
    out_channels: output channels.
    kernel_size: convolution kernel size (default 3).
    padding: convolution padding. Defaults to ``kernel_size // 2``, so odd
      kernels keep the spatial size (1 for the default 3x3, 0 for 1x1).
    num_groups: number of GroupNorm groups (default 32).
  """
  def __init__(self, in_channels, out_channels, kernel_size=3, padding=None, num_groups=32):
    super().__init__()
    if padding is None:
      padding = kernel_size // 2
    self.GN = GroupNorm(num_groups=num_groups, num_channels=in_channels)
    self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, padding=padding)

  def forward(self, x):
    x = self.GN(x)
    x = F.silu(x)
    return self.conv(x)

class VAE_Padding(Module):
  """Zero-pad a feature map by one pixel on its right and bottom edges.

  In VAE_Encoder it precedes each stride-2, padding-0 3x3 convolution, so an
  even input size H is downsampled to exactly H / 2.
  """

  # F.pad ordering for the last two dims: (left, right, top, bottom).
  _PAD = (0, 1, 0, 1)

  def forward(self, x):
    padded = F.pad(x, self._PAD)
    return padded

def VAE_Sample(x, e, scale=None):
  """Draw a reparameterized latent sample from the encoder output and scale it.

  Args:
    x: encoder output whose channel dimension (dim 1) holds the mean and the
      log-variance stacked together; it is split in half along dim 1.
    e: noise broadcastable to each half of ``x`` (presumably standard-normal
      noise supplied by the caller).
    scale: multiplier applied to the sample. Defaults to the module-level
      ``scale_factor`` (described as ~1 / std(z), giving roughly unit-variance
      latents), so existing two-argument calls behave as before.

  Returns:
    ``(mu + std * e) * scale``.
  """
  if scale is None:
    scale = scale_factor
  mu, log_var = torch.chunk(x, 2, dim=1)  # split along the channel dimension
  # Keep exp() away from underflow/overflow.
  log_var = torch.clamp(log_var, -30, 20)
  # std = sqrt(exp(log_var)) == exp(log_var / 2). Computing it directly avoids
  # the intermediate exp(log_var), which overflows float16 for log_var > ~11.
  std = torch.exp(0.5 * log_var)
  return (mu + std * e) * scale


class GatedGElU(Module):
  """GEGLU activation: split the last dimension in half and gate one half.

  The input's last dimension (size 2 * d) is split into ``x`` and ``gate``.
  The result ``x * gelu(gate)`` has last dimension d. Splitting the *last*
  dimension works for both (batch, features) and (batch, length, features)
  inputs; splitting dim 1 would cut the sequence axis of the latter.
  """
  def __init__(self):
    super().__init__()

  def forward(self, x):
    x, gate = torch.chunk(x, 2, dim=-1)
    return x * F.gelu(gate)

In [None]:
class VAE_ResidualBlock(Module):
  """Two ConvBlock stages with an additive skip connection.

  Subclasses ``Module`` so it can be placed in ``nn.Sequential`` and its
  parameters are registered. The skip path is the identity when the channel
  count is unchanged, and a 1x1 convolution otherwise, so the branches can
  be summed.
  """
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv_1 = ConvBlock(in_channels, out_channels)
    self.conv_2 = ConvBlock(out_channels, out_channels)
    if in_channels != out_channels:
      self.residual_layer = Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    else:
      self.residual_layer = Identity()

  def forward(self, x):
    y = self.conv_1(x)
    y = self.conv_2(y)
    return y + self.residual_layer(x)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over an (N, L, E) sequence.

    A single linear layer produces queries, keys and values. Each is split
    into ``heads`` slices of size ``head_dim = embed_size // heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by number of heads"

        self.qkv_linear = nn.Linear(embed_size, embed_size * 3, bias=False)

        # Output projection layer
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: tensor of shape (N, seq_length, embed_size).
            mask: optional tensor broadcastable to (N, heads, seq_length,
                seq_length); positions where ``mask == 0`` receive no attention.

        Returns:
            Tensor of shape (N, seq_length, embed_size).
        """
        N, seq_length, _ = x.shape

        # (N, L, 3E) -> (N, L, 3, heads, head_dim) -> three (N, L, heads, head_dim)
        qkv = self.qkv_linear(x).reshape(N, seq_length, 3, self.heads, self.head_dim)
        queries, keys, values = qkv.unbind(dim=2)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)  # (N, heads, L, L)
        # Each head's dot product sums head_dim terms, so scale by
        # sqrt(head_dim), not sqrt(embed_size).
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # The dtype's most negative finite value is representable in
            # float16/bfloat16, unlike a literal -1e20.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values, heads concatenated back into embed_size.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values).reshape(
            N, seq_length, self.embed_size
        )
        return self.fc_out(out)


class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``context``.

    Args:
        embed_size: feature size of ``x`` and of the output.
        n_heads: number of heads; must divide ``embed_size``.
        d_cross: feature size of ``context``. Defaults to ``embed_size``.
        in_proj_bias: whether the q/k/v projections have a bias term.
        out_proj_bias: whether the output projection has a bias term.
    """

    def __init__(self, embed_size, n_heads, d_cross=None, in_proj_bias=False, out_proj_bias=False):
        super().__init__()
        if d_cross is None:
            d_cross = embed_size
        self.embed_size = embed_size
        self.heads = n_heads
        self.head_dim = embed_size // n_heads

        assert (
            self.head_dim * n_heads == embed_size
        ), "Embedding size must be divisible by number of heads"

        self.q_linear = nn.Linear(embed_size, embed_size, bias=in_proj_bias)
        # Keys and values are projected from the context's own feature size.
        self.k_linear = nn.Linear(d_cross, embed_size, bias=in_proj_bias)
        self.v_linear = nn.Linear(d_cross, embed_size, bias=in_proj_bias)

        # Output projection layer
        self.fc_out = nn.Linear(embed_size, embed_size, bias=out_proj_bias)

    def forward(self, x, context):
        """Attend from ``x`` (N, Lq, embed_size) to ``context`` (N, Lk, d_cross).

        Returns:
            Tensor of shape (N, Lq, embed_size).
        """
        N, q_len, _ = x.shape
        k_len = context.shape[1]

        # Split the projected features into heads: (N, L, heads, head_dim).
        q = self.q_linear(x).view(N, q_len, self.heads, self.head_dim)
        k = self.k_linear(context).view(N, k_len, self.heads, self.head_dim)
        v = self.v_linear(context).view(N, k_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", q, k) / (self.head_dim ** 0.5)
        attention = torch.softmax(energy, dim=-1)  # (N, heads, Lq, Lk)

        # Aggregate values based on attention, heads concatenated.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, v).reshape(
            N, q_len, self.embed_size
        )
        return self.fc_out(out)


  class VAE_AttentionBlock():
    def __init__(self,in_channels):
      super().__init__()
      self.attention = MultiHeadSelfAttention(in_channels,4)
    def forward(self,x):
      # convert pixel space to linear space
      b,c,h,w = x.shape
      y = x.view(b,c,h*w)
      y = y.permute(0,2,1) # text are batch length feats
      y = self.attention(y)
      # convert linear space to pixel space
      y = y.permute(0,2,1)
      y = y.view(b,c,h,w)
      return y

In [None]:
class VAE_Encoder(Sequential):
  """Convolutional VAE encoder: (B, 3, H, W) image -> (B, 8, H/8, W/8) moments.

  Three padded stride-2 convolutions downsample by 8 overall (exactly H/8 for
  H, W divisible by 8). The 8 output channels hold the concatenated mean and
  log-variance that VAE_Sample splits in half along dim 1.
  """
  def __init__(self):
    layers = [
        # Feature extraction at full resolution.
        Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),

        # Downsample to 1/2.
        VAE_Padding(),
        Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2),
        VAE_ResidualBlock(128, 256),
        VAE_ResidualBlock(256, 256),

        # Downsample to 1/4.
        VAE_Padding(),
        Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2),
        VAE_ResidualBlock(256, 512),
        VAE_ResidualBlock(512, 512),

        # Downsample to 1/8.
        VAE_Padding(),
        Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        VAE_AttentionBlock(512),
        VAE_ResidualBlock(512, 512),

        ConvBlock(512, 8),
        Conv2d(8, 8, kernel_size=1, padding=0),  # 1x1 "linear" conv
    ]
    # The layers must be passed on, or the Sequential is empty and acts as identity.
    super().__init__(*layers)

In [None]:
class VAE_Decoder(Sequential):
  """Convolutional VAE decoder: scaled (B, 4, h, w) latent -> (B, 3, 8h, 8w) image.

  Three 2x ``Upsample`` stages (default nearest-neighbour) enlarge the latent
  8x. ``forward`` first divides by ``scale_factor``, undoing the
  multiplication applied in VAE_Sample.
  """
  def __init__(self):
    layers = [
        # The encoder's 8 output channels are split into mean / log-variance
        # and sampled, so the latent entering the decoder has 4 channels.
        Conv2d(4, 4, kernel_size=1, padding=0),  # 1x1 "linear" conv
        Conv2d(4, 512, kernel_size=3, padding=1),

        VAE_ResidualBlock(512, 512),
        VAE_AttentionBlock(512),

        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        # 1/8 -> 1/4
        Upsample(scale_factor=2),
        Conv2d(512, 512, kernel_size=3, padding=1),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        # 1/4 -> 1/2
        Upsample(scale_factor=2),
        Conv2d(512, 512, kernel_size=3, padding=1),
        VAE_ResidualBlock(512, 256),
        VAE_ResidualBlock(256, 256),
        VAE_ResidualBlock(256, 256),

        # 1/2 -> full resolution
        Upsample(scale_factor=2),
        Conv2d(256, 256, kernel_size=3, padding=1),
        VAE_ResidualBlock(256, 128),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),

        ConvBlock(128, 3),
    ]
    super().__init__(*layers)

  def forward(self, x):
    # Out-of-place division: the caller's tensor is left untouched, and this
    # also works on leaf tensors that require grad (in-place ops there raise).
    x = x / scale_factor
    return super().forward(x)


In [None]:
class TransformerBlock(Module):
  """Pre-LayerNorm transformer layer: self-attention then a GELU MLP, each residual.

  Input and output have shape (batch, length, embed_size).
  """
  def __init__(self, embed_size, heads):
    super().__init__()
    self.attention = MultiHeadSelfAttention(embed_size, heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.mlp = Sequential(
        nn.Linear(embed_size, embed_size * 4),
        nn.GELU(),
        nn.Linear(embed_size * 4, embed_size),
    )

  def forward(self, x):
    # Attention reads the *normalised* input; the residual adds the raw input.
    y = x + self.attention(self.norm1(x))
    return y + self.mlp(self.norm2(y))

class CLIPTransformer(Module):
  """Text encoder: token + learned position embeddings, TransformerBlocks, LayerNorm.

  Args:
    n_voc: vocabulary size.
    n_seq: maximum sequence length (size of the position table).
    n_heads: attention heads per TransformerBlock.
    n_embedding: embedding width.
    n_layers: number of TransformerBlocks.
  """
  def __init__(self, n_voc, n_seq, n_heads, n_embedding, n_layers):
    # Required so the embeddings, position table and layers register as parameters.
    super().__init__()
    self.token_embedding = nn.Embedding(n_voc, n_embedding)
    self.position_embedding = nn.Parameter(torch.zeros(1, n_seq, n_embedding))
    self.layers = nn.ModuleList([TransformerBlock(n_embedding, n_heads) for _ in range(n_layers)])
    self.norm = nn.LayerNorm(n_embedding)

  def forward(self, x):
    """Encode integer token ids of shape (B, L), L <= n_seq, into (B, L, n_embedding)."""
    seq_len = x.shape[1]
    # Slice the position table so sequences shorter than n_seq also work.
    y = self.token_embedding(x) + self.position_embedding[:, :seq_len]
    for layer in self.layers:
      y = layer(y)
    return self.norm(y)

In conditional diffusion models, the latent is passed through a U-Net, which consists of a downsampling path, a bottleneck, and an upsampling path.

Every part is built from residual-attention blocks (convolution + attention) that process the features.

A residual-attention block:

1) passes the timestep embedding through an MLP that maps it to `out_channels` features;

2) adds that embedding to the convolved features;

3) passes the features and the prompt context to cross-attention.

***Cross Attention***

* Query -> the diffusion input (the noisy latent features)

* Key, Value -> the prompt/context vector (text embedding or image embedding)

In [None]:
class DiffusionResNet(Module): # handle time step of diffusion
  """Residual conv block conditioned on a diffusion time-step embedding.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels of the output feature map.
    n_time: size of the time embedding ``t`` (default 1280).
  """
  def __init__(self, in_channels, out_channels, n_time=1280):
    super().__init__()
    self.conv_1 = ConvBlock(in_channels, out_channels)
    self.conv_merged = ConvBlock(out_channels, out_channels)
    if in_channels != out_channels:
      self.residual_layer = Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    else:
      self.residual_layer = Identity()
    self.time_linear = nn.Linear(n_time, out_channels)

  def forward(self, x, t):
    """Map x (B, in_channels, H, W) and t (..., n_time) to (B, out_channels, H, W)."""
    y = self.conv_1(x)
    time_emb = self.time_linear(F.silu(t))
    # Broadcast the per-channel time vector over every spatial position.
    y = y + time_emb[..., None, None]
    y = self.conv_merged(y)
    return y + self.residual_layer(x)

class DiffusionAttention(Module): # handle context(embedding of prompt) to condition the diffusion process
  """Transformer block over a feature map, conditioned on a context sequence.

  Works on maps with ``channels = n_heads * n_embedding`` channels:
  conv -> flatten to (B, H*W, channels) -> self-attention -> cross-attention
  with ``context`` -> GEGLU MLP (each sub-layer pre-LayerNorm with a
  residual) -> reshape -> 1x1 conv, plus a skip from the block input.

  Args:
    n_heads: number of attention heads.
    n_embedding: per-head feature size.
    d_context: feature size of ``context``. Defaults to 768, presumably the
      Stable Diffusion v1 text-embedding width - TODO confirm against the
      CLIPTransformer configuration actually used. ``None`` means ``channels``.
  """
  def __init__(self, n_heads, n_embedding, d_context=768):
    super().__init__()
    channels = n_heads * n_embedding
    if d_context is None:
      d_context = channels
    self.conv_block = ConvBlock(channels, channels)

    self.norm_self_attn = nn.LayerNorm(channels)
    self.norm_cross_attn = nn.LayerNorm(channels)

    self.self_attn = MultiHeadSelfAttention(channels, n_heads)
    self.cross_attn = MultiHeadCrossAttention(channels, n_heads, d_context)

    # GEGLU feed-forward: project to 2 * (4 * channels), gate one half with
    # GELU of the other (split on the last dim), project back to channels.
    self.norm_mlp = nn.LayerNorm(channels)
    self.mlp_in = nn.Linear(channels, channels * 4 * 2)
    self.mlp_out = nn.Linear(channels * 4, channels)

    self.conv_out = Conv2d(channels, channels, kernel_size=1, padding=0)

  def forward(self, x, context):
    """Map x (B, channels, H, W) and context (B, L, d_context) to (B, channels, H, W)."""
    y = self.conv_block(x)

    b, c, h, w = y.shape
    y = y.view(b, c, h * w).transpose(1, 2)  # (B, H*W, C)

    y = y + self.self_attn(self.norm_self_attn(y))
    y = y + self.cross_attn(self.norm_cross_attn(y), context)

    hidden, gate = self.mlp_in(self.norm_mlp(y)).chunk(2, dim=-1)
    y = y + self.mlp_out(hidden * F.gelu(gate))

    y = y.transpose(1, 2).reshape(b, c, h, w)  # back to (B, C, H, W)
    return self.conv_out(y) + x

class ResidualAttention(Module):
  """One U-Net unit: time-conditioned residual conv, optional attention, optional 2x upsample.

  Args:
    in_channel: channels entering the residual conv.
    out_channel: channels leaving it.
    attn_dim: per-head size for DiffusionAttention; ``None`` disables attention.
    attn_heads: number of heads (``attn_heads * attn_dim`` should equal ``out_channel``).
    upsample: if True, finish with ``Upsample(scale_factor=2)`` + 3x3 conv.
    d_context: context feature size forwarded to DiffusionAttention.
  """
  def __init__(self, in_channel, out_channel, attn_dim=None, attn_heads=None, upsample=False, d_context=768):
    super().__init__()
    self.conv = DiffusionResNet(in_channel, out_channel)

    # nn.Sequential takes modules as positional arguments, not a list.
    self.upsample = Sequential(
        Upsample(scale_factor=2),
        Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
    ) if upsample else None

    # DiffusionAttention's signature is (n_heads, n_embedding, d_context).
    self.attn = DiffusionAttention(attn_heads, attn_dim, d_context) if attn_dim is not None else None

  def forward(self, x, time, context):
    y = self.conv(x, time)

    if self.attn is not None:
      y = self.attn(y, context)

    if self.upsample is not None:
      y = self.upsample(y)
    return y

In [None]:
class DiffusionUnet(Module):
  """U-Net over 4-channel latents with encoder->decoder skip connections.

  Encoder: stem conv, then three stride-2 downsamplings (320 -> 640 -> 1280
  -> 1280 channels). Bottleneck at 1280. Decoder: each block first
  concatenates (channel dim) the matching encoder output. It upsamples 2x
  three times, returning a 320-channel map at the input resolution.
  """
  def __init__(self):
    super().__init__()

    self.encoder = ModuleList([

        Conv2d(4, 320, kernel_size=3, padding=1),
        ResidualAttention(320, 320, 40, 8),
        ResidualAttention(320, 320, 40, 8),

        Conv2d(320, 320, kernel_size=3, stride=2, padding=1),
        ResidualAttention(320, 640, 80, 8),
        ResidualAttention(640, 640, 80, 8),

        Conv2d(640, 640, kernel_size=3, stride=2, padding=1),
        ResidualAttention(640, 1280, 160, 8),
        ResidualAttention(1280, 1280, 160, 8),

        Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1),
        ResidualAttention(1280, 1280),
        ResidualAttention(1280, 1280),
    ])

    self.bottle_neck = ModuleList([
        ResidualAttention(1280, 1280, 160, 8),
        ResidualAttention(1280, 1280),
    ])

    # in_channel of each decoder block = previous output + popped skip channels.
    self.decoders = ModuleList([

        ResidualAttention(2560, 1280),
        ResidualAttention(2560, 1280),
        ResidualAttention(2560, 1280, upsample=True),

        ResidualAttention(2560, 1280, 160, 8),
        ResidualAttention(2560, 1280, 160, 8),
        ResidualAttention(1920, 1280, 160, 8, upsample=True),

        ResidualAttention(1920, 640, 80, 8),
        ResidualAttention(1280, 640, 80, 8),
        ResidualAttention(960, 640, 80, 8, upsample=True),

        ResidualAttention(960, 320, 40, 8),
        ResidualAttention(640, 320, 40, 8),
        # Last skip is the 320-channel stem output (320 + 320 = 640). Three
        # upsamples already undo the three downsamples, so none here.
        ResidualAttention(640, 320, 40, 8),
    ])

  def forward(self, x, context, time):
    """Map latent x (B, 4, H, W), context, and time embedding to (B, 320, H, W)."""
    skips = []
    for layer in self.encoder:
      # Plain convolutions take only the feature map.
      if isinstance(layer, ResidualAttention):
        x = layer(x, time, context)
      else:
        x = layer(x)
      skips.append(x)

    for layer in self.bottle_neck:
      x = layer(x, time, context)

    for layer in self.decoders:
      x = torch.cat((x, skips.pop()), dim=1)
      x = layer(x, time, context)
    return x

In [None]:
class DiffusionModel(Module):
  """Latent U-Net plus a final ConvBlock projecting its 320 channels to 4."""
  def __init__(self):
    super().__init__()

    # Instantiate the U-Net: assigning the class itself registers no
    # submodule, so its parameters would never be trained or saved.
    self.unet = DiffusionUnet()
    self.output = ConvBlock(320, 4)

In [None]:
# Broadcasting a per-channel time embedding over a feature map. Indexing with
# `[..., None, None]` returns a view (no copy) and equals
# `time_emb.unsqueeze(-1).unsqueeze(-1)`.
def add_time_embedding(y, time_emb):
  """Return ``y + time_emb`` with ``time_emb`` broadcast over the spatial dims.

  Args:
    y: feature map of shape (B, C, H, W).
    time_emb: per-channel embedding of shape (B, C).
  """
  return y + time_emb[..., None, None]
