# About This Notebook

This is a basic starter notebook for the Stable Diffusion Image-to-Prompts competition, written with `keras`, `keras-cv`, and Hugging Face `transformers`. It contains both **training code** and **inference code**. Training is configured to run on a **TPU-VM**; inference runs on a **GPU**.

In [None]:
# Run-mode switch read by every later cell:
# set True for 'Inference' on GPU (turn off the internet)
# set False for 'Training' on TPU-VM (turn on the internet)
SUBMIT = False

In [None]:
import os
import warnings
import logging
from IPython.display import clear_output
# Quiet the notebook: hide Python warnings, TensorFlow's C++ log output (the
# env var only has an effect if set before TensorFlow is imported, which
# happens in the next cell), and all `logging` records at WARNING and below.
warnings.filterwarnings('ignore')
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.disable(logging.WARNING)


if SUBMIT:
    # Offline inference: install the keras-cv wheel from an attached dataset
    # without resolving dependencies (the internet is off in this mode).
    !pip install --no-deps ../input/keras-cv/keras_cv-0.4.2-py3-none-any.whl
else:
    # Training (internet on): quietly upgrade scikit-learn and transformers.
    !pip install -U -q scikit-learn
    !pip install -U -q transformers
    
# Hide the pip output from the cell.
clear_output()

In [None]:
import os
import glob
import numpy as np
import pandas as pd 
import warnings
import random
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras import callbacks
from tensorflow.keras.optimizers import schedules
from tensorflow.python.client import device_lib

try:
    import keras_cv
    from keras_cv.layers import RandomBrightness
    from keras_cv.layers import RandomChannelShift
    from keras_cv.layers import RandomFlip
except:
    from keras.layers import RandomFlip

import transformers
from transformers import AutoConfig
from transformers import TFBertTokenizer
from transformers import TFAutoModel
from transformers import ViTFeatureExtractor
from transformers import TFViTModel
# Hide Hugging Face progress bars (e.g. while downloading checkpoints).
transformers.logging.disable_progress_bar()

# Notebook display: show which TensorFlow / transformers versions are in use.
tf.__version__, transformers.__version__

# Devices

In [None]:
# Pick the distribution strategy for the selected run mode.
if SUBMIT:
    # Inference on GPU(s): enable XLA JIT and mixed precision, and let
    # TensorFlow grow GPU memory on demand rather than reserving it all.
    physical_devices = tf.config.list_physical_devices('GPU')
    tf.config.optimizer.set_jit(True)
    keras.mixed_precision.set_global_policy("mixed_float16")
    # Plain loop instead of a list comprehension used only for its side
    # effects; the loop variable also no longer reuses `pd` (pandas here).
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)
    strategy = tf.distribute.MirroredStrategy()
else:
    # Training on a single-host TPU VM.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu="local") # "local" for 1VM TPU
    strategy = tf.distribute.TPUStrategy(tpu)


# Shared hyper-parameters. The global batch is 40 images per replica.
seed = 80
input_size = 224
batch_size = 40 * strategy.num_replicas_in_sync
num_epochs = 14
keras.utils.set_random_seed(seed)

# Data [Training]

In [None]:
# Load the cleaned DiffusionDB metadata and hold out 10% of rows for
# validation (reproducible through the global `seed`).
df = pd.read_csv('/kaggle/input/diffusiondb-data-cleansing/diffusiondb.csv')
trn_df, val_df = train_test_split(
    df, test_size=0.1, random_state=seed
)
# Notebook display: sizes of the two splits.
trn_df.shape, val_df.shape

In [None]:
# Summary statistics (count/mean/std/min/quartiles/max) of the number of
# whitespace-separated words per training prompt. This is a rough guide for
# the tokenizer's max_length; subword token counts will be higher.
trn_df["prompt"].apply(lambda x: len(x.split())).describe()

# Dataloader [Training]

In [None]:
# Sentence-transformer checkpoint. It is used here for tokenizing prompts and
# later (TFSentenceTransformer) for embedding them into training targets.
# TFBertTokenizer is applied inside a tf.data `map` by `batch_tokenize`.
model_id = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = TFBertTokenizer.from_pretrained(model_id)

def read_image(image_path):
    """Load the PNG at `image_path` as an RGB image of size input_size x input_size.

    The PNG is decoded with 3 channels and resized with `tf.image.resize`, so
    the result is a float tensor that still holds pixel values on the [0, 255]
    scale.
    """
    raw_bytes = tf.io.read_file(image_path)
    rgb = tf.io.decode_png(raw_bytes, 3)
    return tf.image.resize(rgb, (input_size, input_size))

def batch_tokenize(prompt):
    """Tokenize a batch of prompt strings with the module-level `tokenizer`.

    Every sequence is truncated or padded to exactly 512 tokens, so all
    batches share one sequence length. The tokenizer's output is returned
    unchanged; downstream code reads its "attention_mask" entry.
    """
    return tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        padding="max_length",
    )

In [None]:
def dataloader(
    image_paths,
    prompts,
    shuffle=True,
    repeat=True,
    batch_size=1,
):
    """Build the training/validation `tf.data` pipeline.

    Yields ``(images, tokens)`` batches: images come from `read_image`, and
    the prompts of each batch are tokenized with `batch_tokenize`.

    Args:
        image_paths: sequence of image file paths.
        prompts: sequence of prompt strings aligned with ``image_paths``.
        shuffle: shuffle the (path, prompt) pairs, drawing a new order on
            every pass over the dataset (i.e. every epoch).
        repeat: repeat the dataset indefinitely.
        batch_size: batch size; the last partial batch is dropped so every
            batch has the same static size.

    Returns:
        A prefetched ``tf.data.Dataset``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, prompts))

    if shuffle:
        # Shuffle individual samples *before* decoding and batching, and
        # reshuffle every epoch. Shuffling after batching with
        # reshuffle_each_iteration=False fed the model the same fixed batches
        # in the same order every epoch, and its buffer (batch_size * 8
        # *batches*) held decoded image batches in memory. Path/prompt strings
        # are cheap, so the buffer can cover the whole dataset (min 1 for the
        # empty case, since tf.data rejects a zero buffer).
        dataset = dataset.shuffle(
            max(len(image_paths), 1), reshuffle_each_iteration=True
        )

    dataset = dataset.map(
        lambda x, y: (read_image(x), y),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.batch(
        batch_size, drop_remainder=True
    )

    # Tokenize a whole batch of prompts at once.
    dataset = dataset.map(
        lambda x, y: (x, batch_tokenize(y)),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    if repeat:
        dataset = dataset.repeat()

    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
def _split_to_dataset(frame, shuffle):
    """Wrap one dataframe split (filepath/prompt columns) as a non-repeating dataset."""
    return dataloader(
        frame["filepath"].values,
        frame['prompt'].values,
        shuffle=shuffle,
        repeat=False,
        batch_size=batch_size,
    )


# Only the training split is shuffled; validation order stays fixed.
train_ds = _split_to_dataset(trn_df, shuffle=True)
valid_ds = _split_to_dataset(val_df, shuffle=False)

# Model

In [None]:
def preprocess(x):
    """Rescale pixels from [0, 255] to [-1, 1] and reorder (H, W, C) to (C, H, W)."""
    rescale = keras.layers.Rescaling(scale=1./127.5, offset=-1.0)
    channels_first = keras.layers.Permute(dims=(3, 1, 2))
    return channels_first(rescale(x))
    
def augment(x):
    """Apply the image augmentation pipeline (random horizontal flip) to `x`.

    The pipeline is first built with keras-cv's ``RandomChoice``. If that
    fails (e.g. ``keras_cv`` could not be imported at the top of the notebook,
    so the name is undefined here), it falls back to a core Keras
    ``Sequential`` containing ``keras.layers.RandomFlip``.

    Args:
        x: image tensor / KerasTensor to augment.

    Returns:
        The pipeline's output for ``x``.
    """
    try:
        pipeline = keras_cv.layers.RandomChoice(
            layers=[
                RandomFlip("horizontal")
            ],
            auto_vectorize=True
        )
    except Exception:
        # Keep the best-effort fallback, but not via a bare `except:`, which
        # would also swallow KeyboardInterrupt / SystemExit.
        pipeline = keras.Sequential(
            [
                keras.layers.RandomFlip("horizontal")
            ]
        )
    return pipeline(x)

In [None]:
def get_model(mode='inference'):
    """Build the image -> 384-D embedding model (ViT backbone + linear head).

    Args:
        mode: 'inference' builds the ViT architecture from the local config in
            '/kaggle/input/cmp-stablediffusion-imgtopt' without downloading
            weights (the caller is expected to load trained weights). Any
            other value loads the pretrained 'google/vit-base-patch16-224'.

    Returns:
        A functional ``keras.Model`` that maps (input_size, input_size, 3)
        images to float32 vectors of length 384, taken from a Dense layer
        applied to the first token of the ViT's last hidden state.
    """
    images = keras.Input(shape=(input_size, input_size, 3))
    pixels = preprocess(augment(images))

    # ViT backbone
    if mode == 'inference':
        vit_config = AutoConfig.from_pretrained(
            '/kaggle/input/cmp-stablediffusion-imgtopt',
            local_files_only=True
        )
        backbone = TFViTModel.from_config(vit_config)
    else:
        backbone = TFViTModel.from_pretrained(
            'google/vit-base-patch16-224'
        )
    vit_out = backbone(pixels)

    first_token = keras.layers.Lambda(lambda v: v[:, 0])(vit_out.last_hidden_state)

    # Linear projection to the sentence-embedding size, kept in float32.
    head = keras.layers.Dense(384, activation=None, dtype=tf.float32)
    return keras.Model(images, head(first_token))

**Sentence Transformer (all-MiniLM-L6-v2)**

It runs inside the training model: `TextToEmbedding` uses it to convert each tokenized prompt into the 384-D target embedding.

In [None]:
# ref. https://www.philschmid.de/tensorflow-sentence-transformers
class TFSentenceTransformer(keras.layers.Layer):
    """Keras layer that turns tokenized text into sentence embeddings.

    It runs a Hugging Face TF model on the tokenized inputs and mean-pools the
    model's first output over the positions where the attention mask is set.
    The pooled vectors can optionally be L2-normalized.
    """

    def __init__(self, model_id, **kwargs):
        super().__init__(**kwargs)
        # NOTE: the same kwargs go to both the Keras Layer and
        # `from_pretrained`, so only arguments valid for both can be passed.
        self.st_model = TFAutoModel.from_pretrained(model_id, **kwargs)

    def call(self, inputs, normalize=True):
        # `inputs` must be a mapping that contains "attention_mask".
        pooled = self.mean_pooling(
            self.st_model(inputs), inputs["attention_mask"]
        )
        return self.normalize(pooled) if normalize else pooled

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring positions whose mask is 0."""
        # presumably (batch, seq, hidden): the mask is expanded on the last
        # axis and the sums run over axis 1 — TODO confirm
        token_embeddings = model_output[0]
        mask = tf.expand_dims(attention_mask, -1)
        mask = tf.broadcast_to(mask, tf.shape(token_embeddings))
        mask = tf.cast(mask, token_embeddings.dtype)

        summed = tf.reduce_sum(token_embeddings * mask, axis=1)
        # Clamp the per-row token count away from zero before dividing.
        counts = tf.clip_by_value(
            tf.reduce_sum(mask, axis=1), 1e-9, tf.float32.max
        )
        return summed / counts

    def normalize(self, embeddings):
        """Scale every row (axis 1) of `embeddings` to unit L2 norm."""
        unit_vectors, _norms = tf.linalg.normalize(embeddings, 2, axis=1)
        return unit_vectors

In [None]:
class TextToEmbedding(keras.Model):
    '''Train an image model to predict the sentence embedding of its prompt.

    Wraps ``model`` (images -> 384-D vectors). In ``fit``/``evaluate`` the
    targets arrive as tokenized prompts. A text encoder turns them into
    sentence embeddings (without L2 normalization), and the standard Keras
    train/test step then compares the wrapped model's output with them.

    Args:
        model: the image-to-embedding Keras model being trained.
        text_encoder: optional, keyword-only. Callable that embeds the
            tokenized targets; it is called as
            ``text_encoder(y, normalize=False)``. When omitted, the
            notebook-level ``tfst_model`` is looked up each time a step runs
            (the original, implicit behavior).
        **kwargs: forwarded to ``keras.Model``.
    '''
    def __init__(self, model, *, text_encoder=None, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.text_encoder = text_encoder

    def call(self, inputs):
        return self.model(inputs)

    def _embed_targets(self, y):
        """Map tokenized prompts `y` to un-normalized sentence embeddings."""
        # Falling back to the global keeps existing `TextToEmbedding(model)`
        # call sites working while allowing the encoder to be injected.
        encoder = self.text_encoder if self.text_encoder is not None else tfst_model
        return encoder(y, normalize=False)

    def train_step(self, data):
        x, y = data
        return super().train_step((x, self._embed_targets(y)))

    def test_step(self, data):
        x, y = data
        return super().test_step((x, self._embed_targets(y)))

    # kaggle.com/code/ipythonx/keras-rsna-breast-cancer-detection
    def save_weights(
        self, filepath, overwrite=True, save_format=None, options=None,
        **kwargs
    ):
        # Overriding this method will allow us to use the `ModelCheckpoint`:
        # only the wrapped image model is written, so the file can be loaded
        # straight into `get_model(...)`. Extra keyword arguments from newer
        # Keras versions are passed through instead of raising TypeError.
        self.model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
            **kwargs
        )

    def save(
        self, filepath, overwrite=True, include_optimizer=True,
        save_format=None, signatures=None, options=None, **kwargs
    ):
        # Overriding this method will allow us to use the `ModelCheckpoint`:
        # only the wrapped image model is saved. Extra keyword arguments
        # (e.g. `save_traces`) are forwarded rather than rejected.
        self.model.save(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
            include_optimizer=include_optimizer,
            signatures=signatures,
            **kwargs
        )

In [None]:
# https://stackoverflow.com/a/44933346/9215780
# Keras' CosineSimilarity *loss* returns the negated cosine similarity;
# reduction='none' keeps one value per sample.
cosine_similarity_loss = keras.losses.CosineSimilarity(
    reduction='none'
)


def CosineEmbeddingLoss(margin=0., target=1):
    """Return a loss function built on the per-sample cosine similarity.

    With ``cos`` the cosine similarity of the two inputs, each sample
    contributes ``1 - cos`` when ``target == 1`` and otherwise
    ``max(0, cos - margin)``. The contributions are averaged into a scalar.
    """
    def cosine_embedding_loss_fn(input_one, input_two):
        cos_sim = -cosine_similarity_loss(input_one, input_two)
        pull_together = 1. - cos_sim
        push_apart = tf.maximum(tf.zeros_like(cos_sim), cos_sim - margin)
        per_sample = tf.where(tf.equal(target, 1), pull_together, push_apart)
        return tf.reduce_mean(per_sample)

    return cosine_embedding_loss_fn

In [None]:
# Reset Keras global state so re-running this cell starts from scratch.
keras.backend.clear_session()

# Create all model variables under the distribution strategy chosen above.
with strategy.scope():
    if SUBMIT:
        # Inference: rebuild the architecture offline and load the trained
        # checkpoint from the attached dataset.
        model = get_model(mode='inference')
        model.load_weights(
            '/kaggle/input/cmp-stablediffusion-imgtopt/model.02-0.533.h5'
        )
        # No loss/optimizer: the model is only used for predict() below.
        model.compile(jit_compile=True)
        model.trainable = False
    else:
        # token to text-embedding
        # Frozen sentence encoder. TextToEmbedding.train_step/test_step read
        # this global `tfst_model` to turn tokenized prompts into targets.
        tfst_model = TFSentenceTransformer(model_id)
        tfst_model.trainable = False

        model = get_model(mode='training')
        model = TextToEmbedding(model)

        model.compile(
            optimizer='Adamax',
            loss=CosineEmbeddingLoss(margin=0., target=1),
            metrics=[metrics.CosineSimilarity(name='cos')],
        )
        # Subclassed model: built explicitly, presumably so summary() in a
        # later cell can run before fit().
        model.build(
            input_shape=(None, input_size, input_size, 3)
        )
    
clear_output()

In [None]:
# NOTE(review): disabled alternative that captions with a BLIP model.
# `AutoProcessor` and `TFBlipForConditionalGeneration` are not imported
# anywhere in this notebook; add those imports before re-enabling it.
# keras.backend.clear_session()
# with strategy.scope():
#     processor = AutoProcessor.from_pretrained(
#         "/kaggle/input/salesforceblip-image-caption"
#     )
#     model = TFBlipForConditionalGeneration.from_pretrained(
#         "/kaggle/input/salesforceblip-image-caption"
#     )
#     model.compile(
#         jit_compile=True
#     )
#     model.trainable = False

In [None]:
# Print the layer table, including nested sub-models and each layer's
# trainable flag.
model.summary(
    expand_nested=True, 
    line_length=100, 
    show_trainable=True
)

In [None]:
if SUBMIT:
    # Inference run: only show the log recorded by the training run.
    # (`display` is the IPython notebook builtin.)
    training_log = pd.read_csv('/kaggle/input/cmp-stablediffusion-imgtopt/training_log.csv')
    display(training_log.head())
else:
    # Training run. `val_cos` is the validation value of the metric named
    # 'cos' in compile(). A checkpoint is written only when it improves, and
    # per-epoch metrics are appended to training_log.csv.
    model.fit(
        train_ds, 
        validation_data=valid_ds, 
        epochs=num_epochs,
        callbacks=[
            callbacks.ModelCheckpoint(
                filepath='model.{epoch:02d}-{val_cos:.3f}.h5',
                monitor='val_cos',
                mode='max',
                save_best_only=True
            ),
            callbacks.CSVLogger('training_log.csv')
        ]
    )

# Data [Inference]

In [None]:
comp_path = '/kaggle/input/stable-diffusion-image-to-prompts'
# Build the image list once, in a deterministic (sorted) order, and derive the
# submission IDs from that *same* list. A separate `os.listdir` call is not
# guaranteed to return files in the same order as `glob.glob`, and it is not
# limited to *.png, so the IDs could be paired with the wrong predictions.
imag_path = sorted(glob.glob(f'{comp_path}/images/*.png'))
images = [os.path.basename(path) for path in imag_path]
# splitext strips only the extension, so IDs that contain dots stay intact.
imgIds = [os.path.splitext(name)[0] for name in images]

EMBEDDING_LENGTH = 384
eIds = list(range(EMBEDDING_LENGTH))

# One "<imgId>_<eId>" row ID per (image, embedding dimension), image-major:
# all 384 dimensions of the first image, then the next image, and so on.
imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, EMBEDDING_LENGTH),
        np.tile(eIds, len(imgIds)))
]

# Dataloader [Inference]

In [None]:
def dataloader(
    image_paths, batch_size=1
):
    """Build the inference pipeline: decoded test images, batched in input order.

    Unlike the training loader, this one does not shuffle, repeat, or tokenize.
    It keeps the final partial batch, so every image gets a prediction.
    """
    return (
        tf.data.Dataset.from_tensor_slices(image_paths)
        .map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=False)
        .prefetch(tf.data.AUTOTUNE)
    )

In [None]:
# Batches follow the order of `imag_path`; the submission row IDs must be
# generated in that same order.
test_ds = dataloader(
    imag_path, batch_size=batch_size
)

# Prediction

In [None]:
# One 384-D vector per test image, i.e. an array of shape (num_images, 384).
prompt_embeddings = model.predict(test_ds)
# Flatten row-major: all 384 values of the first image, then the next, which
# is the image-major order `imgId_eId` is built in.
prompt_embeddings = np.vstack(prompt_embeddings).flatten()

In [None]:
# One row per "<imgId>_<eId>" ID holding that embedding value in column
# `val`; the index column is named `imgId_eId` in the written CSV.
submission = pd.DataFrame(
    index=imgId_eId,
    data=prompt_embeddings,
    columns=['val']
).rename_axis('imgId_eId')
submission.to_csv('submission.csv')
# Notebook display: first rows of the submission.
submission.head()