In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.transforms import CenterCrop
from dataclasses import dataclass

In [2]:
@dataclass
class DiffusionConfig:
    time_embedding_dim: int = 320

In [3]:
class SelfAttention(nn.Module):
    """
    Self Attention mechanism for sequence data.
    
    Attributes:
        scale (float): Scaling factor for the attention scores.
        n_heads (int): Number of attention heads.
        d_head (int): Dimension of each attention head.
        QKV (nn.Linear): Linear layer for Query, Key, Value.
        O (nn.Linear): Linear output layer.
    """
    
    def __init__(self, d_embed, n_heads: int = 4, qkv_bias=True, out_bias=True) -> None:
        """
        Initializes the SelfAttention class.
        
        Args:
            d_embed (int): Dimension of the embedding.
            n_heads (int): Number of attention heads. Defaults to 4.
            qkv_bias (bool): If True, adds bias to QKV linear layer. Defaults to True.
            out_bias (bool): If True, adds bias to O linear layer. Defaults to True.
        """
        
        super().__init__()
        
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.scale = self.d_head ** -0.5
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)
        
    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        """
        Forward pass for the SelfAttention mechanism.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_embed).
            mask (bool): If True, applies the attention mask. Defaults to False.

        Returns:
            torch.Tensor: Processed tensor.
        """
        
        # x: (batch_size, height*width, channels)
        # x: (batch_size, seq_len, d_embed)
        in_shape = x.shape
        bs, seq_len, d_embed = x.shape
        q, k, v = self.QKV(x).chunk(3, dim=-1) # (batch_size, seq_len, d_embed)@(d_embed, d_embed*3) -> (batch_size, seq_len, d_embed*3) -> (3x) (batch_size, seq_len, d_embed)
        
        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        if not mask:
            mask = torch.ones_like(attn_scores).bool().triu(1) # (batch_size, n_heads, seq_len, seq_len)
            attn_scores.masked_fill_(mask, -1e9)
        weights = F.softmax(attn_scores, dim=-1)
        output = weights @ v # (batch_size, n_heads, seq_len, seq_len) -> (batch_size, n_heads, seq_len, d_head) -> (batch_size, n_heads, seq_len, d_head)
        output = output.transpose(1, 2).contiguous().view(in_shape) # (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, seq_len, d_embed)
        return self.O(output) # (batch_size, seq_len, d_embed)@(d_embed, d_embed) -> (batch_size, seq_len, d_embed)
    
    
    
class CrossAttention(nn.Module):
    """
    Cross Attention mechanism for sequence data.
    
    Attributes:
        scale (float): Scaling factor for the attention scores.
        n_heads (int): Number of attention heads.
        d_head (int): Dimension of each attention head.
        QKV (nn.Linear): Linear layer for Query, Key, Value.
        O (nn.Linear): Linear output layer.
    """
    
    def __init__(self, n_heads: int, d_embed: int, d_cross: int, qkv_bias=True, out_bias=True) -> None:
        """
        Initializes the SelfAttention class.
        
        Args:
            d_embed (int): Dimension of the embedding.
            n_heads (int): Number of attention heads. Defaults to 4.
            qkv_bias (bool): If True, adds bias to QKV linear layer. Defaults to True.
            out_bias (bool): If True, adds bias to O linear layer. Defaults to True.
        """
        
        super().__init__()
        
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.scale = self.d_head ** -0.5
        self.Q = nn.Linear(d_embed, d_embed, bias=qkv_bias)
        self.K = nn.Linear(d_cross, d_embed, bias=qkv_bias)
        self.V = nn.Linear(d_cross, d_embed, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)
        
    def forward(self, x_q: torch.Tensor, x_kv: torch.TensorType) -> torch.Tensor:
        """
        Forward pass for the CrossAttention mechanism.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_embed).
            mask (bool): If True, applies the attention mask. Defaults to False.

        Returns:
            torch.Tensor: Processed tensor.
        """
        
        # x_q(latents): (batch_size, seq_len_Q, d_embed_Q)
        # x_kv(context): (batch_size, seq_len_KV, d_embed_KV)
        in_shape = x_q.shape
        bs, seq_len, d_embed = x_q.shape
        
        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, n_heads, seq_len, d_head)
        q = self.Q(x_q).view(bs, -1, self.n_heads, self.d_head).transpose(1, 2)
        k = self.K(x_kv).view(bs, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.V(x_kv).view(bs, -1, self.n_heads, self.d_head).transpose(1, 2)
        
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = F.softmax(attn_scores, dim=-1)
        output = weights @ v # (batch_size, n_heads, seq_len, seq_len) -> (batch_size, n_heads, seq_len, d_head) -> (batch_size, n_heads, seq_len, d_head)
        output = output.transpose(1, 2).contiguous().view(in_shape) # (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, seq_len, d_embed)
        return self.O(output) # (batch_size, seq_len, d_embed)@(d_embed, d_embed) -> (batch_size, seq_len, d_embed)

In [4]:
class Unet(nn.Module):
    """
    A PyTorch implementation of the U-Net architecture.
    """
    
    def __init__(self) -> None:
        super().__init__()
        
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

In [5]:
unet = Unet()
unet

Unet(
  (downs): ModuleList()
  (ups): ModuleList()
)

In [6]:
class Block(nn.Module):
    
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3) # in_chxout_ch
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3) # out_chxout_ch
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
    
# enc_block1 = Block(in_channels=1, out_channels=64)
# x =  torch.randn(1, 1, 64, 64)
# print(enc_block1(x).shape)

class Encoder(nn.Module):
    
    def __init__(self, channels: tuple = (3, 64, 128, 256, 512, 1024)) -> None:
        super().__init__()
        
        self.enc_blocks = nn.ModuleList(Block(in_ch, out_ch) for in_ch, out_ch in zip(channels, channels[1:]))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = []
        for block in self.enc_blocks:
            x = block(x)
            features.append(x)
            x = self.pool(x)
        return features
        
        
encoder = Encoder().to('mps')
x = torch.randn(1, 3, 572, 572).to('mps')
ftrs = encoder(x)
# for ftr in ftrs: print(ftr.shape)

# an upsampling of the feature map
# a 2x2 convolution (“up-convolution”) that halves the number of feature channels, 
# a concatenation with the correspondingly cropped feature map from the contracting path
# two 3x3 convolutions, each followed by a ReLU

class Decoder(nn.Module):
    def __init__(self, channels=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.channels = channels
        self.up_convs = nn.ModuleList(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2) 
                                      for in_ch, out_ch in zip(channels, channels[1:]))
        self.dec_blocks = nn.ModuleList(Block(in_ch, out_ch) for in_ch, out_ch in zip(channels, channels[1:]))
        
    def forward(self, x: torch.Tensor, features: list) -> torch.Tensor:
        for up_conv, dec_block, feature in zip(self.up_convs, self.dec_blocks, features[::-1][1:]):
            x = up_conv(x)
            enc_feature = CenterCrop(x.shape[2:])(feature)
            x = torch.cat([x, enc_feature], dim=1)
            x = dec_block(x)
        return x
    
decoder = Decoder().to('mps')
x = torch.randn(1, 1024, 28, 28).to('mps')
print(decoder(x, ftrs).shape)

# >> (torch.Size([1, 64, 388, 388])

torch.Size([1, 64, 388, 388])


In [7]:
class UNet(nn.Module):
    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        self.out_sz      = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out

In [8]:
class TimeEmbeddings(nn.Module):
    """
    A PyTorch implementation of the time embeddings used in Diffusion model.
    """
    def __init__(self, time_embedding_dim: int) -> None:
        super().__init__()
        
        self.linear_1 = nn.Linear(in_features=time_embedding_dim, out_features=time_embedding_dim*4) # 320x1280
        self.linear_2 = nn.Linear(in_features=time_embedding_dim*4, out_features=time_embedding_dim*4) # 1280x1280
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_2(F.silu(self.linear_1(x))) # (1, 320) -> (1, 1280)


In [9]:
class Head(nn.Module):
    """
    A PyTorch implementation of the U-Net head.
    """
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        
        self.group_norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, in_channels, height, width) Bx320x64x64
        x = F.silu(self.group_norm(x))
        return self.conv(x) # Bx320x64x64 -> Bx4x64x64

In [10]:
class Upsample(nn.Module):
    
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x:  BxCxHxW -> BxCx2*Hx2*W
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)


In [11]:
class UnetAttentionBlock(nn.Module):
    def __init__(self, n_heads: int, n_embed: int, hidden_size: int = 768) -> None:
        super().__init__()
        
        channel_embed = n_heads * n_embed
        self.group_norm = nn.GroupNorm(num_groups=32, num_channels=channel_embed)
        self.conv_in = nn.Conv2d(in_channels=channel_embed, out_channels=channel_embed, kernel_size=1)
        self.conv_out = nn.Conv2d(in_channels=channel_embed, out_channels=channel_embed, kernel_size=1)
        self.layer_norm_1 = nn.LayerNorm(channel_embed)
        self.layer_norm_2 = nn.LayerNorm(channel_embed)
        self.layer_norm_3 = nn.LayerNorm(channel_embed)
        self.attention_1 = SelfAttention(n_heads=n_heads, d_embed=channel_embed, qkv_bias=False)
        self.attention_2 = CrossAttention(n_heads=n_heads, d_embed=channel_embed, d_cross=hidden_size, qkv_bias=False) 
        self.linear_1 = nn.Linear(in_features=channel_embed, out_features=channel_embed*4*2)
        self.linear_2 = nn.Linear(in_features=channel_embed*4, out_features=channel_embed)
        
    def ffn(self, x: torch.Tensor) -> torch.Tensor:
        residue = x
        x, gate = self.linear_1(x).chunk(2, dim=-1)
        x = x * F.gelu(gate)
        return residue + self.linear_2(x)
        

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, n_heads*n_embed, height, width)
        # context: (batch_size, seq_len, hidden_size)
        n, c, h, w = x.shape
        long_residue = x
        x = self.group_norm(x)
        x = self.conv_in(x)
        
        x = x.view((n, c, h*w)).transpose(-1, -2) # (batch_size, c, height*width) -> (batch_size, height*width, c)
        
        # Layer Normalization + Self-Attention + Residual
        x = x + self.attention_1(self.layer_norm_1(x)) # (batch_size, height*width, c)
        
        # Layer Normalization + Cross-Attention + Residual
        x = x + self.attention_2(self.layer_norm_2(x), context) # (batch_size, height*width, c)
        
        # Layer Normalization + Feed Forward + Residual
        x = self.ffn(self.layer_norm_3(x)) # (batch_size, height*width, c)
        
        x = x.transpose(-1, -2).view((n, c, h, w)) # (batch_size, height*width, c) -> (batch_size, c, height, width)
        
        return self.conv_out(x) + long_residue # (batch_size, c, height, width)

In [12]:
class UnetResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int = 1280) -> None:
        super().__init__()
        
        self.group_norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.linear = nn.Linear(in_features=time_embed_dim, out_features=out_channels)
        
        self.group_norm_merged = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.conv_merged = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            
    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, in_channels, height, width) -> (batch_size, out_channels, height, width)
        residue = x
        x = self.conv(F.silu(self.group_norm(x)))
        t = self.linear(F.silu(time)).unsqueeze(-1).unsqueeze(-1)
        merged = x + t
        return self.conv_merged(F.silu(self.group_norm_merged(merged))) + self.skip(residue)

In [13]:
class ApplyLayer(nn.Sequential):
    
    def forward(self, x: torch.Tensor, time: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        print(f'x: {x.shape}, time: {time.shape}, context: {context.shape}')
        for layer in self:
            if isinstance(layer, UnetAttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UnetResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x


In [14]:
def make_encoder_layers():
    return nn.ModuleList([
        ApplyLayer(nn.Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1)), # Bx4x64x64 -> Bx320x64x64
        ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx320x64x64 -> Bx320x64x64
        ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx320x64x64 -> Bx320x64x64
        
        ApplyLayer(nn.Conv2d(in_channels=320, out_channels=320, kernel_size=3, padding=1, stride=2)), # Bx320x64x64 -> Bx320x32x32
        ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx320x32x32 -> Bx320x32x32
        ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx320x32x32 -> Bx320x32x32
        
        ApplyLayer(nn.Conv2d(in_channels=640, out_channels=640, kernel_size=3, padding=1, stride=2)), # Bx640x32x32 -> Bx640x16x16
        ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx640x16x16 -> Bx640x16x16
        ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx640x16x16 -> Bx640x16x16
        
        ApplyLayer(nn.Conv2d(in_channels=1280, out_channels=1280, kernel_size=3, padding=1, stride=2)), # Bx1280x16x16 -> Bx1280x8x8
        ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280)),
        ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280)),
    ])

In [15]:
def make_decoder_layers():
    return nn.ModuleList([
        ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280)), # Bx2560x8x8 -> Bx1280x8x8
        ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280)), # Bx2560x8x8 -> Bx1280x8x8
        
        ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), Upsample(in_channels=1280)), # Bx2560x8x8 -> Bx1280x16x16
        ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx2560x16x16 -> Bx1280x16x16
        ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx2560x16x16 -> Bx1280x16x16
        
        ApplyLayer(UnetResidualBlock(in_channels=1920, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160), Upsample(in_channels=1280)), # Bx2560x16x16 -> Bx1280x32x32
        ApplyLayer(UnetResidualBlock(in_channels=1920, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx1920x32x32 -> Bx640x32x32
        ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx1280x32x32 -> Bx640x32x32
        
        ApplyLayer(UnetResidualBlock(in_channels=960, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80), Upsample(in_channels=640)), # Bx960x32x32 -> Bx640x64x64
        ApplyLayer(UnetResidualBlock(in_channels=960, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx960x32x32 -> Bx320x64x64
        ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx640x32x32 -> Bx320x64x64
        ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx640x32x32 -> Bx320x64x64  
    ])

In [16]:
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = self._make_encoder_layers()
        self.bottle_neck = ApplyLayer(
            UnetResidualBlock(in_channels=1280, out_channels=1280),
            UnetAttentionBlock(n_heads=8, n_embed=160),
            UnetResidualBlock(in_channels=1280, out_channels=1280)
        )
        self.decoder   = self._make_decoder_layers()
        
    def forward(self, x: torch.Tensor, time: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: Bx4x64x64
        # time: 1x1280
        # context: Bx77x768
        skip_connections = []
        for enc_layer in self.encoder:
            x = enc_layer(x, time, context)
            skip_connections.append(x)
            
        x = self.bottle_neck(skip_connections[-1], time, context)
        
        for dec_layer in self.decoder:
            x = torch.cat([x, skip_connections.pop()], dim=1)
            x = dec_layer(x, time, context)
        
        return x # Bx320x64x64
    
    
    def _make_encoder_layers(self):
        return nn.ModuleList([
            ApplyLayer(nn.Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1)), # Bx4x64x64 -> Bx320x64x64
            ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx320x64x64 -> Bx320x64x64
            ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx320x64x64 -> Bx320x64x64
            
            ApplyLayer(nn.Conv2d(in_channels=320, out_channels=320, kernel_size=3, padding=1, stride=2)), # Bx320x64x64 -> Bx320x32x32
            ApplyLayer(UnetResidualBlock(in_channels=320, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx320x32x32 -> Bx320x32x32
            ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx320x32x32 -> Bx320x32x32
            
            ApplyLayer(nn.Conv2d(in_channels=640, out_channels=640, kernel_size=3, padding=1, stride=2)), # Bx640x32x32 -> Bx640x16x16
            ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx640x16x16 -> Bx640x16x16
            ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx640x16x16 -> Bx640x16x16
            
            ApplyLayer(nn.Conv2d(in_channels=1280, out_channels=1280, kernel_size=3, padding=1, stride=2)), # Bx1280x16x16 -> Bx1280x8x8
            ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280)),
            ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=1280)),
        ])
        
    def _make_decoder_layers(self):
        return nn.ModuleList([
            ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280)), # Bx2560x8x8 -> Bx1280x8x8
            ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280)), # Bx2560x8x8 -> Bx1280x8x8
            
            ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), Upsample(in_channels=1280)), # Bx2560x8x8 -> Bx1280x16x16
            ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx2560x16x16 -> Bx1280x16x16
            ApplyLayer(UnetResidualBlock(in_channels=2560, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160)), # Bx2560x16x16 -> Bx1280x16x16
            
            ApplyLayer(UnetResidualBlock(in_channels=1920, out_channels=1280), UnetAttentionBlock(n_heads=8, n_embed=160), Upsample(in_channels=1280)), # Bx2560x16x16 -> Bx1280x32x32
            ApplyLayer(UnetResidualBlock(in_channels=1920, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx1920x32x32 -> Bx640x32x32
            ApplyLayer(UnetResidualBlock(in_channels=1280, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80)), # Bx1280x32x32 -> Bx640x32x32
            
            ApplyLayer(UnetResidualBlock(in_channels=960, out_channels=640), UnetAttentionBlock(n_heads=8, n_embed=80), Upsample(in_channels=640)), # Bx960x32x32 -> Bx640x64x64
            ApplyLayer(UnetResidualBlock(in_channels=960, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx960x32x32 -> Bx320x64x64
            ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx640x32x32 -> Bx320x64x64
            ApplyLayer(UnetResidualBlock(in_channels=640, out_channels=320), UnetAttentionBlock(n_heads=8, n_embed=40)), # Bx640x32x32 -> Bx320x64x64  
        ])

In [17]:
class DiffusionModel(nn.Module):
    """
    A PyTorch implementation of the Diffusion Model.
    """
    
    def __init__(self) -> None:
        super().__init__()
        
        self.time_embeddings = TimeEmbeddings(time_embedding_dim=320)
        self.unet = UNet()
        self.head = Head(in_channels=320, out_channels=4)
        
    def forward(self, latents: torch.Tensor, time: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # latents: (batch_size, latent_dim, height, width) Bx4x64x64
        # time: (1, 320)
        # context: (batch_size, seq_len, hidden_size)
        
        time_embed = self.time_embeddings(time) # (1, 320) -> (1, 1280)
        unet_out = self.unet(latents, time_embed, context) # (B, 4, 64, 64), (B, 320, 64, 64)
        return self.head(unet_out) # (B, 320, 64, 64) -> (B, 4, 64, 64)

In [18]:
x = torch.randn(1, 4, 64, 64).to('mps')
t = torch.randn(1, 320).to('mps')
context = torch.randn(1, 77, 768).to('mps')
model = DiffusionModel().to('mps')

pred = model(x, t, context)

x: torch.Size([1, 4, 64, 64]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 320, 64, 64]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 320, 64, 64]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 320, 64, 64]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 320, 32, 32]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 640, 32, 32]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 640, 32, 32]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 640, 16, 16]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 1280, 16, 16]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 1280, 16, 16]), time: torch.Size([1, 1280]), context: torch.Size([1, 77, 768])
x: torch.Size([1, 1280, 8, 8]), time: to

In [19]:
print(f'--------------> X SHAPE: {x.shape}')
print(f'--------------> PRED SHAPE: {pred.shape}')

--------------> X SHAPE: torch.Size([1, 4, 64, 64])
--------------> PRED SHAPE: torch.Size([1, 4, 64, 64])
