In [1]:
import torch
torch.cuda.is_available(), torch.cuda.get_device_name(0)


(True, 'NVIDIA GeForce RTX 4050 Laptop GPU')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import os

torch.manual_seed(42)
np.random.seed(42)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.relu = nn.ReLU()
        
        if in_ch != out_ch:
            self.residual_conv = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        h = self.conv1(x)
        h = self.norm1(h)
        h = self.relu(h)
        
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[:, :, None, None]
        
        h = self.conv2(h)
        h = self.norm2(h)
        h = self.relu(h)
        
        return h + self.residual_conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, model_channels=64, time_emb_dim=256):
        super().__init__()
        
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        
        self.down1 = Block(in_channels, model_channels, time_emb_dim)
        self.down2 = Block(model_channels, model_channels * 2, time_emb_dim)
        self.down3 = Block(model_channels * 2, model_channels * 4, time_emb_dim)
        
        self.pool = nn.MaxPool2d(2)
        
        self.bottleneck = Block(model_channels * 4, model_channels * 4, time_emb_dim)
        
        self.up1 = nn.ConvTranspose2d(model_channels * 4, model_channels * 4, 2, stride=2)
        self.up_block1 = Block(model_channels * 8, model_channels * 2, time_emb_dim)
        
        self.up2 = nn.ConvTranspose2d(model_channels * 2, model_channels * 2, 2, stride=2)
        self.up_block2 = Block(model_channels * 4, model_channels, time_emb_dim)
        
        self.up3 = nn.ConvTranspose2d(model_channels, model_channels, 2, stride=2)
        self.up_block3 = Block(model_channels * 2, model_channels, time_emb_dim)
        
        self.out = nn.Conv2d(model_channels, in_channels, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        
        d1 = self.down1(x, t)  # 28x28
        d2 = self.down2(self.pool(d1), t)  # 14x14
        d3 = self.down3(self.pool(d2), t)  # 7x7
        
        b = self.bottleneck(self.pool(d3), t)  # 3x3
        
        u1 = self.up1(b)  # 6x6
        # to match d3 size (7x7)
        u1 = F.pad(u1, (0, 1, 0, 1))  # pad right and bottom
        u1 = torch.cat([u1, d3], dim=1)
        u1 = self.up_block1(u1, t)  # 7x7
        
        u2 = self.up2(u1)  # 14x14
        u2 = torch.cat([u2, d2], dim=1)
        u2 = self.up_block2(u2, t)  # 14x14
        
        u3 = self.up3(u2)  # 28x28
        u3 = torch.cat([u3, d1], dim=1)
        u3 = self.up_block3(u3, t)  # 28x28
        
        return self.out(u3)


class DiffusionModel:
    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device
        
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
        
        # posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, model, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        
        x_noisy = self.q_sample(x_start, t, noise)
        predicted_noise = model(x_noisy, t)
        
        loss = F.mse_loss(noise, predicted_noise)
        return loss

    @torch.no_grad()
    def p_sample(self, model, x, t, t_index):
        betas_t = self.betas[t][:, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
        sqrt_recip_alphas_t = torch.sqrt(1.0 / self.alphas[t])[:, None, None, None]
        
        predicted_noise = model(x, t)
        
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )
        
        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = self.posterior_variance[t][:, None, None, None]
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, channels=3):
        model.eval()
        shape = (batch_size, channels, image_size, image_size)
        img = torch.randn(shape, device=self.device)
        
        for i in reversed(range(0, self.timesteps)):
            img = self.p_sample(
                model,
                img,
                torch.full((batch_size,), i, device=self.device, dtype=torch.long),
                i
            )
        
        return img


def train(model, diffusion, dataloader, optimizer, device, epochs=50):
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()
            
            t = torch.randint(0, diffusion.timesteps, (images.shape[0],), device=device).long()
            
            loss = diffusion.p_losses(model, images, t)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch {epoch+1}, Step {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}")


def generate_images(model, diffusion, num_images=10000, batch_size=100, save_dir='generated_images'):
    os.makedirs(save_dir, exist_ok=True)
    
    num_batches = num_images // batch_size
    
    for i in range(num_batches):
        samples = diffusion.sample(model, image_size=28, batch_size=batch_size, channels=3)
        samples = (samples + 1) / 2  # from [-1, 1] to [0, 1]
        samples = torch.clamp(samples, 0, 1)
        
        for j in range(batch_size):
            img_idx = i * batch_size + j + 1
            save_image(samples[j], os.path.join(save_dir, f'{img_idx:05d}.png'))
        
        print(f'Generated batch {i+1}/{num_batches}')


if __name__ == '__main__':
    BATCH_SIZE = 128
    EPOCHS = 50
    LEARNING_RATE = 2e-4
    TIMESTEPS = 1000
    IMAGE_SIZE = 28
    CHANNELS = 3
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')
    
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # cnvrt nvert to RGB
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # scale to [-1, 1]
    ])
    
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    
    model = UNet(in_channels=CHANNELS, model_channels=64, time_emb_dim=256).to(device)
    diffusion = DiffusionModel(timesteps=TIMESTEPS, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    print(f'model parameters: {sum(p.numel() for p in model.parameters()):,}')
    
    print('starting training..')
    train(model, diffusion, train_loader, optimizer, device, epochs=EPOCHS)
    
    torch.save(model.state_dict(), 'diffusion_model_final.pth')
    print('Final model saved!')
    
    print('generating images for FID evaluation...')
    generate_images(model, diffusion, num_images=10000, batch_size=100, save_dir='generated_images')
    print('image generation complete!')

device: cuda
model parameters: 4,151,299
starting training..
Epoch 1, Step 100/469, Loss: 0.0793
Epoch 1, Step 200/469, Loss: 0.0492
Epoch 1, Step 300/469, Loss: 0.0377
Epoch 1, Step 400/469, Loss: 0.0233
Epoch 2, Step 100/469, Loss: 0.0255
Epoch 2, Step 200/469, Loss: 0.0231
Epoch 2, Step 300/469, Loss: 0.0190
Epoch 2, Step 400/469, Loss: 0.0251
Epoch 3, Step 100/469, Loss: 0.0174
Epoch 3, Step 200/469, Loss: 0.0196
Epoch 3, Step 300/469, Loss: 0.0227
Epoch 3, Step 400/469, Loss: 0.0156
Epoch 4, Step 100/469, Loss: 0.0165
Epoch 4, Step 200/469, Loss: 0.0168
Epoch 4, Step 300/469, Loss: 0.0154
Epoch 4, Step 400/469, Loss: 0.0114
Epoch 5, Step 100/469, Loss: 0.0119
Epoch 5, Step 200/469, Loss: 0.0151
Epoch 5, Step 300/469, Loss: 0.0176
Epoch 5, Step 400/469, Loss: 0.0112
Epoch 6, Step 100/469, Loss: 0.0186
Epoch 6, Step 200/469, Loss: 0.0186
Epoch 6, Step 300/469, Loss: 0.0190
Epoch 6, Step 400/469, Loss: 0.0126
Epoch 7, Step 100/469, Loss: 0.0170
Epoch 7, Step 200/469, Loss: 0.0154
Epo

In [3]:
from torchvision.datasets import MNIST
from torchvision import transforms
from PIL import Image
import os

dataset = MNIST(root="data", train=False, download=True)
out_dir = "mnist_png"
os.makedirs(out_dir, exist_ok=True)

for i, (img, _) in enumerate(dataset):
    img.save(f"{out_dir}/{i}.png")


In [4]:
!python -m pytorch_fid generated_images mnist_png

FID:  6.127406288703725



  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:59<3:17:00, 59.40s/it]
  1%|1         | 2/200 [01:00<1:22:06, 24.88s/it]
  2%|1         | 3/200 [01:00<45:28, 13.85s/it]  
  2%|2         | 4/200 [01:01<28:18,  8.67s/it]
  2%|2         | 5/200 [01:02<18:51,  5.80s/it]
  3%|3         | 6/200 [01:03<13:10,  4.07s/it]
  4%|3         | 7/200 [01:03<09:34,  2.98s/it]
  4%|4         | 8/200 [01:04<07:13,  2.26s/it]
  4%|4         | 9/200 [01:05<05:39,  1.78s/it]
  5%|5         | 10/200 [01:05<04:35,  1.45s/it]
  6%|5         | 11/200 [01:06<03:52,  1.23s/it]
  6%|6         | 12/200 [01:07<03:21,  1.07s/it]
  6%|6         | 13/200 [01:08<03:00,  1.03it/s]
  7%|7         | 14/200 [01:08<02:46,  1.12it/s]
  8%|7         | 15/200 [01:09<02:35,  1.19it/s]
  8%|8         | 16/200 [01:10<02:28,  1.24it/s]
  8%|8         | 17/200 [01:10<02:22,  1.28it/s]
  9%|9         | 18/200 [01:11<02:18,  1.31it/s]
 10%|9         | 19/200 [01:12<02:15,  1.33it/s]
 10%|#         | 20/200 [01:13<

In [5]:
!python -m pytorch_fid generated_images mnist.npz

FID:  5.767365379638221



  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:57<3:09:04, 57.01s/it]
  1%|1         | 2/200 [00:57<1:18:51, 23.90s/it]
  2%|1         | 3/200 [00:58<43:43, 13.32s/it]  
  2%|2         | 4/200 [00:59<27:15,  8.34s/it]
  2%|2         | 5/200 [00:59<18:10,  5.59s/it]
  3%|3         | 6/200 [01:00<12:43,  3.94s/it]
  4%|3         | 7/200 [01:01<09:16,  2.89s/it]
  4%|4         | 8/200 [01:02<07:01,  2.20s/it]
  4%|4         | 9/200 [01:02<05:31,  1.74s/it]
  5%|5         | 10/200 [01:03<04:30,  1.42s/it]
  6%|5         | 11/200 [01:04<03:48,  1.21s/it]
  6%|6         | 12/200 [01:04<03:19,  1.06s/it]
  6%|6         | 13/200 [01:05<02:58,  1.05it/s]
  7%|7         | 14/200 [01:06<02:44,  1.13it/s]
  8%|7         | 15/200 [01:07<02:34,  1.20it/s]
  8%|8         | 16/200 [01:07<02:27,  1.25it/s]
  8%|8         | 17/200 [01:08<02:22,  1.29it/s]
  9%|9         | 18/200 [01:09<02:18,  1.32it/s]
 10%|9         | 19/200 [01:09<02:15,  1.34it/s]
 10%|#         | 20/200 [01:10<

In [6]:
@torch.no_grad()
def diffusion_progress_grid(model, diffusion, save_path="diffusion_grid.png",
                            num_samples=8, image_size=28, channels=3):

    model.eval()
    device = diffusion.device

    timesteps = diffusion.timesteps
    interval = timesteps // 7 

    per_sample_snapshots = []

    for _ in range(num_samples):
        x = torch.randn((1, channels, image_size, image_size), device=device)
        snapshots = []

        for t in reversed(range(timesteps)):
            if t % interval == 0 or t == 0:
                snapshots.append(x.clone())

            x = diffusion.p_sample(
                model,
                x,
                torch.full((1,), t, device=device, dtype=torch.long),
                t
            )

        snapshots = snapshots[:8]  # ensure 8 exactly
        per_sample_snapshots.append(torch.cat(snapshots, dim=0))  # [8, 3, 28, 28]

    rows = []
    for t_idx in range(8):
        row_imgs = [per_sample_snapshots[sample_idx][t_idx] 
                    for sample_idx in range(num_samples)]
        rows.append(torch.stack(row_imgs, dim=0))  # [8, 3, 28, 28]

    # Final grid: 64 images
    grid = torch.cat([row for row in rows], dim=0)  # [64, 3, 28, 28]

    save_image(grid, save_path, nrow=8)
    print(f"Saved diffusion process grid (top → bottom progression) at {save_path}")


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(in_channels=3, model_channels=64, time_emb_dim=256).to(device)
diffusion = DiffusionModel(timesteps=1000, device=device)

state = torch.load("diffusion_model_final.pth", map_location=device)
model.load_state_dict(state)

model.eval()

diffusion_progress_grid(model, diffusion)


  state = torch.load("diffusion_model_final.pth", map_location=device)


Saved diffusion process grid (top → bottom progression) at diffusion_grid.png
