In [None]:
import numpy as np
from trainer.utils import DataLoader, plot_test_images, restore_original_image, get_Fcs
from trainer.MST import MST
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg19 import preprocess_input
%matplotlib inline

In [None]:
# Build the MST model and load the pretrained decoder weights from disk.
# NOTE(review): the meaning of the positional arguments (None, None, 3) is defined in
# trainer.MST — presumably input height/width/channels; confirm there.
DECODER_WEIGHTS = './trainer/data/weights/pretrained.h5'
model = MST(None, None, 3, decoder_weights=DECODER_WEIGHTS)

In [None]:
# Root directory handed to DataLoader; this notebook uses its `vgg` encoder later on.
DATA_DIR = './trainer/data'
dl = DataLoader(datapath=DATA_DIR)

In [None]:
# Print the layer-by-layer summary of the decoder network that produces the output image.
decoder = model.decoder
decoder.summary()

In [None]:
# Encode the content and style images with VGG19, blend their features with get_Fcs,
# and decode the blended features back into an image.
CONTENT_PATH = './trainer/data/test/content_img.jpg'
STYLE_PATH = './trainer/data/test/style_img.jpg'
FCS_K = 3        # `k` argument of get_Fcs — TODO document its meaning (see trainer.utils)
FCS_ALPHA = 1    # `alpha` argument of get_Fcs — TODO document its meaning (see trainer.utils)


def load_vgg_image(path):
    """Load the image at `path` as RGB and apply VGG19 `preprocess_input`.

    Returns an (H, W, 3) array without a batch axis. `img_to_array` produces a
    float32 channels-last array explicitly, so the preprocessing arithmetic never
    depends on uint8 pixel data being cast internally.
    """
    pixels = image.img_to_array(image.load_img(path), data_format='channels_last')
    return preprocess_input(pixels)


content_img = load_vgg_image(CONTENT_PATH)
style_img = load_vgg_image(STYLE_PATH)

# The networks expect a leading batch axis: (1, H, W, 3).
content_vgg = np.expand_dims(content_img, 0)
style_vgg = np.expand_dims(style_img, 0)
assert content_vgg.ndim == 4 and style_vgg.ndim == 4

Fs = dl.vgg.predict(style_vgg)
Fc = dl.vgg.predict(content_vgg)
Fcs = get_Fcs(Fc, Fs, k=FCS_K, alpha=FCS_ALPHA)
# Re-add a batch axis before decoding (get_Fcs presumably returns an un-batched map).
Fcs = np.expand_dims(Fcs, 0)
Ics = model.decoder.predict(Fcs)

In [None]:
def _to_display(batch):
    """Undo the VGG preprocessing of a single-image batch and drop its batch axis.

    NOTE(review): if restore_original_image modifies `batch` in place, re-running
    this cell would restore already-restored data — confirm in trainer.utils.
    """
    restored = restore_original_image(batch, 'channels_last')
    return np.squeeze(restored, axis=0)


# Stylised output, original content and original style, ready for plotting.
Ocs, Oc, Os = (_to_display(b) for b in (Ics, content_vgg, style_vgg))

In [None]:
# Show the content image, the style image and the stylised output side by side,
# each titled with its name and array shape.
im = {
    'Content': Oc,
    'Style': Os,
    'Out': Ocs
}

# squeeze=False keeps `axes` a 2-D array even if `im` holds a single image.
fig, axes = plt.subplots(1, len(im), figsize=(40, 10), squeeze=False)
for ax, (title, img) in zip(axes.flat, im.items()):
    ax.imshow(img)
    ax.set_title("{} - {}".format(title, img.shape))
    ax.axis('off')

# The previous title ('filename - Epoch: 23') was a leftover placeholder from a
# training-monitor plot and did not describe this inference run.
fig.suptitle('MST style transfer (pretrained decoder)')
plt.show()