**Start** (SET PATH)
---

In [None]:
# Name of this run; it is reused in the weight-file names saved by the training cells.
experiment = "SSRIW"
# Directory (relative to the working directory) that receives weights and figures.
path = f"./{experiment}/"

In [None]:
# Install the Python packages and system libraries that the import cell below relies on.
# NOTE(review): no package version is pinned, so a re-run may pick up different releases.
!pip install --upgrade pip
!pip install vit-keras
!pip install einops

!pip install scikit-image
!pip install scikit-learn
!pip install tensorflow-addons
# NOTE(review): both the regular and the headless OpenCV wheels are installed — confirm both are needed.
!pip install opencv-python
!pip install opencv-python-headless
!apt-get update && apt-get install ffmpeg libsm6 libxext6  -y
!pip install imgaug
# Graph-plotting helpers (pydot / graphviz variants).
!pip install pydot
!pip install pydotplus
!apt -y install graphviz
!pip install pydot_ng

In [None]:
# Import TensorFlow and print the installed version so the run records which release was used.
import tensorflow as tf
print(tf.version.VERSION)

In [None]:
from skimage import data, io, transform, color
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.filters import threshold_otsu
from skimage.util import *
from sklearn.utils import shuffle

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from einops.layers.tensorflow import Rearrange
from mpl_toolkits import axes_grid1
from kaggle_datasets import KaggleDatasets

from tensorflow.keras.layers import *
from tensorflow.keras.activations import *
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.datasets import cifar10
import tensorflow.keras.backend as K
from tensorflow.keras.initializers import glorot_uniform
from tensorflow.keras.utils import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard

import vit_keras.layers as vit_layers
from vit_keras import vit, utils, visualize

import imgaug.augmenters as iaa
import json
import os
import random
import math
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
from PIL import Image
from functools import reduce

import warnings
# Hide DeprecationWarning messages raised through Python's warnings module.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# NOTE(review): TensorFlow has already been imported above; confirm this variable
# still has an effect when it is set at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import logging
# Restrict TensorFlow's Python logger to ERROR-level messages and above.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

In [None]:
# Create the experiment output directory. exist_ok=True makes the cell safe to
# re-run and removes the gap between checking for the directory and creating it.
os.makedirs(path, exist_ok=True)

**Select TPU**
---

In [None]:
# detect and init the TPU
# NOTE(review): TPUClusterResolver() is called without arguments, so this cell
# presumably needs a runtime with an attached TPU — confirm before running elsewhere.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

# instantiate a distribution strategy
# `strategy` is used by the later cells for strategy.scope() and strategy.run().
strategy = tf.distribute.TPUStrategy(tpu)

**Training Parameters**
---

In [None]:
# Global batch size: 32 samples for each replica of the distribution strategy.
batch_size = 32 * strategy.num_replicas_in_sync
# Upper bound on epochs for every training loop (early stopping may end sooner).
num_epochs = 500
# Early-stopping counter; each training loop resets it to 0 before starting.
wait = 0
# Number of consecutive non-improving epochs tolerated before a loop stops.
patience = 15
# NOTE(review): not referenced by the visible cells — confirm whether it is still needed.
num_classes = 10
# Shape of one cover image (height, width, channels).
input_shape = (128, 128, 3)
# Shape of one watermark (height, width, channels).
watermark_shape = (8, 8, 1)
# NOTE(review): not referenced by the visible cells — confirm whether it is still needed.
noise_dim = 512

# Transformer settings shared by the model builders below.
mlp_dim = 512      # passed as mlp_dim to the transformer blocks and as key_dim to the embedder attention
hidden_size = 512  # width of the Dense projection applied to the encoder's patch tokens
num_layers = 4     # number of transformer blocks in the encoder
patch_size = 16    # side length, in pixels, of the square image patches
num_heads = 2      # attention heads per attention layer

**Dataset**
---

In [None]:
# Filled by get_dataset(): maps a split name ('TRAIN' / 'VAL') to its recorded sample count.
DATASET_INFO = {}
# Let tf.data choose the parallelism / prefetch buffer sizes.
AUTO = tf.data.AUTOTUNE

In [None]:
def get_dataset(batch_size, dataset_type='TRAIN', batchwise_augment=False):
    """Build an endlessly repeating tf.data pipeline of image/watermark tuples.

    Args:
        batch_size: Number of samples per emitted batch.
        dataset_type: 'TRAIN' reads the training directory; any other value
            reads the validation directory.
        batchwise_augment: Accepted for interface compatibility; not used.

    Returns:
        A batched, repeated and prefetched ``tf.data.Dataset`` whose elements
        are the 7-tuple ``(x, x_n, x_s, w, w_s, y, y_s)``:
        images scaled to [0, 1] (``x``), an identical copy (``x_n``), a
        batch-permuted copy (``x_s``), binary watermarks (``w``), a further
        batch-shuffled copy of them (``w_s``), labels (``y``) and the labels
        permuted like ``x_s`` (``y_s``).

    Side effects:
        Stores ``number_of_batches * batch_size`` in
        ``DATASET_INFO[dataset_type]``.
    """
    if dataset_type == 'TRAIN':
        data_dir = '/kaggle/input/imagenetmini-1000/imagenet-mini/train'
    else:
        data_dir = '/kaggle/input/imagenetmini-1000/imagenet-mini/val'

    dataset = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        labels="inferred",
        label_mode="int",
        class_names=None,
        color_mode="rgb",
        batch_size=batch_size,
        image_size=(128, 128),
        shuffle=True,
        seed=None,
        validation_split=None,
        subset=None,
        interpolation="bilinear",
        follow_links=False,
        crop_to_aspect_ratio=False
    )

    # The last batch may be partial, so this is an upper bound on the sample count.
    DATASET_INFO[dataset_type] = dataset.cardinality().numpy()*batch_size

    def watermark(x):
        """Binarise the first channel of ``x`` at 127 and keep a trailing channel axis."""
        w = x[:, :, :, 0]
        w = tf.math.greater(w, 127)
        w = tf.where(w, 1, 0)
        w = tf.expand_dims(w, axis=-1)
        return w

    def create_triplets(x, y):
        """Turn one (images, labels) batch into the 7-tuple described above."""
        # x_n
        x_n = x

        # w (also for noised): 8x8 thumbnails of the batch, binarised, then
        # shuffled along the batch axis so a watermark is not tied to its image.
        w = x_n
        w = tf.image.resize(w, (8, 8))
        w = watermark(w)
        w = tf.random.shuffle(w)

        # w_s
        w_s = w
        w_s = tf.random.shuffle(w_s)

        # x_s and y_s: apply ONE permutation to both tensors. Two separate
        # shuffle ops that merely share a seed are not guaranteed to stay
        # aligned when map() runs elements in parallel.
        permutation = tf.random.shuffle(tf.range(tf.shape(x)[0]))
        x_s = tf.gather(x, permutation)
        y_s = tf.gather(y, permutation)

        x = tf.cast(x, tf.float32) / 255.0
        x_n = tf.cast(x_n, tf.float32) / 255.0
        x_s = tf.cast(x_s, tf.float32) / 255.0
        w = tf.cast(w, tf.float32)
        w_s = tf.cast(w_s, tf.float32)
        y = tf.cast(y, tf.float32)
        y_s = tf.cast(y_s, tf.float32)

        return x, x_n, x_s, w, w_s, y, y_s

    dataset = dataset.map(create_triplets, num_parallel_calls=AUTO)
    dataset = dataset.repeat()
    # Re-batch so that every emitted batch has exactly batch_size samples.
    dataset = dataset.unbatch()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)

    return dataset

In [None]:
# Batch size seen by a single replica of the distribution strategy.
PER_REPLICA_BATCH_SIZE = batch_size // strategy.num_replicas_in_sync


def _build_split(split):
    """Return the dataset for `split`, distributed when more than one replica is in sync."""
    if strategy.num_replicas_in_sync > 1:
        return strategy.distribute_datasets_from_function(
            lambda _: get_dataset(PER_REPLICA_BATCH_SIZE, split)
        )
    return get_dataset(batch_size, split)


trn_dataset = _build_split('TRAIN')
val_dataset = _build_split('VAL')

# get_dataset() recorded the sample counts while the datasets were being built.
print(f"train dataset size: {DATASET_INFO['TRAIN']}")
print(f"val dataset size: {DATASET_INFO['VAL']}")

# The datasets repeat forever, so these iterators are shared by all later cells.
trn_iterator = iter(trn_dataset)
val_iterator = iter(val_dataset)

In [None]:
def normalize_img(temp):
    """Min-max scale ``temp`` into the range [0, 1].

    Args:
        temp: Array-like of numbers (NumPy array expected).

    Returns:
        ``(temp - min) / (max - min)``. When every element is equal the range
        is zero; an all-zero array of the same shape is returned instead of
        the NaNs a 0/0 division would produce.
    """
    lowest = np.min(temp)
    value_range = np.max(temp) - lowest
    shifted = temp - lowest
    if value_range == 0:
        # Constant input: avoid dividing by zero.
        return np.zeros_like(shifted, dtype=np.result_type(shifted, np.float32))
    return shifted / value_range

In [None]:
def get_BER(original, reconstructed):
    """Print bit-error and bit-recovery statistics for a watermark pair.

    Both inputs are converted to NumPy arrays and reshaped to
    ``watermark_shape`` before being compared element by element.
    """
    reference_bits = np.reshape(np.array(original), watermark_shape)
    recovered_bits = np.reshape(np.array(reconstructed), watermark_shape)

    TB = np.prod(reference_bits.shape)                           # total bits
    EB = np.count_nonzero(reference_bits != recovered_bits)      # mismatching bits
    RB = TB - EB                                                 # matching bits
    BER = EB*100.0/TB
    BRR = RB*100.0/TB

    print('Total bits =', TB)
    print('Recovered bits =', RB)
    print('Error bits =', EB)
    print(f'BRR = {BRR} %')
    print(f'BER = {BER} %')

In [None]:
def get_PSNR(original, reconstructed):
    """Print the PSNR (in dB) between two images, using a peak value of 1.0.

    Both inputs are converted to NumPy arrays and reshaped to ``input_shape``.
    Identical images (zero MSE) are reported as 100 dB.
    """
    reference = np.reshape(np.array(original), input_shape)
    candidate = np.reshape(np.array(reconstructed), input_shape)

    mse = np.mean((reference - candidate) ** 2)
    max_pixel = 1.0
    # Zero error would make the logarithm undefined, so it is capped at 100.
    psnr = 100 if mse == 0 else 20 * math.log10(max_pixel / math.sqrt(mse))

    print(f'PSNR = {psnr} dB')

**All models**
---

**Embedder**
---

In [None]:
def get_embedder(
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        name='Embedder'):
    """Build the embedder model: (cover image, watermark) -> marked image.

    The cover image and the watermark are each cut into patch tokens, given
    position embeddings, and cross-attended in both directions. The two
    residual branches are concatenated on the feature axis, projected to 768
    features per token and folded back into an image.

    Args:
        mlp_dim: Used as ``key_dim`` of both attention layers.
        num_heads: Number of attention heads.
        name: Name of the returned Keras model.

    Returns:
        A Keras ``Model`` with inputs ``[cover_im, watermark]``.
    """
    to_tokens = 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)'
    to_image = 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c'
    watermark_patch = patch_size//16
    grid = 128 // patch_size  # patches per image side

    cover_im = Input(shape=input_shape)
    watermark = Input(shape=watermark_shape)

    cover_tokens = Rearrange(to_tokens, p1=patch_size, p2=patch_size)(cover_im)
    watermark_tokens = Rearrange(to_tokens, p1=watermark_patch, p2=watermark_patch)(watermark)

    cover_tokens = vit_layers.AddPositionEmbs(name="T_Pos_Embed_c")(cover_tokens)
    watermark_tokens = vit_layers.AddPositionEmbs(name="T_Pos_Embed_w")(watermark_tokens)

    # Cross attention in both directions (first argument is the query).
    cover_to_mark = MultiHeadAttention(num_heads=num_heads, key_dim=mlp_dim)(cover_tokens, watermark_tokens)
    mark_to_cover = MultiHeadAttention(num_heads=num_heads, key_dim=mlp_dim)(watermark_tokens, cover_tokens)

    # Residual connections around each attention branch.
    cover_branch = Add()([cover_tokens, cover_to_mark])
    mark_branch = Add()([watermark_tokens, mark_to_cover])

    fused = Concatenate()([cover_branch, mark_branch])
    patch_pixels = Dense(768)(fused)

    marked_im = Rearrange(to_image, h=grid, w=grid, p1=patch_size, p2=patch_size, c=3)(patch_pixels)

    return Model([cover_im, watermark], marked_im, name=name)

**Extractor**
---

In [None]:
def get_extractor(
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        name='Extractor'):
    """Build the extractor model: marked image -> watermark estimate.

    The image is reshaped to an 8x8 grid with 768 features per cell and passed
    through a stack of 3x3 'same' convolutions with ReLU activations, two
    dropout layers and one Dense layer, ending in a single-channel map.

    Args:
        mlp_dim: Accepted for interface symmetry; not used.
        num_heads: Accepted for interface symmetry; not used.
        name: Name of the returned Keras model.

    Returns:
        A Keras ``Model`` mapping the marked image to the watermark estimate.
    """
    def conv_relu(tensor, filters):
        """3x3 'same' convolution followed by ReLU."""
        tensor = Conv2D(filters, 3, padding='same')(tensor)
        return Activation('relu')(tensor)

    marked_im = Input(shape=input_shape)

    features = Reshape((8, 8, 768))(marked_im)

    # Widening half.
    features = conv_relu(features, 64)
    features = conv_relu(features, 128)
    features = Dropout(0.2)(features)
    features = conv_relu(features, 256)

    features = Dense(512)(features)

    # Narrowing half.
    features = conv_relu(features, 128)
    features = Dropout(0.2)(features)
    for filters in (64, 32, 8):
        features = conv_relu(features, filters)
    watermark = conv_relu(features, 1)

    return Model(marked_im, watermark, name=name)

**Encoder**
---

In [None]:
def get_encoder(
        num_layers=num_layers,
        patch_size=patch_size,
        hidden_size=hidden_size,
        mlp_dim=mlp_dim,
        dropout=.02,
        num_heads=num_heads,
        name='Encoder'
    ):
    """Build the transformer encoder: image -> (token code, 1000-way output).

    Args:
        num_layers: Number of transformer blocks.
        patch_size: Side length of the square image patches.
        hidden_size: Width of the Dense projection applied to each patch.
        mlp_dim: ``mlp_dim`` of every transformer block.
        dropout: Dropout rate of every transformer block.
        num_heads: Attention heads of every transformer block.
        name: Name of the returned Keras model.

    Returns:
        A Keras ``Model`` with two outputs: the token sequence produced by the
        last transformer block, and a Dense(1000) projection of that sequence
        after batch normalisation and flattening.
    """
    in_channels = 3

    ip = Input(shape=(128, 128, in_channels))

    # Cut the image into flattened patches and project them to hidden_size.
    y = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)(ip)

    y = Dense(hidden_size)(y)
    y = vit_layers.AddPositionEmbs(name="T_Pos_Embed")(y)

    for n in range(num_layers):
        # Each block also returns a second value, which is not needed here.
        y, _ = vit_layers.TransformerBlock(
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    dropout=dropout,
                    name=f"T_Enc_Block_{n}"
                )(y)

    # Secondary head computed from the same tokens.
    y_ = BatchNormalization()(y)
    y_ = Flatten()(y_)
    y_ = Dense(1000)(y_)

    model = Model(inputs=ip, outputs=[y, y_], name=name)
    return model

**Decoder**
---

In [None]:
def get_decoder(
        input_shape,
        num_layers=num_layers,
        patch_size=patch_size,
        hidden_size=hidden_size,
        mlp_dim=mlp_dim,
        dropout=.02,
        num_heads=num_heads,
        name='Decoder'
    ):
    """Build the transformer decoder: token code -> image.

    Args:
        input_shape: Shape of one token code (without the batch axis).
        num_layers: Number of transformer blocks. It was previously ignored in
            favour of a hard-coded 4, which equals the notebook default.
        patch_size: Side length of the square image patches.
        hidden_size: Accepted for interface symmetry; not used.
        mlp_dim: ``mlp_dim`` of every transformer block.
        dropout: Dropout rate of every transformer block.
        num_heads: Attention heads of every transformer block.
        name: Name of the returned Keras model.

    Returns:
        A Keras ``Model`` that maps a token code to a 3-channel image.
    """
    in_channels = 3
    patch_dim = in_channels * patch_size ** 2  # pixels values per flattened patch
    h = 128 // patch_size                      # patches per image side

    ip = Input(shape=input_shape)
    y = ip

    for n in range(num_layers):
        # Each block also returns a second value, which is not needed here.
        y, _ = vit_layers.TransformerBlock(
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    dropout=dropout,
                    name=f"T_Dec_Block_{n}"
                )(y)

    y = LayerNormalization(
        epsilon=1e-6, name="T_LNorm"
    )(y)

    # Project every token back to a flattened patch and tile the patches into an image.
    y = Dense(patch_dim)(y)
    y = Rearrange('b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=h, w=h, p1=patch_size, p2=patch_size, c=in_channels)(y)

    model = Model(inputs=ip, outputs=y, name=name)
    return model

In [None]:
image_size = 128
# Throw-away encoder, built only to read the shape of its token-code output.
encoder = get_encoder()
code_shape = encoder.output_shape[0][1:]
code_elements = reduce(lambda total, dim: total * dim, code_shape)
# Number of 32x32 planes needed to display one per-replica batch of codes (used by plot()).
channels = PER_REPLICA_BATCH_SIZE * code_elements // (32*32)

**Defining all params inside strategy**
---

In [None]:
# Optimizers, loss helpers and loss-history lists, created inside the distribution strategy scope.
with strategy.scope():
    # Margin of the triplet loss used when training the encoder.
    margin = 1.0
    # Learning-rate schedule for the embedder/extractor optimizer.
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4,
        decay_steps=10000,
        decay_rate=0.9)
    opt_emb_ext = keras.optimizers.Adam(learning_rate=lr_schedule)
    opt_enc = keras.optimizers.Adam(learning_rate=1e-4)
    # NOTE(review): opt_enc_legacy is not referenced by the visible cells — confirm it is still needed.
    opt_enc_legacy = keras.optimizers.legacy.Adam(learning_rate=1e-4)
    opt_ext = keras.optimizers.Adam(learning_rate=1e-4)
    
    # Unreduced MSE, used directly by the encoder triplet loss.
    mse = keras.losses.mean_squared_error
    @tf.function
    def mse_loss(z1, z2):
        # MSE between z1 and z2, reduced with tf.nn.compute_average_loss.
        return tf.nn.compute_average_loss(keras.losses.mean_squared_error(z1, z2))
    
    # Per-epoch loss histories; the training cells re-create these lists before use.
    train_loss_emb = []
    train_loss_enc = []
    train_loss_ext = []

    val_loss_emb = []
    val_loss_enc = []
    val_loss_ext = []

In [None]:
with strategy.scope():
    @tf.function
    def random_augmenter(x, image_size=128):
        # Randomly distort an image batch `x`; each distortion is applied to the
        # whole batch with its own probability (0.5 / 0.8 / 0.4 / 0.2).
        # `image_size` is accepted but not used.
        # The commented-out conditions below are manual switches that force a
        # distortion always on (< 10.0) or always off (> 10.0).

        # Flip
        if tf.random.uniform([], minval=0, maxval=1) < 0.5: # Default
#         if tf.random.uniform([], minval=0, maxval=1) < 10.0: # Always
#         if tf.random.uniform([], minval=0, maxval=1) > 10.0: # Never
            x = tf.image.flip_left_right(x)

        # Color Jitter
        if tf.random.uniform([], minval=0, maxval=1) < 0.8: # Default
#         if tf.random.uniform([], minval=0, maxval=1) < 10.0: # Always
#         if tf.random.uniform([], minval=0, maxval=1) > 10.0: # Never
            # stop_gradient keeps gradients from flowing through the colour distortions.
            x = tf.stop_gradient(tf.image.random_brightness(x, 0.4))
            x = tf.stop_gradient(tf.image.random_contrast(x, 0.5, 2.0))
            x = tf.stop_gradient(tf.image.random_saturation(x, 0.5, 2.0))
            x = tf.stop_gradient(tf.image.random_hue(x, 0.25))

        # Gaussian Blur
        if tf.random.uniform([], minval=0, maxval=1) < 0.4: # Default
#         if tf.random.uniform([], minval=0, maxval=1) < 10.0: # Always
#         if tf.random.uniform([], minval=0, maxval=1) > 10.0: # Never
            # NOTE(review): np.random runs in Python, so inside tf.function this sigma is
            # presumably drawn once per trace rather than once per call — confirm.
            s = np.random.uniform(1.0, 2.5)
            x = tfa.image.gaussian_filter2d(image=x, sigma=s)

        # Solarization
        if tf.random.uniform([], minval=0, maxval=1) < 0.2: # Default
#         if tf.random.uniform([], minval=0, maxval=1) < 10.0: # Always
#         if tf.random.uniform([], minval=0, maxval=1) > 10.0: # Never
            # Invert every value that is at or above 127/255.
            x = tf.where(x < 127/255.0, x, 1.0 - x)

        return x

    # Keras wrapper so the augmenter can be used like a layer inside model graphs.
    augment = tf.keras.Sequential([
    layers.Lambda(random_augmenter),
    ])

In [None]:
# Draw one validation batch. The tensors stay notebook globals because the
# plot() helpers in later cells read og / noised / shuffled / w / w_s.
og, noised, shuffled, w, w_s, y, y_s = next(val_iterator)

# Index of the sample shown below.
i = 0

# First replica's slice of every per-replica value.
og_ = og.values[0]
noised_ = noised.values[0]
shuffled_ = shuffled.values[0]
w_ = w.values[0]
w_s_ = w_s.values[0]
y_ = y.values[0]
y_s_ = y_s.values[0]

print(og_[0][0][0][0])

print(og_.shape, type(og_))
print(noised_.shape, type(noised_))
print(shuffled_.shape, type(shuffled_))
print(w_.shape, type(w_))
print(w_s_.shape, type(w_s_))
print(y_.shape, type(y_))
print(y_s_.shape, type(y_s_))

# Augment once and reuse the result: augment() is random, so separate calls
# would show one image and then report the shape and difference of others.
noised_aug_ = augment(noised_)

fig, axs = plt.subplots(1, 5)

axs[0].imshow(og_[i])
axs[0].set_title(int(y_[i]), fontdict={'fontsize': 80})
axs[1].imshow(noised_aug_[i])
axs[1].set_title(int(y_[i]), fontdict={'fontsize': 80})
axs[2].imshow(shuffled_[i])
axs[2].set_title(int(y_s_[i]), fontdict={'fontsize': 80})
axs[3].imshow(w_[i], cmap='gray')
axs[4].imshow(w_s_[i], cmap='gray')

print(noised_aug_[i].shape)

# Signed sum of the pixel differences between the original and the augmented image shown above.
print(tf.reduce_sum(og_[i] - noised_aug_[i]))

fig.set_figheight(10)
fig.set_figwidth(40)

In [None]:
# BREAKPOINT_1

**Embedder Model**
---

In [None]:
# Build the embedder and wire it to three (cover, watermark) input pairs.
with strategy.scope():
    embedder = get_embedder()
#     embedder = get_concat_embedder()
    
    # Getting 3 cover images
    cover_og = keras.Input(shape=input_shape)
    cover_noised = keras.Input(shape=input_shape)
    cover_shuffled = keras.Input(shape=input_shape)
    
    # Getting 3 watermarks
    watermark_og = keras.Input(shape=watermark_shape)
    watermark_noised = keras.Input(shape=watermark_shape)
    watermark_shuffled = keras.Input(shape=watermark_shape)
    
    # Getting marked image
    # The same embedder instance (shared weights) is applied to all three pairs.
    marked_og = embedder([cover_og, watermark_og])
    marked_noised = embedder([cover_noised, watermark_noised])
    marked_shuffled = embedder([cover_shuffled, watermark_shuffled])
    
    # Model over the "og" pair only.
    emb_train = Model([cover_og, watermark_og], marked_og, name='Embedder_train')

In [None]:
# Print the layer summary of the embedder model.
emb_train.summary()

In [None]:
# Chain the extractor behind the embedder so both can be trained jointly.
with strategy.scope():
    extractor = get_extractor()
    
    # Getting watermark back
    watermark_og_ = extractor(marked_og)

    # Outputs: the marked image and the watermark recovered from it.
    emb_ext_train = Model([cover_og, watermark_og], [marked_og, watermark_og_], name='Embedder_Extractor_train')

In [None]:
# Print the layer summary of the joint embedder/extractor model.
emb_ext_train.summary()

**Training: Embedder-Extractor Model**
---

In [None]:
# BREAKPOINT_2

In [None]:
@tf.function
def train_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Run one optimisation step of the joint embedder/extractor model.

    Only ``im_og`` and ``w`` are used; the other arguments are accepted so the
    whole dataset tuple can be passed through unchanged.

    Returns:
        ``(embedding_loss, watermark_loss)`` for this batch.
    """
    weights = emb_ext_train.trainable_variables

    with tf.GradientTape() as grad_tape:
        marked_og, watermark_og_ = emb_ext_train([im_og, w], training=True)

        # How far the marked image drifts from the cover image.
        embedding_loss = mse_loss(im_og, marked_og)

        # How well the watermark is recovered from the marked image.
        watermark_loss = mse_loss(w, watermark_og_)

    # Both losses are passed as targets of a single gradient call.
    gradients = grad_tape.gradient([embedding_loss, watermark_loss], weights)
    opt_emb_ext.apply_gradients(zip(gradients, weights))

    return embedding_loss, watermark_loss

In [None]:
@tf.function
def validate_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Evaluate the joint embedder/extractor model on one batch (no weight update).

    Only ``im_og`` and ``w`` are used; the other arguments are accepted so the
    whole dataset tuple can be passed through unchanged.

    Returns:
        ``(embedding_loss, watermark_loss)`` for this batch.
    """
    # No GradientTape: nothing is differentiated during validation, so
    # recording the forward pass would only cost memory.
    marked_og, watermark_og_ = emb_ext_train([im_og, w], training=False)

    # Embedder loss
    embedding_loss = mse_loss(im_og, marked_og)

    # Watermark extraction loss
    watermark_loss = mse_loss(w, watermark_og_)

    return embedding_loss, watermark_loss

In [None]:
# Train the joint embedder/extractor model with early stopping on the validation losses.
# Sentinel "best so far" values, larger than any expected loss.
best_emb = 1000000000
best_ext = 1000000000
wait = 0
train_loss_emb = []
train_loss_ext = []
val_loss_emb = []
val_loss_ext = []
    
for epoch in range(num_epochs):
    
    batch_count = 0
    emb_sum = 0
    wat_sum = 0
    # NOTE(review): the step count divides by the per-replica batch size; if every
    # strategy.run call consumes one batch per replica, an "epoch" here covers the
    # data several times — confirm the intended number of steps.
    for i in range(DATASET_INFO['TRAIN']//PER_REPLICA_BATCH_SIZE):
        batch_count += 1
        
        # '\r' rewrites the same console line to act as a progress indicator.
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, num_epochs, batch_count, '.' * (batch_count % 10)), end='')
        # The dataset element is a 7-tuple and is passed as the positional args of train_step.
        emb_loss, wat_loss = strategy.run(train_step, args=(next(trn_iterator)))
        # Add up the per-replica losses of this step.
        emb_sum += tf.math.reduce_sum(emb_loss.values)
        wat_sum += tf.math.reduce_sum(wat_loss.values)
    # NOTE(review): the accumulated sums are divided by batch_size, not by the number of steps — confirm.
    train_loss_emb.append(emb_sum/batch_size)
    train_loss_ext.append(wat_sum/batch_size)
    
    emb_sum = 0
    wat_sum = 0
    for i in range(DATASET_INFO['VAL']//PER_REPLICA_BATCH_SIZE):
        emb_loss, wat_loss = strategy.run(validate_step, args=(next(val_iterator)))
        emb_sum += tf.math.reduce_sum(emb_loss.values)
        wat_sum += tf.math.reduce_sum(wat_loss.values)
    val_loss_emb.append(emb_sum/batch_size)
    val_loss_ext.append(wat_sum/batch_size)
    
    print()
    print('Training loss(emb): {} - Validation loss(emb): {}'.format(train_loss_emb[epoch], val_loss_emb[epoch]))
    print('Training loss(ext): {} - Validation loss(ext): {}'.format(train_loss_ext[epoch], val_loss_ext[epoch]))
    
    # Early Stopping
    # A checkpoint is written only when BOTH validation losses improve in the same epoch.
    wait += 1
    if val_loss_emb[epoch] < best_emb and val_loss_ext[epoch] < best_ext:
        best_emb = val_loss_emb[epoch]
        best_ext = val_loss_ext[epoch]
        wait = 0
        print(f'Best Embedder loss: {best_emb}, Best Extractor loss: {best_ext}')
        emb_ext_train.save_weights(path + 'emb_ext_{}.h5'.format(experiment))
    else:
        print('Not saved!')
    if wait >= patience:
        print(f'Training stagnated for {patience} epochs')
        break

In [None]:
# Plotting loss
# Draw the four loss histories of the embedder/extractor run on one axis.
fig, ax = plt.subplots(1, 1, figsize=(20, 20))

# The plotting order must match the legend entries below.
for loss_history in (train_loss_emb, train_loss_ext, val_loss_emb, val_loss_ext):
    ax.plot(loss_history)

ax.set_title('Embedder Extractor loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid()
ax.legend(['Train_emb', 'Train_ext', 'Validation_emb', 'Validation_ext'], loc='upper right')

fig.savefig(path + 'Embedder_Extractor - Training.png')

In [None]:
# Restore the checkpoint written by the training loop above.
emb_ext_train.load_weights(path + 'emb_ext_{}.h5'.format(experiment))

In [None]:
def plot():
    """Show and save embedder/extractor results for the first sample of the inspected batch.

    Reads the notebook globals ``noised`` and ``w`` (first replica slice),
    embeds the watermark, extracts it again, prints PSNR and BER, and saves
    the 2x2 figure to the experiment directory.

    Returns:
        ``(cover_image, marked_image)`` for the first sample.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(40, 40))

    cover_batch = noised.values[0]
    watermark_batch = w.values[0]

    # The model's own watermark output is discarded; the extractor is run explicitly below.
    noised_, _ = emb_ext_train.predict([cover_batch, watermark_batch])
    w_ = extractor.predict(noised_)

    # Binarise the extracted watermark once and reuse it for display and for the BER.
    recovered = cv.threshold(normalize_img(w_[0]), 0.5, 1.0, cv.THRESH_BINARY)[1]

    ax1.imshow(cover_batch[0])
    ax2.imshow(watermark_batch[0], cmap='gray')
    ax3.imshow(noised_[0])
    ax4.imshow(recovered, cmap='gray')

    print('===================================\nCover and marked:')
    get_PSNR(cover_batch[0], noised_[0])

    print('===================================\nWatermark:')
    get_BER(watermark_batch[0], recovered)

    # Hide the tick labels BEFORE saving so the saved file matches what is shown.
    for ax in (ax1, ax2, ax3, ax4):
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    fig.savefig(path + 'Embedder_Extractor - Results.png')
    plt.show()

    return cover_batch[0], noised_[0]

temp_c, temp_m = plot()

**Watermark Location**
---

In [None]:
# Visualise where the embedder changed the cover image.
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(30, 20))

# Amplify the cover/marked difference by 20 so that it becomes visible.
diff = (temp_c - temp_m)*20.0

# Top row: cover, marked image, amplified difference.
ax1.imshow(temp_c)
ax2.imshow(temp_m)
ax3.imshow(diff, cmap='gray')
# Bottom row: the amplified difference, one colour channel per panel.
for channel_index, channel_ax in enumerate((ax4, ax5, ax6)):
    channel_ax.imshow(diff[:, :, channel_index], cmap='gray')

for panel in (ax1, ax2, ax3, ax4, ax5, ax6):
    panel.set_xticklabels([])
    panel.set_yticklabels([])

for image in (temp_c, temp_m, diff):
    print(np.min(image), np.max(image))

plt.savefig(path + 'Watermark location.png')
plt.show()

In [None]:
# BREAKPOINT_3

**Encoder Model**
---

In [None]:
# Build the encoder inside the strategy scope and apply it to the three marked images.
with strategy.scope():
    # Replaces the encoder that was created earlier only to compute `channels`.
    encoder = get_encoder()
    # Only the "noised" branch is passed through the random augmenter.
    marked_noised_ = augment(marked_noised)
    
    # Getting 3 latent domains
    # Each call returns the token code and the Dense(1000) output.
    code_og, label_og = encoder(marked_og)
    code_noised, label_noised = encoder(marked_noised_)
    code_shuffled, label_shuffled = encoder(marked_shuffled)

In [None]:
# Freeze the embedder/extractor model before the encoder training model is assembled.
emb_ext_train.trainable = False

In [None]:
with strategy.scope():
    # Training model
    # Inputs: three covers followed by three watermarks.
    # Outputs: three token codes followed by the three Dense(1000) outputs.
    enc_train = Model([cover_og, cover_noised, cover_shuffled,
                 watermark_og, watermark_noised, watermark_shuffled],
                [code_og, code_noised, code_shuffled, label_og, label_noised, label_shuffled],
                name='Encoder_train')

In [None]:
# Print the layer summary of the encoder training model.
enc_train.summary()

**Training: Encoder Model**
---

In [None]:
# BREAKPOINT_4

In [None]:
@tf.function
def train_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Run one optimisation step of the encoder with a triplet-style loss.

    The code of the original image is the anchor, the code of the augmented
    copy is the positive and the code of the shuffled image is the negative.
    ``y`` and ``y_s`` are accepted but not used.

    Returns:
        The encoding loss tensor for this batch.
    """
    weights = enc_train.trainable_variables

    with tf.GradientTape() as grad_tape:
        # The "og" and "noised" branches receive the same watermark `w`.
        outputs = enc_train([im_og, im_n, im_s, w, w, w_s], training=True)
        anchor_code, positive_code, negative_code = outputs[0], outputs[1], outputs[2]

        # Pull the positive towards the anchor and push the negative at least `margin` away.
        positive_anchor_loss = mse(anchor_code, positive_code)
        negative_anchor_loss = mse(anchor_code, negative_code)
        encoding_loss = tf.maximum(positive_anchor_loss - negative_anchor_loss + margin, 0.0)

    enc_grads = grad_tape.gradient(encoding_loss, weights)
    opt_enc.apply_gradients(zip(enc_grads, weights))

    return encoding_loss

In [None]:
@tf.function
def validate_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Evaluate the encoder triplet-style loss on one batch (no weight update).

    ``y`` and ``y_s`` are accepted but not used.

    Returns:
        The encoding loss tensor for this batch.
    """
    # No GradientTape: nothing is differentiated during validation, so
    # recording the forward pass would only cost memory.
    code_og, code_noised, code_shuffled, label_og, label_noised, label_shuffled = enc_train([im_og, im_n, im_s, w, w, w_s], training=False)

    # Encoder triplet loss
    positive_anchor_loss = mse(code_og, code_noised)
    negative_anchor_loss = mse(code_og, code_shuffled)
    encoding_loss = tf.maximum(positive_anchor_loss - negative_anchor_loss + margin, 0.0)

    return encoding_loss

In [None]:
# Train the encoder with early stopping on the validation loss.
# Sentinel "best so far" value, larger than any expected loss.
best = 1000000000
wait = 0
train_loss_enc = []
val_loss_enc = []
    
for epoch in range(num_epochs):
    
    batch_count = 0
    enc_sum = 0
    # NOTE(review): the step count divides by the per-replica batch size; if every
    # strategy.run call consumes one batch per replica, an "epoch" here covers the
    # data several times — confirm the intended number of steps.
    for i in range(DATASET_INFO['TRAIN']//PER_REPLICA_BATCH_SIZE):
        batch_count += 1
        
        # '\r' rewrites the same console line to act as a progress indicator.
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, num_epochs, batch_count, '.' * (batch_count % 10)), end='')
        # The dataset element is a 7-tuple and is passed as the positional args of train_step.
        enc_loss = strategy.run(train_step, args=(next(trn_iterator)))
        # Add up every element of the per-replica loss tensors.
        enc_sum += tf.math.reduce_sum(enc_loss.values)
    # NOTE(review): the accumulated sum is divided by batch_size, not by the number of steps — confirm.
    train_loss_enc.append(enc_sum/batch_size)
    
    enc_sum = 0
    for i in range(DATASET_INFO['VAL']//PER_REPLICA_BATCH_SIZE):
        enc_loss = strategy.run(validate_step, args=(next(val_iterator)))
        enc_sum += tf.math.reduce_sum(enc_loss.values)
    val_loss_enc.append(enc_sum/batch_size)
    
    print()
    print('Training loss(enc): {} - Validation loss(enc): {}'.format(train_loss_enc[epoch], val_loss_enc[epoch]))
    
    # Early Stopping
    # Save a checkpoint whenever the validation loss improves; stop after `patience` stale epochs.
    wait += 1
    if np.mean(val_loss_enc[epoch]) < best:
        best = np.mean(val_loss_enc[epoch])
        wait = 0
        print(f'Best Encoding loss: {best}')
        enc_train.save_weights(path + 'enc_{}.h5'.format(experiment))
    else:
        print('Not saved!')
    if wait >= patience:
        break

In [None]:
# Draw the training and validation loss histories of the encoder run.
fig, ax = plt.subplots(1, 1, figsize=(20, 20))

# The plotting order must match the legend entries below.
for loss_history in (train_loss_enc, val_loss_enc):
    ax.plot(loss_history)

ax.set_title('Encoder loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Encoder loss')
ax.grid()
ax.legend(['Train_enc', 'Validation_enc'], loc='upper right')

# The file name contains no placeholder, so it is used verbatim.
fig.savefig(path + 'Encoder - Training.png')

In [None]:
# Restore the checkpoint written by the encoder training loop above.
enc_train.load_weights(path + 'enc_{}.h5'.format(experiment))

In [None]:
def plot():
    """Show and save encoder results for the inspected validation batch.

    Rows of the 4x3 figure: cover images, watermarks, marked images (the
    middle one augmented) and the first 32x32 plane of each encoder code.
    Columns: original, noised, shuffled. Reads the notebook globals ``og``,
    ``noised``, ``shuffled``, ``w`` and ``w_s`` (first replica slice).
    """
    fig, axes = plt.subplots(4, 3, figsize=(60, 80))

    og_batch, noised_batch, shuffled_batch = og.values[0], noised.values[0], shuffled.values[0]
    w_batch, w_s_batch = w.values[0], w_s.values[0]

    for col, covers in enumerate((og_batch, noised_batch, shuffled_batch)):
        axes[0, col].imshow(covers[0])

    # The original and noised columns share the same watermark.
    for col, marks in enumerate((w_batch, w_batch, w_s_batch)):
        axes[1, col].imshow(marks[0], cmap='gray')

    og_, og_w_ = emb_ext_train.predict([og_batch, w_batch])
    noised_, noised_w_ = emb_ext_train.predict([noised_batch, w_batch])
    shuffled_, shuffled_w_ = emb_ext_train.predict([shuffled_batch, w_s_batch])

    noised_marked = augment(noised_)

    for col, marked in enumerate((og_, noised_marked, shuffled_)):
        axes[2, col].imshow(marked[0])

    og_c, og_label = encoder.predict(og_)
    noised_c, noised_label = encoder.predict(noised_marked)
    shuffled_c, shuffled_label = encoder.predict(shuffled_)

    # The whole code batch is reshaped into 32x32 planes; only plane 0 is displayed.
    for col, code in enumerate((og_c, noised_c, shuffled_c)):
        axes[3, col].imshow(np.reshape(code, (32, 32, channels))[:, :, 0])

    fig.savefig(path + 'Encoder - Results.png')

    print("Og_c - Noised_c = ", tf.reduce_mean(og_c[0] - noised_c[0]))
    print("Og_c - Shuffled_c = ", tf.reduce_mean(og_c[0] - shuffled_c[0]))
    print("Noised_c - Shuffled_c = ", tf.reduce_mean(noised_c[0] - shuffled_c[0]))

    plt.show()

plot()

**Decoder Model**
---

In [None]:
# Build the decoder and apply it to the three encoder codes.
with strategy.scope():
    # The decoder input shape is the encoder's token-code shape without the batch axis.
    decoder = get_decoder(encoder.output_shape[0][1:])
    
    # Getting 3 watermarks and cover images back
    marked_og_ = decoder(code_og)
    marked_noised__ = decoder(code_noised)
    marked_shuffled_ = decoder(code_shuffled)

In [None]:
# enc_train.trainable = False

In [None]:
with strategy.scope():
    # Training model
    # Outputs: the three marked images followed by their three decoder reconstructions.
    ae_train = Model([cover_og, cover_noised, cover_shuffled,
                watermark_og, watermark_noised, watermark_shuffled],
               [marked_og, marked_noised, marked_shuffled,
                marked_og_, marked_noised__, marked_shuffled_],
               name='AE_train')

In [None]:
# Print the layer summary of the auto-encoder model.
ae_train.summary()

In [None]:
# BREAKPOINT_5

**Extractor Model**
---

In [None]:
# Attach an extractor to the three decoder reconstructions.
with strategy.scope():
    # This creates a NEW extractor and rebinds the name `extractor`; the instance
    # built earlier remains part of emb_ext_train.
    extractor = get_extractor()
    
    # Getting 3 watermarks and cover images back
    watermark_og_ = extractor(marked_og_)
    watermark_noised_ = extractor(marked_noised__)
    watermark_shuffled_ = extractor(marked_shuffled_)

In [None]:
# decoder.trainable = False
# encoder.trainable = False

In [None]:
# Mark the new extractor as trainable before the extractor training model is assembled.
extractor.trainable=True

In [None]:
with strategy.scope():
    # Training model
    # Inputs: three covers followed by three watermarks; outputs: the three extracted watermarks.
    ext_train = Model([cover_og, cover_noised, cover_shuffled,
                       watermark_og, watermark_noised, watermark_shuffled],
                      [watermark_og_, watermark_noised_, watermark_shuffled_],
                      name='Extractor_train')

In [None]:
# Print the layer summary of the extractor training model.
ext_train.summary()

**Training: Extractor Model**
---

In [None]:
# BREAKPOINT_6

In [None]:
@tf.function
def train_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Run one optimisation step of the extractor training model.

    The loss is the sum of the watermark MSE over the original, noised and
    shuffled branches. ``y`` and ``y_s`` are accepted but not used.

    Returns:
        The extraction loss for this batch.
    """
    weights = ext_train.trainable_variables

    with tf.GradientTape() as grad_tape:
        # The "og" and "noised" branches receive the same watermark `w`.
        extracted_og, extracted_noised, extracted_shuffled = ext_train(
            [im_og, im_n, im_s, w, w, w_s], training=True)

        og_term = mse_loss(w, extracted_og)
        noised_term = mse_loss(w, extracted_noised)
        shuffled_term = mse_loss(w_s, extracted_shuffled)
        extraction_loss = og_term + noised_term + shuffled_term

    gradients = grad_tape.gradient(extraction_loss, weights)
    opt_ext.apply_gradients(zip(gradients, weights))

    return extraction_loss

In [None]:
@tf.function
def validate_step(im_og, im_n, im_s, w, w_s, y, y_s):
    """Evaluate the extractor training model on one batch (no weight update).

    ``y`` and ``y_s`` are accepted but not used.

    Returns:
        The extraction loss for this batch.
    """
    # No GradientTape: nothing is differentiated during validation, so
    # recording the forward pass would only cost memory.
    watermark_og_, watermark_noised_, watermark_shuffled_ = ext_train([im_og, im_n, im_s, w, w, w_s], training=False)

    # Watermark extraction loss
    extraction_loss = mse_loss(w, watermark_og_) + mse_loss(w, watermark_noised_) + mse_loss(w_s, watermark_shuffled_)

    return extraction_loss

In [None]:
# Train the extractor training model with early stopping on the validation loss.
# Sentinel "best so far" value, larger than any expected loss.
best_ext = 1000000000
wait = 0
train_loss_ext = []
val_loss_ext = []
    
for epoch in range(num_epochs):
    
    batch_count = 0
    ext_sum = 0
    # NOTE(review): the step count divides by the per-replica batch size; if every
    # strategy.run call consumes one batch per replica, an "epoch" here covers the
    # data several times — confirm the intended number of steps.
    for i in range(DATASET_INFO['TRAIN']//PER_REPLICA_BATCH_SIZE):
        batch_count += 1
        
        # '\r' rewrites the same console line to act as a progress indicator.
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, num_epochs, batch_count, '.' * (batch_count % 10)), end='')
        # The dataset element is a 7-tuple and is passed as the positional args of train_step.
        ext_loss = strategy.run(train_step, args=(next(trn_iterator)))
        # Add up the per-replica losses of this step.
        ext_sum += tf.math.reduce_sum(ext_loss.values)
    # NOTE(review): the accumulated sum is divided by batch_size, not by the number of steps — confirm.
    train_loss_ext.append(ext_sum/batch_size)
    
    ext_sum = 0
    for i in range(DATASET_INFO['VAL']//PER_REPLICA_BATCH_SIZE):
        ext_loss = strategy.run(validate_step, args=(next(val_iterator)))
        ext_sum += tf.math.reduce_sum(ext_loss.values)
    val_loss_ext.append(ext_sum/batch_size)
    
    print()
    print('Training loss(ext): {} - Validation loss(ext): {}'.format(train_loss_ext[epoch], val_loss_ext[epoch]))
    
    # Early Stopping
    # Save a checkpoint whenever the validation loss improves; stop after `patience` stale epochs.
    wait += 1
    if val_loss_ext[epoch] < best_ext:
        best_ext = val_loss_ext[epoch]
        wait = 0
        print(f'Best Extractor loss: {best_ext}')
        ext_train.save_weights(path + 'ext_{}.h5'.format(experiment))
    else:
        print('Not saved!')
    if wait >= patience:
        break

In [None]:
# Draw the training and validation loss histories of the extractor run.
fig, ax = plt.subplots(1, 1, figsize=(20, 20))

# The plotting order must match the legend entries below.
for loss_history in (train_loss_ext, val_loss_ext):
    ax.plot(loss_history)

ax.set_title('Extractor loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Extractor loss')
ax.grid()
ax.legend(['Train', 'Validation'], loc='upper right')

# The file name contains no placeholder, so it is used verbatim.
fig.savefig(path + 'Extractor - Training.png')

In [None]:
# Restore the checkpoint written by the extractor training loop above.
ext_train.load_weights(path + 'ext_{}.h5'.format(experiment))

In [None]:
def plot():
    """Show, save and score the full pipeline on the inspected validation batch.

    Rows of the 3x5 figure: original, noised, shuffled. Columns: cover image,
    watermark, marked image (the noised one augmented), first 32x32 plane of
    the encoder code, and the binarised watermark extracted from the decoder
    output. PSNR and BER figures are printed for the first sample of each row.
    Reads the notebook globals ``og``, ``noised``, ``shuffled``, ``w`` and
    ``w_s`` (first replica slice).
    """
    def binarise(prediction):
        """Min-max normalise a watermark estimate and threshold it at 0.5."""
        return cv.threshold(normalize_img(prediction), 0.5, 1.0, cv.THRESH_BINARY)[1]

    fig, axs = plt.subplots(3, 5, figsize=(100, 60))

    og_batch, noised_batch, shuffled_batch = og.values[0], noised.values[0], shuffled.values[0]
    w_batch, w_s_batch = w.values[0], w_s.values[0]

    # Column 0: cover images.
    for row, covers in enumerate((og_batch, noised_batch, shuffled_batch)):
        axs[row, 0].imshow(covers[0])

    # Column 1: watermarks (original and noised rows share the same one).
    for row, marks in enumerate((w_batch, w_batch, w_s_batch)):
        axs[row, 1].imshow(marks[0], cmap='gray')

    og_, og_w_ = emb_ext_train.predict([og_batch, w_batch])
    noised_, noised_w_ = emb_ext_train.predict([noised_batch, w_batch])
    shuffled_, shuffled_w_ = emb_ext_train.predict([shuffled_batch, w_s_batch])

    # Column 2: marked images.
    noised_marked = augment(noised_)
    for row, marked in enumerate((og_, noised_marked, shuffled_)):
        axs[row, 2].imshow(marked[0])

    og_c, og_label = encoder.predict(og_)
    noised_c, noised_label = encoder.predict(noised_marked)
    shuffled_c, shuffled_label = encoder.predict(shuffled_)

    # Column 3: plane 0 of the code batch reshaped into 32x32 planes.
    for row, code in enumerate((og_c, noised_c, shuffled_c)):
        axs[row, 3].imshow(np.reshape(code, (32, 32, channels))[:, :, 0])

    marked_og_ = decoder.predict(og_c)
    marked_noised_ = decoder.predict(noised_c)
    marked_shuffled_ = decoder.predict(shuffled_c)

    og_w_ = extractor.predict(marked_og_)
    noised_w_ = extractor.predict(marked_noised_)
    shuffled_w_ = extractor.predict(marked_shuffled_)

    # Column 4: binarised extracted watermarks, reused for the BER figures below.
    og_bits = binarise(og_w_[0])
    noised_bits = binarise(noised_w_[0])
    shuffled_bits = binarise(shuffled_w_[0])
    for row, bits in enumerate((og_bits, noised_bits, shuffled_bits)):
        axs[row, 4].imshow(bits, cmap='gray')

    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    fig.savefig(path + 'Functionality - Results.png')

    print("===================================")
    print("Og_c - Noised_c = ", tf.reduce_mean(og_c[0] - noised_c[0]))
    print("Og_c - Shuffled_c = ", tf.reduce_mean(og_c[0] - shuffled_c[0]))

    print('===================================\nCover and marked:')
    print('----------\nOg - Og_:')
    get_PSNR(og_batch[0], og_[0])
    print('----------\nNoised - Noised_:')
    get_PSNR(noised_batch[0], noised_marked[0])
    print('----------\nShuffled - Shuffled_:')
    get_PSNR(shuffled_batch[0], shuffled_[0])

    print('===================================\nWatermark:')
    print('----------\nOg:')
    get_BER(w_batch[0], og_bits)
    print('----------\nNoised:')
    get_BER(w_batch[0], noised_bits)
    print('----------\nShuffled:')
    get_BER(w_s_batch[0], shuffled_bits)

    plt.show()

plot()