In [None]:
# Import
from Diffusion import *
from utils import makedir
import torch
from torch.utils.data import Dataset as Dataset
from torch.utils.data import DataLoader as DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from datetime import datetime
import h5py as h5
import pickle
import os

# Allow float32 matmuls to use faster, lower-precision internal arithmetic where the hardware supports it.
torch.set_float32_matmul_precision(precision="medium")

In [None]:
# Parameters
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Root directory for datasets, checkpoints and logs. Change the default to your storage
# directory, or set the DIFFUSION_STORAGE_DIR environment variable instead of editing it.
local_storage_dir = os.environ.get("DIFFUSION_STORAGE_DIR", "/home/md775/LocalStorage/MLProjects/Diffusion/")
dataset_path = os.path.join(local_storage_dir, "Datasets", "RegularTriangulations", "dataset.h5")
# Directory paths end with a separator because later cells append file names with `+`.
checkpoint_dir = os.path.join(local_storage_dir, "Checkpoints", "Triangulations", "")
log_dir = os.path.join(local_storage_dir, "Logs", "")
sample_dir = os.path.join(os.getcwd(), "samples", "triangulations", "")
gif_dir = os.path.join(os.getcwd(), "gifs", "")
for output_dir in (checkpoint_dir, log_dir, sample_dir, gif_dir):
    makedir(output_dir)
seed = None # Set to an int to seed PyTorch's RNGs for reproducible runs
if seed is not None:
    torch.manual_seed(seed)
num_channels = 1 # 1 for grayscale
num_timesteps = 2000 # Number of timesteps of the diffusion process
beta_min = 1e-6
beta_max = 0.99
image_size = 8
batch_size = 2**10
max_dataset_size = -1 # Set to -1 to use the entire dataset

In [None]:
# Load dataset
class hdf5Dataset(Dataset):
    """Map-style dataset of single-channel images stored in an HDF5 file.

    Each item is read from the HDF5 dataset ``dataset_key`` (``'height_images'`` by
    default), given a leading channel axis, converted to a float32 tensor and, if a
    ``transform`` is set, passed through it.

    Args:
        dataset_path: path to the HDF5 file.
        load_all: if True, read the whole array into memory once in ``__init__``;
            otherwise read images from disk on demand in ``__getitem__``.
        transform: optional callable applied to every returned image tensor.
        dataset_key: name of the HDF5 dataset holding the images.
    """

    def __init__(self, dataset_path, load_all=False, transform=None, dataset_key='height_images'):
        self.transform = transform
        self.dataset_path = dataset_path
        self.load_all = load_all
        self.dataset_key = dataset_key
        # Lazy-mode file handle and the pid of the process that opened it.
        self._file = None
        self._owner_pid = None
        if load_all:
            # Read everything, then close the file immediately instead of leaking the handle.
            with h5.File(dataset_path, 'r') as f:
                images = f[dataset_key][:]
            self.dataset = torch.from_numpy(images[:, None, :, :]).float()
            self._length = len(self.dataset)
        else:
            # Only record the length here. An h5py handle opened now would be inherited by
            # forked DataLoader workers, which h5py does not support, so every process
            # opens its own handle on first access (see _lazy_dataset).
            with h5.File(dataset_path, 'r') as f:
                self._length = len(f[dataset_key])
            self.dataset = None

    def _lazy_dataset(self):
        """Return the on-disk HDF5 dataset, opening the file in the current process if needed."""
        pid = os.getpid()
        if self.dataset is None or self._owner_pid != pid:
            self._file = h5.File(self.dataset_path, 'r')
            self.dataset = self._file[self.dataset_key]
            self._owner_pid = pid
        return self.dataset

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if self.load_all:
            image = self.dataset[idx]
        else:
            # Add the channel axis: (H, W) -> (1, H, W).
            image = self._lazy_dataset()[idx][None, :, :]
            image = torch.from_numpy(image).float()
        if self.transform:
            image = self.transform(image)
        return image

# Scale the images to make sure noising is not too weak or too strong
transform = transforms.Lambda(lambda t: 16*t)

image_dataset = hdf5Dataset(dataset_path, load_all=False, transform=transform)
# Honour max_dataset_size from the Parameters cell: a positive value keeps only the first
# max_dataset_size images; -1 (any non-positive value) uses the entire dataset.
if max_dataset_size > 0:
    training_dataset = torch.utils.data.Subset(image_dataset, range(min(max_dataset_size, len(image_dataset))))
else:
    training_dataset = image_dataset
dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16, pin_memory=True, persistent_workers=True)

In [None]:
# Create diffusion model: gather the hyper-parameters from the Parameters cell, then build it.
diffusion_kwargs = dict(
    image_size=image_size,
    num_channels=num_channels,
    num_timesteps=num_timesteps,
    beta_min=beta_min,
    beta_max=beta_max,
    beta_schedule="cosine",
    batch_size=batch_size,
    device=device,
)
DiffusionModel = Diffusion(**diffusion_kwargs)

In [None]:
# Visualize forward process
def image_from_tensor(tensor):
    """Convert a CHW image tensor into an HWC float32 NumPy array suitable for ``plt.imshow``.

    The tensor is detached and moved to the CPU first, so tensors that live on a GPU
    or require grad are handled too.
    """
    return tensor.detach().permute(1, 2, 0).cpu().numpy().astype(np.float32)  # CHW -> HWC

initial_tensor = next(iter(dataloader)).to(device)
plt.imshow(image_from_tensor(initial_tensor[0]), cmap='gray')

# Show num_images evenly spaced outputs of DiffusionModel.forward_process for the first image.
num_images = 16
num_cols = 8
stepsize = max(1, num_timesteps // num_images)  # max() avoids a zero step when num_timesteps < num_images
timesteps_to_show = range(0, num_timesteps, stepsize)
# Ceil division: just enough rows of num_cols panels for every selected timestep.
num_rows = (len(timesteps_to_show) + num_cols - 1) // num_cols
plt.figure(figsize=(30, 30 / num_cols * max(1, num_rows)))
for plot_idx, timestep in enumerate(timesteps_to_show):
    t = torch.tensor([timestep], dtype=torch.int64)
    plt.subplot(num_rows, num_cols, plot_idx + 1)
    tensor, noise = DiffusionModel.forward_process(initial_tensor[0, None], t)
    plt.imshow(image_from_tensor(tensor[0]), cmap="gray")

In [None]:
# Create Unet model
model = DiffusionModel.create_model(
    num_init_ch=64,
    num_downsamples=2,
    num_mid_convs=1,
    )
print("Num params: ", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.99 on every lr_scheduler.step() (presumably once per
# epoch inside train_model). last_epoch=-1 and verbose=False were just the defaults;
# passing `verbose` explicitly is deprecated in recent PyTorch and only emits a warning.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [None]:
# Training
def loss_fn(true, pred):
    """Training loss: mean squared error plus mean absolute error between ``true`` and ``pred``."""
    squared_term = F.mse_loss(true, pred)
    absolute_term = F.l1_loss(true, pred)
    return squared_term + absolute_term

# Set to True to resume from the lowest-loss checkpoint saved by a previous run.
load_from_checkpoint = False
if load_from_checkpoint:
    resume_path = checkpoint_dir + "model_min_loss.pt"
    DiffusionModel.load_from_checkpoint(resume_path, model, optimizer, lr_scheduler)

# Each training run logs into its own timestamped path under log_dir.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_log_dir = log_dir + timestamp
DiffusionModel.train_model(
    epochs=200,
    data_loader=dataloader,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_function=loss_fn,
    checkpoint_dir=checkpoint_dir,
    checkpoint_interval=1,
    log_dir=run_log_dir,
    )

In [None]:
# Sample from model: restore the lowest-loss weights, draw one sample, move it to the CPU.
best_checkpoint_path = checkpoint_dir + "model_min_loss.pt"
DiffusionModel.load_from_checkpoint(best_checkpoint_path, model)
tensor_sample = DiffusionModel.sample(num_images=1, variance_coeff=1.0).cpu()

In [None]:
# Save sample: pickle the first entry of tensor_sample as a NumPy array for the triangulation script.
makedir(sample_dir)
height_images_path = sample_dir + "height_images.pkl"
first_sample = tensor_sample[0].numpy()
with open(height_images_path, "wb") as pkl_file:
    pickle.dump(first_sample, pkl_file)


In [None]:
# Create triangulations
# This requires CYTools. See https://cy.tools/ for installation instructions and more information.
# `$(id -u)` is used instead of `$UID`: UID is a bash-only variable, and IPython's `!` may run the
# command through /bin/sh (dash on Debian/Ubuntu), where `$UID` expands to an empty string.
# `"$(pwd)"` replaces the relative `./` bind mount, which older Docker versions reject.
! docker run --rm -it --name cytools-uid-$(id -u) -v "$(pwd)":/home/cytools/mounted_volume -p $(($(id -u)+2875)):$(($(id -u)+2875)) cytools:uid-$(id -u) python3 ./cytools_triangulate.py

In [None]:
# VISUALIZE REVERSE PROCESS
# Assumes tensor_sample[0] holds the intermediate CHW images of the first sample, one per
# reverse step — TODO confirm against Diffusion.sample.
tensors = tensor_sample[0]
num_step_images = 10
num_cols = 8
# Spacing comes from the number of stored steps; max() keeps the step positive.
stepsize = max(1, len(tensors) // num_step_images)
steps_to_show = range(0, len(tensors), stepsize)
# Ceil division: just enough rows of num_cols panels for every selected step.
num_rows = (len(steps_to_show) + num_cols - 1) // num_cols
plt.figure(figsize=(60, 60 / num_cols * max(1, num_rows)))
for plot_idx, step in enumerate(steps_to_show):
    plt.subplot(num_rows, num_cols, plot_idx + 1)
    plt.imshow(image_from_tensor(tensors[step]), cmap="gray")

# Last stored step (presumably the final denoised image) at a readable size.
plt.figure(figsize=(5,5))
plt.imshow(image_from_tensor(tensors[-1]), cmap="gray")

In [None]:
# Visualize triangulations

# Load triangulations
# allow_pickle lets the file execute arbitrary code on load; only load files this notebook produced.
points_list, simplices_list = (
    np.load(sample_dir + pkl_name, allow_pickle=True)
    for pkl_name in ("points.pkl", "simplices.pkl")
)

def Plot_2D_Triangulation(points, simplices):
    """Draw every edge of a 2D triangulation on the current pyplot axes.

    Args:
        points: array-like of vertex coordinates; column 0 is x and column 1 is y.
        simplices: array-like of triangles, each a triple of indices into ``points``.

    Returns:
        The list of Line2D objects returned by ``plt.plot`` (one per unique edge).
    """
    points = np.asarray(points)
    simplices = np.asarray(simplices, dtype=np.intp).reshape(-1, 3)
    # All three edges of every triangle: shape (3 * n_simplices, 2).
    edges = simplices[:, [[0, 1], [0, 2], [1, 2]]].reshape(-1, 2)
    # Sort each (i, j) pair so an edge shared by two triangles listed with different
    # vertex orders is recognised as one edge and drawn only once.
    edges = np.sort(edges, axis=1)
    if len(edges):
        edges = np.unique(edges, axis=0)
    x = points[:, 0].flatten()
    y = points[:, 1].flatten()
    # Indexing with edges.T gives (2, n_edges) arrays; plt.plot draws one segment per column.
    return plt.plot(x[edges.T], y[edges.T], linestyle='-', color='b', markerfacecolor='red', marker='.')
    

num_images = 10
num_cols = 8
# NOTE(review): this rebinds the global `num_timesteps` (set in the Parameters cell) to the
# number of triangulations; cells re-run after this one will see the new value.
num_timesteps = len(points_list)
stepsize = max(1, num_timesteps // num_images)  # max() avoids a zero step for short sequences
frames_to_show = range(0, num_timesteps, stepsize)
# Ceil division: just enough rows of num_cols panels for every selected frame.
num_rows = (len(frames_to_show) + num_cols - 1) // num_cols
plt.figure(figsize=(30, 30 / num_cols * max(1, num_rows)))
for plot_idx, frame in enumerate(frames_to_show):
    plt.subplot(num_rows, num_cols, plot_idx + 1)
    Plot_2D_Triangulation(points_list[frame], simplices_list[frame])

# Last triangulation on its own at a readable size.
plt.figure(figsize=(5,5))
Plot_2D_Triangulation(points_list[-1], simplices_list[-1])
plt.show()


In [None]:
# Quantify fine-ness: number of distinct vertices referenced by each triangulation.
# np.unique flattens its input, so each entry may be an array or a nested list.
point_counts = [len(np.unique(simplices)) for simplices in tqdm(simplices_list)]

plt.figure(figsize=(5,5))
plt.plot(point_counts)
plt.xlabel("Triangulation index")
plt.ylabel("Distinct vertices used")
plt.title("Vertices used per triangulation")

In [None]:
# Create gif: render roughly num_images evenly spaced triangulations and save them as an animation.
from PIL import Image

num_images = 200
num_frames = len(points_list)
stepsize = max(1, num_frames // num_images)  # max() avoids a zero step when there are < num_images frames
images = []
for t in range(0, num_frames, stepsize):
    plt_figure = plt.figure(figsize=(5,5))
    Plot_2D_Triangulation(points_list[t], simplices_list[t])
    plt_figure.canvas.draw()
    # buffer_rgba() has the true rendered pixel size (including HiDPI scaling) and replaces
    # the deprecated canvas.tostring_rgb(); convert() copies the pixels before the figure is closed.
    rgba = np.asarray(plt_figure.canvas.buffer_rgba())
    images.append(Image.fromarray(rgba).convert('RGB'))
    plt.close(plt_figure)
images[0].save(os.path.join(gif_dir, 'triangulations.gif'), save_all=True, append_images=images[1:], duration=50)