```
Copyright 2023 Nikolai Körber. All Rights Reserved.

Based on:
https://keras.io/examples/generative/ddpm/

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# IADB (Sonar dataset)

This notebook demonstrates how to train an IADB diffusion model on the Sonar dataset. The overall design is inspired by https://keras.io/examples/generative/ddpm/.

The official PyTorch implementation can be found [here](https://github.com/tchambon/IADB).

## Setup

In [57]:
#!git clone https://github.com/Nikolai10/Diffusion-TF

In [58]:
import sys
sys.path.append('/Diffusion-TF/projects/IADB')

import math
import numpy as np
import matplotlib.pyplot as plt

# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# import tensorflow_datasets as tfds
# from model import IADBModel
from Diffusion_TF.projects.IADB.model import IADBModel
from Diffusion_TF.projects.IADB.tutorials.helpers_flowers import plot_images, make_gif
#from tutorials.helpers_flowers import plot_images, make_gif

In [59]:
data_dir = '/home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right'
gen_path = '/home/republic/Documents/Ganadev/Sonar/Outputs/Diffusion/Result1/Generated/'

## Hyperparameters

In [60]:
batch_size = 8
#num_epochs = 1500 # e.g. 1000
num_epochs = 1 # e.g. 1000
norm_groups = 8  # Number of groups used in GroupNormalization layer
learning_rate = 2e-4

# img_size = 64
# img_height = 128
#img_width = 800
img_height = 128
img_width = 800
img_size = (img_height, img_width)
img_channels = 1
clip_min = -1.0
clip_max = 1.0

first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

dataset_name = "sonar"
splits = ["train"]



## Dataset

In [61]:
# Load the dataset
# (ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)

# ds = tf.keras.utils.image_dataset_from_directory(
#   data_dir,
#   # validation_split=0.2,
#   # subset="training",
#   seed=123,
#   image_size=(img_height, img_width),
#   # batch_size=batch_size,
#   color_mode="grayscale",
#   shuffle=True,
#   label_mode = None # Set label_mode to None to yield only images
# )

# def augment(img):
#     """Flips an image left/right randomly."""
#     return tf.image.random_flip_left_right(img)


# def resize_and_rescale(img, size):
#     """Resize the image to the desired size first and then
#     rescale the pixel values in the range [-1.0, 1.0].

#     Args:
#         img: Image tensor
#         size: Desired image size for resizing
#     Returns:
#         Resized and rescaled image tensor
#     """

#     height = tf.shape(img)[0]
#     width = tf.shape(img)[1]
#     crop_size = tf.minimum(height, width)

#     img = tf.image.crop_to_bounding_box(
#         img,
#         (height - crop_size) // 2,
#         (width - crop_size) // 2,
#         crop_size,
#         crop_size,
#     )

#     # Resize
#     img = tf.cast(img, dtype=tf.float32)
#     img = tf.image.resize(img, size=size, antialias=True)

#     # Rescale the pixel values
#     img = img / 127.5 - 1.0
#     img = tf.clip_by_value(img, clip_min, clip_max)
#     return img


# def train_preprocessing(x):
#     img = x # x is already the image tensor
#     print(img)
#     # img = resize_and_rescale(img, size=(img_size[0], img_size[1]))
#     # img = augment(img)
#     return img

# for element in ds.take(1):
#   print("shape")
#   print(element.shape)
# train_ds = (
#     ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
#     .batch(batch_size, drop_remainder=True)
#     .shuffle(batch_size * 2)
#     .prefetch(tf.data.AUTOTUNE)
# )
# # AUTOTUNE = tf.data.AUTOTUNE
# # train_ds = ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
# print(train_ds)
# # <_PrefetchDataset element_spec=TensorSpec(shape=(32, 64, 64, 3), dtype=tf.float32, name=None)>

In [62]:
# import tensorflow as tf
import pathlib

# Assuming 'data_dir' now points to the directory with only images
data_dir = pathlib.Path(data_dir)
image_files = list(data_dir.glob('*.jpg'))  # Adjust the pattern if needed
image_files = [str(file) for file in image_files]

def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=img_channels)
    image = tf.image.resize(image, (img_height, img_width))
    print(image)
    #image = tf.expand_dims(image, axis=-1)
    print(image)
    return image

ds = tf.data.Dataset.from_tensor_slices(image_files)
print("ds:")
for path in ds:
    print("Path:", path.numpy().decode('utf-8'))
train_ds = ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
print("\nAfter map:")
for image in train_ds.take(1):  # Take one element to check
    print("Image shape:", image.shape)
train_ds = train_ds.batch(batch_size, drop_remainder=True)
print("\nAfter batch:")
for batch in train_ds.take(1):
    print("Batch shape:", batch.shape)
train_ds = train_ds.shuffle(batch_size * 2)
print("\nAfter shuffle:")
for batch in train_ds.take(1):
    print("Batch shape:", batch.shape)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
print("\nAfter prefetch:")
for batch in train_ds.take(1):
    print("Batch shape:", batch.shape)
# Now inspect the shape
# print("here")
# print(ds.take(1))
# for element in ds.take(1):
#     print(element.shape)
print("leng")
print(len(train_ds))


ds:
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_24.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_18.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_42.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_23.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_50.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_32.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_51.jpg
Path: /home/republic/Documents/Ganadev/Sonar/Datasets/Data Preparation 4/mine_like_object_right/Clipboard_05-15-2024_54.jpg
Path

## Network architecture

In [63]:
# Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return keras.initializers.VarianceScaling(
        scale, mode="fan_avg", distribution="uniform"
    )


class AttentionBlock(layers.Layer):
    """Applies self-attention.

    Args:
        units: Number of units in the dense layers
        groups: Number of groups to be used for GroupNormalization layer
    """

    def __init__(self, units, groups=8, **kwargs):
        self.units = units
        self.groups = groups
        super().__init__(**kwargs)

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        inputs = self.norm(inputs)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)

        attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])

        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        return inputs + proj


class TimeEmbedding(layers.Layer):
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(x)

        temb = activation_fn(t)
        temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
            :, None, None, :
        ]

        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)

        x = layers.Add()([x, temb])
        x = layers.GroupNormalization(groups=groups)(x)
        x = activation_fn(x)

        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
        )(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownSample(width):
    def apply(x):
        x = layers.Conv2D(
            width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )(x)
        return x

    return apply


def UpSample(width, interpolation="nearest"):
    def apply(x):
        x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        x = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(x)
        return x

    return apply


def TimeMLP(units, activation_fn=keras.activations.swish):
    def apply(inputs):
        temb = layers.Dense(
            units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
        )(inputs)
        temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
        return temb

    return apply


def build_model(
    img_size,
    img_channels,
    widths,
    has_attention,
    num_res_blocks=2,
    norm_groups=8,
    interpolation="nearest",
    activation_fn=keras.activations.swish,
):
    image_input = layers.Input(
        shape=(img_size[0], img_size[1], img_channels), name="image_input"
    )
    print(image_input)
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    x = layers.Conv2D(
        first_conv_channels,
        kernel_size=(3, 3),
        padding="same",
        kernel_initializer=kernel_init(1.0),
    )(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # DownBlock
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        if widths[i] != widths[-1]:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )

    # UpBlock
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(3, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
    return keras.Model([image_input, time_input], x, name="unet")

## Training

In [64]:
# Build the unet model
network = build_model(
    img_size=img_size,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
)

print(network)
print(img_size)
print(img_channels)

KerasTensor(type_spec=TensorSpec(shape=(None, 128, 800, 1), dtype=tf.float32, name='image_input'), name='image_input', description="created by layer 'image_input'")
<keras.src.engine.functional.Functional object at 0x7be8cd2dd240>
(128, 800)
1


In [65]:
# create IADB model
model = IADBModel(network=network)

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
)

In [66]:
print(train_ds)
print(type(train_ds))
# Reset the iterator of the dataset before iterating again
train_ds_iterator = iter(train_ds)
element = next(train_ds_iterator)
print(element)

<_PrefetchDataset element_spec=TensorSpec(shape=(8, 128, 800, 1, 1), dtype=tf.float32, name=None)>
<class 'tensorflow.python.data.ops.prefetch_op._PrefetchDataset'>
tf.Tensor(
[[[[[  3.]]

   [[  0.]]

   [[  0.]]

   ...

   [[ 28.]]

   [[ 18.]]

   [[  6.]]]


  [[[  3.]]

   [[  0.]]

   [[  0.]]

   ...

   [[ 25.]]

   [[ 16.]]

   [[  6.]]]


  [[[  3.]]

   [[  0.]]

   [[  0.]]

   ...

   [[ 21.]]

   [[ 14.]]

   [[  7.]]]


  ...


  [[[  7.]]

   [[  5.]]

   [[  3.]]

   ...

   [[ 15.]]

   [[ 25.]]

   [[ 23.]]]


  [[[  7.]]

   [[  5.]]

   [[  3.]]

   ...

   [[ 12.]]

   [[ 21.]]

   [[ 18.]]]


  [[[  7.]]

   [[  5.]]

   [[  3.]]

   ...

   [[ 10.]]

   [[ 17.]]

   [[ 14.]]]]



 [[[[  4.]]

   [[  4.]]

   [[  4.]]

   ...

   [[ 43.]]

   [[ 42.]]

   [[ 41.]]]


  [[[  4.]]

   [[  4.]]

   [[  4.]]

   ...

   [[ 42.]]

   [[ 40.]]

   [[ 39.]]]


  [[[  4.]]

   [[  4.]]

   [[  4.]]

   ...

   [[ 40.]]

   [[ 37.]]

   [[ 35.]]]


  ...


  [[[  0.]]

 

In [67]:
# Train the model (we used a V100 GPU -> ~8s/epoch)
model.fit(
    train_ds,
    epochs=num_epochs,
    batch_size=batch_size)

ValueError: in user code:

    File "/home/republic/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1377, in train_function  *
        return step_function(self, iterator)
    File "/home/republic/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1360, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/republic/.local/lib/python3.10/site-packages/keras/src/engine/training.py", line 1349, in run_step  **
        outputs = model.train_step(data)
    File "/home/republic/Documents/Ganadev/Sonar/Models/Diffusion/Diffusion_TF/projects/IADB/model.py", line 62, in train_step
        x_alpha = alpha_bc * x1 + (1 - alpha_bc) * x0

    ValueError: Dimensions must be equal, but are 8 and 128 for '{{node mul}} = Mul[T=DT_FLOAT](ExpandDims_2, IteratorGetNext)' with input shapes: [8,1,1,1], [8,128,800,1,1].


## Results

In [None]:
num_images = 4 * 8 # don't change this value
print(img_size)
print(img_channels)
x0 = tf.random.normal(shape=(num_images, img_size[0], img_size[1], img_channels))
print(tf.shape(x0))
_, trajectory = model.sample_iadb(x0, nb_step=1024)

In [None]:
plot_images(trajectory, gen_path)

In [None]:
make_gif(gen_path)

In [None]:
from IPython.display import Image

Image(open(gen_path + 'diffusion.gif','rb').read())