<a href="https://colab.research.google.com/github/mitran27/GenerativeNetworks/blob/main/StableDiffusionPytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Diffusion Models**
* Generate high quality images

1) Unconditional: generate samples without any guiding input

2) Conditional: generation is conditioned on an input such as text, an image, or a class label




***Latent diffusion Model***

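*  Instead of running the diffusion model in pixel space, first compress the image into a smaller latent representation with a VAE encoder, run the diffusion process in that latent space, and finally convert the result back to pixel space with the VAE decoder. (The latent is still a spatial tensor — lower resolution and fewer channels than the image — not a flat 1-D vector.)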

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Sequential,Module
from torch.nn import Conv2d,Identity,GroupNorm,Upsample

In [None]:
class ConvBlock(Module):
    """Pre-activation conv unit: GroupNorm(32 groups) -> SiLU -> 3x3 conv.

    The 3x3 kernel with padding=1 keeps the spatial size unchanged; only the
    channel count changes, from ``in_channels`` to ``out_channels``.
    ``in_channels`` must be divisible by 32 (a GroupNorm requirement).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are kept stable because they become state_dict keys.
        self.GN = GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv = Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        normalized = self.GN(x)
        activated = F.silu(normalized)
        return self.conv(activated)

class VAE_Padding(Module):
    """Zero-pad one column on the right and one row on the bottom.

    In the encoder this precedes an unpadded stride-2 3x3 convolution, so an
    even input size H becomes H + 1 and the convolution yields exactly H / 2.
    """

    # F.pad pairs start from the last dimension:
    # (width_left, width_right, height_top, height_bottom).
    _PAD = (0, 1, 0, 1)

    def forward(self, x):
        return F.pad(x, self._PAD)

def VAE_Sample(x, e, *, scale_factor=0.18215):
    """Draw a latent sample via the reparameterization trick and scale it.

    Args:
        x: encoder output whose channel dimension (dim 1) holds the mean in
            its first half and the log-variance in its second half.
        e: noise shaped like one half of ``x`` (presumably standard normal —
            it is supplied by the caller).
        scale_factor: multiplier applied to the sample. The default 0.18215
            was reportedly estimated by averaging over many VAE outputs, as
            roughly ``1 / std(z)``, so that latents have about unit variance.

    Returns:
        ``(mu + std * e) * scale_factor``, with half as many channels as ``x``.
    """
    mu, log_var = torch.chunk(x, 2, dim=1)  # split along the channel dim
    # Clamp before exponentiating so std stays finite and strictly positive.
    log_var = torch.clamp(log_var, -30, 20)
    std = torch.exp(log_var).sqrt()

    return (mu + std * e) * scale_factor


In [None]:
class VAE_ResidualBlock(Module):
  """Two ConvBlocks with a skip connection.

  Computes ``conv_2(conv_1(x)) + residual_layer(x)``. When the channel count
  changes, the skip path is a 1x1 convolution so the two branches can be
  added; otherwise it is the identity. Spatial size is preserved.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv_1 = ConvBlock(in_channels, out_channels)
    self.conv_2 = ConvBlock(out_channels, out_channels)
    if in_channels != out_channels:
      # 1x1 projection so the skip branch matches out_channels.
      self.residual_layer = Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    else:
      self.residual_layer = Identity()

  def forward(self, x):
    y = self.conv_1(x)
    y = self.conv_2(y)
    return y + self.residual_layer(x)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Input and output are ``(N, seq_length, embed_size)``. Queries, keys and
    values come from one fused linear layer; the per-head results are
    concatenated and passed through an output projection.
    """

    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size must be divisible by number of heads"

        # Fused projection producing queries, keys and values in one matmul.
        self.qkv_linear = nn.Linear(embed_size, embed_size * 3, bias=False)

        # Output projection layer
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(N, seq_length, embed_size)``.

        ``mask``, if given, must broadcast to ``(N, heads, seq, seq)``;
        positions where it is 0 receive (effectively) zero attention weight.
        """
        N, seq_length, _ = x.shape

        qkv = self.qkv_linear(x)  # (N, seq_length, embed_size * 3)
        qkv = qkv.reshape(N, seq_length, 3, self.heads, self.head_dim)
        queries, keys, values = qkv.unbind(dim=2)  # each (N, seq, heads, head_dim)

        # Scale by sqrt(head_dim): each per-head dot product sums over
        # head_dim terms, not over the full embed_size.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)

        if mask is not None:
            # The dtype's most negative finite value; a literal like -1e20
            # cannot be represented in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)  # (N, heads, query_len, key_len)

        # Aggregate values and concatenate heads.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, seq_length, self.heads * self.head_dim
        )

        # Output projection
        return self.fc_out(out)


class VAE_AttentionBlock(Module):
    """Self-attention across the spatial positions of a feature map, with a skip.

    The ``(b, c, h, w)`` map is group-normalized, flattened into ``h*w`` tokens
    of size ``c``, passed through multi-head self-attention, reshaped back, and
    added to the input, following the LDM / Stable Diffusion VAE attention
    block. ``in_channels`` must be divisible by 32 (GroupNorm) and by
    ``num_heads``.
    """

    def __init__(self, in_channels, num_heads=4):
        super().__init__()
        self.groupnorm = GroupNorm(num_groups=32, num_channels=in_channels)
        self.attention = MultiHeadSelfAttention(in_channels, num_heads)

    def forward(self, x):
        residual = x
        b, c, h, w = x.shape
        # Pixel space -> token sequence: (b, c, h, w) -> (b, h*w, c).
        y = self.groupnorm(x).reshape(b, c, h * w).transpose(1, 2)
        y = self.attention(y)
        # Token sequence -> pixel space. reshape, not view: the transpose
        # leaves y non-contiguous, and view would raise on it.
        y = y.transpose(1, 2).reshape(b, c, h, w)
        return y + residual

In [None]:
class VAE_Encoder(Sequential):
  """VAE encoder: RGB image -> 8-channel feature map at 1/8 resolution.

  Three stride-2 stages, each preceded by VAE_Padding, downsample H and W by 8
  (for sizes divisible by 8). The 8 output channels are presumably meant to be
  split by ``VAE_Sample`` into mean / log-variance halves of 4 channels each.
  """

  def __init__(self):
    layers = [
        # Stem: 3 -> 128 channels at full resolution.
        Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),

        # Downsample to H/2 (asymmetric pad + unpadded stride-2 conv).
        VAE_Padding(),
        Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2),
        VAE_ResidualBlock(128, 256),
        VAE_ResidualBlock(256, 256),

        # Downsample to H/4.
        VAE_Padding(),
        Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2),
        VAE_ResidualBlock(256, 512),
        VAE_ResidualBlock(512, 512),

        # Downsample to H/8.
        VAE_Padding(),
        Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        VAE_AttentionBlock(512),
        VAE_ResidualBlock(512, 512),

        # GroupNorm -> SiLU -> 3x3 conv down to 8 channels, then a 1x1 conv.
        ConvBlock(512, 8),
        Conv2d(8, 8, kernel_size=1, padding=0),
    ]
    # The layers must be handed to Sequential; calling super().__init__()
    # without them yields an empty container whose forward is the identity.
    super().__init__(*layers)

In [None]:
class VAE_Decoder(Sequential):
  """VAE decoder: 4-channel latent -> 3-channel image at 8x the resolution.

  Mirrors VAE_Encoder: three nearest-neighbour ``Upsample(scale_factor=2)``
  stages take the latent from H/8 back to H, while channels go
  4 -> 512 -> 256 -> 128 -> 3.

  NOTE(review): VAE_Sample multiplies latents by 0.18215 and this module does
  not undo that; callers presumably need to divide by the same factor before
  decoding — confirm against the sampling pipeline.
  """

  def __init__(self):
    layers = [
        # Encoder output was 8 channels, split into mean/log-variance and
        # resampled, so the latent has 4 channels.
        Conv2d(4, 4, kernel_size=1, padding=0),  # 1x1 "linear" conv
        Conv2d(4, 512, kernel_size=3, padding=1),

        VAE_ResidualBlock(512, 512),
        VAE_AttentionBlock(512),

        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        # H/8 -> H/4
        Upsample(scale_factor=2),
        Conv2d(512, 512, kernel_size=3, padding=1),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),

        # H/4 -> H/2
        Upsample(scale_factor=2),
        Conv2d(512, 512, kernel_size=3, padding=1),
        VAE_ResidualBlock(512, 256),
        VAE_ResidualBlock(256, 256),
        VAE_ResidualBlock(256, 256),

        # H/2 -> H
        Upsample(scale_factor=2),
        Conv2d(256, 256, kernel_size=3, padding=1),
        VAE_ResidualBlock(256, 128),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),

        # GroupNorm -> SiLU -> 3x3 conv to RGB.
        ConvBlock(128, 3),
    ]
    # The layers must be handed to Sequential; calling super().__init__()
    # without them yields an empty container whose forward is the identity.
    super().__init__(*layers)