In [None]:
!pip install transformers
# !pip install --upgrade --force-reinstall "tensorflow==2.15.1"

In [None]:
import torch
from diffusers import AutoencoderKL
import os, math, numpy as np, pandas as pd, tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.models import Model
from matplotlib import pyplot as plt
from tensorflow.keras import mixed_precision
# mixed_precision.set_global_policy("mixed_float16")  


In [None]:
# Pretrained PyTorch VAE; decode_latents uses it to turn latents back into images.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae.to(torch_device)
vae.eval()  # inference only
print("VAE device:", torch_device)

In [None]:

# Pre-computed text embeddings and VAE image latents (Kaggle dataset paths).
train_label_embeddings = np.load("/kaggle/input/text-to-image-dataset/text_embeddings.npy")
latent_train_data = np.load("/kaggle/input/text-to-image-dataset/image_latents.npy")

# batch_generator indexes both arrays with the same row indices, so they must
# describe the same samples in the same order.
assert len(train_label_embeddings) == len(latent_train_data), (
    "text embeddings and image latents have different lengths")
# The U-Net and decode_latents work on NHWC latents.
assert latent_train_data.ndim == 4, "expected latents shaped (N, H, W, C)"

print("training dataset ready")
print("latent_train_data shape:", latent_train_data.shape)

In [None]:
def imshow(img):
    """Display one image whose values are nominally in [-1, 1].

    A single-channel (H, W, 1) image is squeezed to (H, W). Values are clipped
    to [-1, 1] and mapped to [0, 1] before being passed to ``plt.imshow``.
    """
    def norm_0_1(img):
        return (img + 1) / 2  # from [-1,1] to [0,1]
    if img.shape[-1] == 1:
        img = img.reshape(img.shape[0], img.shape[1])
    img = np.clip(img, -1, 1)
    plt.imshow(norm_0_1(img))


def plot_images(imgs, size=8, nrows=8, save_name=None):
    """Plot ``imgs`` on a grid and optionally save the figure.

    Args:
        imgs: sequence of images accepted by ``imshow``.
        size: figure width and height in inches. This is written to
            ``plt.rcParams``, so it also affects figures created later.
        nrows: number of grid rows (must be >= 1). The grid has at least
            ``nrows`` columns and grows wider when ``len(imgs) > nrows * nrows``.
        save_name: path without extension; ``.png`` is appended.
    """
    plt.rcParams["figure.figsize"] = (size, size)
    # The grid used to be fixed at nrows x nrows, so e.g. 5 images with nrows=2
    # made plt.subplot raise. Widen the grid when needed instead.
    ncols = max(nrows, math.ceil(len(imgs) / nrows))
    for i, img in enumerate(imgs):
        ax = plt.subplot(nrows, ncols, i + 1)
        imshow(img)
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    if save_name:
        plt.savefig(f"{save_name}.png", dpi=200)
    plt.show()

In [None]:
def decode_latents(latent_arr, std_latent=1.0, batch_size=32):
    """Decode NHWC latents into NHWC images with the global PyTorch ``vae``.

    Latents are first multiplied by ``std_latent`` (the training latents were
    divided by it). They are then decoded ``batch_size`` at a time under
    ``torch.no_grad()`` on ``torch_device``.

    Returns a NumPy array shaped (N, H_dec, W_dec, 3). Values are nominally in
    [-1, 1] but are not clamped here.
    """
    n_total = latent_arr.shape[0]
    scale = float(std_latent)
    pieces = []
    with torch.no_grad():
        for lo in range(0, n_total, batch_size):
            hi = min(lo + batch_size, n_total)
            z = torch.from_numpy(latent_arr[lo:hi] * scale)
            # NHWC -> NCHW, float32, on the VAE's device
            z = z.permute(0, 3, 1, 2).to(torch_device, dtype=torch.float32)
            images = vae.decode(z).sample  # (n, 3, H_dec, W_dec)
            # NCHW -> NHWC, back to host memory
            pieces.append(images.permute(0, 2, 3, 1).cpu().numpy())
    return np.concatenate(pieces, axis=0)



In [None]:
# --- Latent / network geometry ----------------------------------------------
image_size = 16        # side length of the latent grid fed to the U-Net
num_channels = 4       # channels of each latent
block_depth = 3        # residual blocks per U-Net stage
emb_size = 512         # size of one text embedding (U-Net label input)
embedding_dims = 32    # NOTE(review): not read by sinusoidal_embedding, which hard-codes 32

# --- Training ------------------------------------------------------------------
epochs = 130
batch_size = 256
learning_rate = 3e-4
train_size = 100000    # NOTE(review): not referenced by the visible cells
validation_num = 300   # NOTE(review): not referenced by the visible cells

# --- Sampling --------------------------------------------------------------------
class_guidance = 4     # classifier-free guidance weight for training previews
num_imgs = 36          # NOTE(review): a later cell reassigns this to 25

# --- Output locations ----------------------------------------------------------------
MODEL_NAME = "text_to_image"
home_dir = MODEL_NAME
model_path = os.path.join(home_dir, MODEL_NAME + ".h5")
os.makedirs(home_dir, exist_ok=True)

In [None]:
def attention(qkv):
    """Dot-product attention used inside the ``Lambda`` layers.

    ``qkv`` is ``[q, k, v]``, each rank-3 (batch, positions, channels). The
    scores are ``softmax(k @ q^T)`` over the *q* positions, so the output has
    one row per *k* position: (batch, N_k, C_v). ``q`` and ``v`` must have the
    same number of positions. No 1/sqrt(d) temperature is applied.
    """
    q, k, v = qkv
    weights = tf.nn.softmax(tf.matmul(k, q, transpose_b=True))
    return tf.matmul(weights, v)

def spatial_attention(img):
    """Self-attention across all H*W positions of a (B, H, W, C) feature map.

    The steps are:
    1. BatchNorm.
    2. 1x1-conv projections to C//8 (q, k) and C (v) channels.
    3. ``attention`` over the flattened positions, reshaped back to (H, W, C).
    4. A final 1x1 conv and BatchNorm.

    The caller adds the result to its input as a residual.
    """
    height, width, channels = img.shape[1], img.shape[2], img.shape[3]
    normed = layers.BatchNormalization()(img)

    # Layer creation order (q, k, v) matches the original graph.
    q = layers.Conv2D(channels // 8, 1, padding="same")(normed)
    k = layers.Conv2D(channels // 8, 1, padding="same")(normed)
    v = layers.Conv2D(channels, 1, padding="same")(normed)

    def flatten(t):
        # (B, H, W, C) -> (B, H*W, C)
        return layers.Reshape((t.shape[1] * t.shape[2], t.shape[3]))(t)

    k = flatten(k)
    q = flatten(q)
    v = flatten(v)

    out = layers.Lambda(attention)([q, k, v])
    out = layers.Reshape((height, width, channels))(out)
    out = layers.Conv2D(channels, 1, padding="same")(out)
    return layers.BatchNormalization()(out)

In [None]:
def cross_attention(img, text):
    """Let every spatial position of ``img`` attend over the positions of ``text``.

    Both inputs are (B, H, W, C) feature maps and may differ in spatial size and
    channel count. Each is batch-normalised. The text map is projected by 1x1
    convs to C_img // 8 and C_img channels; the image map is projected to
    C_img // 8 channels.

    ``attention`` receives ``[text_proj, img_proj, text_values]``, so its output
    rows follow the image positions. The output is reshaped back to ``img``'s
    shape, then passed through a 1x1 conv and BatchNorm.
    """
    height, width, channels = img.shape[1], img.shape[2], img.shape[3]
    img_n = layers.BatchNormalization()(img)
    text_n = layers.BatchNormalization()(text)

    q = layers.Conv2D(channels // 8, 1, padding="same")(text_n)
    k = layers.Conv2D(channels // 8, 1, padding="same")(img_n)
    v = layers.Conv2D(channels, 1, padding="same")(text_n)

    # Flatten spatial dims: (B, H, W, C) -> (B, H*W, C), in q, k, v order.
    q, k, v = [layers.Reshape((t.shape[1] * t.shape[2], t.shape[3]))(t) for t in (q, k, v)]

    out = layers.Lambda(attention)([q, k, v])
    out = layers.Reshape((height, width, channels))(out)
    out = layers.Conv2D(channels, 1, padding="same")(out)
    return layers.BatchNormalization()(out)

In [None]:
def sinusoidal_embedding(x):
    """Map noise levels to 32-dim sin/cos features.

    ``x`` is the (batch, 1, 1, 1) noise-level input. There are 16 frequencies,
    log-spaced between 1 and 1000. The result is ``[sin(2*pi*f*x), cos(2*pi*f*x)]``
    concatenated on axis 3, giving (batch, 1, 1, 32).
    """
    min_freq, max_freq, dims = 1.0, 1000.0, 32
    log_freqs = tf.linspace(tf.math.log(min_freq), tf.math.log(max_freq), dims // 2)
    omega = 2.0 * math.pi * tf.exp(log_freqs)
    phase = omega * x
    return tf.concat([tf.sin(phase), tf.cos(phase)], axis=3)

In [None]:
def ResidualBlock(channel_num):
    """Return a builder for a residual block with ``channel_num`` output channels.

    The block computes: BatchNorm (no scale/offset) -> 3x3 conv + swish -> 3x3 conv.
    The result is added to a shortcut. The shortcut is the input itself, or a
    1x1 conv projection of it when the channel count differs.
    """
    def apply(x):
        if x.shape[3] == channel_num:
            shortcut = x
        else:
            shortcut = layers.Conv2D(channel_num, 1)(x)
        h = layers.BatchNormalization(center=False, scale=False)(x)
        h = layers.Conv2D(channel_num, 3, padding="same", activation=keras.activations.swish)(h)
        h = layers.Conv2D(channel_num, 3, padding="same")(h)
        return layers.Add()([h, shortcut])
    return apply

In [None]:
def DownBlock(channel_num, block_depth, use_self_attention=True):
    """Return a builder for one encoder stage of the U-Net.

    The builder takes ``[x, skip, cond]``. It runs ``block_depth`` residual
    blocks; when ``use_self_attention`` is set, each is followed by residual
    self-attention and cross-attention on ``cond``. Every intermediate output is
    appended to the shared ``skip`` list (mutated in place). The stage then
    halves the resolution with 2x average pooling.
    """
    def apply(inputs):
        x, skip, cond = inputs
        for _ in range(block_depth):
            x = ResidualBlock(channel_num)(x)
            if use_self_attention:
                self_att = spatial_attention(x)
                x = layers.Add()([x, self_att])
                cross_att = cross_attention(x, cond)
                x = layers.Add()([x, cross_att])
            skip.append(x)
        return layers.AveragePooling2D(pool_size=2)(x)
    return apply

In [None]:
def UpBlock(channel_num, block_depth, use_self_attention=True):
    """Return a builder for one decoder stage of the U-Net.

    The builder takes ``[x, skips, cond]``. It upsamples ``x`` 2x (bilinear),
    then repeats ``block_depth`` times:
    1. Pop the last skip tensor (LIFO) and concatenate it onto ``x``.
    2. Apply a residual block.
    3. Optionally add residual self-attention and cross-attention on ``cond``.
    """
    def apply(inputs):
        x, skips, cond = inputs
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            skip = skips.pop()
            x = layers.Concatenate()([x, skip])
            x = ResidualBlock(channel_num)(x)
            if use_self_attention:
                self_att = spatial_attention(x)
                x = layers.Add()([x, self_att])
                cross_att = cross_attention(x, cond)
                x = layers.Add()([x, cross_att])
        return x
    return apply

In [None]:
def get_network(latent_image_size, block_depth=3, emb_size=512, latent_channels=4):
    """Build the text-conditioned U-Net denoiser that operates on VAE latents.

    Inputs, in this order:
      * noisy latents  -- (latent_image_size, latent_image_size, latent_channels)
      * noise level    -- (1, 1, 1), turned into a 32-dim sinusoidal embedding
      * text embedding -- (emb_size,)

    Output: a tensor shaped like the noisy latents. The training loop fits it to
    the clean latents.

    ``latent_image_size`` must be divisible by 8, because there are three 2x
    downsamplings whose outputs are concatenated back with skip tensors.

    The noise and text embeddings are tiled over the spatial grid and
    concatenated into ``emb_and_noise``. That map conditions every
    attention-enabled block through cross-attention.
    """
    noisy_images = keras.Input(shape=(latent_image_size, latent_image_size, latent_channels))
    x = layers.Conv2D(128, 1)(noisy_images)  # lift latents to 128 channels

    noise_variances = keras.Input(shape=(1, 1, 1))
    # (1, 1, 1) noise level -> (1, 1, 32) sin/cos features, then tiled over the grid
    e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32),name="time_embedding")(noise_variances)
    e = layers.UpSampling2D(size=(latent_image_size, latent_image_size), interpolation="nearest")(e)

    input_label = layers.Input(shape=(emb_size,))
    # Project the text embedding to emb_size // 2 channels and tile it over the grid
    emb_label = layers.Dense(emb_size // 2)(input_label)
    emb_label = layers.Reshape((1, 1, emb_size // 2))(emb_label)
    emb_label = layers.UpSampling2D(size=(latent_image_size, latent_image_size), interpolation="nearest")(emb_label)

    # Conditioning map: 32 + emb_size // 2 channels at full latent resolution
    emb_and_noise = layers.Concatenate()([e, emb_label])
    skips_connections = []  # filled by the DownBlocks, consumed (LIFO) by the UpBlocks

    # Encoder. The first stage has no attention. After each stage the conditioning
    # map is pooled so it matches the next stage's resolution.
    x = DownBlock(128, block_depth, use_self_attention=False)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.AveragePooling2D(pool_size=(2, 2))(emb_and_noise)

    x = DownBlock(256, block_depth)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.AveragePooling2D(pool_size=(2, 2))(emb_and_noise)

    x = DownBlock(512, block_depth)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.AveragePooling2D(pool_size=(2, 2))(emb_and_noise)

    # Bottleneck at latent_image_size / 8 with 640 channels
    for _ in range(block_depth):
        x = ResidualBlock(128*5)(x)
        o = spatial_attention(x)
        x = layers.Add()([x, o])
        cross_att = cross_attention(x, emb_and_noise)
        x = layers.Add()([x, cross_att])

    # Decoder.
    # NOTE(review): emb_and_noise is upsampled *after* each UpBlock. Each UpBlock
    # (which upsamples x first) therefore cross-attends to the conditioning map at
    # half its own resolution. This works because attention does not require
    # equal sequence lengths, but it differs from the encoder. The last upsample
    # below is never consumed. Confirm whether this is intended before changing
    # it; doing so changes the architecture.
    x = UpBlock(512, block_depth)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.UpSampling2D(size=(2, 2), interpolation="nearest")(emb_and_noise)

    x = UpBlock(256, block_depth)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.UpSampling2D(size=(2, 2), interpolation="nearest")(emb_and_noise)

    x = UpBlock(128, block_depth)([x, skips_connections, emb_and_noise])
    emb_and_noise = layers.UpSampling2D(size=(2, 2), interpolation="nearest")(emb_and_noise)

    # Zero-initialised projection back to latent channels: an untrained network outputs zeros.
    x = layers.Conv2D(latent_channels, 1, kernel_initializer="zeros")(x)
    return keras.Model([noisy_images, noise_variances, input_label], x, name="unet")


In [None]:
np.random.seed(100)
num_imgs = 25  # overrides the config value: 25 fixed seeds -> a 5x5 preview grid
# Fixed noise seeds, reused for every epoch's preview so grids are comparable.
rand_image = np.random.normal(0, 1, (num_imgs, image_size, image_size, num_channels))

# Build from the config cell instead of hard-coding 16 and relying on the
# function defaults, so changing the config actually changes the network.
unet = get_network(
    image_size,
    block_depth=block_depth,
    emb_size=emb_size,
    latent_channels=num_channels,
)
unet.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss="mae")
print("Number of parameters:", unet.count_params())

In [None]:
# Preview captions: 5 consecutive training embeddings, each repeated 5 times
# in a row, so there are 25 labels (one per preview seed).
ns = np.arange(10020, 10020 + 5)
labels = np.repeat(train_label_embeddings[ns], 5, axis=0)

In [None]:
# Normalise latents: clip at +/-2.5 standard deviations, then scale into [-1, 1].
# std_latent is kept so decode_latents can undo the scaling.
# NOTE(review): this rebinds latent_train_data, so re-running the cell normalises twice.
std_latent = 2.5 * np.std(latent_train_data)
latent_train_data = np.clip(latent_train_data, -std_latent, std_latent) / std_latent
print("std_latent", std_latent)

In [None]:
def add_noise(array, mu=0, std=1):
    """Mix each sample of ``array`` with Gaussian noise at a random noise level.

    A per-sample level ``t`` is drawn as ``|N(0, std)|``, restricted to values
    below 3 and divided by 3 (so ``t`` is in [0, 1)). The noise level is
    ``sin(t)`` and the signal level is ``sqrt(1 - noise**2)``. The output is
    ``array * signal + N(0, 1) * noise``.

    Args:
        array: batch with samples on axis 0 (any rank >= 1).
        mu: unused; kept for backward compatibility with existing callers.
        std: scale of the normal distribution the levels are drawn from.

    Returns:
        ``(noisy_data, noise_levels)``; ``noise_levels`` has shape (len(array),).
    """
    n = len(array)
    # Oversample, then reject values >= 3. For std=1 one draw almost always
    # suffices. Keep drawing otherwise (large std, tiny batches): previously too
    # few survivors made the broadcast below fail. The common path consumes the
    # RNG exactly as before.
    x = np.abs(np.random.normal(0, std, 2 * n))
    x = x[x < 3]
    while len(x) < n:
        extra = np.abs(np.random.normal(0, std, 2 * n))
        x = np.concatenate([x, extra[extra < 3]])
    noise_levels = np.sin(x[:n] / 3)
    signal_levels = np.sqrt(1 - np.square(noise_levels))

    # Broadcast one level per sample over the remaining axes (rank-agnostic).
    per_sample = (-1,) + (1,) * (np.ndim(array) - 1)
    noise_level_reshape = noise_levels.reshape(per_sample)
    signal_level_reshape = signal_levels.reshape(per_sample)

    pure_noise = np.random.normal(0, 1, size=array.shape).astype("float32")
    noisy_data = array * signal_level_reshape + pure_noise * noise_level_reshape
    return noisy_data, noise_levels


In [None]:
def dynamic_thresholding(img, perc=99.5):
    """Clip ``img`` at its ``perc``-th percentile of absolute values and rescale.

    The threshold ``s`` is floored at 1, so only arrays with large values are
    shrunk. Values are clipped to [-s, s] and divided by ``s``, which puts them
    in [-1, 1].
    """
    s = np.percentile(np.abs(img.ravel()), perc)
    s = np.max([s, 1])  # never amplify: the threshold is at least 1
    img = img.clip(-s, s) / s
    return img

class Diffuser:
    """Deterministic sampler with classifier-free guidance.

    ``denoiser`` must provide a Keras-style
    ``predict([x_t, noise_level, label], batch_size=..., verbose=...)`` that
    returns a prediction of the clean sample x0 with the same shape as ``x_t``.

    ``noise_levels`` is built from ``diffusion_steps`` as ``1 - t**(1/3)`` for
    ``t = arange(0.0001, 0.99, 1 / diffusion_steps)``. That gives roughly
    ``0.99 * diffusion_steps`` levels, starting near 0.95 and shrinking. The last
    level is forced to 0.01.

    Raises:
        ValueError: if ``diffusion_steps`` yields fewer than 2 noise levels.
    """

    def __init__(self, denoiser, class_guidance, diffusion_steps, perc_thresholding=99.5, batch_size=64):
        self.denoiser = denoiser
        self.class_guidance = class_guidance
        self.diffusion_steps = diffusion_steps
        self.noise_levels = 1 - np.power(np.arange(0.0001, 0.99, 1 / self.diffusion_steps), 1 / 3)
        # Only used in the final new_img update of reverse_diffusion. That update
        # is not returned, so this value does not affect the result.
        self.noise_levels[-1] = 0.01
        # reverse_diffusion needs at least one (current, next) pair. With fewer
        # levels it used to fail later with UnboundLocalError on x0_pred.
        if len(self.noise_levels) < 2:
            raise ValueError(
                f"diffusion_steps={diffusion_steps} yields fewer than 2 noise levels; "
                "use diffusion_steps >= 2"
            )
        self.perc_thresholding = perc_thresholding
        self.batch_size = batch_size

    def predict_x_zero(self, x_t, label, noise_level):
        """Return the guided x0 prediction for batch ``x_t`` at one noise level.

        Conditional (``label``) and unconditional (all-zero label) predictions
        come from a single ``predict`` call on the doubled batch. They are
        combined as ``g * cond + (1 - g) * uncond`` with ``g = class_guidance``,
        then passed through ``dynamic_thresholding``.
        """
        num_imgs = len(x_t)
        # Zero label = "no text"; training zeroes some labels for this purpose.
        label_empty_ohe = np.zeros(shape=label.shape)
        noise_in = np.array([noise_level] * num_imgs)[:, None, None, None]
        nn_inputs = [np.vstack([x_t, x_t]),
                     np.vstack([noise_in, noise_in]),
                     np.vstack([label, label_empty_ohe])]
        x0_pred = self.denoiser.predict(nn_inputs, batch_size=self.batch_size, verbose=0)
        x0_pred_label = x0_pred[:num_imgs]
        x0_pred_no_label = x0_pred[num_imgs:]
        x0_pred = self.class_guidance * x0_pred_label + (1 - self.class_guidance) * x0_pred_no_label
        x0_pred = dynamic_thresholding(x0_pred, perc=self.perc_thresholding)
        return x0_pred

    def reverse_diffusion(self, seeds, label, show_img=False):
        """Run the sampler from noise ``seeds`` and return the final x0 prediction.

        Each step predicts x0 at the current noise level, then moves ``new_img``
        to the next level by mixing x0 with the current sample. The returned
        value is the x0 prediction made at ``noise_levels[-2]``. If ``show_img``
        is true, every intermediate x0 prediction is plotted and saved as
        ``<step>.png``.
        """
        new_img = seeds
        x0_pred = None
        for i in range(len(self.noise_levels) - 1):
            curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
            x0_pred = self.predict_x_zero(new_img, label, curr_noise)
            new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
            if show_img:
                # plot_images already calls plt.show().
                plot_images(x0_pred, nrows=int(np.sqrt(len(new_img))), save_name=str(i), size=12)
        return x0_pred


In [None]:
def batch_generator(model, model_path, train_data, train_label_embeddings, epochs,
                    batch_size, rand_image, labels, home_dir, diffuser, epoch=0,
                    sample_every=1, label_dropout=0.15):
    """Yield training batches for epochs ``epoch .. epochs-1``.

    Each item is ``((noisy_latents, noise_levels, text_embeddings), clean_latents)``,
    so the denoiser is trained to predict the clean latents directly.

    At the start of every epoch the model is saved to ``model_path``. Every
    ``sample_every`` epochs (each epoch by default; pass 0 to disable),
    ``diffuser`` samples previews from the fixed ``rand_image`` / ``labels``.
    The previews are decoded with the VAE and saved as ``<home_dir>/<epoch>.png``.

    A fraction ``label_dropout`` of the text embeddings in every batch is zeroed,
    so the model also learns the unconditional prediction used for
    classifier-free guidance.

    Notes:
        * Indices left over at the end of an epoch (``len(train_data) % batch_size``)
          start the first batch of the next epoch; they are not dropped.
        * The model is saved only at the *start* of each epoch; the caller must
          save after the last epoch.
        * Relies on the notebook globals ``std_latent``, ``decode_latents``,
          ``plot_images`` and ``add_noise``.
    """
    indices = np.arange(len(train_data))
    batch = []  # deliberately kept across epochs (see Notes)
    print("Training for {0}".format(epochs))
    while epoch < epochs:
        print("saving model:")
        # Keras picks the format from the extension: a ".h5" path is written as HDF5.
        model.save(model_path)

        if sample_every and epoch % sample_every == 0:
            diffuser.denoiser = model
            imgs = diffuser.reverse_diffusion(rand_image, labels)
            imgs = decode_latents(imgs, std_latent=std_latent)
            img_path = os.path.join(home_dir, str(epoch))
            # ceil keeps non-square preview counts inside the grid.
            plot_images(imgs, save_name=img_path, nrows=int(np.ceil(np.sqrt(len(imgs)))))

        print("new epoch {0}".format(epoch))
        np.random.shuffle(indices)
        for i in indices:
            batch.append(i)
            if len(batch) == batch_size:
                tr_batch = train_data[batch].copy()

                # Classifier-free guidance: drop the text condition for some samples.
                drop = np.random.binomial(1, label_dropout, size=batch_size).astype("bool")
                train_label_dropout = train_label_embeddings[batch].copy()
                train_label_dropout[drop] = 0

                noisy_train_data, noise_level_train = add_noise(tr_batch, mu=0, std=1)
                noise_level_train = noise_level_train[:, None, None, None]

                yield (noisy_train_data, noise_level_train, train_label_dropout), tr_batch
                batch = []
        epoch += 1


In [None]:
batch_size = 256
diffuser = Diffuser(
    unet,
    class_guidance=class_guidance,
    diffusion_steps=100,
    perc_thresholding=99.75,
)

# batch_generator is a Python generator: nothing runs until fit() pulls from it.
train_generator = batch_generator(
    model=unet,
    model_path=model_path,
    train_data=latent_train_data,
    train_label_embeddings=train_label_embeddings,
    epochs=epochs,
    batch_size=batch_size,
    rand_image=rand_image,
    labels=labels,
    home_dir=home_dir,
    diffuser=diffuser,
)

In [None]:
# batch_generator runs its own epoch loop and yields the batches for *all*
# epochs. Without steps_per_epoch, Keras' first epoch drains the entire
# generator and later epochs stop with "Your input ran out of data".
# Telling Keras how many batches make one pass gives a proper per-epoch history.
steps_per_epoch = len(latent_train_data) // batch_size
history = unet.fit(x=train_generator, epochs=epochs, steps_per_epoch=steps_per_epoch)
# The generator only checkpoints at the *start* of each epoch, so the last
# epoch's weights would otherwise never reach model_path.
unet.save(model_path)

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [None]:
# Reload the checkpoint written during training (same path the training cells
# use, rather than a hard-coded absolute Kaggle path). The Lambda layers' Python
# functions must be supplied again; compile=False because only inference follows.
unet = keras.models.load_model(
    model_path,
    custom_objects={"sinusoidal_embedding": sinusoidal_embedding, "attention": attention},
    compile=False,
)

In [None]:
# Inference sampler: guidance weight 6, stronger than the config value used for training previews.
diffuser = Diffuser(denoiser=unet, class_guidance=6, diffusion_steps=100, perc_thresholding=99.75)

In [None]:
import torch
import clip
# Load CLIP ViT-B/32 on the same device as the VAE.
# NOTE: clip.load returns (model, image_preprocess_transform). The second value
# is NOT a text tokenizer (the name is kept for compatibility); text is
# tokenized with clip.tokenize.
model, tokenizer = clip.load("ViT-B/32", device=torch_device)

In [None]:
text = "beautiful landscape with river and flowers"
# Put the tokens on whatever device the CLIP model lives on (a hard-coded
# .cuda() fails on CPU-only machines).
clip_device = next(model.parameters()).device
text_tokens = clip.tokenize(text, truncate=False).to(clip_device)

with torch.no_grad():
    text_encoding = model.encode_text(text_tokens)

# One embedding row per sample to generate. CLIP output may be fp16 (e.g. on
# GPU), so convert it to a float32 NumPy array for Keras.
text_encoding = text_encoding.repeat(4, 1).float().cpu().numpy()

In [None]:
# Sanity check: rows = number of samples (4); columns must equal the U-Net's label input size (emb_size).
text_encoding.shape

In [None]:
# One latent noise seed per text-embedding row, shaped like the training latents (from the config cell).
rand_image = np.random.normal(0, 1, (len(text_encoding), image_size, image_size, num_channels))

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
diffuser.denoiser = unet
# Sampling runs in TensorFlow; decode_latents already wraps the VAE in torch.no_grad().
imgs = decode_latents(diffuser.reverse_diffusion(rand_image, text_encoding), std_latent=std_latent)

In [None]:
# Decoded images as returned by decode_latents: NHWC, (num_samples, H, W, 3).
imgs.shape

In [None]:
# Grid sized from the number of samples (nrows**2 >= len(imgs)). The prompt
# becomes the file name, so '/' is replaced to avoid it being read as a directory.
plot_images(imgs, nrows=math.ceil(math.sqrt(len(imgs))), save_name=text.replace("/", "_"))