In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import Adam
from torchvision.transforms import transforms, Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST
from torch.utils.data import DataLoader, random_split
import sys
sys.path.append('../..')
from tools.models import *
from tools.data_utils import *
import matplotlib.pyplot as plt
import numpy as np
import random
import imageio
from argparse import ArgumentParser
import einops
from tools.plot_utils import show_images, show_forward, generate_new_images
from tools.models import Autoencoder
from torch.utils.data import Subset

# Reproducibility: a single seed shared by PyTorch, NumPy and Python's `random`
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Checkpoint locations and the compute device (GPU when one is available)
STORE_PATH_SYNTH = "./history/ddpm_model_synth.pt"
STORE_PATH_MNIST = "./history/ddpm_model_mnist.pt"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Smoke test: push one random batch through a UNet to check that it runs.
# NOTE: this cell re-seeds PyTorch (42) and assigns the globals `batch_size`
# and `n_steps`, which the hyperparameter/training cell assigns again later.

# Set the seed for reproducibility
torch.manual_seed(42)

# Define the input dimensions
batch_size = 2
channels = 4
height = 100
width = 100

# Generate random inputs for testing
input_data = torch.randn(batch_size, channels, height, width)
n_steps = 1000
time_steps = torch.randint(0, n_steps, (batch_size,))

# Fix: `unet` was not defined in any visible cell, so this cell failed with a
# NameError on a fresh kernel. Build it with the same constructor form the
# training cell uses for 100x100 inputs (`UNet(input_channel, shape=100)`).
# It is created after the inputs so the seeded inputs above stay unchanged.
# NOTE(review): the training loop passes time steps reshaped to (n, 1) via
# ddpm.backward; confirm the network also accepts the (n,) shape used here.
unet = UNet(channels, shape=height)
output = unet(input_data, time_steps)

In [2]:
# Module-level loss histories, one list per split. None of the visible cells
# appends to them; they start out empty.
train_losses, test_losses = [], []


def training_loop(dataset, ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    """Train a DDPM noise predictor and checkpoint the best epoch.

    For every batch a random noise tensor and a random time step per image are
    drawn, the images are noised with ``ddpm(x0, t, eta)``, the noise is
    predicted with ``ddpm.backward(noisy, t)`` and the MSE between predicted
    and injected noise is minimised.

    Args:
        dataset: Dataset name. ``"MNIST"`` reads the images from ``data[0]``;
            any other value reads ``data["image"]`` and casts it to float.
        ddpm: Model exposing ``n_steps``, ``__call__(x0, t, eta)``,
            ``backward(x, t)`` (returns a pair whose first item is the
            predicted noise) and ``state_dict()``.
        loader: Iterable of batches; ``loader.dataset`` must support ``len``.
        n_epochs: Number of passes over ``loader``.
        optim: Optimizer updating the model parameters. Inside this function
            the name shadows the ``torch.optim`` module alias.
        device: Device the batches, noise and time steps are moved to.
        display: When true, images are generated and shown after every epoch.
        store_path: File the best ``state_dict`` (lowest epoch loss) is saved to.

    Returns:
        None. One log line is printed per epoch.
    """
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps
    # Loop-invariant: used to turn per-batch losses into a per-sample average
    n_samples = len(loader.dataset)

    # Fix: iterate over the `n_epochs` argument. The loop previously read the
    # global `num_epochs`, ignoring the argument and raising NameError when no
    # such global existed.
    for epoch in range(n_epochs):
        train_loss = 0.0
        for data in loader:
            # Loading data
            if dataset == "MNIST":
                x0 = data[0].to(device)
            else:
                x0 = data["image"].float().to(device)
            n = len(x0)

            # Noise for each image (randn_like already matches x0's device)
            # and one random time step per image
            eta = torch.randn_like(x0)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, eta)

            # Getting model estimation of noise based on the images and the time-step;
            # the second returned value is not used here
            eta_theta, _ = ddpm.backward(noisy_imgs, t.reshape(n, -1))

            # Optimizing the MSE between the noise plugged and the predicted noise
            loss = mse(eta_theta, eta)
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Weight by batch size so the epoch loss is a per-sample mean
            train_loss += loss.item() * n / n_samples

        # Display images generated at this epoch
        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {train_loss:.3f}"

        # Storing the model whenever the epoch loss improves
        if best_loss > train_loss:
            best_loss = train_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

In [None]:
# Hyperparameters
dataset = "Cars"
# NOTE(review): absolute, machine-specific path; the trailing comment records
# the relative alternative
data_dir = "/work/DNAL/Datasets" # ./data
store_path = f"../../history/ddpm_model_{dataset}.pt"
batch_size = 128
learning_rate = 0.001
num_epochs = 100

# Number of image channels handed to the UNet.
# Fix: the data-loading branch and training_loop use the name "MNIST", but
# only "PinMNIST" was recognised here, so "MNIST" fell through to 4 channels
# even though MNIST images are single-channel. Both names now map to 1.
if dataset in ("MNIST", "PinMNIST"):
    input_channel = 1
elif dataset == "Cars":
    input_channel = 3
else:
    input_channel = 4

# Resize transform class used when building the dataset transforms
if dataset == "Building":
    resize = Resize100
elif dataset == "Cars":
    resize = Resize200
else:
    resize = Resize
    
# Build the dataset object for the selected `dataset`.
# NOTE(review): only the "Building" and "Cars" branches assign
# `transformed_dataset`, which the split code that follows reads; the "MNIST"
# and "Synthetic" branches assign `train_dataloader` instead.
if dataset == "MNIST":
    # NOTE(review): `transform` is not defined in any visible cell
    ds_fn = MNIST
    dataset_fn = ds_fn(
        "./datasets",
        download=True,
        train=True,
        transform=transform,
    )
    train_dataloader = DataLoader(dataset_fn, batch_size, shuffle=True)
elif dataset == "Synthetic":
    # Use enough images to get great generation
    data_folder = "./data/Synthetic/10images_28by28pixels_4_distanced_grid_pins_4seed"
    dataset_fn = PinDataset(
        csv_file=f"{data_folder}/pins.csv",
        root_dir=f"{data_folder}/images/",
        transform=transforms.Compose([ToTensor(), resize()]),
    )
    train_dataloader = DataLoader(dataset_fn, batch_size, shuffle=True)
elif dataset == "Building":
    # NOTE(review): `Lambda()` is called with no arguments; confirm that is
    # valid for the `Lambda` in scope
    transformed_dataset = PinDataset(
        csv_file=f"./data/{dataset}/pins_full.csv",
        root_dir=f"./data/{dataset}/PS-RGBNIR/",
        transform=Compose([ToTensor(), resize(), Lambda()]),
    )
else:
    # Cars
    train_data_folder = os.path.join(data_dir, "Cars/train/mesh_20step_800by800pixels_100radius_4seed")
    transformed_dataset = PinDataset(
        csv_file=f"{train_data_folder}/pins.csv",
        root_dir=os.path.join(data_dir, "Cars/images/train"),
        transform=Compose([ExtractImage(), ToTensor(), Resize200()]),
        n=1000,
    )
            
# Train/val/test split: reuse saved index files when present, otherwise create
# a random 70/10/20 split and save its indices.
# Fix: the existence check used to probe "./data/{dataset}/" while the files
# are loaded from and saved to "{data_dir}/{dataset}/". With `data_dir`
# pointing elsewhere the saved split was never found (and was overwritten on
# every run), or a stale local file could trigger a failing load. All three
# operations now use the same directory.
split_dir = f'{data_dir}/{dataset}'
if os.path.exists(f'{split_dir}/train_indices.npy'):
    train_indices = np.load(f'{split_dir}/train_indices.npy')
    val_indices = np.load(f'{split_dir}/val_indices.npy')
    test_indices = np.load(f'{split_dir}/test_indices.npy')
    # NOTE(review): hard-coded index range, applied only when the split is
    # loaded from disk; confirm the indices exist in `transformed_dataset`
    train_indices = np.concatenate((train_indices, np.arange(1000, 1697))) # add all un-used images for DDPM training
    # Use the indices to create new datasets
    train_dataset = Subset(transformed_dataset, train_indices)
    val_dataset = Subset(transformed_dataset, val_indices)
    test_dataset = Subset(transformed_dataset, test_indices)
else:
    dataset_size = len(transformed_dataset)
    train_size = int(0.7 * dataset_size)
    val_size = int(0.10 * dataset_size)
    # The remainder goes to the test set so the three sizes sum to dataset_size
    test_size = dataset_size - train_size - val_size
    # Split the dataset into train, validation, and test sets
    train_dataset, val_dataset, test_dataset = random_split(
        transformed_dataset, [train_size, val_size, test_size]
    )
    # Save the split so later runs reuse the same partition
    np.save(f'{split_dir}/train_indices.npy', train_dataset.indices)
    np.save(f'{split_dir}/val_indices.npy', val_dataset.indices)
    np.save(f'{split_dir}/test_indices.npy', test_dataset.indices)

# One DataLoader per split; all share the batch size and the project's
# custom_collate_fn. Train and validation batches are shuffled, test is not.
loader_options = {"batch_size": batch_size, "collate_fn": custom_collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Diffusion schedule (originally used by the authors)
n_steps = 1000
min_beta = 10 ** -4
max_beta = 0.02

# Noise-prediction network: the spatial `shape` argument is passed only for
# the "Building" and "Cars" datasets
if dataset == "Building":
    network = UNet(input_channel, shape=100)
elif dataset == "Cars":
    network = UNet(input_channel, shape=200)
else:
    network = UNet(input_channel)

# Wrap the network in the DDPM model, then train it
model = DDPM(network, n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
training_loop(dataset, model, train_loader, num_epochs, optimizer, device, store_path=store_path)

Loss at epoch 1: 1.000 --> Best model ever (stored)
Loss at epoch 2: 0.988 --> Best model ever (stored)
Loss at epoch 3: 0.960 --> Best model ever (stored)
Loss at epoch 4: 0.914 --> Best model ever (stored)
Loss at epoch 5: 0.857 --> Best model ever (stored)
Loss at epoch 6: 0.800 --> Best model ever (stored)
Loss at epoch 7: 0.766 --> Best model ever (stored)
Loss at epoch 8: 0.737 --> Best model ever (stored)
Loss at epoch 9: 0.728 --> Best model ever (stored)
Loss at epoch 10: 0.709 --> Best model ever (stored)
Loss at epoch 11: 0.697 --> Best model ever (stored)
Loss at epoch 12: 0.677 --> Best model ever (stored)
Loss at epoch 13: 0.655 --> Best model ever (stored)
Loss at epoch 14: 0.629 --> Best model ever (stored)
Loss at epoch 15: 0.602 --> Best model ever (stored)
Loss at epoch 16: 0.580 --> Best model ever (stored)
Loss at epoch 17: 0.562 --> Best model ever (stored)
Loss at epoch 18: 0.539 --> Best model ever (stored)
Loss at epoch 19: 0.514 --> Best model ever (stored)
Lo

In [None]:
# Walk over the training batches once and print each batch index and size.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

count = 0
for batch_index, batch in enumerate(train_loader):
    # Images are only moved to the device here; no channel selection happens
    images = batch['image'].to(device)
    pins = batch['pins']
    outputs = batch['outputs']
    print(batch_index, len(images))
    # Keep `count` equal to the number of batches seen so far
    count = batch_index + 1

In [None]:
# Loading the trained model
# Fix: read the checkpoint from the location the training cell saved it to
# ("../../history/..."); this cell previously pointed at "./history/...".
store_path = f"../../history/ddpm_model_{dataset}.pt"

# Fix: rebuild the network with the same constructor arguments used for
# training. The previous bare `UNet()` ignored `input_channel` and `shape`,
# so the saved state dict would not match the rebuilt model.
if dataset == "Building":
    best_network = UNet(input_channel, shape=100)
elif dataset == "Cars":
    best_network = UNet(input_channel, shape=200)
else:
    best_network = UNet(input_channel)
best_model = DDPM(best_network, n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)
# map_location lets a checkpoint written on GPU be loaded on a CPU-only machine
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")

print("Generating new images")
# NOTE(review): generate_new_images is called with its default image size and
# a gif named "mnist.gif" whatever `dataset` is; confirm the defaults suit
# the selected dataset.
generated = generate_new_images(
        best_model,
        n_samples=100,
        device=device,
        gif_name="mnist.gif"
    )
show_images(generated, "Final result")

In [None]:
# Visualise the forward (noising) process on training batches.
# Fix: `train_dataloader` is only assigned in the "MNIST"/"Synthetic" loading
# branches; the loader built for training is `train_loader`.
# NOTE(review): confirm show_forward accepts the dict-style batches produced
# by custom_collate_fn.
show_forward(model, train_loader, device)