In [None]:
from skimage import io
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from utils import grainPreprocess
from sklearn.model_selection import train_test_split

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterating (rather than indexing gpus[0]) keeps this cell working on machines
# with no GPU, where the list is empty.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
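# Uncomment to rebuild the cached array from the raw images in data/dataset (slow);
# otherwise the next cell loads a previously saved copy.
#all_images=grainPreprocess.read_preprocess_data('data/dataset',images_num_per_class=150,preprocess=False,save=True,save_name='all_images_no_preprocess.npy',resize=True)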

In [None]:
# Load the cached images: one list of grayscale images per class.
# allow_pickle=True is presumably needed because this is an object array of
# ragged per-class lists. Unpickling can execute arbitrary code, so only load
# files you created or trust.
all_images = np.load(
    file='data/saved np/all_images_no_preprocess.npy',
    allow_pickle=True,
)

In [None]:
# Flatten all classes into a single array of RGB images:
# scale each grayscale image by 1/255, resize it to 1024x1024 (must match
# `image_shape` used by the model), then replicate the channel to RGB.
# Resizing before the gray->RGB step yields identical pixels (the three channels
# are copies of one another) while doing a third of the interpolation work.
all_images_rgb = []
for images_list in all_images:  # one list of images per class
    for image_gray in images_list:
        # assumes image_gray is a 2-D (H, W) array -- TODO confirm for all inputs
        tf_image = tf.expand_dims(image_gray / 255, 2)          # (H, W) -> (H, W, 1)
        tf_resized = tf.image.resize(tf_image, (1024, 1024))     # bilinear, float32
        tf_rgb = tf.image.grayscale_to_rgb(tf_resized)           # (1024, 1024, 3)
        all_images_rgb.append(tf_rgb.numpy())

all_images_rgb = np.array(all_images_rgb)

In [None]:
# Model input shape (height, width, channels). The side length must match the
# resize target used in the preprocessing cell.
IMG_SIDE, IMG_CHANNELS = 1024, 3
image_shape = (IMG_SIDE, IMG_SIDE, IMG_CHANNELS)

In [None]:
# Convolutional backbones used inside the autoencoder (randomly initialised).
# NOTE: the names are inverted relative to their role. `decoder` (ResNet50) is
# applied FIRST to the input image and produces the bottleneck. `encoder`
# (ResNet50V2) runs SECOND on that bottleneck reshaped to (32, 32, 2).
# The names are kept because later cells refer to them.

# Stage 1: (1024, 1024, 3) image -> 2048-d vector via global max pooling.
decoder = tf.keras.applications.resnet50.ResNet50(
    include_top=False,
    weights=None,
    pooling='max',
    input_shape=image_shape,
)

# Stage 2: (32, 32, 2) input -> unpooled 4-D feature map for the upsampling stack.
# pooling must be None (True is not an accepted value). None returns the
# spatial feature map that the Conv2DTranspose layers need.
encoder = tf.keras.applications.resnet_v2.ResNet50V2(
    include_top=False,
    weights=None,
    pooling=None,
    input_shape=(32, 32, 2),
)

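The model below is adapted from the [Keras autoencoder tutorial](https://blog.keras.io/building-autoencoders-in-keras.html), with ResNet backbones in place of the tutorial's dense layers.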

In [None]:
# Nominal latent size kept from the Keras tutorial; it is NOT used by the model
# below. The actual bottleneck is the 2048-d ResNet50 vector, reshaped to (32, 32, 2).
encoding_dim = 32

# Input image; values are in [0, 1] after the /255 scaling in preprocessing
# (assuming the raw images are 0-255).
input_img = tf.keras.Input(shape=image_shape)

# Bottleneck: ResNet50 + global max pooling -> 2048 values -> (32, 32, 2).
x = decoder(input_img)
x = tf.keras.layers.Reshape((32, 32, 2))(x)

# Second backbone -> low-resolution feature map.
x = encoder(x)

# Upsample back to full resolution.
# padding='same' makes each stride-4 transposed conv grow the side exactly 4x:
#   1 -> 4 -> 16 -> 64 -> 256 -> 1024.
# With the default 'valid' padding the sides become 3, 11, 43, 171, 683, and the
# final Reshape to image_shape cannot succeed.
x = tf.keras.layers.Conv2DTranspose(32, (3, 3), (4, 4), padding='same')(x)  # 4x4
x = tf.keras.layers.Conv2DTranspose(32, (3, 3), (4, 4), padding='same')(x)  # 16x16
x = tf.keras.layers.Conv2DTranspose(32, (3, 3), (4, 4), padding='same')(x)  # 64x64
x = tf.keras.layers.Conv2D(32, (1, 1), (1, 1))(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.Conv2DTranspose(32, (3, 3), (4, 4), padding='same')(x)  # 256x256
x = tf.keras.layers.Conv2DTranspose(3, (3, 3), (4, 4), padding='same')(x)   # 1024x1024
x = tf.keras.layers.Conv2D(3, (1, 1), (1, 1))(x)
# Sigmoid keeps reconstructions in [0, 1], the same range as the targets.
# tanh would allow negative values the targets never contain.
x = tf.keras.layers.Activation('sigmoid')(x)

# Shape guard: a no-op for the sizes above, but it raises if the element count
# of the decoder output ever stops matching image_shape.
x = tf.keras.layers.Reshape(image_shape)(x)

# Full model: image -> reconstruction.
autoencoder = tf.keras.Model(input_img, x, name='autoencoder')

In [None]:
# Print the layer-by-layer overview (output shapes and parameter counts) to check
# that the model was built with the expected input/output shapes.
autoencoder.summary()

In [None]:
# Hold out 20% of the images for validation. An autoencoder's target is its own
# input, so split the array once and reuse the same arrays as targets. Passing
# all_images_rgb twice would make train_test_split materialise a second copy of
# every 1024x1024x3 image. random_state makes the split reproducible.
x_train, x_test = train_test_split(all_images_rgb, test_size=0.2, random_state=42)
y_train, y_test = x_train, x_test

In [None]:
# Train with a pixel-wise mean-squared-error reconstruction loss and also report
# mean absolute error; RMSprop uses its default learning rate.
autoencoder.compile(
    optimizer=tf.keras.optimizers.RMSprop(),
    loss='MSE',
    metrics=['MAE'],
)

In [None]:
# Sanity-check the split: each should be (n_images, 1024, 1024, 3).
for split in (x_train, x_test):
    print(split.shape)

In [None]:
# Train the autoencoder to reproduce its input. Validation uses the held-out
# (x_test, y_test) pair so it mirrors the (x_train, y_train) training call.
history = autoencoder.fit(
    x_train, y_train,
    epochs=15,
    batch_size=8,  # small batches, presumably so 1024x1024x3 samples fit in GPU memory
    shuffle=True,
    validation_data=(x_test, y_test),
)

In [None]:
# Reconstruct the held-out *test* images with the full trained pipeline.
# The standalone backbones cannot be chained for this: `encoder` expects
# (32, 32, 2) inputs rather than images, and the reshape/upsampling layers
# exist only inside `autoencoder`.
decoded_imgs = autoencoder.predict(x_test, batch_size=8)

# Bottleneck codes: `decoder` (ResNet50) is the autoencoder's first stage and
# shares its trained weights, so its output is the 2048-d code per image.
encoded_imgs = decoder.predict(x_test, batch_size=8)

In [None]:
# Compare test images (top row) with their reconstructions (bottom row).
# Relies on `plt` and `np` from the imports cell.
n = min(10, len(x_test))  # how many test images to display
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
for i in range(n):
    # Images are RGB float arrays; clip to [0, 1] so imshow renders them without
    # out-of-range warnings if a reconstruction overshoots.
    axes[0, i].imshow(np.clip(x_test[i], 0, 1))
    axes[1, i].imshow(np.clip(decoded_imgs[i], 0, 1))
for ax in axes.flat:
    ax.set_axis_off()
axes[0, 0].set_title('Original', loc='left')
axes[1, 0].set_title('Reconstruction', loc='left')
plt.show()