<a href="https://colab.research.google.com/github/ekonishi8524/my-colab-notebooks/blob/main/mnist_ddpm_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class SimpleUNet(nn.Module):
    """Small convolutional noise predictor conditioned on the diffusion step.

    Every convolution is 3x3 with stride 1 and padding 1, so the output has
    the same spatial size as the input; there is no down/up-sampling and no
    skip connection.

    The timestep is encoded with a sinusoidal embedding, passed through a
    two-layer MLP and added to the 32-channel feature map between ``down``
    and ``up``, so the prediction depends on the timestep.

    Args:
        time_dim: Width of the sinusoidal timestep embedding. Must be a
            positive even number.

    Raises:
        ValueError: If ``time_dim`` is not a positive even number.
    """

    def __init__(self, time_dim=32):
        super().__init__()
        if time_dim < 2 or time_dim % 2 != 0:
            raise ValueError("time_dim must be a positive even number")
        self.time_dim = time_dim

        self.down = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1), nn.ReLU(),
        )
        # Projects the timestep embedding to one bias value per feature channel.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, 32), nn.ReLU(),
            nn.Linear(32, 32),
        )
        self.up = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),
        )

    def _timestep_embedding(self, t):
        """Return a ``(len(t), time_dim)`` sinusoidal embedding of ``t``."""
        half = self.time_dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.pow(10000.0, -exponents)
        angles = t.reshape(-1).float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep ``t``.

        Args:
            x: Image batch with one channel, shape ``(B, 1, H, W)``.
            t: Integer timesteps, one per batch element (or a single value
                that is broadcast). ``None`` skips the conditioning.

        Returns:
            Tensor with the same shape as ``x``.
        """
        features = self.down(x)
        if t is not None:
            t_bias = self.time_mlp(self._timestep_embedding(t))
            # Broadcast the per-channel bias over the spatial dimensions.
            features = features + t_bias[:, :, None, None]
        return self.up(features)

In [None]:
class Diffusion:
    """Forward (noising) and reverse (sampling) process with a linear beta schedule.

    On construction the following 1-D tensors of length ``noise_steps`` are
    placed on ``device``:

    * ``beta``      -- linearly spaced from ``beta_start`` to ``beta_end``
    * ``alpha``     -- ``1 - beta``
    * ``alpha_hat`` -- cumulative product of ``alpha``

    Args:
        noise_steps: Number of diffusion timesteps.
        beta_start: First value of the beta schedule.
        beta_end: Last value of the beta schedule.
        image_size: Height and width of the square images produced by
            :meth:`sample`.
        device: Device on which the schedule and samples are stored.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02,
                 image_size=28, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.image_size = image_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    @staticmethod
    def _expand(coeff, like):
        """Reshape per-sample ``coeff`` to ``(B, 1, ..., 1)`` so it broadcasts over ``like``."""
        return coeff.reshape(-1, *([1] * (like.dim() - 1)))

    def forward_diffusion(self, x0, t):
        """Noise ``x0`` to timestep ``t`` in a single step.

        Computes ``sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * epsilon``
        with ``epsilon`` drawn from a standard normal distribution.

        Args:
            x0: Clean batch; the first dimension is the batch dimension.
            t: Long tensor of timestep indices, one per batch element.

        Returns:
            Tuple ``(xt, epsilon)``: the noised batch and the noise that was
            added, both shaped like ``x0``.
        """
        alpha_hat_t = self.alpha_hat[t]
        signal_scale = self._expand(torch.sqrt(alpha_hat_t), x0)
        noise_scale = self._expand(torch.sqrt(1. - alpha_hat_t), x0)
        epsilon = torch.randn_like(x0)

        xt = signal_scale * x0 + noise_scale * epsilon
        return xt, epsilon

    def sample(self, model, n_samples):
        """Generate ``n_samples`` images by running the reverse process.

        Starts from standard normal noise and walks the timesteps from
        ``noise_steps - 1`` down to ``0``. No noise is added at step 0.

        The model is put in eval mode while sampling and restored to the
        mode it was in beforehand, even if sampling raises.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise.
            n_samples: Number of images to generate.

        Returns:
            Tensor of shape ``(n_samples, 1, image_size, image_size)`` clamped
            to ``[-1, 1]``.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn(
                    (n_samples, 1, self.image_size, self.image_size),
                    device=self.device,
                )
                # reversed(range(...)) has no len(), so give tqdm the total explicitly.
                steps = tqdm(reversed(range(self.noise_steps)),
                             total=self.noise_steps, position=0)
                for i in steps:
                    t = torch.full((n_samples,), i, dtype=torch.long,
                                   device=self.device)
                    predicted_noise = model(x, t)
                    alpha = self._expand(self.alpha[t], x)
                    alpha_hat = self._expand(self.alpha_hat[t], x)
                    beta = self._expand(self.beta[t], x)

                    mean = (1 / torch.sqrt(alpha)) * (
                        x - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise
                    )

                    if i > 0:
                        # Standard deviation of the added noise is sqrt(beta_t).
                        x = mean + torch.sqrt(beta) * torch.randn_like(x)
                    else:
                        x = mean
        finally:
            model.train(was_training)
        return torch.clamp(x, -1, 1)

In [None]:
def train_diffusion_model(num_epochs=200, batch_size=32, lr=0.0002):
    """Train a ``SimpleUNet`` noise predictor on MNIST.

    Each step draws a random timestep per image, noises the image with
    ``Diffusion.forward_diffusion`` and minimises the MSE between the
    predicted noise and the noise that was actually added.

    MNIST is downloaded to ``dataset/`` if not already present. Images are
    normalised with mean 0.5 and std 0.5. Training runs on CUDA when
    available, otherwise on CPU.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size for the data loader.
        lr: Learning rate for the Adam optimizer.

    Returns:
        Tuple ``(model, diffusion, epoch_loss)`` where ``epoch_loss`` is a
        list with the average training loss of every epoch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device for training: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root="dataset", train=True, download=True,
                             transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleUNet().to(device)
    diffusion = Diffusion(device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    print("Starting training...")
    epoch_loss = []
    for epoch in range(num_epochs):
        running_loss = []
        for images, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)

            # One uniformly random timestep per image (randint already yields int64).
            t = torch.randint(0, diffusion.noise_steps, (images.shape[0],),
                              device=device)

            noisy_images, actual_noise = diffusion.forward_diffusion(images, t)
            predicted_noise = model(noisy_images, t)

            loss = criterion(predicted_noise, actual_noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())

        avg_epoch_loss = sum(running_loss) / len(running_loss)
        epoch_loss.append(avg_epoch_loss)
        print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")
    print("Training finished.")
    return model, diffusion, epoch_loss

In [None]:
if __name__ == '__main__':
    # Train, then sample and display a few images. Everything that uses the
    # training results lives under the guard so that importing this file does
    # not hit undefined names.
    trained_model, diffusion_process, epoch_losses = train_diffusion_model()

    print("Generating images from the trained model.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device for sampling: {device}")

    n_samples = 8
    sampled_images = diffusion_process.sample(trained_model, n_samples=n_samples)

    for i in range(n_samples):
        image_tensor = sampled_images[i].cpu().squeeze()
        # Map from [-1, 1] (the clamp range of sample()) to [0, 1] for display.
        image_tensor = (image_tensor + 1) / 2.0
        image_np = image_tensor.numpy()
        plt.imshow(image_np, cmap='gray')
        plt.axis('off')
        plt.title(f"Sample {i+1}")
        plt.show()

    print("Sampled images shape:", sampled_images.shape)



In [None]:
# Plot the average training loss of each epoch.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(epoch_losses)
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Training Loss")
plt.show()