In [1]:
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import matplotlib.pyplot as plt
from torch.utils.data import RandomSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device to CPU (since you want to run without CUDA)
device = torch.device('cpu')

In [3]:
# Load the MNIST dataset
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize to 128x128
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Repeat channel to make it 3 channels
])


In [4]:
mnist_dataset = MNIST(root='./data', train=True, download=True, transform=transform)

# Create a RandomSampler
sampler = RandomSampler(mnist_dataset, num_samples=100)  

# Create the DataLoader with the sampler
data_loader = DataLoader(mnist_dataset, batch_size=2, sampler=sampler, num_workers=4)

In [5]:
# Initialize the Unet model
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8)
).to(device)

In [6]:
# Initialize the GaussianDiffusion object
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=10,   # number of steps
).to(device)

In [7]:
# Initialize your optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    for images, _ in data_loader:
        images = images.to(device)
        
        optimizer.zero_grad()
        loss = diffusion(images)
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/1], Loss: 1.8970
Epoch [1/1], Loss: 0.7690
Epoch [1/1], Loss: 0.6157
Epoch [1/1], Loss: 0.7643
Epoch [1/1], Loss: 0.3624
Epoch [1/1], Loss: 0.3418
Epoch [1/1], Loss: 1.0761
Epoch [1/1], Loss: 0.4282
Epoch [1/1], Loss: 1.6978
Epoch [1/1], Loss: 0.4365
Epoch [1/1], Loss: 0.1379
Epoch [1/1], Loss: 0.3744
Epoch [1/1], Loss: 0.4183
Epoch [1/1], Loss: 0.2997
Epoch [1/1], Loss: 0.4708
Epoch [1/1], Loss: 0.3994
Epoch [1/1], Loss: 1.0513
Epoch [1/1], Loss: 1.6409
Epoch [1/1], Loss: 1.4830
Epoch [1/1], Loss: 0.6126
Epoch [1/1], Loss: 0.5166
Epoch [1/1], Loss: 0.7833
Epoch [1/1], Loss: 0.3782
Epoch [1/1], Loss: 0.0588
Epoch [1/1], Loss: 0.9906
Epoch [1/1], Loss: 0.5061
Epoch [1/1], Loss: 0.1911
Epoch [1/1], Loss: 0.3169
Epoch [1/1], Loss: 0.5829
Epoch [1/1], Loss: 0.9013
Epoch [1/1], Loss: 0.3964
Epoch [1/1], Loss: 0.1202


In [None]:
# Sampling images after training
sampled_images = diffusion.sample(batch_size=4)

sampling loop time step:  62%|██████▏   | 31/50 [00:45<00:27,  1.46s/it]

In [None]:
# Convert sampled images to numpy for visualization or saving
sampled_images = (sampled_images + 1) / 2  # Rescale to [0, 1]
sampled_images = sampled_images.clamp(0, 1)  # Ensure values are within bounds

In [None]:
def show_images(images):
    images = images.detach().cpu().numpy()
    fig, axes = plt.subplots(1, 4, figsize=(10, 5))
    for i, ax in enumerate(axes):
        ax.imshow(images[i].squeeze(), cmap='gray')
        ax.axis('off')
    plt.show()

In [None]:
show_images(sampled_images)