In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from datetime import datetime
import glob
import os
import sys
import tensorflow_addons as tfa

In [None]:
from clear_diffusion_keras.dataset import prepare_dataset
from clear_diffusion_keras.architecture import get_augmenter, get_network
from clear_diffusion_keras.model import DiffusionModel

In [None]:
def _load_generated_data(dataset_name, source_version):
  """Stack every ``.npy`` file saved for one dataset/version into one array.

  Args:
    dataset_name: sub-folder of ``./Results_experiments`` to read from.
    source_version: version whose ``Version_<n>`` folder is read.

  Returns:
    ``np.ndarray`` whose leading axis has one entry per saved file, in sorted
    file-name order.

  Raises:
    FileNotFoundError: if the folder contains no ``.npy`` files.
  """
  pattern = os.path.join(".", "Results_experiments", dataset_name,
                         f"Version_{source_version}", "*.npy")
  # glob returns matches in arbitrary order; sort so the train/val split is reproducible.
  file_paths = sorted(glob.glob(pattern))
  if not file_paths:
    raise FileNotFoundError(f"No generated .npy files match {pattern!r}")
  # np.stack raises a clear error if the saved arrays have different shapes.
  return np.stack([np.load(path) for path in file_paths])


def _save_generated_images(model, dataset_name, version, index, num_images=64):
  """Generate one batch with ``model`` and save it under the version's results folder.

  The file goes to ``./Results_experiments/<dataset_name>/Version_<version>/``,
  which is created if missing. The file name contains ``index`` as well as a
  timestamp with one-second resolution. Without the index, batches finished
  within the same second would overwrite each other.
  """
  # NOTE(review): presumably returns an array/tensor of `num_images` images — confirm in DiffusionModel.generate.
  images = model.generate(
      num_images=num_images,
      diffusion_steps=20,
      stochasticity=1.0,
      variance_preserving=True,
      num_multisteps=2,
      second_order_alpha=0.5,
  )
  output_dir = os.path.join(".", "Results_experiments", dataset_name, f"Version_{version}")
  os.makedirs(output_dir, exist_ok=True)
  timestamp = int(round(datetime.now().timestamp()))
  np.save(os.path.join(output_dir, f"Generated_{timestamp}_{index}.npy"), images)


def Generate_Iteration(version, dataset_name,
                      num_epochs = 40, uncropped_image_size = 64, image_size = 64, batch_size = 64, generated_images=1280,
                      kid_image_size = 75  , kid_diffusion_steps = 5,
                      prediction_type = "noise", loss_type = "noise", ema = 0.999, learning_rate = 1e-3, weight_decay = 1e-4,
                      schedule_type = "cosine", start_log_snr = 2.5, end_log_snr = -7.5,
                      noise_embedding_max_frequency = 200.0, noise_embedding_dims = 32, image_embedding_dims = 64, widths = [32, 64, 96, 128], block_depth = 2,
  ):
  """Train one diffusion-model generation, then use it to produce the next dataset.

  Version 1 trains on the real TFDS dataset mapped from ``dataset_name``. Every
  later version trains on the files generated by version ``version - 1``. After
  training, the checkpoint with the lowest ``val_kid`` is restored and evaluated.
  The model then writes ``generated_images`` files (one 64-image generation call
  each) to ``./Results_experiments/<dataset_name>/Version_<version>/``.

  Args:
    version: generation index. Values below 1 are skipped: a message is
      printed and None is returned.
    dataset_name: "Birds" or "Flowers" for version 1. For later versions it is
      the results sub-folder to read from and write to.
    num_epochs, batch_size, learning_rate, weight_decay: training settings.
    generated_images: number of generation calls (saved files) after training.
    remaining keyword arguments: forwarded to the augmenter, network and
      ``DiffusionModel`` constructors.

  Returns:
    None.

  Raises:
    KeyError: if ``version == 1`` and ``dataset_name`` is not a known dataset.
    FileNotFoundError: if ``version > 1`` and version ``version - 1`` saved no files.
  """
  ###
  #Validate and load the dataset first, so bad input fails before the costly model build
  ###
  if version < 1:
    # Version 0 would read "Version_-1", which is never produced.
    print(f"Skipping version {version}: versions start at 1.")
    return None

  if version == 1:
    datasets_names = {
        "Birds": "caltech_birds2011",
        "Flowers": "oxford_flowers102",
    }
    tfds_name = datasets_names[dataset_name]  # KeyError for unknown dataset names
    train_dataset = prepare_dataset(tfds_name, "train", uncropped_image_size, batch_size)
    val_dataset = prepare_dataset(tfds_name, "validation", uncropped_image_size, batch_size)
  else:
    total_data = _load_generated_data(dataset_name, version - 1)
    # Each element is one saved file (presumably a whole 64-image batch), so these
    # slices count files. NOTE(review): with fewer than 1024 + 256 files the
    # validation slice overlaps the training slice.
    train_dataset = tf.data.Dataset.from_tensor_slices(total_data[:1024])
    val_dataset = tf.data.Dataset.from_tensor_slices(total_data[-256:])

  ###
  #Model
  ###
  model = DiffusionModel(
      id=version,
      augmenter=get_augmenter(
          uncropped_image_size=uncropped_image_size, image_size=image_size
      ),
      network=get_network(
          image_size=image_size,
          noise_embedding_max_frequency=noise_embedding_max_frequency,
          noise_embedding_dims=noise_embedding_dims,
          image_embedding_dims=image_embedding_dims,
          # Pass a copy so the shared default list can never be mutated downstream.
          widths=list(widths),
          block_depth=block_depth,
      ),
      prediction_type=prediction_type,
      loss_type=loss_type,
      batch_size=batch_size,
      ema=ema,
      schedule_type=schedule_type,
      start_log_snr=start_log_snr,
      end_log_snr=end_log_snr,
      kid_image_size=kid_image_size,
      kid_diffusion_steps=kid_diffusion_steps,
      is_jupyter=True,
  )

  model.compile(
      optimizer=tfa.optimizers.AdamW(
          learning_rate=learning_rate, weight_decay=weight_decay
      ),
      loss=keras.losses.mean_absolute_error,
  )

  ###
  #Train the model
  ###
  checkpoint_path = "checkpoints/model_{}".format(version)
  # Keep only the weights with the lowest validation KID.
  checkpoint_callback = keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      save_weights_only=True,
      monitor="val_kid",
      mode="min",
      save_best_only=True,
  )
  # NOTE(review): presumably a normalization layer fitted to training statistics — confirm in get_augmenter.
  model.augmenter.layers[0].adapt(train_dataset)
  model.fit(
      train_dataset,
      epochs=num_epochs,
      validation_data=val_dataset,
      callbacks=[
          keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
          checkpoint_callback,
      ],
  )
  # Restore the best checkpoint before evaluating and generating.
  model.load_weights(checkpoint_path)
  model.evaluate(val_dataset)

  ###
  #Generate Dataset
  ###
  for batch_index in range(generated_images):
    _save_generated_images(model, dataset_name, version, batch_index)
    print(f"Generated: {batch_index}")


In [None]:
# Dataset to iterate on; set to 'Flowers' to run the Flowers experiment instead.
DATASET_NAME = 'Birds'
# Number of model generations to train (versions 1..NUM_VERSIONS).
NUM_VERSIONS = 9

for version in range(1, NUM_VERSIONS + 1):
  print(version)
  Generate_Iteration(version, DATASET_NAME)
  # Each iteration builds a fresh Keras model in this kernel. Clear Keras'
  # global state so memory does not keep growing across versions.
  keras.backend.clear_session()