In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from utils import SinusoidalEmbedding, get_loss, sample_timestep
import matplotlib.pyplot as plt

class MultiHeadAttention(nn.Module):
    """Unmasked multi-head scaled dot-product self-attention.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``n_heads``.
        n_heads: Number of heads; each head attends over
            ``d_model // n_heads`` features.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Query, key, value and output projections, created in that order
        # (creation order determines seeded initialisation).
        self.w_q, self.w_k, self.w_v, self.w_o = (
            nn.Linear(d_model, d_model) for _ in range(4)
        )

    def _split_heads(self, projected, batch, length):
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_k)."""
        return projected.view(batch, length, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, length, width = x.shape

        q = self._split_heads(self.w_q(x), batch, length)
        k = self._split_heads(self.w_k(x), batch, length)
        v = self._split_heads(self.w_v(x), batch, length)

        # Softmax over keys of the sqrt(d_k)-scaled similarity scores.
        scale = self.d_k ** 0.5
        weights = F.softmax(q @ k.transpose(-2, -1) / scale, dim=-1)
        per_head = weights @ v

        # Put heads back side by side: (batch, length, n_heads * d_k).
        merged = per_head.transpose(1, 2).contiguous().view(batch, length, width)
        return self.w_o(merged)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model->d_ff) -> ReLU -> Dropout -> Linear(d_ff->d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)   # project back to model width
        self.dropout = nn.Dropout(dropout)        # applied to the hidden activations

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class TransformerBlock(nn.Module):
    """Post-norm transformer encoder layer.

    Two sub-layers are applied in sequence (self-attention, then the
    feed-forward MLP), each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Residual + dropout + post-LayerNorm around each sub-layer, attention first.
        stages = (
            (self.attention, self.norm1),
            (self.feed_forward, self.norm2),
        )
        for sublayer, norm in stages:
            x = norm(x + self.dropout(sublayer(x)))
        return x

class DiTMNIST(nn.Module):
    """Diffusion-Transformer denoiser for small square images (MNIST by default).

    The image is cut into non-overlapping ``patch_size x patch_size`` patches.
    Each flattened patch is linearly embedded. A learned positional embedding
    and a timestep embedding are added, and the token sequence goes through a
    stack of ``TransformerBlock``s. Every token is then projected back to a
    patch, and the patches are reassembled into an image of the input's shape.

    Args:
        img_size: Height and width of the (square) input image.
        patch_size: Patch side length; must divide ``img_size``.
        d_model: Transformer width.
        n_heads: Attention heads per block (must divide ``d_model``).
        n_layers: Number of transformer blocks.
        d_ff: Hidden width of each block's feed-forward MLP.
        dropout: Dropout probability used by the blocks and on the input tokens.
        time_embed_dim: Width passed to ``SinusoidalEmbedding`` and expected
            by the first layer of ``time_mlp``.
        in_channels: Number of image channels (1 for MNIST). Added last so
            existing positional/keyword calls behave exactly as before.

    Raises:
        ValueError: If ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, img_size=28, patch_size=4, d_model=256, n_heads=8, n_layers=6,
                 d_ff=1024, dropout=0.1, time_embed_dim=128, in_channels=1):
        super().__init__()
        # Without this check n_patches is silently floored and forward() later
        # fails with an opaque reshape error.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.d_model = d_model
        self.in_channels = in_channels

        # Number of patches and number of values in one flattened patch.
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_channels * patch_size * patch_size

        # Flattened patch -> d_model token.
        self.patch_embedding = nn.Linear(self.patch_dim, d_model)

        # Timestep embedding: sinusoidal features (from utils) refined by an MLP.
        self.time_embedding = SinusoidalEmbedding(time_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # Learned positional embedding, one vector per patch position.
        self.pos_encoding = nn.Parameter(torch.randn(self.n_patches, d_model))

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Token -> flattened patch.
        self.output_projection = nn.Linear(d_model, self.patch_dim)
        self.dropout = nn.Dropout(dropout)

    def patchify(self, x):
        """Split images into flattened patches.

        Args:
            x: Tensor of shape (batch, channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, n_patches, channels * patch_size**2).
            Patches are in row-major grid order, and each patch is flattened
            as (channel, row, col).
        """
        batch, channels = x.shape[0], x.shape[1]
        p = self.patch_size
        grid = self.img_size // p
        x = x.reshape(batch, channels, grid, p, grid, p)
        # -> (batch, grid_row, grid_col, channels, p, p)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(batch, grid * grid, channels * p * p)

    def unpatchify(self, x):
        """Inverse of :meth:`patchify`.

        Args:
            x: Tensor of shape (batch, n_patches, channels * patch_size**2).

        Returns:
            Tensor of shape (batch, channels, img_size, img_size).
        """
        batch = x.shape[0]
        p = self.patch_size
        grid = self.img_size // p
        channels = x.shape[-1] // (p * p)
        x = x.reshape(batch, grid, grid, channels, p, p)
        # -> (batch, channels, grid_row, p, grid_col, p)
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(batch, channels, self.img_size, self.img_size)

    def forward(self, x, t):
        """Predict the network output for noisy images ``x`` at timesteps ``t``.

        Args:
            x: Tensor of shape (batch, in_channels, img_size, img_size).
            t: Tensor of shape (batch,) holding the diffusion timestep of each image.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (batch, n_patches, d_model), plus the positional embedding broadcast over batch.
        tokens = self.patch_embedding(self.patchify(x)) + self.pos_encoding

        # (batch, d_model); unsqueeze(1) broadcasts it onto every patch token.
        time_emb = self.time_mlp(self.time_embedding(t))
        tokens = self.dropout(tokens + time_emb.unsqueeze(1))

        for block in self.transformer_blocks:
            tokens = block(tokens)

        # Back to (batch, n_patches, patch_dim), then to image layout.
        return self.unpatchify(self.output_projection(tokens))

# Load MNIST dataset
def get_mnist_dataloader(batch_size=32, root='../data', train=True, download=True,
                         shuffle=True, num_workers=0):
    """Build a DataLoader over MNIST with pixel values scaled to [-1, 1].

    Every argument after ``batch_size`` defaults to the value that used to be
    hard-coded, so ``get_mnist_dataloader(batch_size)`` behaves as before.

    Args:
        batch_size: Images per batch.
        root: Directory where torchvision stores/looks for the MNIST files.
        train: Load the training split (True) or the test split (False).
        download: Download the files into ``root`` if they are missing.
        shuffle: Reshuffle the data every epoch.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),                 # PIL uint8 -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])

    dataset = torchvision.datasets.MNIST(
        root=root,
        train=train,
        download=download,
        transform=transform,
    )

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers)

# Training function
def train_dit_mnist(epochs=10, batch_size=32, learning_rate=1e-4, num_timesteps=1000,
                    log_every=500, device=None):
    """Train a ``DiTMNIST`` denoiser on the MNIST training split and return it.

    All arguments default to the values that used to be hard-coded, so
    ``train_dit_mnist()`` behaves as before.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size for the MNIST DataLoader.
        learning_rate: Adam learning rate.
        num_timesteps: Training timesteps are drawn uniformly from
            ``[0, num_timesteps)``. Presumably this must match the noise
            schedule used by ``utils.get_loss`` / ``utils.sample_timestep``
            — TODO confirm against utils.
        log_every: Print the current batch loss every this many batches;
            0 or None disables per-batch logging.
        device: Device to train on. Defaults to CUDA when available, else CPU.

    Returns:
        The trained model. It is left in training mode (dropout enabled), so
        call ``model.eval()`` before inference.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DiTMNIST(
        img_size=28,
        patch_size=4,
        d_model=256,
        n_heads=8,
        n_layers=6,
        d_ff=1024,
        dropout=0.1,
        time_embed_dim=128,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    dataloader = get_mnist_dataloader(batch_size)

    print(f"Training DiT on MNIST with {sum(p.numel() for p in model.parameters())} parameters")

    model.train()  # explicit: dropout on during training
    for epoch in range(epochs):
        total_loss = 0.0
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)

            # One random diffusion timestep per image.
            t = torch.randint(0, num_timesteps, (images.shape[0],), device=device)

            # Loss is computed by the diffusion framework in utils.
            loss = get_loss(model, images, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once (one device sync) and reuse it for logging.
            batch_loss = loss.item()
            total_loss += batch_loss

            if log_every and batch_idx % log_every == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss:.6f}')

        # Guard against an empty loader instead of raising ZeroDivisionError.
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f'Epoch {epoch} completed, Average Loss: {avg_loss:.6f}')

    return model

# Sampling function for MNIST
@torch.no_grad()
def sample_dit_mnist(model, n_samples=8, timesteps=1000, image_shape=(1, 28, 28)):
    """Generate images by running the reverse diffusion chain from pure noise.

    Args:
        model: Trained denoiser; samples are produced on its parameters' device.
        n_samples: Number of images to generate.
        timesteps: Length of the reverse chain. Steps run t = timesteps-1, ..., 0,
            so this should equal the number of diffusion steps the model was
            trained with; a smaller value starts from pure noise at a timestep
            where the forward process is still mostly signal.
        image_shape: (channels, height, width) of each generated image.

    Returns:
        Tensor of shape (n_samples, *image_shape) on the model's device.
    """
    device = next(model.parameters()).device

    # Dropout must be disabled while sampling (train_dit_mnist returns the
    # model in train mode). Restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Start from standard Gaussian noise.
        x = torch.randn(n_samples, *image_shape, device=device)

        for i in reversed(range(timesteps)):
            t = torch.full((n_samples,), i, device=device, dtype=torch.long)
            x = sample_timestep(model, x, t)
    finally:
        model.train(was_training)

    return x

# Visualization function
def visualize_samples(samples, n_samples=8):
    """Show the first ``n_samples`` images of ``samples`` side by side.

    Args:
        samples: Indexable batch of image tensors with values in [-1, 1].
            Each item is squeezed to 2-D, so single-channel images are
            assumed — TODO confirm if multi-channel samples are ever passed.
        n_samples: Maximum number of images to show. It is capped at
            ``len(samples)``; nothing is drawn if that leaves zero.
    """
    n_samples = min(n_samples, len(samples))
    if n_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single image; with the default,
    # plt.subplots(1, 1) returns a bare Axes and indexing it fails.
    _, axes = plt.subplots(1, n_samples, figsize=(n_samples * 2, 2), squeeze=False)

    for i, ax in enumerate(axes[0]):
        img = samples[i].detach().cpu().squeeze().numpy()
        # Denormalize from [-1, 1] to [0, 1]. Fixed vmin/vmax stop imshow from
        # auto-stretching each image to its own min/max, which would otherwise
        # undo this mapping and hide the model's actual intensities.
        img = ((img + 1) / 2).clip(0, 1)
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(f'Sample {i+1}')

    plt.tight_layout()
    plt.show()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Train the model (timesteps drawn from [0, 1000) inside train_dit_mnist)
model = train_dit_mnist()

# Generate samples.
# The reverse chain must span the same 1000 steps used in training. With only
# 100 steps, sampling would start from pure noise at t=99 and skip most of the
# denoising process. sample_dit_mnist already runs under torch.no_grad().
print("Generating MNIST samples...")
samples = sample_dit_mnist(model, n_samples=8, timesteps=1000)
visualize_samples(samples)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 40463806.31it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 2587042.84it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 20328898.87it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2344969.08it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training DiT on MNIST with 4858384 parameters
Epoch 0, Batch 0, Loss: 1.349607
