In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

class SimpleTransformerBlock(nn.Module):
    """Post-norm transformer encoder block that logs tensor shapes as it runs.

    The block applies self-attention followed by a two-layer feed-forward
    network. Each sub-layer is wrapped in a residual connection and the sum
    is passed through its own ``LayerNorm``.

    Args:
        dim: Feature size of the input and output.
        heads: Number of attention heads handed to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        hidden_dim = dim * 4  # feed-forward expands to 4x the model width
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(in_features=dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        ``x`` is used as query, key and value (self-attention). The
        attention layer is built with ``batch_first=True``, so the batch
        axis comes first. Shapes are printed after every stage.
        """
        print(f"[SimpleTransformerBlock] Input: {x.shape}")

        # Attention weights are returned as the second item and not needed.
        attended = self.attn(query=x, key=x, value=x)[0]
        print(f"[SimpleTransformerBlock] After Attention: {attended.shape}")

        # First residual connection, then normalisation.
        hidden = self.norm1(x + attended)

        projected = self.ff(hidden)
        print(f"[SimpleTransformerBlock] After FeedForward: {projected.shape}")

        # Second residual connection, then normalisation.
        result = self.norm2(hidden + projected)
        print(f"[SimpleTransformerBlock] Output: {result.shape}")
        return result

# diffusion noise schedule
def noise_schedule(t, beta_start=1e-4, beta_end=0.02, steps=1000):
    """Return the cumulative product ``alpha_bar`` at timestep(s) ``t``.

    A linear ``beta`` schedule with ``steps`` entries is built from
    ``beta_start`` to ``beta_end``; ``alpha_bar`` is the running product of
    ``1 - beta``. The value(s) selected by ``t`` are printed and returned.

    Args:
        t: Index into the schedule. May be a Python int, or an integer
            tensor / list of indices to look up several timesteps at once
            (e.g. one per batch element).
        beta_start: First value of the linear ``beta`` schedule.
        beta_end: Last value of the linear ``beta`` schedule.
        steps: Number of entries in the schedule.

    Returns:
        ``alpha_bar[t]`` as a tensor: 0-dim for an int ``t``, otherwise
        shaped like the index.

    Raises:
        IndexError: If ``t`` lies outside the schedule.
    """
    # Build the schedule on the device of a tensor index so the lookup
    # below never mixes devices.
    device = t.device if torch.is_tensor(t) else None
    beta = torch.linspace(beta_start, beta_end, steps, device=device)
    alpha_bar = torch.cumprod(1 - beta, dim=0)

    selected = alpha_bar[t]
    if selected.numel() == 1:
        print(f"[Noise Schedule] alpha_bar[{t}] = {selected.item():.6f}")
    else:
        # ``.item()`` only works for single-element tensors, so batched
        # lookups are logged as a list instead of raising.
        indices = t.tolist() if torch.is_tensor(t) else t
        values = ", ".join(f"{v:.6f}" for v in selected.flatten().tolist())
        print(f"[Noise Schedule] alpha_bar[{indices}] = [{values}]")
    return selected

# decoupled transformer diffusion model
class DecoupledDiffusionTransformer(nn.Module):
    """Transformer that encodes a condition separately and fuses it with the input.

    Both token sequences go through one shared embedding table. The
    condition embedding is encoded by ``transformer_cond``, added to the
    scaled input embedding, and the sum is processed by ``transformer_x``
    before a linear projection back to vocabulary logits.

    Args:
        embed_dim: Embedding and model width.
        vocab_size: Number of token IDs; also the size of the output logits.
        num_layers: Number of blocks in ``transformer_x``.
        num_cond_layers: Number of blocks in ``transformer_cond``.
        heads: Attention heads per block.
    """

    def __init__(self, embed_dim, vocab_size, num_layers=4, num_cond_layers=2, heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.transformer_x = nn.Sequential(
            *[SimpleTransformerBlock(embed_dim, heads) for _ in range(num_layers)]
        )
        self.transformer_cond = nn.Sequential(
            *[SimpleTransformerBlock(embed_dim, heads) for _ in range(num_cond_layers)]
        )
        self.output = nn.Linear(embed_dim, vocab_size)

    def forward(self, x_noisy, condition, t):
        """Return vocabulary logits for ``x_noisy`` guided by ``condition``.

        Args:
            x_noisy: Integer token IDs for the sequence to process.
            condition: Integer token IDs for the conditioning sequence. Its
                encoding is added element-wise to the input embedding, so
                the two shapes must be broadcast-compatible.
            t: Timestep index (or indices) passed to ``noise_schedule``.

        Returns:
            Logits with a trailing dimension of ``vocab_size``.
        """
        print(f"[Model] Input IDs shape: {x_noisy.shape}")
        print(f"[Model] Condition IDs shape: {condition.shape}")

        x_embed = self.embed(x_noisy)
        cond_embed = self.embed(condition)
        print(f"[Model] Embedded input: {x_embed.shape}")
        print(f"[Model] Embedded condition: {cond_embed.shape}")

        # Two trailing axes are added so sqrt(alpha_bar) broadcasts over the
        # sequence and feature dimensions.
        # NOTE(review): this only scales the embedding; no Gaussian noise is
        # added in this method.
        noise_level = noise_schedule(t).sqrt().unsqueeze(-1).unsqueeze(-1).to(x_embed.device)
        print(f"[Model] Noise level shape: {noise_level.shape}")
        x_embed = x_embed * noise_level
        print(f"[Model] Noisy input embedding: {x_embed.shape}")

        cond_encoded = self.transformer_cond(cond_embed)
        print(f"[Model] Encoded condition: {cond_encoded.shape}")

        # Conditioning is injected by plain addition before the main stack.
        fused_input = x_embed + cond_encoded
        print(f"[Model] Fused input (x + condition): {fused_input.shape}")
        x_encoded = self.transformer_x(fused_input)
        print(f"[Model] Encoded x: {x_encoded.shape}")

        output_logits = self.output(x_encoded)
        print(f"[Model] Output logits: {output_logits.shape}")
        return output_logits

# Smoke test: push random token IDs through the model once
batch_size, seq_len = 4, 16
vocab_size, embed_dim = 10000, 256
timestep = 100

model = DecoupledDiffusionTransformer(embed_dim=embed_dim, vocab_size=vocab_size)

# Random stand-ins for real token IDs; both tensors are (batch_size, seq_len)
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
condition_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# One forward pass at the chosen timestep; shapes are printed along the way
output_logits = model(x_noisy=input_ids, condition=condition_ids, t=timestep)


[Model] Input IDs shape: torch.Size([4, 16])
[Model] Condition IDs shape: torch.Size([4, 16])
[Model] Embedded input: torch.Size([4, 16, 256])
[Model] Embedded condition: torch.Size([4, 16, 256])
[Noise Schedule] alpha_bar[100] = 0.895141
[Model] Noise level shape: torch.Size([1, 1])
[Model] Noisy input embedding: torch.Size([4, 16, 256])
[SimpleTransformerBlock] Input: torch.Size([4, 16, 256])
[SimpleTransformerBlock] After Attention: torch.Size([4, 16, 256])
[SimpleTransformerBlock] After FeedForward: torch.Size([4, 16, 256])
[SimpleTransformerBlock] Output: torch.Size([4, 16, 256])
[SimpleTransformerBlock] Input: torch.Size([4, 16, 256])
[SimpleTransformerBlock] After Attention: torch.Size([4, 16, 256])
[SimpleTransformerBlock] After FeedForward: torch.Size([4, 16, 256])
[SimpleTransformerBlock] Output: torch.Size([4, 16, 256])
[Model] Encoded condition: torch.Size([4, 16, 256])
[Model] Fused input (x + condition): torch.Size([4, 16, 256])
[SimpleTransformerBlock] Input: torch.Size(

DDT (Decoupled Diffusion Transformer)
1. a model that adds noise to input data progressively during training (diffusion)
2. learns to remove that noise using transformer architecture
3. separates the roles of input and conditioning information into decoupled transformers
    * one transformer (transformer_x) learns how to denoise the noisy input
    * another transformer (transformer_cond) handles additional context or conditioning information

Diffusion Process
1. start with clean input data
2. add gaussian noise step by step -> simulate a "diffusion" process
3. train a model to reverse this noise step by step (denoise)

In Action
* the condition provides context like a sentence prompt or a label
* the input is a noisy version of what we want to generate
* the model learns to denoise the input guided by the condition

