In [12]:
import math
import numpy as np
import matplotlib.pyplot as plt

# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [13]:
batch_size = 32
num_epochs = 1  # Just for the sake of demonstration
total_timesteps = 1000
norm_groups = 8  # Number of groups used in GroupNormalization layer
learning_rate = 2e-4

img_size = 64
img_channels = 3
clip_min = -1.0
clip_max = 1.0

first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2  # Number of residual blocks

dataset_name = "oxford_flowers102"
splits = ["train"]

In [14]:
# Load the dataset
(ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)


def augment(img):
    """Flips an image left/right randomly."""
    return tf.image.random_flip_left_right(img)


def resize_and_rescale(img, size):
    """Resize the image to the desired size first and then
    rescale the pixel values in the range [-1.0, 1.0].

    Args:
        img: Image tensor
        size: Desired image size for resizing
    Returns:
        Resized and rescaled image tensor
    """

    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    crop_size = tf.minimum(height, width)

    img = tf.image.crop_to_bounding_box(
        img,
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # Resize
    img = tf.cast(img, dtype=tf.float32)
    img = tf.image.resize(img, size=size, antialias=True)

    # Rescale the pixel values
    img = img / 127.5 - 1.0
    img = tf.clip_by_value(img, clip_min, clip_max)
    return img


def train_preprocessing(x):
    img = x["image"]
    img = resize_and_rescale(img, size=(img_size, img_size))
    img = augment(img)
    return img


train_ds = (
    ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .shuffle(batch_size * 2)
    .prefetch(tf.data.AUTOTUNE)
)

In [18]:
for i in train_ds.take(1):
    print()
i




<tf.Tensor: shape=(32, 64, 64, 3), dtype=float32, numpy=
array([[[[-0.8306644 , -0.7371212 , -0.95182765],
         [-0.7626058 , -0.61431676, -0.93218565],
         [-0.729423  , -0.54163826, -0.89981776],
         ...,
         [-0.80454034, -0.79692495, -0.9529009 ],
         [-0.8358447 , -0.77503514, -0.9722911 ],
         [-0.89407015, -0.8369471 , -0.98566157]],

        [[-0.9097879 , -0.87738204, -0.97831434],
         [-0.8989304 , -0.833475  , -0.9730195 ],
         [-0.8452291 , -0.72692204, -0.9410543 ],
         ...,
         [-0.7874836 , -0.7885021 , -0.9680369 ],
         [-0.7820765 , -0.707676  , -0.98519665],
         [-0.8749994 , -0.82552   , -0.98216295]],

        [[-0.9407452 , -0.92313987, -0.9856089 ],
         [-0.94402033, -0.92948395, -0.9812078 ],
         [-0.9340651 , -0.9103471 , -0.9849005 ],
         ...,
         [-0.76639307, -0.72621936, -0.946405  ],
         [-0.73303014, -0.649737  , -0.9730999 ],
         [-0.7895955 , -0.7183142 , -0.96416277

In [17]:
train_ds

<_PrefetchDataset element_spec=TensorSpec(shape=(32, 64, 64, 3), dtype=tf.float32, name=None)>

In [20]:
class GaussianDiffusion:
    """Gaussian diffusion utility.

    Args:
        beta_start: Start value of the scheduled variance
        beta_end: End value of the scheduled variance
        timesteps: Number of time steps in the forward process
    """

    def __init__(
        self,
        beta_start=1e-4,
        beta_end=0.02,
        timesteps=1000,
        clip_min=-1.0,
        clip_max=1.0,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Define the linear variance schedule
        self.betas = betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,  # Using float64 for better precision
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )

        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # Log calculation clipped because the posterior variance is 0 at the beginning
        # of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a, t, x_shape):
        """Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

        Args:
            a: Tensor to extract from
            t: Timestep for which the coefficients are to be extracted
            x_shape: Shape of the current batched samples
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Extracts the mean, and the variance at current timestep.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data.

        Args:
            x_start: Initial sample (before the first diffusion step)
            t: Current timestep
            noise: Gaussian noise to be added at the current timestep
        Returns:
            Diffused samples at timestep `t`
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, tf.shape(x_start)) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion
        posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Stating point(sample) for the posterior computation
            x_t: Sample at timestep `t`
            t: Current timestep
        Returns:
            Posterior mean and variance at current timestep
        """

        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Sample from the diffusion model.

        Args:
            pred_noise: Noise predicted by the diffusion model
            x: Samples at a given timestep for which the noise was predicted
            t: Current timestep
            clip_denoised (bool): Whether to clip the predicted noise
                within the specified range or not.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t, clip_denoised=clip_denoised
        )
        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
        # No noise when t == 0
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

In [41]:
beta_start=1e-4
beta_end=0.02
timesteps=1000
clip_min=-1.0
clip_max=1.0

# Define the linear variance schedule
betas = betas = np.linspace(
    beta_start,
    beta_end,
    timesteps,
    dtype=np.float64,  # Using float64 for better precision
)
np.mean(betas), np.std(betas)

(np.float64(0.010049999999999998), np.float64(0.005750382688807276))

In [57]:
num_timesteps = int(timesteps)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
betas = tf.constant(betas, dtype=tf.float32)
alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

In [56]:
alphas_cumprod

<tf.Tensor: shape=(1000,), dtype=float32, numpy=
array([9.99899983e-01, 9.99780059e-01, 9.99640286e-01, 9.99480605e-01,
       9.99301016e-01, 9.99101520e-01, 9.98882174e-01, 9.98642981e-01,
       9.98383999e-01, 9.98105168e-01, 9.97806549e-01, 9.97488141e-01,
       9.97149944e-01, 9.96792018e-01, 9.96414304e-01, 9.96016920e-01,
       9.95599866e-01, 9.95163143e-01, 9.94706810e-01, 9.94230807e-01,
       9.93735254e-01, 9.93220150e-01, 9.92685556e-01, 9.92131472e-01,
       9.91557896e-01, 9.90964949e-01, 9.90352631e-01, 9.89720941e-01,
       9.89069939e-01, 9.88399625e-01, 9.87710118e-01, 9.87001419e-01,
       9.86273587e-01, 9.85526621e-01, 9.84760582e-01, 9.83975530e-01,
       9.83171523e-01, 9.82348561e-01, 9.81506765e-01, 9.80646074e-01,
       9.79766607e-01, 9.78868425e-01, 9.77951586e-01, 9.77016151e-01,
       9.76062119e-01, 9.75089550e-01, 9.74098563e-01, 9.73089159e-01,
       9.72061455e-01, 9.71015394e-01, 9.69951153e-01, 9.68868792e-01,
       9.67768312e-01, 9.666

In [62]:
# Calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (
    betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)
posterior_log_variance_clipped = tf.constant(
    np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
)
posterior_mean_coef1 = tf.constant(
    betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
    dtype=tf.float32,
)
posterior_mean_coef2 = tf.constant(
    (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
    dtype=tf.float32,
)
# Calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = tf.constant(
    np.sqrt(alphas_cumprod), dtype=tf.float32
)
sqrt_one_minus_alphas_cumprod = tf.constant(
    np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
)
log_one_minus_alphas_cumprod = tf.constant(
    np.log(1.0 - alphas_cumprod), dtype=tf.float32
)
sqrt_recip_alphas_cumprod = tf.constant(
    np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
)
sqrt_recipm1_alphas_cumprod = tf.constant(
    np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
)

In [22]:
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

In [23]:
gdf_util.p_sample

<bound method GaussianDiffusion.p_sample of <__main__.GaussianDiffusion object at 0x133639850>>

In [24]:
images = i[0]


In [30]:
t = tf.random.uniform(
    minval=0, 
    maxval=total_timesteps, 
    shape=(batch_size,), 
    dtype=tf.int64
)
noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)

In [37]:
gdf_util.betas

<tf.Tensor: shape=(1000,), dtype=float32, numpy=
array([0.0001    , 0.00011992, 0.00013984, 0.00015976, 0.00017968,
       0.0001996 , 0.00021952, 0.00023944, 0.00025936, 0.00027928,
       0.0002992 , 0.00031912, 0.00033904, 0.00035896, 0.00037888,
       0.0003988 , 0.00041872, 0.00043864, 0.00045856, 0.00047848,
       0.0004984 , 0.00051832, 0.00053824, 0.00055816, 0.00057808,
       0.000598  , 0.00061792, 0.00063784, 0.00065776, 0.00067768,
       0.0006976 , 0.00071752, 0.00073744, 0.00075736, 0.00077728,
       0.0007972 , 0.00081712, 0.00083704, 0.00085696, 0.00087688,
       0.0008968 , 0.00091672, 0.00093664, 0.00095656, 0.00097648,
       0.0009964 , 0.00101632, 0.00103624, 0.00105616, 0.00107608,
       0.001096  , 0.00111592, 0.00113584, 0.00115576, 0.00117568,
       0.0011956 , 0.00121552, 0.00123544, 0.00125536, 0.00127528,
       0.0012952 , 0.00131512, 0.00133504, 0.00135495, 0.00137487,
       0.00139479, 0.00141471, 0.00143463, 0.00145455, 0.00147447,
       0.0014

In [36]:
x_start_shape = tf.shape(x_start)
return (
    self._extract(self.sqrt_alphas_cumprod, t, tf.shape(x_start)) * x_start
    + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
    * noise
)

NameError: name 'x_start' is not defined

In [35]:
# 4. Diffuse the images with noise
images_t = gdf_util.q_sample(images, t, noise)

InvalidArgumentError: {{function_node __wrapped__Reshape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input to reshape is a tensor with 32 values, but the requested shape has 64 [Op:Reshape]