In [2]:
def q_sample_masked(x_0, t, T, mask_token_id):
    """
    Discrete diffusion process for LlaDA.
    Progressively masks tokens based on timestep t.
    """
    # Calculate masking ratio (linearly increasing)
    mask_ratio = t / T
    
    # Create a mask based on the ratio
    mask = torch.rand_like(x_0.float()) < mask_ratio
    
    # Apply mask: Replace tokens with mask_token_id
    x_t = x_0.clone()
    x_t[mask] = mask_token_id
    
    return x_t

In [3]:
import torch
import torch.nn.functional as F

def calculate_sequence_probability(model, input_ids):
    """
    Implements p(x) = p(x1) * product_{n=2}^N p(xn | x1:n-1)
    In practice, we use log-probabilities to avoid underflow:
    log p(x) = sum_{n=1}^N log p(xn | x1:n-1)
    """
    model.eval()
    with torch.no_grad():
        # 1. Forward pass to get logits for all positions
        outputs = model(input_ids)
        logits = outputs.logits # Shape: (batch, seq_len, vocab_size)
        
        # 2. Shift logits and targets to align p(xn | x_{1:n-1})
        # Logits at index i predict token at index i+1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        
        # 3. Calculate log p(xn | history) using CrossEntropy
        # CrossEntropyLoss(reduction='none') returns -log p for each token
        log_probs = -F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), 
            shift_labels.view(-1), 
            reduction='none'
        )
        
        # 4. Add the probability of the first token p(x1)
        # Usually assumed to be 1.0 (0.0 log) if the sequence starts with BOS
        # Or calculated via a start-of-sequence distribution.
        
        # Total sequence log-probability (Sum of logs = Log of product)
        total_log_prob = log_probs.sum()
        
        return total_log_prob.exp(), total_log_prob

def autoregressive_generation(model, tokenizer, prompt, max_n=50):
    """
    Demonstrates the iterative application of p(xn | x1:n-1)
    to generate a sequence.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    for _ in range(max_n):
        # Current sequence: x1:n-1
        outputs = model(input_ids)
        
        # Get logits for the very last token: p(xn | x1:n-1)
        next_token_logits = outputs.logits[:, -1, :]
        
        # Sample or take Argmax to get xn
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
        
        # Append xn to history: x1:n = [x1:n-1, xn]
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        if next_token == tokenizer.eos_token_id:
            break
            
    return tokenizer.decode(input_ids[0])

In [4]:
import torch
import torch.nn.functional as F

def diffusion_loss_fn(model, x_0, mask_token_id):
    """
    Implements Equation (3):
    L(θ) = -E [ w(t) * sum_{n=1}^N 1_{xt_n=MASK} * log pθ(x0_n | xt) ]
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device

    # 1. Sample t ~ U(0, 1)
    t = torch.rand(batch_size, device=device)

    # 2. Generate xt ~ q(xt | x0) 
    # For α_t = 1 - t, we mask tokens with probability t
    noise_probs = t.view(-1, 1).expand(batch_size, seq_len)
    mask_indices = torch.bernoulli(noise_probs).bool()
    
    x_t = x_0.clone()
    x_t[mask_indices] = mask_token_id

    # 3. Get model predictions pθ(x0 | xt)
    outputs = model(x_t)
    logits = outputs.logits  # Shape: (batch, seq_len, vocab_size)

    # 4. Calculate Cross-Entropy Loss
    # Flatten for F.cross_entropy
    ce_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), 
        x_0.view(-1), 
        reduction='none'
    ).view(batch_size, seq_len)

    # 5. Apply Indicator Function: 1_{xt_n=MASK}
    # Only compute loss for positions that were masked
    masked_loss = ce_loss * mask_indices.float()

    # 6. Apply Reweighting Term: w(t) = 1/t
    # We use a small epsilon to avoid division by zero
    w_t = 1.0 / torch.clamp(t, min=1e-5)
    
    # Final Weighted Loss (Sum over N, Mean over Batch)
    loss = (w_t.view(-1, 1) * masked_loss).sum(dim=1).mean()

    return loss

In [5]:
import torch

# Create a sample tensor (e.g., 2 sentences, 4 tokens each)
# Shape: (Batch=2, Seq_Len=4)
x = torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]])
print(f"Original Shape: {x.shape}")

# --- VIEW ---
# Flatten the tensor into one long sequence
# Shape: (8)
flattened = x.view(8) 
# OR use -1 to auto-calculate: x.view(-1)
print(f"Flattened (view): {flattened.shape}")
print(f"Flattened (view): {flattened}")

# Reshape into (4, 2)
reshaped = x.view(4, 2)
print(f"Reshaped (4x2): {reshaped.shape}")
print(f"Reshaped (4x2): {reshaped}")

# --- UNSQUEEZE (Adding Dimensions) ---
# Add a dimension at the start (often needed for model input)
# Shape: (1, 2, 4) -> (Extra_Dim, Batch, Seq_Len)
unsqueezed = x.unsqueeze(0)
print(f"Unsqueezed at 0: {unsqueezed.shape}")
print(f"Unsqueezed at 0: {unsqueezed}")

# Add a dimension in the middle
# Shape: (2, 1, 4)
middle_dim = x.unsqueeze(1)
print(f"Unsqueezed at 1: {middle_dim.shape}")
print(f"Unsqueezed at 1: {middle_dim}")

# --- SQUEEZE (Removing Dimensions) ---
# Create a tensor with many 'empty' dimensions
# Shape: (1, 2, 1, 4)
y = torch.randn(1, 2, 1, 4)
print(f"\nComplex Tensor y: {y.shape}")
print(f"\nComplex Tensor y: {y}")
# Remove all size-1 dimensions
# Shape: (2, 4)
y_simple = y.squeeze()
print(f"Fully Squeezed: {y_simple.shape}")
print(f"Fully Squeezed: {y_simple}")
# Remove only the dimension at index 0
# Shape: (2, 1, 4)
y_specific = y.squeeze(0)
print(f"Squeezed at index 0: {y_specific.shape}")

Original Shape: torch.Size([2, 4])
Flattened (view): torch.Size([8])
Flattened (view): tensor([10, 20, 30, 40, 50, 60, 70, 80])
Reshaped (4x2): torch.Size([4, 2])
Reshaped (4x2): tensor([[10, 20],
        [30, 40],
        [50, 60],
        [70, 80]])
Unsqueezed at 0: torch.Size([1, 2, 4])
Unsqueezed at 0: tensor([[[10, 20, 30, 40],
         [50, 60, 70, 80]]])
Unsqueezed at 1: torch.Size([2, 1, 4])
Unsqueezed at 1: tensor([[[10, 20, 30, 40]],

        [[50, 60, 70, 80]]])

Complex Tensor y: torch.Size([1, 2, 1, 4])

Complex Tensor y: tensor([[[[ 0.4877,  0.9192,  0.0088,  0.8201]],

         [[ 0.8905,  0.6694, -1.3324, -0.0963]]]])
Fully Squeezed: torch.Size([2, 4])
Fully Squeezed: tensor([[ 0.4877,  0.9192,  0.0088,  0.8201],
        [ 0.8905,  0.6694, -1.3324, -0.0963]])
Squeezed at index 0: torch.Size([2, 1, 4])


In [23]:
import torch
import torch.nn as nn

# 1. Define the layer
# 1000 words in vocab, each represented by a 100-dim vector
embedding = nn.Embedding(1000, 100)

# 2. Create some "input_ids" (e.g., a batch of 2 sentences, 4 tokens each)
input_ids = torch.tensor([[1, 42, 5, 9], [10, 3, 999, 0]])

# 3. Pass through the embedding layer
embedded_output = embedding(input_ids)

print(f"Input Shape:  {input_ids}")      # torch.Size([2, 4])
print(f"Output Shape: {embedded_output}") # torch.Size([2, 4, 100])

Input Shape:  tensor([[  1,  42,   5,   9],
        [ 10,   3, 999,   0]])
Output Shape: tensor([[[ 4.1667e-01,  9.6009e-01, -3.7386e-01,  9.4238e-01, -6.1474e-01,
          -8.7216e-01,  6.3071e-01,  1.5478e+00,  3.3835e-01, -4.1710e-01,
          -4.6828e-02,  9.9981e-01, -6.6385e-01,  5.2313e-01,  6.0928e-01,
          -2.2387e-01, -1.0923e+00,  1.8785e+00, -8.2437e-01,  1.3850e-01,
           8.7227e-01,  9.0803e-01,  7.3128e-03,  7.4487e-01,  6.6348e-02,
          -1.6220e-02, -1.5093e+00, -1.9420e+00,  7.0590e-01, -9.9569e-01,
           2.8487e-01, -1.0313e+00,  1.1899e+00,  6.2449e-01,  6.7444e-01,
          -7.3596e-01,  9.8426e-01,  5.4845e-01, -1.0978e+00,  5.0840e-01,
          -1.0143e+00, -8.4057e-01,  1.1627e+00,  1.0908e+00,  1.0844e+00,
          -3.8818e-01,  6.9291e-01, -1.1719e+00,  1.2504e+00,  1.0046e+00,
          -2.8554e-01,  8.5279e-01, -2.5443e-03,  7.3911e-02, -2.0046e-01,
          -1.0361e+00, -6.4444e-01,  6.6033e-01,  1.6177e+00, -7.6691e-01,
          -