In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import fnmatch
from tqdm import tqdm

from img_utils import ImgUtils
from noise_scheduler import NoiseScheduler
from unet import UNet



Init Plugin
Init Graph Optimizer
Init Kernel


In [2]:
# Collect every *.jpg under Images/. os.walk visits directories and files in
# filesystem-dependent order, so sort the paths to make the N_IMAGES subset
# below reproducible across machines and runs.
N_IMAGES = 1000  # only use the first 1000 images for quicker runtime; set to None to use all
filenames = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _dirs, files in os.walk('Images')
    for filename in fnmatch.filter(files, '*.jpg')
)
filenames = filenames[:N_IMAGES]
assert filenames, "no *.jpg files found under 'Images'"

# Load, resize to 256x256, cast to float and rescale via the project's ImgUtils
# helpers -- one pass per stage so each step gets its own progress bar.
imgs = [plt.imread(fn) for fn in tqdm(filenames, desc="Loading")]
imgs = [ImgUtils.resize_img(img, (256, 256)) for img in tqdm(imgs, desc="Resizing")]
imgs = [ImgUtils.int_to_float_img(img) for img in tqdm(imgs, desc="Casting")]
imgs = [ImgUtils.scale_img(img) for img in tqdm(imgs, desc="Scaling")]

Loading: 100%|██████████| 1000/1000 [00:01<00:00, 554.41it/s]
Resizing: 100%|██████████| 1000/1000 [00:00<00:00, 6099.27it/s]
Casting: 100%|██████████| 1000/1000 [00:00<00:00, 4769.04it/s]
Scaling: 100%|██████████| 1000/1000 [00:00<00:00, 1074.46it/s]


In [3]:
n_timesteps = 100
noiser = NoiseScheduler(n_timesteps, start=0.0001, end=0.06)

# Build one (noised input, target) pair per image per timestep.
# Each array holds len(imgs) * n_timesteps * 256*256*3 values, i.e. about
# 78.6 GB per array in float32 for 1000 images x 100 steps. Preallocate float32
# buffers and fill them in place: growing Python lists and then copying them
# with np.array() stored float64 and roughly doubled peak memory.
# TODO(review): if this still does not fit in RAM, generate pairs on the fly
# or sample a subset of timesteps per image.
n_samples = len(imgs) * n_timesteps
training_inputs = np.empty((n_samples, 256, 256, 3), dtype=np.float32)
training_outputs = np.empty_like(training_inputs)
for img_idx, img in enumerate(tqdm(imgs, desc="Noising")):
    for step in range(n_timesteps):
        # Rows are grouped per image: all timesteps of image 0, then image 1, ...
        row = img_idx * n_timesteps + step
        noised_img = noiser.forward(img, step)
        training_inputs[row] = noised_img
        # Target is the residual that maps the noised image back to the clean one.
        training_outputs[row] = img - noised_img

assert training_inputs.shape == training_outputs.shape == (len(imgs) * n_timesteps, 256, 256, 3)

Noising:  14%|█▍        | 144/1000 [01:29<17:14,  1.21s/it]

: 

: 

In [None]:
# Inspect the first image at its first and last timestep: the two noised
# inputs followed by the two corresponding training targets (clean - noised).
first_img_rows = (0, n_timesteps - 1)
chosen_imgs = (
    [training_inputs[row] for row in first_img_rows]
    + [training_outputs[row] for row in first_img_rows]
)
shown_imgs = [ImgUtils.unscale_img(pic).clip(0, 1) for pic in chosen_imgs]
ImgUtils.show_images(shown_imgs, cols=2)

In [None]:
# Build the U-Net and train it to regress the residual target with mean
# squared error. The built-in "mean_squared_error" averages over the same
# elements as the previous hand-written `lambda hx, y: (y-hx)**2`, but it uses
# Keras's (y_true, y_pred) argument order and can be restored by name when the
# model is reloaded, which an anonymous lambda cannot.
unet = UNet.new()
unet.compile(optimizer="adam", loss="mean_squared_error")
unet.summary()

In [None]:
# Keras takes the validation_split slice from the END of the arrays, before
# shuffling. Samples are stored grouped per image, so the held-out 20% is
# (approximately) whole images the model never trains on.
hist = unet.fit(
    x=training_inputs,
    y=training_outputs,
    validation_split=0.2,
    epochs=10,
)

In [None]:
# Per-epoch loss curves recorded by fit(); number epochs from 1 to match
# Keras's progress output.
train_loss = hist.history["loss"]
val_loss = hist.history["val_loss"]
epochs = range(1, len(train_loss) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_loss, label="Training Loss")
ax.plot(epochs, val_loss, label="Validation Loss")
ax.set(title="Training and Validation Loss per Epoch", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# Single-step generation (not iterative DDPM sampling): start from Gaussian
# noise standing in for a heavily noised image, predict the residual the model
# was trained on (clean - noised), and add it back to recover an image.
# NOTE(review): assumes NoiseScheduler.forward adds standard-normal noise, so
# N(0, 1) matches the late-timestep inputs seen in training -- confirm.
# The old `randn(...) * 2 - 1` turned the draw into N(-1, 4); that transform
# is only meant for uniform rand() in [0, 1).
SAMPLE_SEED = 0  # change to draw a different sample
rng = np.random.default_rng(SAMPLE_SEED)
noise = rng.standard_normal((1, 256, 256, 3))
predicted_residual = unet.predict(noise)[0]
new_dog = noise[0] + predicted_residual
plt.imshow(ImgUtils.unscale_img(new_dog).clip(0, 1))
plt.show()