In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ConditionalDiffusionModel(nn.Module):
    def __init__(self, input_dim, cond_dim, hidden_dim):
        super(ConditionalDiffusionModel, self).__init__()
        # Encoder to map input + condition to a hidden representation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + cond_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Decoder to predict the noise from the hidden representation
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x, condition):
        # Concatenate input and condition
        combined = torch.cat([x, condition], dim=-1)
        # Encode the combined input and condition
        hidden = self.encoder(combined)
        # Decode to predict noise
        predicted_noise = self.decoder(hidden)
        return predicted_noise

def diffusion_training_loop(model, data_loader, condition_loader, optimizer, num_timesteps, device):
    model.train()
    mse_loss = nn.MSELoss()

    for epoch in range(10):  # Number of epochs (can be adjusted)
        for (x, condition) in zip(data_loader, condition_loader):
            x = x.to(device)
            condition = condition.to(device)
            
            # Sample random noise
            noise = torch.randn_like(x).to(device)
            
            # Randomly choose a timestep
            t = torch.randint(0, num_timesteps, (x.size(0),), device=device).float() / num_timesteps
            t = t.view(-1, 1)  # Reshape for broadcasting

            # Add noise to input (forward diffusion process)
            noisy_x = x * torch.sqrt(1 - t) + noise * torch.sqrt(t)

            # Predict noise using the model
            predicted_noise = model(noisy_x, condition)

            # Calculate loss between predicted and true noise
            loss = mse_loss(predicted_noise, noise)

            # Backpropagation and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

# Example usage
if __name__ == "__main__":
    # Parameters
    input_dim = 2
    cond_dim = 1
    hidden_dim = 128
    num_timesteps = 1000
    batch_size = 32

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model and optimizer
    model = ConditionalDiffusionModel(input_dim, cond_dim, hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Create dummy data loaders (for demonstration purposes)
    data_loader = [torch.randn(batch_size, input_dim) for _ in range(100)]
    condition_loader = [torch.randn(batch_size, cond_dim) for _ in range(100)]

    # Train the diffusion model
    diffusion_training_loop(model, data_loader, condition_loader, optimizer, num_timesteps, device)

Epoch [1/10], Loss: 0.5884
Epoch [2/10], Loss: 0.7069
Epoch [3/10], Loss: 0.3952
Epoch [4/10], Loss: 0.4699
Epoch [5/10], Loss: 0.7835
Epoch [6/10], Loss: 0.6919
Epoch [7/10], Loss: 0.6965
Epoch [8/10], Loss: 0.4070
Epoch [9/10], Loss: 0.8690
Epoch [10/10], Loss: 0.6602
