# Arquitectura DiT

In [1]:
import torch
import torch.nn as nn

In [2]:
class DiTBlock(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each wrapped in a residual.

    Args:
        hidden_size: Feature dimension of the tokens (input and output).
        num_heads: Number of attention heads given to ``nn.MultiheadAttention``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        # Position-wise feed-forward network with a 4x expansion.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers.

        Args:
            x: Token tensor of shape (B, N, hidden_size).

        Returns:
            Tensor of shape (B, N, hidden_size).
        """
        # Normalize once and reuse the same tensor as query, key and value.
        # Calling norm1 three times triples the LayerNorm work and produces
        # three distinct tensor objects, which stops MultiheadAttention from
        # recognising the call as self-attention (it checks object identity).
        normed = self.norm1(x)  # (B, N, hidden_size)
        # The attention weights are discarded, so do not ask for them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out  # (B, N, hidden_size)
        x = x + self.mlp(self.norm2(x))  # (B, N, hidden_size)
        return x  # (B, N, hidden_size)

In [3]:
class DiT(nn.Module):
    """Minimal diffusion-transformer style model operating on image patches.

    The input is split into non-overlapping ``patch_size x patch_size``
    patches, each patch is linearly embedded, a timestep embedding is added,
    the tokens go through a stack of ``DiTBlock`` layers, and a linear head
    maps every token back to a patch that is placed at its original location.

    Args:
        hidden_size: Token feature dimension.
        num_heads: Number of attention heads in every block.
        num_blocks: Number of ``DiTBlock`` layers.
        patch_size: Side length of the square patches.
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, hidden_size, num_heads, num_blocks, patch_size, in_channels=4):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        patch_dim = in_channels * patch_size * patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_size)
        self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads) for _ in range(num_blocks)])
        self.final_norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, patch_dim)

    @staticmethod
    def _timestep_embedding(t, dim, max_period=10000):
        """Return parameter-free sinusoidal embeddings of shape (len(t), dim)."""
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=t.device) / max(half, 1)
        freqs = 1.0 / (max_period ** exponents)  # (half,)
        angles = t.to(torch.float32).unsqueeze(-1) * freqs  # (B, half)
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)  # (B, 2 * half)
        if dim % 2:
            # Odd dim: pad one zero feature so the width matches exactly.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=-1)
        return emb  # (B, dim)

    def forward(self, x, t):
        """Run the model on a batch of images at the given timesteps.

        Args:
            x: Input of shape (B, C, H, W); H and W must be multiples of
                ``patch_size`` (otherwise the patch reshape fails).
            t: Timesteps of shape (B,).

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            AssertionError: If C differs from ``in_channels``.
        """
        # x: (B, C, H, W), t: (B,)
        B, C, H, W = x.shape
        assert C == self.in_channels, f"Expected {self.in_channels} channels, got {C}"
        p = self.patch_size

        # Patchify: every token holds one patch flattened in (C, p, p) order.
        x = x.reshape(B, C, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)  # (B, N, C * p * p), N = (H * W) / (p * p)

        x = self.patch_embed(x)  # (B, N, hidden_size)

        # Timestep conditioning. Adding the raw scalar t to every feature is a
        # constant offset along the hidden dimension, and every reader of the
        # residual stream (norm1/norm2 in the blocks, final_norm here) is a
        # LayerNorm over that dimension whose mean subtraction cancels it, so
        # the output would not depend on t. A sinusoidal embedding varies
        # across features and therefore survives normalization.
        t_emb = self._timestep_embedding(t, x.size(-1)).to(device=x.device, dtype=x.dtype)  # (B, hidden_size)
        x = x + t_emb.unsqueeze(1)  # broadcast over the N tokens -> (B, N, hidden_size)
        # NOTE(review): no positional embedding is added, so tokens carry no
        # information about their patch location - confirm this is intended.

        for block in self.blocks:
            x = block(x)  # (B, N, hidden_size)

        x = self.final_norm(x)  # (B, N, hidden_size)

        x = self.head(x)  # (B, N, C * p * p)

        # Unpatchify: exact inverse of the patchify reshape/permute above.
        x = x.reshape(B, H // p, W // p, C, p, p)
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        return x  # (B, C, H, W)

In [4]:
# Example: build a small DiT and check that a forward pass keeps the input shape.
in_channels = 4
patch_size = 2
hidden_size, num_heads, num_blocks = 48, 3, 2

model = DiT(
    hidden_size=hidden_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    patch_size=patch_size,
    in_channels=in_channels,
)
x = torch.randn(1, in_channels, 32, 32)  # (1, 4, 32, 32)
t = torch.randint(0, 1000, (1,))  # (1,)
output = model(x, t)

# The model maps (B, C, H, W) back to (B, C, H, W).
assert tuple(output.shape) == (1, 4, 32, 32)