In [1]:
# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torchvision
from PIL import Image
from Unet import UNet
from Diffusion_model import Diffusion
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [3]:
import os
import numpy as np

directory = '../../00_data/Samples'

# os.listdir() returns entries in arbitrary order, so sort the file names to
# make the sample order (and therefore every training batch) reproducible.
npy_file_names = sorted(
    file_name for file_name in os.listdir(directory) if file_name.endswith('.npy')
)
if not npy_file_names:
    # Fail here with a clear message instead of a confusing indexing error
    # later when the (empty) array is reshaped.
    raise FileNotFoundError(f"No .npy sample files found in {directory!r}")

# One NumPy array per .npy file, in sorted file-name order.
numpy_list = [
    np.load(os.path.join(directory, file_name)) for file_name in npy_file_names
]



In [4]:


def save_images(generated, path, **kwargs):
    """Render a batch of generated images as a grid and save it to ``path``.

    Images are laid out four per row in grayscale, each titled
    "Generated <k>"; unused cells of the last row are left blank.

    Args:
        generated: Torch tensor whose first dimension indexes the images.
            Each image is squeezed before display, so singleton axes are
            dropped.
        path: Destination passed to ``Figure.savefig``. When it is a string
            or path-like object, missing parent directories are created.
        **kwargs: Forwarded unchanged to ``Axes.imshow``.

    Raises:
        ValueError: If ``generated`` contains no images.
    """
    # detach() so tensors that still carry autograd history can be converted.
    generated = generated.detach().to('cpu').numpy()

    num_images = generated.shape[0]
    if num_images == 0:
        raise ValueError("save_images() needs at least one image to plot")

    images_per_row = 4
    # Ceiling division: a partially filled last row still needs its own row.
    num_rows = (num_images + images_per_row - 1) // images_per_row

    # Make sure the output directory exists before spending time on plotting.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(
        nrows=num_rows,
        ncols=images_per_row,
        figsize=(20, num_rows * 5),
        squeeze=False,
    )
    try:
        axes = axes.flatten()  # Flatten the axes array for easier indexing

        # Adjust space between images
        fig.subplots_adjust(wspace=0.3, hspace=0.3)

        for i, gen_img in enumerate(generated):
            axes[i].imshow(gen_img.squeeze(), cmap='gray', **kwargs)
            axes[i].axis('off')
            axes[i].set_title(f"Generated {i+1}", fontsize=10)

        # Hide the unused cells when the last row is not fully populated.
        for unused_ax in axes[num_images:]:
            unused_ax.axis('off')

        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # (one per epoch) do not accumulate open figures.
        plt.close(fig)


In [5]:
from PIL import Image

In [6]:

# Resize every loaded sample to 64x64 by round-tripping through PIL.
# squeeze() drops singleton axes so PIL receives a plain 2-D image.
resized_samples = []
for sample in np.array(numpy_list):
    pil_image = Image.fromarray(sample.squeeze())
    resized_samples.append(np.array(pil_image.resize((64, 64))))

# Stack to (N, 64, 64), then insert the channel axis -> (N, 1, 64, 64).
x_train = np.array(resized_samples)[:, np.newaxis, :, :]
x_train = torch.Tensor(x_train)

print(x_train.shape)



torch.Size([10000, 1, 64, 64])


In [7]:
# Quick sanity check: show the Python type of the prepared training data.
type(x_train)

torch.Tensor

In [8]:
# Wrap the image tensor so the loader yields 1-tuples: (images,).
dataset = TensorDataset(x_train)
# shuffle=True reshuffles the samples every epoch; without it each epoch
# replays exactly the same batches in exactly the same order.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)


In [9]:
# Display the DataLoader's repr to confirm it was constructed.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f646ef05100>

In [19]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of crashing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Noise-prediction network and its optimizer.
model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

# Loss between the sampled noise and the model's predicted noise.
mse = nn.MSELoss()

# Diffusion helper configured for 64x64 images on the selected device.
# NOTE(review): assumes Diffusion accepts a "cpu" device string - confirm.
diffusion = Diffusion(img_size=64, device=device)

# Number of batches per epoch (name kept for backward compatibility).
l = len(dataloader)

In [20]:
checkpoint_path = './Weights/Diff_ckpt_1.pt'
start_epoch = 13  # epoch to resume from when the checkpoint is present
epochs = 100

# If a checkpoint exists, load it and resume; otherwise train from scratch.
if os.path.exists(checkpoint_path):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Model loaded")
else:
    # Without saved weights the model is untrained, so resuming at a later
    # epoch would silently skip the first `start_epoch` epochs.
    start_epoch = 0
    print("No checkpoint found; training from epoch 0")


Model loaded


In [21]:
# Make sure the output folders exist before the first (expensive) epoch ends.
os.makedirs("Results", exist_ok=True)
os.makedirs("Weights", exist_ok=True)

for epoch in range(start_epoch, epochs):
    print(f"Starting epoch {epoch}:")
    pbar = tqdm(dataloader)
    mse_sum = 0.0

    # Unconditional DDPM: the conditioning input (`model(x_t, t, conditions)`)
    # is disabled for now, so each batch is a 1-tuple holding only the images.
    for i, (images,) in enumerate(pbar):
        images = images.to(device)

        # Draw a timestep per image, noise the images to that timestep, and
        # train the model to predict the noise that was added.
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the progress bar and the sum.
        batch_mse = loss.item()
        pbar.set_postfix(MSE=batch_mse)
        mse_sum += batch_mse

    # Average over the actual number of batches instead of a hard-coded 1000.
    print(f'Average MSE: {mse_sum / len(dataloader):.5f}\n')

    # Sample as many images as the last batch held, save the grid for visual
    # inspection, and overwrite the single rolling checkpoint.
    sampled_images = diffusion.sample(model, n=images.shape[0])
    save_images(sampled_images, os.path.join("Results", f"{epoch}.png"))
    torch.save(model.state_dict(), os.path.join("Weights", "Diff_ckpt_1.pt"))

Starting epoch 13:


100%|██████████| 1000/1000 [02:03<00:00,  8.11it/s, MSE=0.00151]


Average MSE: 0.00186

Sampling 10 new images....


999it [00:44, 22.34it/s]


Starting epoch 14:


100%|██████████| 1000/1000 [02:02<00:00,  8.13it/s, MSE=0.00155]


Average MSE: 0.00170

Sampling 10 new images....


999it [00:42, 23.65it/s]


Starting epoch 15:


100%|██████████| 1000/1000 [02:04<00:00,  8.05it/s, MSE=0.000491]


Average MSE: 0.00165

Sampling 10 new images....


999it [00:42, 23.42it/s]


Starting epoch 16:


100%|██████████| 1000/1000 [02:06<00:00,  7.93it/s, MSE=0.00249]


Average MSE: 0.00161

Sampling 10 new images....


999it [00:41, 23.81it/s]


Starting epoch 17:


100%|██████████| 1000/1000 [02:24<00:00,  6.90it/s, MSE=0.00165]


Average MSE: 0.00156

Sampling 10 new images....


999it [00:50, 19.89it/s]


Starting epoch 18:


100%|██████████| 1000/1000 [02:17<00:00,  7.28it/s, MSE=0.000841]


Average MSE: 0.00156

Sampling 10 new images....


999it [00:46, 21.36it/s]


Starting epoch 19:


100%|██████████| 1000/1000 [02:13<00:00,  7.49it/s, MSE=0.000709]


Average MSE: 0.00152

Sampling 10 new images....


999it [00:46, 21.48it/s]


Starting epoch 20:


100%|██████████| 1000/1000 [02:15<00:00,  7.38it/s, MSE=0.00784]


Average MSE: 0.00154

Sampling 10 new images....


999it [00:56, 17.82it/s]


Starting epoch 21:


100%|██████████| 1000/1000 [02:15<00:00,  7.40it/s, MSE=0.00198]


Average MSE: 0.00146

Sampling 10 new images....


999it [00:47, 21.24it/s]


Starting epoch 22:


100%|██████████| 1000/1000 [02:18<00:00,  7.22it/s, MSE=0.00265]


Average MSE: 0.00145

Sampling 10 new images....


999it [00:48, 20.74it/s]


Starting epoch 23:


100%|██████████| 1000/1000 [02:18<00:00,  7.21it/s, MSE=0.00224]


Average MSE: 0.00148

Sampling 10 new images....


999it [00:47, 20.83it/s]


Starting epoch 24:


100%|██████████| 1000/1000 [04:57<00:00,  3.36it/s, MSE=0.000707] 


Average MSE: 0.00141

Sampling 10 new images....


999it [02:04,  8.01it/s]


Starting epoch 25:


100%|██████████| 1000/1000 [03:17<00:00,  5.07it/s, MSE=0.000907]


Average MSE: 0.00141

Sampling 10 new images....


999it [01:49,  9.11it/s]


Starting epoch 26:


100%|██████████| 1000/1000 [03:54<00:00,  4.27it/s, MSE=0.000731]


Average MSE: 0.00140

Sampling 10 new images....


999it [01:09, 14.40it/s]


Starting epoch 27:


100%|██████████| 1000/1000 [02:14<00:00,  7.44it/s, MSE=0.000647]


Average MSE: 0.00137

Sampling 10 new images....


999it [00:48, 20.55it/s]


Starting epoch 28:


100%|██████████| 1000/1000 [04:05<00:00,  4.08it/s, MSE=0.0014] 


Average MSE: 0.00141

Sampling 10 new images....


999it [00:57, 17.35it/s]


Starting epoch 29:


100%|██████████| 1000/1000 [03:19<00:00,  5.00it/s, MSE=0.00121]


Average MSE: 0.00142

Sampling 10 new images....


999it [00:54, 18.48it/s]


Starting epoch 30:


100%|██████████| 1000/1000 [03:44<00:00,  4.45it/s, MSE=0.00108] 


Average MSE: 0.00139

Sampling 10 new images....


999it [01:24, 11.80it/s]


Starting epoch 31:


100%|██████████| 1000/1000 [08:35<00:00,  1.94it/s, MSE=0.000518] 


Average MSE: 0.00133

Sampling 10 new images....


999it [08:21,  1.99it/s]


Starting epoch 32:


100%|██████████| 1000/1000 [04:45<00:00,  3.51it/s, MSE=0.000396] 


Average MSE: 0.00138

Sampling 10 new images....


999it [00:42, 23.41it/s]


Starting epoch 33:


100%|██████████| 1000/1000 [02:06<00:00,  7.91it/s, MSE=0.00115]


Average MSE: 0.00134

Sampling 10 new images....


999it [00:42, 23.50it/s]


Starting epoch 34:


100%|██████████| 1000/1000 [02:02<00:00,  8.19it/s, MSE=0.00117]


Average MSE: 0.00133

Sampling 10 new images....


999it [00:41, 24.20it/s]


Starting epoch 35:


100%|██████████| 1000/1000 [02:48<00:00,  5.92it/s, MSE=0.0031] 


Average MSE: 0.00138

Sampling 10 new images....


999it [02:01,  8.21it/s]


Starting epoch 36:


100%|██████████| 1000/1000 [02:44<00:00,  6.09it/s, MSE=0.000933] 


Average MSE: 0.00130

Sampling 10 new images....


999it [00:41, 23.94it/s]


Starting epoch 37:


100%|██████████| 1000/1000 [02:07<00:00,  7.86it/s, MSE=0.000718]


Average MSE: 0.00133

Sampling 10 new images....


999it [00:44, 22.67it/s]


Starting epoch 38:


100%|██████████| 1000/1000 [02:03<00:00,  8.08it/s, MSE=0.00242]


Average MSE: 0.00129

Sampling 10 new images....


999it [00:41, 23.95it/s]


Starting epoch 39:


100%|██████████| 1000/1000 [02:04<00:00,  8.05it/s, MSE=0.00433]


Average MSE: 0.00129

Sampling 10 new images....


999it [00:42, 23.70it/s]


Starting epoch 40:


100%|██████████| 1000/1000 [02:11<00:00,  7.62it/s, MSE=0.000923]


Average MSE: 0.00127

Sampling 10 new images....


999it [00:41, 23.94it/s]


Starting epoch 41:


100%|██████████| 1000/1000 [02:04<00:00,  8.04it/s, MSE=0.0012] 


Average MSE: 0.00133

Sampling 10 new images....


999it [00:41, 23.94it/s]


Starting epoch 42:


100%|██████████| 1000/1000 [02:03<00:00,  8.12it/s, MSE=0.00176]


Average MSE: 0.00128

Sampling 10 new images....


999it [00:41, 23.95it/s]


Starting epoch 43:


100%|██████████| 1000/1000 [02:48<00:00,  5.93it/s, MSE=0.000824]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:39, 25.54it/s]


Starting epoch 44:


100%|██████████| 1000/1000 [01:56<00:00,  8.62it/s, MSE=0.00101]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:38, 26.07it/s]


Starting epoch 45:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00172]


Average MSE: 0.00127

Sampling 10 new images....


999it [00:38, 26.12it/s]


Starting epoch 46:


100%|██████████| 1000/1000 [01:53<00:00,  8.81it/s, MSE=0.000995]


Average MSE: 0.00135

Sampling 10 new images....


999it [00:37, 26.32it/s]


Starting epoch 47:


100%|██████████| 1000/1000 [01:54<00:00,  8.77it/s, MSE=0.00135]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:39, 25.34it/s]


Starting epoch 48:


100%|██████████| 1000/1000 [01:58<00:00,  8.44it/s, MSE=0.000496]


Average MSE: 0.00127

Sampling 10 new images....


999it [00:39, 25.29it/s]


Starting epoch 49:


100%|██████████| 1000/1000 [01:59<00:00,  8.35it/s, MSE=0.000624]


Average MSE: 0.00122

Sampling 10 new images....


999it [00:40, 24.63it/s]


Starting epoch 50:


100%|██████████| 1000/1000 [01:59<00:00,  8.38it/s, MSE=0.000916]


Average MSE: 0.00127

Sampling 10 new images....


999it [00:40, 24.83it/s]


Starting epoch 51:


100%|██████████| 1000/1000 [02:00<00:00,  8.27it/s, MSE=0.00104]


Average MSE: 0.00124

Sampling 10 new images....


999it [00:40, 24.87it/s]


Starting epoch 52:


100%|██████████| 1000/1000 [01:57<00:00,  8.54it/s, MSE=0.000298]


Average MSE: 0.00125

Sampling 10 new images....


999it [00:40, 24.54it/s]


Starting epoch 53:


100%|██████████| 1000/1000 [01:59<00:00,  8.36it/s, MSE=0.00623]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:38, 26.07it/s]


Starting epoch 54:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.000726]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:38, 26.18it/s]


Starting epoch 55:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.00068]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:38, 26.06it/s]


Starting epoch 56:


100%|██████████| 1000/1000 [01:54<00:00,  8.75it/s, MSE=0.000562]


Average MSE: 0.00126

Sampling 10 new images....


999it [00:40, 24.58it/s]


Starting epoch 57:


100%|██████████| 1000/1000 [01:58<00:00,  8.42it/s, MSE=0.00313]


Average MSE: 0.00118

Sampling 10 new images....


999it [00:39, 25.00it/s]


Starting epoch 58:


100%|██████████| 1000/1000 [01:57<00:00,  8.52it/s, MSE=0.00112]


Average MSE: 0.00121

Sampling 10 new images....


999it [00:39, 25.22it/s]


Starting epoch 59:


100%|██████████| 1000/1000 [01:57<00:00,  8.51it/s, MSE=0.000621]


Average MSE: 0.00120

Sampling 10 new images....


999it [00:38, 25.65it/s]


Starting epoch 60:


100%|██████████| 1000/1000 [01:55<00:00,  8.66it/s, MSE=0.00277]


Average MSE: 0.00126

Sampling 10 new images....


999it [00:38, 25.68it/s]


Starting epoch 61:


100%|██████████| 1000/1000 [01:57<00:00,  8.51it/s, MSE=0.00137]


Average MSE: 0.00125

Sampling 10 new images....


999it [00:40, 24.74it/s]


Starting epoch 62:


100%|██████████| 1000/1000 [02:00<00:00,  8.32it/s, MSE=0.000803]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:40, 24.92it/s]


Starting epoch 63:


100%|██████████| 1000/1000 [01:59<00:00,  8.38it/s, MSE=0.00101]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:40, 24.71it/s]


Starting epoch 64:


100%|██████████| 1000/1000 [01:54<00:00,  8.73it/s, MSE=0.000544]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:38, 26.26it/s]


Starting epoch 65:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.00194]


Average MSE: 0.00121

Sampling 10 new images....


999it [00:37, 26.29it/s]


Starting epoch 66:


100%|██████████| 1000/1000 [01:53<00:00,  8.81it/s, MSE=0.00121]


Average MSE: 0.00118

Sampling 10 new images....


999it [00:38, 26.22it/s]


Starting epoch 67:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00054]


Average MSE: 0.00120

Sampling 10 new images....


999it [00:38, 26.16it/s]


Starting epoch 68:


100%|██████████| 1000/1000 [01:55<00:00,  8.68it/s, MSE=0.00118]


Average MSE: 0.00121

Sampling 10 new images....


999it [00:38, 25.77it/s]


Starting epoch 69:


100%|██████████| 1000/1000 [01:58<00:00,  8.46it/s, MSE=0.00171]


Average MSE: 0.00117

Sampling 10 new images....


999it [00:40, 24.96it/s]


Starting epoch 70:


100%|██████████| 1000/1000 [01:58<00:00,  8.46it/s, MSE=0.00097]


Average MSE: 0.00117

Sampling 10 new images....


999it [00:39, 25.33it/s]


Starting epoch 71:


100%|██████████| 1000/1000 [01:57<00:00,  8.48it/s, MSE=0.000831]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:39, 25.42it/s]


Starting epoch 72:


100%|██████████| 1000/1000 [01:56<00:00,  8.58it/s, MSE=0.00252]


Average MSE: 0.00122

Sampling 10 new images....


999it [00:38, 25.67it/s]


Starting epoch 73:


100%|██████████| 1000/1000 [01:55<00:00,  8.66it/s, MSE=0.000775]


Average MSE: 0.00118

Sampling 10 new images....


999it [00:38, 26.07it/s]


Starting epoch 74:


100%|██████████| 1000/1000 [01:53<00:00,  8.77it/s, MSE=0.000484]


Average MSE: 0.00123

Sampling 10 new images....


999it [00:38, 26.27it/s]


Starting epoch 75:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.00113]


Average MSE: 0.00118

Sampling 10 new images....


999it [00:38, 26.24it/s]


Starting epoch 76:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00058]


Average MSE: 0.00118

Sampling 10 new images....


999it [00:37, 26.29it/s]


Starting epoch 77:


100%|██████████| 1000/1000 [01:53<00:00,  8.81it/s, MSE=0.00137]


Average MSE: 0.00114

Sampling 10 new images....


999it [00:38, 26.20it/s]


Starting epoch 78:


100%|██████████| 1000/1000 [01:54<00:00,  8.75it/s, MSE=0.00115]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:38, 26.23it/s]


Starting epoch 79:


100%|██████████| 1000/1000 [01:53<00:00,  8.78it/s, MSE=0.00177]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:38, 26.26it/s]


Starting epoch 80:


100%|██████████| 1000/1000 [01:54<00:00,  8.73it/s, MSE=0.00144]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:38, 26.07it/s]


Starting epoch 81:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00134]


Average MSE: 0.00119

Sampling 10 new images....


999it [00:38, 26.27it/s]


Starting epoch 82:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.000534]


Average MSE: 0.00111

Sampling 10 new images....


999it [00:38, 26.24it/s]


Starting epoch 83:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.000958]


Average MSE: 0.00119

Sampling 10 new images....


999it [00:38, 26.27it/s]


Starting epoch 84:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00109]


Average MSE: 0.00117

Sampling 10 new images....


999it [00:38, 26.20it/s]


Starting epoch 85:


100%|██████████| 1000/1000 [01:54<00:00,  8.74it/s, MSE=0.000805]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:38, 26.17it/s]


Starting epoch 86:


100%|██████████| 1000/1000 [01:53<00:00,  8.79it/s, MSE=0.00104]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:38, 26.27it/s]


Starting epoch 87:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.000852]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:38, 26.15it/s]


Starting epoch 88:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.000355]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:37, 26.29it/s]


Starting epoch 89:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.00043]


Average MSE: 0.00112

Sampling 10 new images....


999it [00:38, 26.28it/s]


Starting epoch 90:


100%|██████████| 1000/1000 [01:53<00:00,  8.81it/s, MSE=0.000724]


Average MSE: 0.00119

Sampling 10 new images....


999it [00:38, 26.27it/s]


Starting epoch 91:


100%|██████████| 1000/1000 [01:55<00:00,  8.66it/s, MSE=0.000362]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:40, 24.80it/s]


Starting epoch 92:


100%|██████████| 1000/1000 [01:58<00:00,  8.41it/s, MSE=0.00154]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:40, 24.46it/s]


Starting epoch 93:


100%|██████████| 1000/1000 [02:00<00:00,  8.29it/s, MSE=0.000286]


Average MSE: 0.00112

Sampling 10 new images....


999it [00:40, 24.90it/s]


Starting epoch 94:


100%|██████████| 1000/1000 [01:55<00:00,  8.68it/s, MSE=0.00106]


Average MSE: 0.00115

Sampling 10 new images....


999it [00:38, 26.08it/s]


Starting epoch 95:


100%|██████████| 1000/1000 [01:53<00:00,  8.80it/s, MSE=0.00123]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:37, 26.31it/s]


Starting epoch 96:


100%|██████████| 1000/1000 [01:53<00:00,  8.78it/s, MSE=0.000962]


Average MSE: 0.00116

Sampling 10 new images....


999it [00:38, 26.24it/s]


Starting epoch 97:


100%|██████████| 1000/1000 [01:53<00:00,  8.81it/s, MSE=0.00106]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:38, 26.08it/s]


Starting epoch 98:


100%|██████████| 1000/1000 [01:54<00:00,  8.75it/s, MSE=0.00084]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:38, 26.16it/s]


Starting epoch 99:


100%|██████████| 1000/1000 [01:54<00:00,  8.76it/s, MSE=0.000923]


Average MSE: 0.00113

Sampling 10 new images....


999it [00:38, 26.25it/s]
