In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets.mfnet_dataset import MFNetDataset
from models.semantic_encoder import SemanticEncoder
from models.semantic_decoder import SemanticDecoder
from models.simple_diffusion import SimpleDiffusion
import torch.optim as optim
import torch.nn as nn
from config import root, max_samples, img_size, latent_dim, batch_size, num_epochs, lr

In [2]:
class NoisePredictor(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
    def forward(self, z, t):
        return self.net(z)

In [3]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
    ])
    train_dataset = MFNetDataset(root=root, transform=transform, max_samples=max_samples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    encoder = SemanticEncoder(latent_dim=latent_dim).to(device)
    decoder = SemanticDecoder(latent_dim=latent_dim).to(device)
    diffusion = SimpleDiffusion(latent_dim=latent_dim).to(device)
    noise_predictor = NoisePredictor(latent_dim).to(device)

    optimizer = optim.Adam(list(encoder.parameters()) + list(noise_predictor.parameters()), lr=lr)
    mse_loss = nn.MSELoss()
    timesteps = diffusion.timesteps

    for epoch in range(num_epochs):
        for batch in train_loader:
            rgb = batch['rgb'].to(device)
            z = encoder(rgb)
            t = torch.randint(0, timesteps, (z.size(0),), device=device)
            noise = torch.randn_like(z)
            z_noisy = diffusion.q_sample(z, t, noise)
            pred_noise = noise_predictor(z_noisy, t)
            loss = mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}: LDM loss={loss.item():.4f}")

if __name__ == "__main__":
    main()



Epoch 1: LDM loss=1.0076
