# Script for training VAE with NASNet.
# Made by [oblivisheee](https://github.com/oblivisheee) and [Integrio Team](https://github.com/Integrio-Team).
*It is recommended to copy this notebook to your own Drive and customise it for yourself.*<br>
*This is an alpha version; I am still fixing problems.*





In [None]:
# @title # Installing requirements.
# Colab form toggle: when enabled, mount Google Drive at /content/drive/.
connect_google_drive = True # @param {type:"boolean"}
if connect_google_drive:
    # Imported only when requested so the cell still runs with the toggle off.
    from google.colab import drive
    drive.mount('/content/drive/')


!pip install safetensors
!pip install matplotlib
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Lambda, BatchNormalization, Reshape, Conv2DTranspose, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications import NASNetLarge
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from safetensors.tensorflow import save_file
import matplotlib as plt

In [None]:
# @title # Installing NASNet.
# Number of channels of the input image (last axis of input_shape).
amount_of_filters = 3 # @param {type:"integer"}
width_of_shape = 600 # @param {type:"integer"}
height_of_shape = 850 # @param {type:"integer"}
# Input image: (height, width, channels).
input_shape = (height_of_shape, width_of_shape, amount_of_filters)
input_tensor = Input(shape=input_shape)
include_top = False # @param {type:"boolean"}
# Load the NASNetLarge backbone, either with ImageNet weights or with weights
# read from a user-supplied file.
weights_type = 'imagenet' # @param ["imagenet", "custom"]
# @markdown * Enter file path of weights if you chose "custom".
file_path_weights = "" # @param {type:"string"}
if weights_type == 'custom':
    # The path must be read *before* it is handed to NASNetLarge; previously
    # weights_type was overwritten with None and the path was never used.
    # An empty path keeps the old effective behaviour (weights=None).
    weights_type = file_path_weights or None
nasnet_model = NASNetLarge(weights=weights_type, include_top=include_top, input_tensor=input_tensor)

In [None]:
# @title # Model config and compile.

# Width of the latent mean / log-variance vectors.
latent_amount = 256 # @param {type:"integer"}
# @markdown * Size of filters like Conv2D and Dense.
size_of_filters = 128 # @param {type:"integer"}

factors_of_reshape = None


# Split size_of_filters into factor * second_factor; the decoder reshapes its
# Dense(size_of_filters) output to (factor, second_factor, 1).
# `factor` is the smallest divisor >= 5. The fallback (size_of_filters, 1)
# covers values below 5, for which the search range is empty and the two
# names would otherwise never be defined.
factor, second_factor = size_of_filters, 1
for candidate in range(5, size_of_filters + 1):
    if size_of_filters % candidate == 0:
        factor, second_factor = candidate, size_of_filters // candidate
        break
# Encoder
# Convolutional head stacked on the NASNetLarge output: two
# Conv2D -> BatchNormalization -> MaxPooling2D stages.
x = Conv2D(size_of_filters, (3, 3), activation='relu', padding='same')(nasnet_model.output)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(size_of_filters, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2), padding='same')(x)
# Flatten the pooled feature map and project it to size_of_filters units.
x = Flatten()(x)
x = Dense(size_of_filters, activation='relu')(x)
x = BatchNormalization()(x)
# Dropout applied to the dense features before the latent heads.
# @markdown * Paremeter of against overtrain.
dropout_rate = 0.15 # @param {type:"number"}
x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
# Two parallel linear heads of width latent_amount; `sampling` treats them as
# the mean and the log-variance of the latent distribution.
z_mean = Dense(latent_amount)(x)
z_log_var = Dense(latent_amount)(x)

def sampling(args):
    """Draw a latent sample from a ``(mean, log_var)`` pair.

    Returns ``mean + exp(0.5 * log_var) * noise`` where ``noise`` is standard
    normal with shape ``(batch, latent_amount)``; ``batch`` is read from the
    first axis of the mean tensor at run time.
    """
    mean, log_var = args
    batch_size = tf.shape(mean)[0]
    noise = tf.random.normal(shape=(batch_size, latent_amount), mean=0.0, stddev=1.0)
    scale = tf.exp(0.5 * log_var)
    return mean + scale * noise


# Wrap the function in a Lambda layer and apply it to the encoder heads.
z = Lambda(sampling, output_shape=(latent_amount,))([z_mean, z_log_var])

class CustomLayer(tf.keras.layers.Layer):
    """Weightless layer that returns its input incremented by one."""

    @staticmethod
    def custom_add(inputs):
        """Return ``inputs + 1`` computed with ``tf.add``."""
        return tf.add(inputs, 1)

    def call(self, inputs):
        """Forward pass: delegate to ``custom_add``."""
        return self.custom_add(inputs)


# Apply the layer to the latent mean; the result becomes the encoder's third
# output further down in this cell.
custom_layer = CustomLayer()
z_mean_custom = custom_layer(z_mean)
input_layer = Input(shape=input_shape)

# @markdown * Weight of Regularization Layer

weight_of_regularization_layer = 0.01 # @param {type:"number"}
class CustomRegularizationLayer(tf.keras.layers.Layer):
    """Identity layer that registers a weighted sum-of-squares loss on its input.

    Args:
        weight: Multiplier applied to ``sum(inputs ** 2)``. The default is the
            value of ``weight_of_regularization_layer`` at class-definition time.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, weight=weight_of_regularization_layer, **kwargs):
        super().__init__(**kwargs)
        self.weight = weight

    def call(self, inputs):
        """Add the regularization loss and return ``inputs`` unchanged."""
        regularization_loss = tf.reduce_sum(tf.square(inputs)) * self.weight
        # The former `inputs=` keyword of add_loss is a deprecated
        # compatibility argument, so only the loss tensor is passed.
        self.add_loss(regularization_loss)
        return inputs

    def get_config(self):
        """Return the layer config including ``weight`` for (de)serialization."""
        config = super().get_config()
        config["weight"] = self.weight
        return config

# Use the form parameter instead of a hard-coded 0.01 so that editing
# "Weight of Regularization Layer" actually changes the layer.
regularization_layer = CustomRegularizationLayer(weight=weight_of_regularization_layer)

# NOTE(review): in the visible code nothing built in this section
# (input_layer, regularized_inputs, output_layer, x) is connected to the
# encoder, decoder or vae models; x is rebound before it is read again.
# The statements are kept because removing them would shift the auto-generated
# layer names of everything created afterwards.
input_layer = Input(shape=input_shape)

regularized_inputs = regularization_layer(input_layer)

x = Dense(size_of_filters, activation='relu')(regularized_inputs)
output_layer = Dense(size_of_filters, activation='softmax')(x)
x = Dense(size_of_filters, activation='relu', kernel_initializer='glorot_uniform')(input_tensor)
x = Conv2D(size_of_filters, (3, 3), activation='relu', padding='same')(nasnet_model.output)

dynamic_of_sampling_layer = True # @param {type:"boolean"}
class SamplingLayer(tf.keras.layers.Layer):
    """Layer form of the latent sampling step.

    Takes ``[mean, log_var]`` and returns ``mean + exp(0.5 * log_var) * noise``
    with standard-normal ``noise`` sized from the first two axes of ``mean``.
    """

    def compute_output_shape(self, input_shape):
        """Report the shape of the first input (the mean) as the output shape."""
        mean_shape = input_shape[0]
        return mean_shape

    def call(self, inputs):
        """Return a random sample centred on the mean."""
        mean, log_var = inputs
        rows = tf.shape(mean)[0]
        cols = tf.shape(mean)[1]
        noise = tf.keras.backend.random_normal(shape=(rows, cols), mean=0.0, stddev=1.0)
        scale = tf.exp(0.5 * log_var)
        return mean + scale * noise


# The form toggle above is forwarded as the layer's `dynamic` flag.
sampling_layer = SamplingLayer(dynamic=dynamic_of_sampling_layer)

# Build the encoder model: image in, (z_mean, z_log_var, z_mean_custom) out.
encoder = Model(input_tensor, [z_mean, z_log_var, z_mean_custom])
# NOTE(review): encoder_output is not read anywhere in the visible code.
encoder_output = encoder(input_tensor)

# Decoder
# Maps a latent vector of width latent_amount back to an image-like tensor.
decoder_input = Input(shape=(latent_amount,))
# NOTE(review): sampled_z is assigned here and again below, and is not read
# anywhere in the visible code; the decoder is driven by decoder_input.
sampled_z = Lambda(sampling, output_shape=(latent_amount,))([z_mean, z_log_var])
x = Dense(size_of_filters, activation='relu')(decoder_input)
x = BatchNormalization()(x)
# factor * second_factor == size_of_filters, so the Dense output can be viewed
# as a single-channel 2-D map.
x = Reshape((factor, second_factor, 1))(x)
# Two Conv2DTranspose -> BatchNormalization -> UpSampling2D(2x) stages.
x = Conv2DTranspose(size_of_filters, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = UpSampling2D((2, 2))(x)
x = Conv2DTranspose(size_of_filters, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = UpSampling2D((2, 2))(x)
# Final layer emits amount_of_filters channels through a sigmoid.
decoder_output = Conv2DTranspose(amount_of_filters, (3, 3), activation='sigmoid', padding='same')(x)

# Build the decoder model.
decoder = Model(decoder_input, decoder_output)

sampled_z = sampling_layer([z_mean, z_log_var])
# End-to-end VAE graph on a fresh input. From here on z_mean, z_log_var and
# z_mean_custom are rebound to the encoder outputs for vae_input.
vae_input = Input(shape=input_shape, name='vae_input')
z_mean, z_log_var, z_mean_custom = encoder(vae_input)
z = Lambda(sampling, output_shape=(latent_amount,))([z_mean, z_log_var])
vae_output = decoder(z)


# The decoder's spatial size is determined by factor/second_factor and the two
# 2x upsamplings, so its output is resized (nearest neighbour) to the input
# height/width before it is compared with vae_input.
output_layer_resized = tf.image.resize(vae_output, size=(height_of_shape, width_of_shape), method='nearest')

# vae_loss
# Per-sample mean squared reconstruction error plus a KL term; note the KL term
# is averaged (reduce_mean) over the latent axis.
reconstruction_loss = tf.reduce_mean(tf.square(vae_input - output_layer_resized), axis=(1, 2, 3))
kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
vae_loss = reconstruction_loss + kl_loss

vae = Model(vae_input, output_layer_resized)

# The training objective is attached with add_loss, so compile() below gets
# only an optimizer and no `loss=` argument.
vae.add_loss(vae_loss)
# Model compiling.
vae.compile(optimizer='Adam')

In [None]:
# @title # Set path of train data.
import os
main_dir = "/content/datasets/" # @param {type:"string"}
train_data_dir = "/content/datasets/train_data" # @param {type:"string"}
validation_data_dir = "/content/datasets/validation_dir/" # @param {type:"string"}
# Create whichever of the dataset folders does not exist yet.
for dataset_dir in (main_dir, train_data_dir, validation_data_dir):
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)

In [None]:
# @title # Dataset config.
def preprocess_image(image):
    """Center-crop an image tensor and resize it to the model input size.

    Args:
        image: Image tensor whose last axis holds the channels.

    Returns:
        The image with 3 channels, center-cropped to 80% and resized to
        ``(height_of_shape, width_of_shape)``.
    """
    # The datasets in this cell are declared as 3-channel RGB. The previous
    # fallback called rgb_to_grayscale, which itself requires 3 channels and so
    # could never succeed on a non-RGB input; expand grayscale to RGB instead.
    if image.shape[-1] != 3:
        image = tf.image.grayscale_to_rgb(image)
    image = tf.image.central_crop(image, central_fraction=0.8)
    return tf.image.resize(image, size=(height_of_shape, width_of_shape))

# Augmentation settings exposed as Colab form fields.
# Pixel values are multiplied by 1.0 / rescale in the generator below.
rescale = 255 # @param {type:"integer"}
rotation_range = 20 # @param {type:"integer"}
width_shift_range = 0.2 # @param {type:"number"}
height_shift_range = 0.2 # @param {type:"number"}
shear_range = 0.2 # @param {type:"number"}
zoom_range = 0.3 # @param {type:"number"}
horizontal_flip = True # @param {type:"boolean"}
fill_mode = "nearest" # @param ["nearest", "reflect", "wrap", "constant"]

class_mode = "input" # @param ["input", "categorical", "binary", "sparse"]
shuffle = True # @param {type:"boolean"}
# One augmenting generator shared by the training and validation iterators.
# NOTE(review): the validation data therefore receives the same random
# augmentation as the training data - confirm this is intended.
datagen = ImageDataGenerator(
    rescale=1.0/rescale,
    rotation_range=rotation_range,
    width_shift_range=width_shift_range,
    height_shift_range=height_shift_range,
    shear_range=shear_range,
    zoom_range=zoom_range,
    horizontal_flip=horizontal_flip,
    fill_mode=fill_mode
)

batch_size = 5 # @param {type:"integer"}
image_size = (height_of_shape, width_of_shape)

# Directory iterators that load RGB images resized to image_size.
train_generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode=class_mode,
    color_mode='rgb',
    shuffle=shuffle
)

validation_generator = datagen.flow_from_directory(
    validation_data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode=class_mode,
    color_mode='rgb',
    shuffle=shuffle
)


# tf.data wrappers around the iterators, declared as batches of float32 RGB images.
# NOTE(review): with class_mode="input" the Keras iterator is documented to
# yield (inputs, targets) pairs, which would not match this single-tensor
# signature - TODO confirm.
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_generator,
    output_signature=tf.TensorSpec(shape=(None, *image_size, 3), dtype=tf.float32)
)

validation_dataset = tf.data.Dataset.from_generator(
    lambda: validation_generator,
    output_signature=tf.TensorSpec(shape=(None, *image_size, 3), dtype=tf.float32)
)


# Apply the crop/resize preprocessing to every batch.
train_dataset = train_dataset.map(preprocess_image)
validation_dataset = validation_dataset.map(preprocess_image)

# NOTE(review): the visible training cell passes train_generator /
# validation_generator to vae.fit, not these datasets.
num_epochs_for_repeat = 15 # @param {type:"integer"}
train_dataset = train_dataset.repeat(num_epochs_for_repeat)
validation_dataset = validation_dataset.repeat(num_epochs_for_repeat)


Found 557 images belonging to 3 classes.
Found 0 images belonging to 0 classes.


In [None]:
# @title # Start training.

num_epochs = 1 # @param {type:"integer"}
learning_rate = 0.01 # @param {type:"number"}
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# The model was compiled earlier with the string 'Adam', so this optimizer (and
# the learning_rate form field) had no effect. Recompile to apply it; the loss
# registered through vae.add_loss is kept.
vae.compile(optimizer=optimizer)
checkpoint_per_epoch = True # @param {type:"boolean"}
# The model has no metric called 'vae_loss'; the total training loss is logged
# under 'loss', so that is the key the scheduler has to watch.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.2, patience=5, min_lr=0.0001)


main_dir = "/content/trained-models/" # @param {type:"string"}
models_dir = "/content/trained-models/main-models/" # @param {type:"string"}
checkpoint_models_dir = "/content/trained-models/saved_models_per_epoch/" # @param {type:"string"}
# Create the output folders; exist_ok avoids the separate existence check.
os.makedirs(main_dir, exist_ok=True)
if checkpoint_per_epoch:
    os.makedirs(checkpoint_models_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)

# Base file name (without extension) of the final saved model.
name_of_model = "!change_ME!" # @param {type:"string"}
safetensors_ext_name = ".safetensors"

class SaveSafetensors(tf.keras.callbacks.Callback):
    """Keras callback that writes a .safetensors checkpoint after every epoch.

    The file is stored in ``checkpoint_models_dir`` and contains the current
    values of all weights of the model being trained.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Save the model weights for ``epoch`` (0-based, as passed by Keras)."""
        name_of_model = "manga-diffusion-vae_epoch_{:02d}".format(epoch)
        # Previously two all-zero placeholder tensors were written, so the
        # checkpoint contained nothing of the trained model.
        tensors = {}
        for index, weight in enumerate(self.model.weights):
            # Prefer the fully qualified path when the Keras version has one.
            key = getattr(weight, "path", None) or weight.name
            if key in tensors:
                # Keep keys unique should two weights share a name.
                key = "{}#{}".format(key, index)
            tensors[key] = tf.convert_to_tensor(weight)
        save_file(tensors, os.path.join(checkpoint_models_dir, name_of_model + safetensors_ext_name))
        print("\nEpoch-checkpoint saved in " + checkpoint_models_dir + " with the name: " + name_of_model + ".safetensors")


save_safetensors_callback = SaveSafetensors()


# The learning-rate scheduler is always active; per-epoch checkpoints are optional.
callbacks = [reduce_lr]


if checkpoint_per_epoch:
    callbacks.append(save_safetensors_callback)


# Train directly from the directory iterators, one full pass per epoch.
history = vae.fit(
    train_generator,
    epochs=num_epochs,
    validation_data=validation_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(validation_generator),
    callbacks=callbacks
)


# The working directory is switched to models_dir (kept as in the original
# cell) so the relative file name below lands there.
os.chdir(models_dir)
# Export the trained weights. Previously two all-zero placeholder tensors were
# written, so the saved "model" contained no trained parameters.
tensors = {}
for index, weight in enumerate(vae.weights):
    # Prefer the fully qualified path when the Keras version has one.
    key = getattr(weight, "path", None) or weight.name
    if key in tensors:
        # Keep keys unique should two weights share a name.
        key = "{}#{}".format(key, index)
    tensors[key] = tf.convert_to_tensor(weight)
save_file(tensors, name_of_model + safetensors_ext_name)
print("Model with name " + name_of_model + safetensors_ext_name + " saved in " + models_dir + " .")


In [None]:
# @title #Test the model.
import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, BatchNormalization, Dropout, Lambda, Layer, Conv2DTranspose, Reshape, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow as tf
color_mode = "grayscale" # @param ["rgb", "grayscale"]

def preprocess_image(image_path, target_size):
    """Load one image file and turn it into a single-item model input batch.

    The file is read with the module-level ``color_mode``; grayscale images are
    expanded to RGB, pixel values are divided by the module-level ``rescale``,
    and a leading batch axis is added.

    Args:
        image_path: Path of the image file to load.
        target_size: Size forwarded unchanged to ``image.load_img``.

    Returns:
        The prepared image with a leading axis of length 1.
    """
    loaded = image.load_img(image_path, target_size=target_size, color_mode=color_mode)
    pixels = image.img_to_array(loaded)

    is_grayscale = color_mode == "grayscale"
    if is_grayscale:
        # Replicate the single channel so the array has three channels.
        as_tensor = tf.convert_to_tensor(pixels, dtype=tf.float32)
        pixels = tf.image.grayscale_to_rgb(as_tensor)

    scaled = pixels / rescale
    return np.expand_dims(scaled, axis=0)


def visualize_images(original_img, generated_img):
    """Show the original and the generated image side by side.

    Only the first item of each batch and its first channel are displayed,
    rendered with a gray colormap.

    Args:
        original_img: Batch of input images, indexed as ``[0, :, :, 0]``.
        generated_img: Batch of model outputs, indexed the same way.
    """
    # The notebook binds `plt` to the top-level `matplotlib` package, which does
    # not offer the pyplot functions used here; import pyplot locally.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(original_img[0, :, :, 0], cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title('Generated Image')
    plt.imshow(generated_img[0, :, :, 0], cmap='gray')

    plt.show()


# NOTE(review): model_link is not read anywhere in the visible code; prediction
# below uses the in-memory `vae` model.
model_link = "/content/trained-models/main-models/" + name_of_model + safetensors_ext_name

image_path = '' # @param {type:"string"}
# load_img takes target_size as (height, width); the channel count is chosen
# through color_mode, so no third entry is passed.
target_size = (height_of_shape, width_of_shape)
num_steps = 30 # @param {type:"integer"}


# Load the image, run it through the VAE and plot input next to output.
# NOTE(review): `steps` asks predict for num_steps batches although a single
# image is supplied - confirm this argument is needed.
input_image = preprocess_image(image_path, target_size)
generated_image = vae.predict(input_image, steps=num_steps)
visualize_images(input_image, generated_image)



