In [None]:
import torch
from torch import nn
from torch.nn import functional as F
# from decoder import 

In [None]:
from clip import CLIPLayer, CLIP

In [None]:
class VAE_Encoder(nn.Sequential):
    """Convolutional VAE encoder: image -> sampled, scaled latent.

    The layer stack ends in an 8-channel head; `forward` splits that output
    into a 4-channel mean and a 4-channel log-variance and draws a sample
    using the caller-supplied noise.

    NOTE(review): the VAE_ResidualBlock / VAE_AttentionBlock stages are still
    commented out below, so the channel counts of consecutive convolutions do
    not chain (128 -> 256 -> 512). The stack cannot run end to end until those
    stages are restored.
    """

    def __init__(self):
        super().__init__(
            # (Batch_size, Channel, Height, Width) -> (B, 128, Height, Width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            # (B, 128, Height, Width) -> (B, 128, Height, Width)
            # VAE_ResidualBlock(128, 128),
            # VAE_ResidualBlock(128, 128),

            # (B, 128, Height, Width) -> (B, 128, Height / 2, Width / 2)
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            # (B, 128, Height / 2, Width / 2) -> (B, 256, Height / 2, Width / 2)
            # VAE_ResidualBlock(128, 256),
            # VAE_ResidualBlock(256, 256),

            # (B, 256, Height / 4, Width / 4)
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            # (B, 512, Height / 4, Width / 4)
            # VAE_ResidualBlock(256, 512),
            # VAE_ResidualBlock(512, 512),

            # (B, 512, Height / 8, Width / 8)
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            # (B, 512, Height / 8, Width / 8)
            # VAE_ResidualBlock(512, 512),
            # VAE_ResidualBlock(512, 512),
            # VAE_ResidualBlock(512, 512),

            # (B, 512, Height / 8, Width / 8)
            # VAE_AttentionBlock(512)

            # VAE_ResidualBlock(512, 512),
            # VAE_ResidualBlock(512, 512),

            # (B, 512, Height / 8, Width / 8)
            nn.GroupNorm(32, 512),
            nn.SiLU(),

            # (B, 8, Height / 8, Width / 8)
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Encode `x` and sample a latent with the reparameterisation trick.

        Args:
            x: input batch, (B, C, H, W).
            noise: standard-normal noise broadcastable to the latent shape,
                i.e. (B, 4, H / 8, W / 8) for the full stack.

        Returns:
            `(mean + std * noise) * 0.18215`.
        """
        for module in self:
            if getattr(module, 'stride', None) == (2, 2):
                # The strided convs use padding=0, so pad one pixel on the
                # right and bottom only: F.pad order is (left, right, top, bottom).
                x = F.pad(x, (0, 1, 0, 1))
            # Every layer must run; only the padding above is conditional.
            x = module(x)

        # The 8 output channels are two stacked 4-channel maps: the mean and
        # the log-variance of the latent distribution.
        mean, log_variance = torch.chunk(x, 2, dim=1)
        # Clamp before exponentiating so exp() cannot overflow or underflow.
        log_variance = torch.clamp(log_variance, -30, 20)
        variance = log_variance.exp()
        std_dev = variance.sqrt()

        # Turn N(0, 1) noise into a sample from N(mean, variance).
        latent = mean + std_dev * noise

        # Fixed latent scale; VAE_Decoder.forward divides by the same constant.
        return latent * 0.18215


In [None]:
# from attention import SelfAttention

In [None]:
# Scratch: a random (1, 1, 8, 8) input and a 1x1 convolution mapping 1 -> 8 channels.
x = torch.randn(1, 1, 8, 8)
conv = nn.Conv2d(1, 8, kernel_size=1, padding=0)
print(x)

In [None]:
# conv(x)

In [None]:
class VAE_ResidualBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    The skip path is an identity when the channel count is unchanged and a
    1x1 convolution otherwise, so both branches can be summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Stage 1 is where the channel count may change.
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Stage 2 keeps the channel count.
        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        needs_projection = in_channels != out_channels
        self.residual_layer = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, in_channels, H, W) to (B, out_channels, H, W)."""
        shortcut = self.residual_layer(x)
        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))
        hidden = self.conv_2(F.silu(self.groupnorm_2(hidden)))
        return hidden + shortcut

In [None]:
class VAE_AttentionBlock(nn.Module):
    """Single-head self-attention over the pixels of a feature map, with a residual add."""

    def __init__(self, channels: int):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply GroupNorm and self-attention to (B, C, H, W); the shape is preserved."""
        residue = x

        # Normalise before attending; the block owns a GroupNorm for this purpose.
        x = self.groupnorm(x)

        n, c, h, w = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): every pixel becomes a
        # token whose features are its channels.
        x = x.reshape(n, c, h * w).transpose(-1, -2)

        # Attention between pixels.
        x = self.attention(x)

        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        x = x.transpose(-1, -2).reshape(n, c, h, w)

        return x + residue


In [None]:
# Build attention - both self attention and cross attention

In [None]:
import math

In [None]:
# Scratch: chunk(3, dim=-1) splits the size-3 last axis into three (8, 3, 1) tensors.
x = torch.randn(8, 3, 3)
q, k, v = x.chunk(3, dim=-1)
q.shape, k.shape, v.shape

In [None]:
# Scratch: (8, 3, 16, 4) @ (8, 3, 4, 16) -> (8, 3, 16, 16); matmul batches over the two leading axes.
q = torch.randn(8, 3, 16, 4)
v = torch.randn(8, 3, 16, 4)
torch.matmul(q, v.transpose(-2, -1)).shape

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, S, D) sequence."""

    def __init__(self, n_heads: int, d_embed: int, in_proj_bias=True, out_proj_bias=True):
        '''
        n_heads: number of attention heads; must divide d_embed.
        d_embed: feature size per token (channels per pixel for image use).
        in_proj_bias / out_proj_bias: whether the projections carry a bias.
        '''
        super().__init__()
        if d_embed % n_heads != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")
        # One fused projection produces queries, keys and values.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.head_dim = d_embed // n_heads

    def forward(self, x, causual_mask=False, causal_mask=False):
        '''
        x: (B, S, D), e.g. (B, H*W, Features) for a flattened feature map.
        causual_mask: original (misspelled) flag, kept so existing positional
            and keyword callers continue to work.
        causal_mask: correctly spelled alias; the mask is applied when either
            flag is true.

        Returns a tensor of the same shape as x.
        '''
        input_shape = x.shape
        batch_size, sequence_length, _ = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.head_dim)

        # (B, S, D) -> (B, S, 3*D) -> 3 tensors of shape (B, S, D)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (B, S, D) -> (B, S, H, D/H) -> (B, H, S, D/H)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (B, H, S, D/H) @ (B, H, D/H, S) -> (B, H, S, S)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if causual_mask or causal_mask:
            # Strict upper triangle = positions after the current token.
            mask = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(mask, -torch.inf)

        # Normalise over the key axis; these weights (not the raw scores)
        # are what mix the values.
        weights = F.softmax(scores, dim=-1)

        # (B, H, S, S) @ (B, H, S, D/H) -> (B, H, S, D/H)
        output = weights @ v
        # (B, H, S, D/H) -> (B, S, H, D/H) -> (B, S, D). The transpose leaves
        # the tensor non-contiguous, so reshape (not view) is required when
        # there is more than one head.
        output = output.transpose(1, 2).reshape(input_shape)
        # (B, S, D)
        return self.out_proj(output)

In [None]:
class VAE_Decoder(nn.Sequential):
    """Convolutional VAE decoder: 4-channel latent -> 3-channel image at 8x the resolution."""

    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),

            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (B, 512, Height / 8, Width / 8) -> (64, 64) for us

            # (B, 512, Height / 8, Width / 8) -> (B, 512, Height / 4, Width / 4)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (B, 512, Height / 4, Width / 4) -> (B, 256, Height / 2, Width / 2) (256, 256) for us
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            # (B, 256, Height / 2, Width / 2) -> (B, 128, Height, Width) (512, 512) for us
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # (B, 128, Height, Width) -> (B, 3, Height, Width)
            nn.Conv2d(128, 3, kernel_size=3, padding=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent batch.

        Args:
            x: latent, (B, 4, H / 8, W / 8). The tensor is not modified.

        Returns:
            The output of the layer stack, (B, 3, H, W).
        """
        # Undo the 0.18215 scaling applied at the end of VAE_Encoder.forward.
        # Done out of place so the caller's latent is left untouched.
        x = x / 0.18215
        for module in self:
            x = module(x)

        # (B, 3, H, W)
        return x

In [None]:
# CLIP text encoder

In [None]:
class CLIPEmbedding(nn.Module):
    """Token embedding plus a learned per-position embedding."""

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        # Lookup table: token id -> n_embd vector.
        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # One learnable n_embd vector per position, initialised to zeros.
        self.position_embedding = nn.Parameter(torch.zeros(n_token, n_embd))

    def forward(self, tokens):
        """Embed token ids: (B, S) -> (B, S, D), with S expected to equal n_token."""
        # add_ mirrors the in-place `+=`: the positional table must broadcast
        # into the embedded tokens without changing their shape.
        return self.token_embedding(tokens).add_(self.position_embedding)

In [None]:
# Small instance for a smoke test: vocabulary 1024, embedding width 16, 20 positions.
clip_embedding = CLIPEmbedding(1024, 16, 20)

In [None]:
# Dummy token ids 0..19, one per position of the test embedding.
tokens = torch.arange(20)
tokens

In [None]:
# `tokens` is 1-D here, so there is no batch axis: expect torch.Size([20, 16]).
clip_embedding(tokens).shape

In [None]:
class CLIPLayer(nn.Module):
    """One pre-norm transformer layer: causal self-attention, then a QuickGELU MLP."""

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = SelfAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)

        # MLP expands to 4 * n_embd and projects back.
        self.linear1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        """Map (B, S, n_embd) to (B, S, n_embd)."""
        # --- Self-attention sub-layer with residual connection ---
        residue = x
        x = self.layernorm_1(x)
        # The causal flag is passed positionally so the call does not depend
        # on how the attention module spells that keyword.
        x = self.attention(x, True)
        x = x + residue

        # --- Feed-forward sub-layer with residual connection ---
        # Keep the n_embd-wide input aside: the hidden activation is
        # 4 * n_embd wide and cannot itself be the residual.
        residue = x
        x = self.layernorm_2(x)
        x = self.linear1(x)
        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
        x = self.linear2(x)
        return x + residue

In [None]:
class CLIP(nn.Module):
    '''
    Returns an embedding for every token in the input sequence.

    Built from a CLIPEmbedding(49408, 768, 77), twelve CLIPLayer(12, 768)
    blocks and a final LayerNorm(768).
    '''
    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for _ in range(12)
        ])
        self.layernorm = nn.LayerNorm(768)

    def forward(self, x):
        '''
        x: token ids, (B, S); any integer dtype (cast to long for the lookup).

        Returns (B, S, Dim) token embeddings.
        '''
        # The parameter is named `x`; bind the ids explicitly instead of
        # reading an unassigned local.
        tokens = x.type(torch.long)

        # (B, S) -> (B, S, Dim)
        state = self.embedding(tokens)
        for layer in self.layers:
            state = layer(state)

        return self.layernorm(state)

In [None]:
### Diffusion

In [None]:
# Scratch: a 3x3 conv with padding=1 keeps the 16x16 spatial size; expect torch.Size([8, 10, 16, 16]).
x = torch.randn(8, 3, 16, 16)
conv_1 = nn.Conv2d(3, 10, kernel_size=3, padding=1)
conv_1(x).shape

In [None]:
# Toy schedule: 100 evenly spaced points on [0, 10], then squared element-wise.
beta = torch.linspace(0, 10, 100, dtype=torch.float32) ** 2
# beta

In [None]:
from transformers import CLIPTokenizer

In [None]:
# Build the CLIP tokenizer from the local vocabulary and merges files (paths relative to the notebook).
tokenizer = CLIPTokenizer("../data/tokenizer_vocab.json", merges_file="../data/tokenizer_merges.txt")

In [None]:
# NOTE: rebinds `tokens`, which an earlier cell set to torch.arange(20).
# Presumably a plain Python list of ids — wrap in torch.tensor(...) before feeding a model.
tokens = tokenizer.encode("How are you doing")

In [None]:
# Despite the variable name, this is the full CLIP text encoder, not a single CLIPLayer.
clip_layer = CLIP()

In [None]:
# clip_layer(torch.tensor(tokens)).shape

In [None]:
# Inspect the id the tokenizer uses for its beginning-of-sequence token.
tokenizer.bos_token_id

In [1]:
import model_converter

In [2]:
# Convert the checkpoint into per-model state dicts, loading onto the CPU.
model_file = "../data/v1-5-pruned-emaonly.ckpt"
state_dict = model_converter.load_from_standard_weights(model_file, 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Top-level keys: one sub-dict per model component.
state_dict.keys()

dict_keys(['diffusion', 'encoder', 'decoder', 'clip'])

In [4]:
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion
import torch

In [5]:
# strict=True makes load_state_dict raise on any missing or unexpected key.
clip = CLIP().to('cpu')
clip.load_state_dict(state_dict['clip'], strict=True)

<All keys matched successfully>

In [6]:
# Uses the VAE_Encoder imported from the `encoder` module in the cell above; strict key matching.
encoder = VAE_Encoder().to('cpu')
encoder.load_state_dict(state_dict['encoder'], strict=True)

<All keys matched successfully>

In [7]:
# Uses the VAE_Decoder imported from the `decoder` module in the cell above; strict key matching.
decoder = VAE_Decoder().to('cpu')
decoder.load_state_dict(state_dict['decoder'], strict=True)

<All keys matched successfully>

In [8]:
# Uses the Diffusion class imported from the `diffusion` module in the cell above; strict key matching.
diffusion = Diffusion().to('cpu')
diffusion.load_state_dict(state_dict['diffusion'], strict=True)

<All keys matched successfully>

In [9]:
# Noise schedule: 1000 betas spaced linearly in sqrt(beta) between the two endpoints, then squared.
beta_start = .00085
beta_end = .0120
betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, 1000, dtype=torch.float32) ** 2
# betas

In [None]:
# Scratch: 20 evenly spaced values from 1 to 10, both endpoints included.
torch.linspace(1, 10, 20)

In [None]:
# Scratch: running product, [1, 2, 6, 24].
a = torch.tensor([1, 2, 3, 4])
torch.cumprod(a, dim=-1)

In [None]:
import numpy as np
# Scratch: 0..9 reversed with a negative-step slice.
np.arange(0, 10)[::-1]

In [None]:
# Scratch: 99 down to 1 — the stop value 0 is excluded.
torch.arange(99, 0, -1)

In [None]:
# Scratch: .copy() materialises the reversed array; torch.from_numpy cannot wrap a negative-stride view.
torch.from_numpy(np.arange(0, 20)[::-1].copy())

In [None]:
# Scratch: the same 19..0 sequence built directly in torch (stop=-1 so that 0 is included).
torch.arange(19, -1, -1)

In [None]:
# Scratch: 50 inference steps over 1000 training steps -> stride 20, timesteps 980, 960, ..., 0.
step_ratio = 1000 // 50
timesteps = (np.arange(0, 50) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps

In [None]:
# Scratch: view(2, 25) followed by flatten() gives back the original 49..0 sequence.
torch.arange(49, -1, -1).view(2, 25).flatten()

In [None]:
class DDPMSampler:
    """DDPM noise schedule with forward noising (`add_noise`) and one reverse step (`step`).

    Betas are spaced linearly in sqrt(beta) between `beta_start` and
    `beta_end` and then squared. Until `set_inference_timesteps` is called,
    the sampler walks every training timestep (stride 1).
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start=0.00085, beta_end=0.0120):
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1 - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator
        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.arange(num_training_steps - 1, -1, -1)

        # Defaults matching the full schedule above, so step() is usable
        # before set_inference_timesteps() / set_strength() are called.
        self.num_inference_steps = num_training_steps
        self.step_ratio = 1
        self.start_step = 0

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select `num_inference_steps` evenly strided timesteps, in decreasing order."""
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        self.step_ratio = step_ratio
        # In decreasing order, e.g. 980, 960, 940, ..., 0 for 1000 / 50.
        self.timesteps = torch.arange(num_inference_steps - 1, -1, -1) * step_ratio

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep one inference stride earlier (negative past the end)."""
        return timestep - self.step_ratio

    def add_noise(self, original_samples: torch.Tensor, timesteps: torch.Tensor):
        """Sample q(x_t | x_0) for each entry of `timesteps` (eq. 4 of the DDPM paper).

        Args:
            original_samples: x_0; the leading axis is the batch.
            timesteps: integer timestep indices, one per sample (a scalar is
                also accepted).

        Returns:
            sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        """
        device = original_samples.device
        # Index on the samples' device so CPU-held tables work with GPU inputs.
        alphas_cumprod = self.alphas_cumprod.to(device)
        timesteps = torch.as_tensor(timesteps, device=device)

        def _broadcastable(coeff):
            # Flatten to (B,) then append singleton axes up to the sample rank.
            coeff = coeff.flatten()
            while len(coeff.shape) < len(original_samples.shape):
                coeff = coeff.unsqueeze(-1)
            return coeff

        sqrt_alpha_prod = _broadcastable(alphas_cumprod[timesteps] ** 0.5)
        # Standard deviation of the added noise.
        sqrt_one_minus_alpha_prod = _broadcastable((1 - alphas_cumprod[timesteps]) ** 0.5)

        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

    def set_strength(self, strength=1):
        '''
        Drop the first (noisiest) timesteps so denoising starts part-way.

        More noise (strength near 1): output will be further from the input.
        Less noise (strength near 0): output will be closer to the input image.
        '''
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, t: int, latents: torch.Tensor, model_output: torch.Tensor):
        '''
        Remove noise from the latents and get the latent at the previous timestep.

        t: current timestep.
        latents: x_t at timestep t.
        model_output: the predicted noise.

        Returns the sampled x_{t-1}; at t == 0 no noise is added and the
        posterior mean is returned.
        '''
        prev_t = self._get_previous_timestep(t)

        # Formulas 6 and 7 of the paper: gather the alpha / beta terms first.
        alpha_prod_t = self.alphas_cumprod[t]
        # Before the first timestep nothing has been noised: alpha_bar = 1.
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Predicted x_0 from x_t and the predicted noise.
        pred_original_sample = (latents - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5)

        # Posterior mean: a weighted sum of predicted x_0 and the current x_t.
        pred_original_sample_coeff = ((alpha_prod_t_prev ** 0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = (current_alpha_t ** 0.5 * beta_prod_t_prev) / beta_prod_t
        prev_sample_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        if t > 0:
            noise = torch.randn(model_output.shape, generator=self.generator, device=model_output.device, dtype=model_output.dtype)
            variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
            # Clamp so the square root below never sees an exact zero.
            std_dev = torch.clamp(variance, min=1e-20) ** 0.5
            return prev_sample_mean + std_dev * noise

        # Final step: return the mean without sampling noise.
        return prev_sample_mean

In [None]:
from PIL import Image
# Seeded generator so the sampler's noise draws are repeatable.
generator = torch.Generator()
generator.manual_seed(0)

In [None]:
# Sampler with the default schedule (1000 training steps).
ddpm_sampler = DDPMSampler(generator)

In [None]:
# Load the test image as a tensor; assumes an RGB file, i.e. shape (H, W, 3) — TODO confirm.
img = Image.open("../images/dog.jpg")
img_tensor = torch.tensor(np.array(img))
img_tensor.shape
# img

In [None]:
# Rescale from [0, 255] to [-1, 1].
# NOTE: this rebinds img_tensor, so re-running the cell rescales an already-scaled tensor.
img_tensor = ((img_tensor / 255.0) * 2.0) - 1.0
img_tensor.shape

In [None]:
# Timesteps at which to visualise the forward noising process.
noise_levels = [0, 10, 50, 75, 100, 250, 500, 750]

In [None]:
# One copy of the image per noise level: a new leading axis of size len(noise_levels).
batch = img_tensor.repeat(len(noise_levels), 1, 1, 1)
batch.shape

In [None]:
# Noise levels as a tensor, used below to index alphas_cumprod.
ts = torch.tensor(noise_levels)

In [None]:
# Accumulator for the noised images, plus one standard-normal noise tensor per batch entry.
# NOTE: drawn from the global RNG, not from the seeded `generator`.
noise_imgs = []
epsilons = torch.randn(batch.shape)

In [None]:
import math

In [None]:
# Noise copy i of the image to timestep ts[i]: x_t = sqrt(a_hat) * x_0 + sqrt(1 - a_hat) * eps.
# NOTE: appends to the existing noise_imgs list, so re-running this cell duplicates entries.
for i in range(len(ts)):
    a_hat = ddpm_sampler.alphas_cumprod[ts[i]]
    noise_imgs.append(
        # Equation 4 of the paper, all images in the batch are identical
        math.sqrt(a_hat) * batch[i] + math.sqrt(1 - a_hat) * epsilons[i]
    )

In [None]:
# Each entry has the shape of a single image (batch[i]).
noise_imgs[0].shape

In [None]:
# List of images -> one tensor with a leading batch axis.
# NOTE: rebinds noise_imgs from a list to a tensor.
noise_imgs = torch.stack(noise_imgs, dim=0)
noise_imgs.shape

In [None]:
# Clip to [-1, 1], map to [0, 1], then to 8-bit pixel values (the uint8 cast truncates).
noise_imgs = (noise_imgs.clamp(-1, 1) + 1) / 2
noise_imgs = (noise_imgs * 255).type(torch.uint8)

In [None]:
# squeeze(0) only removes the leading axis if it has size 1; otherwise the shape is unchanged.
noise_imgs[7].squeeze(0).shape

In [None]:
# Convert entry 4 (noise level 100 in noise_levels) back to a PIL image for viewing.
display_img = Image.fromarray(noise_imgs[4].squeeze(0).numpy(), 'RGB')
# display_img

In [None]:
import torch.nn.functional as F
import torch

In [None]:
# Scratch: small (1, 2, 3) tensor used to inspect F.interpolate below.
x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32).view(1, 2, 3); x

In [None]:
# A 3-D input is treated as (N, C, L): nearest-neighbour upsampling repeats each value
# along the last axis, giving shape (1, 2, 6).
y = F.interpolate(x, scale_factor=2, mode='nearest'); y
print(y.shape)
y

In [None]:
# Scratch: a list used as a stack, the same way UNET.forward uses its skip-connection list.
a = []
a.append(1)
a.append(2)

In [None]:
# pop() returns the most recently appended item (last in, first out).
a.pop()

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens a timestep embedding from n_embed to 4 * n_embed."""

    def __init__(self, n_embed):
        super().__init__()
        widened = 4 * n_embed
        self.linear_1 = nn.Linear(n_embed, widened)
        self.linear_2 = nn.Linear(widened, widened)

    def forward(self, x):
        """Map (1, n_embed) to (1, 4 * n_embed), e.g. (1, 320) -> (1, 1280)."""
        hidden = self.linear_1(x)
        hidden = F.silu(hidden)
        return self.linear_2(hidden)


class UNET_ResidualBlock(nn.Module):
    """Residual conv block that injects the time embedding between its two conv stages."""

    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        # Feature branch: norm -> SiLU -> 3x3 conv (may change channel count).
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Time branch: project the embedding to one value per output channel.
        self.linear_time = nn.Linear(n_time, out_channels)

        # Merged branch: norm -> SiLU -> 3x3 conv.
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Skip path: identity if shapes already agree, else a 1x1 projection.
        same_width = in_channels == out_channels
        self.residual_layer = (
            nn.Identity()
            if same_width
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, feature, time):
        """Combine features and time embedding.

        feature: (B, in_channels, H, W)
        time: (1, n_time)

        Returns (B, out_channels, H, W).
        """
        skip = self.residual_layer(feature)

        # (B, in_channels, H, W) -> (B, out_channels, H, W)
        hidden = self.conv_feature(F.silu(self.groupnorm_feature(feature)))

        # (1, n_time) -> (1, out_channels) -> (1, out_channels, 1, 1) so it
        # broadcasts over height and width.
        time_bias = self.linear_time(F.silu(time))[..., None, None]

        merged = hidden + time_bias
        merged = self.conv_merged(F.silu(self.groupnorm_merged(merged)))

        return merged + skip


class UNET_AttentionBlock(nn.Module):
    """Transformer block on a feature map: self-attention, cross-attention to the
    context, then a GeGLU feed-forward, wrapped in 1x1 convs and a long residual."""

    def __init__(self, n_head: int, n_embed: int, d_context: int=768):
        super().__init__()
        # n_embed is the width per head; the feature map has n_head * n_embed channels.
        channels = n_head * n_embed

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # Emits value and gate halves (4 * channels each) in one projection.
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, image, context):
        """Attend within the image and to the context.

        image: (B, F, H, W)
        context: (B, seq_len, Dim)

        Returns (B, F, H, W).
        """
        residue = image

        # (B, F, H, W)
        image = self.groupnorm(image)
        image = self.conv_input(image)

        n, c, h, w = image.shape

        # (B, F, H, W) -> (B, F, H*W) -> (B, H*W, F): one token per pixel.
        image = image.reshape(n, c, h * w).transpose(1, 2)

        # Self attention with skip connection, (B, H*W, F).
        image = image + self.attention_1(self.layernorm_1(image))

        # Cross attention with skip connection, (B, H*W, F).
        image = image + self.attention_2(self.layernorm_2(image), context)

        # GeGLU feed-forward with skip connection.
        residue_tmp = image
        image = self.layernorm_3(image)
        # (B, H*W, F) -> two tensors of shape (B, H*W, 4 * F)
        image, gate = self.linear_geglu_1(image).chunk(2, dim=-1)
        image = image * F.gelu(gate)
        # (B, H*W, 4 * F) -> (B, H*W, F)
        image = self.linear_geglu_2(image)
        image = image + residue_tmp

        # (B, H*W, F) -> (B, F, H*W) -> (B, F, H, W). The target is the saved
        # spatial shape, not the current sequence-layout shape.
        image = image.transpose(-1, -2).reshape(n, c, h, w)

        # Long residual connection -> (B, F, H, W)
        return residue + self.conv_output(image)
    
class UpSample(nn.Module):
    """Double the spatial resolution (nearest neighbour), then apply a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        # Channel count is preserved; padding=1 keeps the upsampled size.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map (B, F, H, W) to (B, F, H * 2, W * 2)."""
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)

class SwitchSequential(nn.Sequential):
    """Sequential container that hands each layer the extra inputs it needs.

    UNET_AttentionBlock layers receive (image, context), UNET_ResidualBlock
    layers receive (image, time), and every other layer receives only image.
    """

    def forward(self, image, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                # Pass the running feature map, not the layer itself.
                image = layer(image, context)
            elif isinstance(layer, UNET_ResidualBlock):
                image = layer(image, time)
            else:
                image = layer(image)

        return image

In [None]:
# Smoke test: 8 heads x 40 = 320 channels. NOTE: rebinds `a`, used as a list in earlier cells.
a = UNET_AttentionBlock(8, 40)

In [None]:
# Scratch: the first three U-Net encoder stages on their own, to check that they construct.
encoders = nn.ModuleList([
            # (B, 4, H / 8, W / 8) -> (B, 320, H / 8, W / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            # (B, 320, H / 8, W / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
])

In [None]:
# Every entry of the ModuleList is a SwitchSequential wrapper.
for layer in encoders:
    print(layer.__class__)

In [None]:
class UNET(nn.Module):
    """U-Net over latents: 12 encoder stages, a bottleneck and 12 decoder stages.

    Each decoder stage consumes the output of the matching encoder stage
    (last in, first out) by concatenating it on the channel axis, so decoder
    input widths are the sum of the running width and the skip width.
    """

    def __init__(self):
        super().__init__()
        # Decrease image dimension and increase the number of channels; the
        # input is the output of the VAE encoder (latent).
        self.encoders = nn.ModuleList([
            # (B, 4, H / 8, W / 8) -> (B, 320, H / 8, W / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            # (B, 320, H / 8, W / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

            # (B, 320, H / 16, W / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, padding=1, stride=2)),
            # (B, 640, H / 16, W / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
            # The previous stage already outputs 640 channels.
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),

            # (B, 640, H / 32, W / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, padding=1, stride=2)),
            # (B, 1280, H / 32, W / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),

            # (B, 1280, H / 64, W / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, padding=1, stride=2)),
            # (B, 1280, H / 64, W / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = SwitchSequential(
            # (B, 1280, H / 64, W / 64)
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        self.decoders = nn.ModuleList([
            # (B, 1280, H / 64, W / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)), # 1280 concat
            # (B, 1280, H / 64, W / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),  # 1280 concat

            # (B, 1280, H / 32, W / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UpSample(1280)), # 1280 concat

            # (B, 1280, H / 32, W / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)), # 1280 concat
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)), # 1280 concat

            # (B, 1280, H / 16, W / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), UpSample(1280)), # 640 concat
            # (B, 640, H / 16, W / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)), # 640 concat
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)), # 640 concat

            # (B, 640, H / 8, W / 8)
            # The residual block outputs 640 channels, so attention is 8 x 80.
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), UpSample(640)), # 320 concat
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)), # 320 concat
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)), # 320 concat
            # (B, 320, H / 8, W / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)), # 320 concat
        ])

    def forward(self, image, context, time):
        '''
        image: (B, 4, H / 8, W / 8)
        context: (B, s_len, Dim)
        time: (1, 1280)

        Returns (B, 320, H / 8, W / 8).
        '''

        skip_connections = []
        for layers in self.encoders:
            image = layers(image, context, time)
            skip_connections.append(image)

        image = self.bottleneck(image, context, time)

        for layers in self.decoders:
            # Concatenate the most recent unused encoder output on the channel
            # axis. Example: the last encoder output and the bottleneck output
            # are both (B, 1280, H / 64, W / 64), so the first decoder stage
            # sees 2560 channels.
            image = torch.cat((image, skip_connections.pop()), dim=1)
            image = layers(image, context, time)

        # (B, 320, H / 8, W / 8)
        return image


In [None]:
class UNET_OutputLayer(nn.Module):
    """Output head: GroupNorm -> SiLU -> 3x3 conv to the requested channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, image):
        """Map (B, in_channels, H / 8, W / 8) to (B, out_channels, H / 8, W / 8)."""
        activated = F.silu(self.groupnorm(image))
        return self.conv(activated)

In [None]:
class Diffusion(nn.Module):
    """Noise predictor: time embedding -> U-Net -> output head."""

    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        '''
        latent: (B, 4, H / 8, W / 8)
        context: (B, seq_len, dim)
        time: (1, 320)

        Returns (B, 4, H / 8, W / 8).
        '''
        # Widen the raw time embedding before it reaches the U-Net.
        time_embedding = self.time_embedding(time)
        unet_features = self.unet(latent, context, time_embedding)
        # Project the U-Net features back to the latent channel count.
        return self.final(unet_features)
