In [1]:
import torch
from torch.nn import functional as F
from torch import nn
import math
import os
import shutil
from sklearn.model_selection import train_test_split
import os


In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head self-attention mechanism for processing sequences.

    Args:
        n_heads: Number of attention heads; must be positive and evenly divide ``embed_dim``
        embed_dim: Dimension of the embedding space
        in_proj_bias: Whether to include bias in input projection
        out_proj_bias: Whether to include bias in output projection

    Raises:
        ValueError: If ``n_heads`` is not positive or does not evenly divide ``embed_dim``
    """
    def __init__(self, n_heads, embed_dim, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Fail at construction time: otherwise the mismatch only surfaces as an
        # opaque shape error from ``.view`` inside ``forward``.
        if n_heads < 1 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        # A single projection produces query, key and value side by side.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias)
        self.d_head = embed_dim // n_heads

    def forward(self, x, causal_mask=False):
        """
        Forward pass for multi-head self-attention.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            causal_mask: Whether to apply causal masking for autoregressive generation
                (each position may only attend to itself and earlier positions)

        Returns:
            Output tensor of shape (batch_size, sequence_length, embedding_dim)
        """
        batch_size, sequence_length, embedding_dim = x.shape

        # Project input to query, key, and value
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
        query, key, value = self.in_proj(x).chunk(3, dim=-1)

        # Reshape for multi-head attention:
        # (batch_size, sequence_length, embed_dim) -> (batch_size, n_heads, sequence_length, d_head)
        query = query.view(interim_shape).transpose(1, 2)
        key = key.view(interim_shape).transpose(1, 2)
        value = value.view(interim_shape).transpose(1, 2)

        # Compute attention scores: (batch_size, n_heads, sequence_length, sequence_length)
        attention_weights = query @ key.transpose(-1, -2)

        # Apply causal mask if specified
        if causal_mask:
            # One (sequence_length, sequence_length) mask broadcasts over batch and
            # heads; True above the diagonal marks "future" positions.
            mask = torch.ones(
                sequence_length,
                sequence_length,
                dtype=torch.bool,
                device=attention_weights.device,
            ).triu(1)
            attention_weights = attention_weights.masked_fill(mask, float("-inf"))

        # Scale and normalize attention weights (-inf entries become 0 after softmax)
        attention_weights = attention_weights / math.sqrt(self.d_head)
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Apply attention to values
        output = attention_weights @ value

        # Merge heads back: (batch_size, n_heads, sequence_length, d_head)
        # -> (batch_size, sequence_length, embed_dim)
        output = output.transpose(1, 2)
        output = output.reshape(batch_size, sequence_length, embedding_dim)

        # Final output projection
        return self.out_proj(output)


In [None]:
class AttentionBlock(nn.Module):
    """
    Attention block for spatial feature processing in convolutional networks.

    Normalizes the feature map with a 32-group GroupNorm, treats every spatial
    position as one token of a sequence, runs single-head self-attention over
    those tokens, and adds the result back to the input.

    Args:
        channels: Number of input and output channels (GroupNorm with 32 groups
            requires this to be divisible by 32)
    """
    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """
        Apply self-attention to spatial features.

        Args:
            x: Input tensor of shape (batch_size, channels, height, width)

        Returns:
            Output tensor of shape (batch_size, channels, height, width).
            The input tensor is not modified.
        """
        # No clone needed: ``x`` is only rebound below, never mutated in place,
        # so the skip connection can reference the input directly.
        residual = x

        # Normalize features
        x = self.groupnorm(x)

        batch_size, channels, height, width = x.shape

        # Flatten spatial dimensions and move channels last so that each pixel
        # becomes one token:
        # (batch_size, channels, height, width) -> (batch_size, height * width, channels)
        x = x.reshape(batch_size, channels, height * width).transpose(-1, -2)

        # Apply self-attention without causal masking
        # (batch_size, height * width, channels) -> (batch_size, height * width, channels)
        x = self.attention(x)

        # Restore channel-first spatial layout:
        # (batch_size, height * width, channels) -> (batch_size, channels, height, width)
        x = x.transpose(-1, -2).reshape(batch_size, channels, height, width)

        # Add residual connection (out of place, leaving the attention output intact)
        return x + residual


In [None]:
class ResidualBlock(nn.Module):
    """
    Residual block with group normalization and convolutional layers.

    Runs two (GroupNorm -> SELU -> 3x3 conv) stages and adds a skip connection
    from the input. Spatial size is preserved (kernel 3, padding 1).

    Args:
        in_channels: Number of input channels (must be divisible by 32 for GroupNorm)
        out_channels: Number of output channels (must be divisible by 32 for GroupNorm)
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Use identity mapping if channels match, otherwise use 1x1 convolution
        # so the skip path has the same channel count as the main path.
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """
        Forward pass through residual block.

        Args:
            x: Input tensor of shape (batch_size, in_channels, height, width)

        Returns:
            Output tensor of shape (batch_size, out_channels, height, width).
            The input tensor is not modified.
        """
        # No clone needed: ``x`` is only rebound below, never mutated in place.
        residual = x

        # First stage: normalization, activation, convolution
        # NOTE(review): Encoder's output head uses SiLU while this block uses
        # SELU - confirm SELU is the intended activation here.
        x = self.groupnorm1(x)
        x = F.selu(x)
        x = self.conv1(x)

        # Second stage: normalization, activation, convolution.
        # The activation mirrors the first stage; previously groupnorm2 fed
        # straight into conv2 with no non-linearity in between.
        x = self.groupnorm2(x)
        x = F.selu(x)
        x = self.conv2(x)

        # Add residual connection
        return x + self.residual_layer(residual)


In [None]:
class Encoder(nn.Sequential):
    """
    Variational Autoencoder (VAE) Encoder with hierarchical feature extraction.

    Progressively downsamples input images while increasing channel depth,
    culminating in a latent space representation with mean and log variance,
    from which a scaled latent sample is drawn.
    """

    # Constant factor applied to every sampled latent before it is returned.
    LATENT_SCALE = 0.18215

    def __init__(self):
        super().__init__(
            # Initial convolution: (batch_size, 3, height, width) -> (batch_size, 128, height, width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            # First residual block: (batch_size, 128, height, width) -> (batch_size, 128, height, width)
            ResidualBlock(128, 128),

            # Downsample by 2: (batch_size, 128, height, width) -> (batch_size, 128, height/2, width/2)
            # (padding=0 here; forward() pads right/bottom by one pixel before stride-2 convs)
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            # Increase channels to 256: (batch_size, 128, height/2, width/2) -> (batch_size, 256, height/2, width/2)
            ResidualBlock(128, 256),

            # Maintain resolution: (batch_size, 256, height/2, width/2) -> (batch_size, 256, height/2, width/2)
            ResidualBlock(256, 256),

            # Downsample by 2: (batch_size, 256, height/2, width/2) -> (batch_size, 256, height/4, width/4)
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),

            # Increase channels to 512: (batch_size, 256, height/4, width/4) -> (batch_size, 512, height/4, width/4)
            ResidualBlock(256, 512),

            # Maintain resolution: (batch_size, 512, height/4, width/4) -> (batch_size, 512, height/4, width/4)
            ResidualBlock(512, 512),

            # Downsample by 2: (batch_size, 512, height/4, width/4) -> (batch_size, 512, height/8, width/8)
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),

            # Deep processing at lowest resolution
            # (batch_size, 512, height/8, width/8) -> (batch_size, 512, height/8, width/8)
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),

            # Apply spatial attention: (batch_size, 512, height/8, width/8) -> (batch_size, 512, height/8, width/8)
            AttentionBlock(512),

            # Final residual processing: (batch_size, 512, height/8, width/8) -> (batch_size, 512, height/8, width/8)
            ResidualBlock(512, 512),

            # Normalize and activate: (batch_size, 512, height/8, width/8) -> (batch_size, 512, height/8, width/8)
            nn.GroupNorm(32, 512),
            nn.SiLU(),

            # Project to latent space: (batch_size, 512, height/8, width/8) -> (batch_size, 8, height/8, width/8)
            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            # Final projection: (batch_size, 8, height/8, width/8) -> (batch_size, 8, height/8, width/8)
            nn.Conv2d(8, 8, kernel_size=1, padding=0)
        )

    def forward(self, x, noise=None):
        """
        Encode input image to latent space representation.

        Args:
            x: Input tensor of shape (batch_size, 3, height, width)
            noise: Optional standard-normal noise of shape
                (batch_size, 4, height/8, width/8) used for the
                reparameterization step. When omitted, fresh noise is drawn
                with ``torch.randn_like``; pass it explicitly to make the
                encoding deterministic and reproducible.

        Returns:
            Sampled latent of shape (batch_size, 4, height/8, width/8), scaled
            by ``LATENT_SCALE``. Only the sample is returned; the mean and log
            variance are not exposed.

        Raises:
            ValueError: If ``noise`` is given but its shape does not match the latent shape
        """
        # Apply each module in sequence
        for module in self:
            # The stride-2 convolutions are built with padding=0, so pad one
            # pixel on the right and bottom only (asymmetric) to halve the
            # spatial size exactly.
            if isinstance(module, nn.Conv2d) and module.stride == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))  # (left, right, top, bottom)
            x = module(x)

        # Split output into mean and log variance for reparameterization
        # (batch_size, 8, height/8, width/8) -> 2 tensors of shape (batch_size, 4, height/8, width/8)
        mean, log_variance = torch.chunk(x, 2, dim=1)

        # Clamp log variance to prevent numerical instability in exp()
        log_variance = torch.clamp(log_variance, -30, 20)
        standard_deviation = torch.exp(0.5 * log_variance)

        if noise is None:
            noise = torch.randn_like(standard_deviation)
        elif noise.shape != standard_deviation.shape:
            raise ValueError(
                f"noise shape {tuple(noise.shape)} does not match latent shape "
                f"{tuple(standard_deviation.shape)}"
            )

        # Reparameterization trick: sample from N(mean, variance), then apply
        # the constant latent scale.
        return (mean + noise * standard_deviation) * self.LATENT_SCALE
