In [4]:
import torch
import torch.nn as nn

# Diffusion hyperparameters
timesteps = 1000  # total number of diffusion steps


# Linear beta schedule: `timesteps` evenly spaced values from beta_start to beta_end.
# (Only the linear schedule is built here; no cosine variant is implemented.)
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(beta_start, beta_end, steps=timesteps)

# Per-step alphas and their running product over timesteps
alphas = 1.0 - betas
alpha_cumprod = alphas.cumprod(dim=0)
# Running product shifted right by one step, with 1.0 in the slot for t = 0
alpha_cumprod_prev = torch.cat([torch.tensor([1.0]), alpha_cumprod[:-1]])

# Lookup tables derived from the schedule; every tensor is indexed by timestep
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alpha_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alpha_cumprod).sqrt()
# Evaluates to exactly 0 at t = 0 because alpha_cumprod_prev[0] is 1.0
posterior_variance = betas * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)



In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """
    Run one reverse-diffusion step on `x` at timestep `t`.

    The noise predicted by the module-level `model` is scaled and subtracted
    from `x` to form the step mean. If `t == 0` the mean is returned as is;
    otherwise Gaussian noise scaled by sqrt(posterior_variance) is added.

    `model`, `get_index_from_list` and the schedule tensors are read from
    module scope rather than passed in.
    NOTE(review): `get_index_from_list` is not defined in the visible cells;
    presumably it picks the schedule entry for `t` and reshapes it so it
    broadcasts against `x` -- confirm against its definition.
    NOTE(review): `t == 0` is used as a Python bool, so `t` presumably holds a
    single timestep -- confirm with the caller.
    """
    def lookup(schedule):
        # Schedule value(s) for timestep t, as returned by the external helper
        return get_index_from_list(schedule, t, x.shape)

    beta_now = lookup(betas)
    sqrt_one_minus_alpha_bar_now = lookup(sqrt_one_minus_alphas_cumprod)
    inv_sqrt_alpha_now = lookup(sqrt_recip_alphas)

    # Step mean: current sample minus the scaled noise prediction
    predicted_noise = model(x, t)
    step_mean = inv_sqrt_alpha_now * (
        x - beta_now * predicted_noise / sqrt_one_minus_alpha_bar_now
    )
    variance_now = lookup(posterior_variance)

    # Final step: return the mean without adding fresh noise
    if t == 0:
        return step_mean

    fresh_noise = torch.randn_like(x)
    return step_mean + torch.sqrt(variance_now) * fresh_noise

In [None]:
@torch.no_grad()
def sample(model, shape):
    """
    Generate samples by running the full reverse diffusion chain.

    Starts from Gaussian noise of the requested shape and applies the DDPM
    reverse update once per timestep, from `timesteps - 1` down to 0:

        mean = 1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - alpha_bar_t) * eps)

    where `eps = model(x, t)`, `alpha_t = 1 - beta_t` is the per-step alpha
    and `alpha_bar_t` is its cumulative product. The schedule tensors
    (`timesteps`, `betas`, `sqrt_recip_alphas`,
    `sqrt_one_minus_alphas_cumprod`) are read from module scope.

    Args:
        model: callable `model(x, t)` returning a noise prediction shaped like
            `x`; `t` is an int64 tensor with one entry per element of the
            leading dimension of `x`.
        shape: shape of the tensor to generate, as accepted by `torch.randn`.

    Returns:
        Tensor of the requested shape holding the final samples. Runs under
        `torch.no_grad()`, so the result carries no autograd history.
    """
    x = torch.randn(shape)  # Start with noise
    # One timestep entry per sample so the model can embed and concatenate it
    # with a batch of any size (a single-entry tensor only fits batch size 1).
    batch_size = x.shape[0] if x.dim() > 0 else 1

    for t in reversed(range(timesteps)):
        t_tensor = torch.full((batch_size,), t, dtype=torch.long)
        noise_pred = model(x, t_tensor)

        # Reverse-step mean. The scale is 1/sqrt(alpha_t) with the per-step
        # alpha; dividing by the square root of the cumulative product instead
        # inflates x by orders of magnitude at large t.
        mean = sqrt_recip_alphas[t] * (
            x - betas[t] / sqrt_one_minus_alphas_cumprod[t] * noise_pred
        )

        if t == 0:
            # Final step: return the mean without adding noise
            x = mean
        else:
            # Uses the sigma_t^2 = beta_t variance choice
            x = mean + torch.sqrt(betas[t]) * torch.randn_like(x)
    return x


In [10]:
import math

class TimestepEmbedding(nn.Module):
    """
    Sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of timesteps with shape (batch,) to a float tensor of
    shape (batch, embedding_dim): the first `embedding_dim // 2` columns are
    sines, the next `embedding_dim // 2` are cosines, both over geometrically
    spaced frequencies from 1 down towards 1/10000. The module has no
    learnable parameters.

    Args:
        embedding_dim: width of the returned embedding. When it is odd, one
            zero column is appended so the output width still equals
            `embedding_dim`.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, t):
        """Return the (batch, embedding_dim) embedding for timesteps `t` of shape (batch,)."""
        half_dim = self.embedding_dim // 2
        # Frequency bands, created on the same device as `t` so the product
        # below does not mix devices when `t` is not on the CPU.
        positions = torch.arange(half_dim, device=t.device).float()
        freqs = torch.exp(-math.log(10000) * positions / half_dim)
        angles = t[:, None].float() * freqs[None, :]  # Shape: (batch_size, half_dim)
        # Sinusoidal embedding
        embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.embedding_dim % 2 == 1:
            # sin/cos halves cover only embedding_dim - 1 columns; zero-pad the
            # last one so consumers sized with embedding_dim get matching width.
            embedding = nn.functional.pad(embedding, (0, 1))
        return embedding

In [5]:
def forward_diffusion(x, t):
    """
    Noise clean samples `x` to timestep(s) `t` in one shot.

    Computes sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise, where
    alpha_bar_t is read from the module-level `alpha_cumprod` table and
    `noise` is drawn from a standard normal with the shape of `x`.

    Args:
        x: tensor of clean samples with the batch on the leading dimension;
            any number of trailing dimensions is supported.
        t: timestep indices into `alpha_cumprod`, one per sample (shape
            (batch,)), or a single index applied to the whole batch.

    Returns:
        Tuple `(noisy_x, noise)`, both with the same shape as `x`.
    """
    noise = torch.randn_like(x)
    # Reshape the per-sample coefficients to (batch, 1, ..., 1) with one
    # singleton per trailing dimension of x. A fixed (-1, 1, 1) only fits 3-D
    # input; against a 2-D batch it broadcasts to (batch, batch, features).
    coeff_shape = (-1,) + (1,) * (x.dim() - 1)
    alpha_bar_t = alpha_cumprod[t]
    sqrt_alpha_cumprod = torch.sqrt(alpha_bar_t).view(coeff_shape)
    sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_bar_t).view(coeff_shape)
    return sqrt_alpha_cumprod * x + sqrt_one_minus_alpha_cumprod * noise, noise


In [9]:
# Scratch check: repeating the 1-tuple (4 - 1) times yields (1, 1, 1).
# NOTE(review): presumably the trailing singleton dims used to broadcast a
# per-sample value over a 4-D batch -- confirm, or drop this cell if it is leftover.
(1,) * (4 - 1)

(1, 1, 1)

In [11]:
class DenoisingModel(nn.Module):
    """
    Timestep-conditioned MLP that maps a noisy input to a tensor of the same width.

    The timestep is embedded with `TimestepEmbedding`, concatenated to the
    input along the last dimension, and passed through three linear layers
    with ReLU activations in between.

    Args:
        input_dim: size of the last dimension of `x` (and of the output).
        embedding_dim: width requested from the timestep embedding.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim = 128):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.timestep_embedding = TimestepEmbedding(embedding_dim)

        # (input + timestep embedding) -> hidden -> hidden -> input
        layers = [
            nn.Linear(input_dim + embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Run the network on `x` joined with the embedding of `t` along the last dim."""
        conditioned = torch.cat((x, self.timestep_embedding(t)), dim=-1)
        return self.net(conditioned)

In [12]:
def diffusion_loss(model, x, t):
    """
    Noise-prediction training loss for one batch.

    Noises `x` to timestep(s) `t` with `forward_diffusion`, asks `model` to
    predict the injected noise from the noisy input and `t`, and returns the
    mean squared error between the prediction and the noise actually added.
    """
    noised_input, target_noise = forward_diffusion(x, t)
    noise_estimate = model(noised_input, t)
    return nn.functional.mse_loss(noise_estimate, target_noise)


In [None]:
import torch.optim as optim

epochs = 1

# Build the denoiser and its optimizer
input_dim = 10  # Example input dimension
embedding_dim = 16  # Embedding dimension for timestep
model = DenoisingModel(input_dim, embedding_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop: one optimization step per epoch on a freshly drawn random batch
for epoch in range(epochs):
    # Synthetic batch of 32 Gaussian samples, each paired with a random timestep
    x = torch.randn(32, input_dim)
    t = torch.randint(low=0, high=timesteps, size=(32,))

    # Forward pass: noise-prediction loss for this batch
    loss = diffusion_loss(model, x, t)

    # Clear old gradients, backpropagate, then update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")