<a href="https://colab.research.google.com/github/norhum/deep_learning/blob/main/mini_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionModel(nn.Module):
    """Two-layer MLP noise predictor bundled with a linear noise schedule.

    The schedule ``alphas`` runs linearly from 1 down to a small positive
    value, and ``alphas[t]`` is used directly as the signal level at step
    ``t``: a noised sample is
    ``sqrt(alphas[t]) * x + sqrt(1 - alphas[t]) * noise``.

    Args:
        input_dim: Size of the last dimension of the data.
        hidden_dim: Width of the hidden layer.
        timesteps: Number of entries in the noise schedule.
    """

    def __init__(self, input_dim, hidden_dim, timesteps=100):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

        # The schedule stops at 0.0001 rather than 0 so that
        # denoising_process never divides by sqrt(0) at the last timestep.
        # Buffers follow the module across .to(device); persistent=False
        # keeps state_dict() limited to the fc1/fc2 parameters.
        self.register_buffer(
            "alphas", torch.linspace(1, 0.0001, timesteps), persistent=False
        )
        self.register_buffer(
            "alphas_cumprod", torch.cumprod(self.alphas, dim=0), persistent=False
        )

    @property
    def alphas_cumpord(self):
        # Backward-compatible alias for the historically misspelled name.
        return self.alphas_cumprod

    def _alpha_at(self, t, x):
        # Look up alphas[t]; a batched t gets trailing singleton dims so it
        # scales each sample of x instead of broadcasting over features.
        alpha_t = self.alphas[t]
        if alpha_t.dim() > 0:
            alpha_t = alpha_t.reshape(-1, *([1] * (x.dim() - 1)))
        return alpha_t

    def noising_process(self, x, t):
        """Return ``x`` mixed with fresh Gaussian noise at timestep ``t``.

        Args:
            x: Clean data tensor.
            t: Timestep index: an int, or an integer tensor holding either
                a single index or one index per sample along dim 0 of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        noise = torch.randn_like(x)
        alpha_t = self._alpha_at(t, x)
        return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise

    def denoising_process(self, x_noised, t):
        """Estimate the clean sample by inverting the noising formula.

        The network output is used as the noise estimate.

        Args:
            x_noised: Noised data tensor.
            t: Timestep index, in the same forms accepted by
                ``noising_process``.

        Returns:
            Tensor with the same shape as ``x_noised``.
        """
        predicted_noise = self.forward(x_noised, t)
        alpha_t = self._alpha_at(t, x_noised)
        return (x_noised - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)

    def forward(self, x, t):
        """Run the MLP on ``x``.

        ``t`` is accepted for interface compatibility but is not used: this
        network has no time conditioning. No noise is drawn here; callers
        that want a noised input call ``noising_process`` themselves.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
class DiffusionModel(nn.Module):
    """Time-conditioned MLP noise predictor with a diffusion noise schedule.

    The timestep is embedded by a small MLP, concatenated with the noisy
    input, and passed through the main network to predict the noise.
    Noising and denoising use ``alphas_cumprod[t]`` as the signal level.

    Every method taking ``t`` accepts a Python int, a one-element integer
    tensor (shared by the whole batch), or an integer tensor with one index
    per sample along dim 0.

    Args:
        input_dim: Size of the feature dimension of the data.
        hidden_dim: Width of the hidden layers and of the time embedding.
        timesteps: Number of entries in the noise schedule.
    """

    def __init__(self, input_dim, hidden_dim, timesteps=100):
        super().__init__()
        self.timesteps = timesteps

        # Time embedding: raw timestep value -> hidden_dim vector.
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Main network: [x_noisy, time embedding] -> predicted noise.
        self.net = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # Noise schedule, kept as buffers so it moves with the module.
        self.register_buffer('alphas', torch.linspace(1, 0.0001, timesteps))
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))

    def _alpha_bar(self, t, x):
        # Look up alphas_cumprod[t]; a batched t gets trailing singleton
        # dims so it scales each sample of x rather than its features.
        alpha_bar = self.alphas_cumprod[t]
        if alpha_bar.dim() > 0:
            alpha_bar = alpha_bar.reshape(-1, *([1] * (x.dim() - 1)))
        return alpha_bar

    def get_time_embedding(self, t):
        """Embed timestep(s) ``t``; returns a ``(num_timesteps, hidden_dim)`` tensor."""
        # as_tensor lets plain Python ints through; view(-1, 1) gives one
        # row per timestep for the first Linear(1, hidden_dim).
        t_float = torch.as_tensor(t, device=self.alphas.device).float().view(-1, 1)
        return self.time_embed(t_float)

    def forward(self, x_noisy, t):
        """Predict the noise contained in ``x_noisy`` at timestep(s) ``t``.

        Args:
            x_noisy: Tensor of shape ``(batch, input_dim)``.
            t: Shared or per-sample timestep index (see class docstring).

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        t_emb = self.get_time_embedding(t)

        # A single shared timestep is broadcast over the batch. Per-sample
        # timesteps already have one row per sample and are used as-is; any
        # other row count is left for torch.cat to reject.
        if t_emb.shape[0] == 1:
            t_emb = t_emb.expand(x_noisy.shape[0], -1)

        # Concatenate input with time embedding and predict noise.
        return self.net(torch.cat([x_noisy, t_emb], dim=1))

    def noising_process(self, x, t):
        """Mix ``x`` with fresh Gaussian noise at timestep(s) ``t``.

        Returns:
            Tuple ``(x_noised, noise)``, both shaped like ``x``.
        """
        alpha_t = self._alpha_bar(t, x)
        noise = torch.randn_like(x)
        return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise, noise

    def denoising_process(self, x_noised, t):
        """Estimate the clean sample from ``x_noised`` using the predicted noise.

        Returns:
            Tensor with the same shape as ``x_noised``.
        """
        predicted_noise = self.forward(x_noised, t)
        alpha_t = self._alpha_bar(t, x_noised)

        # The cumulative product can shrink to zero in float32 near the end
        # of the schedule; clamping the denominator keeps the result finite.
        signal_scale = torch.sqrt(alpha_t.clamp_min(torch.finfo(alpha_t.dtype).tiny))
        return (x_noised - torch.sqrt(1 - alpha_t) * predicted_noise) / signal_scale

