<a href="https://colab.research.google.com/github/leonyangucl/2023fyp/blob/main/inpainting_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
import random

In [None]:

# Load MNIST and pick one training digit to demonstrate the masking.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
image = x_train[5]

# Choose a random top-left corner so the whole mask_size x mask_size hole
# fits inside the image (random.randint bounds are inclusive).
mask_size = 10
start_row = random.randint(0, image.shape[0] - mask_size)
start_col = random.randint(0, image.shape[1] - mask_size)
hole = (slice(start_row, start_row + mask_size),
        slice(start_col, start_col + mask_size))

# Blank the hole in a copy so the original digit stays intact for comparison.
image_copy = image.copy()
image_copy[hole] = 0

# Show the untouched digit next to the masked one.
fig, ax = plt.subplots(1, 2)
panels = [(image, 'Original Image'), (image_copy, 'Masked Image')]
for axis, (picture, title) in zip(ax, panels):
    axis.imshow(picture, cmap='gray')
    axis.set_title(title)
plt.show()


In [3]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixels to [0, 1] to match the sigmoid output / binary cross-entropy
# targets used by the models below. Reloading the raw data here (instead of
# reusing the previous cell's arrays) keeps this cell safe to re-run: it never
# divides already-normalized data by 255 a second time.
# Casting to float32 first halves memory compared with the float64 that
# `uint8_array / 255.0` would produce.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

mask_size = 10  # side length, in pixels, of the square hole cut into every image


def mask_random_square(images, mask_size, rng=random):
    """Return a copy of ``images`` with one randomly placed square zeroed per image.

    Args:
        images: array whose first axis indexes images and whose next two axes are
            (rows, cols). It is not modified.
        mask_size: side length of the square to blank out. It must not exceed
            either image dimension, otherwise ``rng.randint`` raises ValueError.
        rng: object with an inclusive ``randint(a, b)``, e.g. the ``random``
            module (default) or a seeded ``random.Random`` for reproducibility.

    Returns:
        A new array with the same shape and dtype as ``images``.
    """
    masked = images.copy()
    for img in masked:  # each `img` is a view, so the assignment edits `masked`
        row = rng.randint(0, img.shape[0] - mask_size)
        col = rng.randint(0, img.shape[1] - mask_size)
        img[row:row + mask_size, col:col + mask_size] = 0
    return masked


# Same random-call order as before: all training images, then all test images.
x_train_noisy = mask_random_square(x_train, mask_size)
x_test_noisy = mask_random_square(x_test, mask_size)

In [None]:
from keras.models import Sequential
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Reshape
from keras.callbacks import EarlyStopping


# Plain convolutional autoencoder. Two conv + 2x2 max-pool stages compress the
# image 28x28 -> 14x14 -> 7x7; two conv + 2x upsampling stages expand it back
# to 28x28.
model = Sequential([
    # Declare the input with an Input object: passing `input_shape=` to the
    # first layer is deprecated in Keras 3 and emits a warning.
    Input(shape=(28, 28)),
    Reshape((28, 28, 1)),  # add the channel axis Conv2D expects

    # encoder
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),

    # decoder
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(1, (3, 3), activation='sigmoid', padding='same'),  # per-pixel value in (0, 1)
    Reshape((28, 28)),  # drop the channel axis to match the (N, 28, 28) targets
])

model.summary()




In [None]:
model.compile(optimizer='adam', loss='binary_crossentropy')

# Stop once validation loss has not improved for 5 epochs, then roll back to
# the weights from the best epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Early stopping picks the model by validation loss, so the validation data
# must not be the test set the results are shown on afterwards. Hold out the
# last 10% of the (masked, clean) training pairs instead.
history = model.fit(
    x_train_noisy, x_train,
    epochs=50,
    batch_size=128,
    validation_split=0.1,
    callbacks=[early_stop],
)

In [None]:
decoded_imgs = model.predict(x_test_noisy)

n = 10
# replace=False: np.random.choice samples *with* replacement by default, which
# could show the same test digit in two columns.
random_test_images = np.random.choice(x_test_noisy.shape[0], size=n, replace=False)

# Rows: ground truth, masked input, model reconstruction; one column per sample.
# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(3, n, figsize=(20, 4), squeeze=False)
rows = [
    ('Original', x_test),
    ('Masked', x_test_noisy),
    ('Inpainted', decoded_imgs),
]
for row_axes, (label, images) in zip(axes, rows):
    for ax, image_idx in zip(row_axes, random_test_images):
        # reshape also drops a trailing channel axis if the model output has one
        ax.imshow(images[image_idx].reshape(28, 28), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
    row_axes[0].set_ylabel(label)
fig.suptitle('Convolutional autoencoder')
plt.show()

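## Simple U-Net

Same masked-MNIST task and training setup as above, now with a small U-Net. Skip connections concatenate each encoder feature map onto the decoder stage of the same resolution, so detail around the hole does not have to pass through the 7x7 bottleneck.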

In [None]:
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

def create_model():
    """Build and compile a two-level U-Net for 28x28 single-channel images.

    Encoder: 64- then 128-filter 3x3 convolutions, each followed by 2x2 max
    pooling (28 -> 14 -> 7). Decoder: a convolution, 2x upsampling, and
    channel-wise concatenation with the encoder map of the same resolution
    (skip connection), repeated twice. A 1-filter sigmoid convolution then
    produces the output.

    Returns:
        A ``keras.Model`` compiled with the Adam optimizer and binary
        cross-entropy loss.
    """
    def conv3x3(filters, activation='relu'):
        # Every convolution here is 3x3 with 'same' padding, so spatial size
        # changes only at the pooling / upsampling layers.
        return Conv2D(filters, (3, 3), activation=activation, padding='same')

    image_in = Input((28, 28, 1))

    # Encoder: keep each pre-pooling map for the matching skip connection.
    skip_full = conv3x3(64)(image_in)
    down_half = MaxPooling2D(pool_size=(2, 2))(skip_full)
    skip_half = conv3x3(128)(down_half)
    down_quarter = MaxPooling2D(pool_size=(2, 2))(skip_half)

    # Decoder: upsample, then concatenate the same-resolution encoder map.
    bottleneck = conv3x3(128)(down_quarter)
    merged_half = concatenate([UpSampling2D(size=(2, 2))(bottleneck), skip_half], axis=-1)
    decoded_half = conv3x3(64)(merged_half)
    merged_full = concatenate([UpSampling2D(size=(2, 2))(decoded_half), skip_full], axis=-1)

    # A single sigmoid channel gives a per-pixel value in (0, 1).
    reconstruction = conv3x3(1, activation='sigmoid')(merged_full)

    unet = Model(inputs=[image_in], outputs=[reconstruction])
    unet.compile(optimizer='adam', loss='binary_crossentropy')
    return unet

model = create_model()
model.summary()

In [None]:
# create_model() already compiled this model with the same settings; compiling
# again is redundant on a first run but keeps this cell self-contained.
model.compile(optimizer='adam', loss='binary_crossentropy')

# Stop once validation loss has not improved for 5 epochs, then roll back to
# the weights from the best epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Early stopping picks the model by validation loss, so the validation data
# must not be the test set the results are shown on afterwards. Hold out the
# last 10% of the (masked, clean) training pairs instead.
history = model.fit(
    x_train_noisy, x_train,
    epochs=50,
    batch_size=128,
    validation_split=0.1,
    callbacks=[early_stop],
)

In [None]:
decoded_imgs = model.predict(x_test_noisy)

n = 10
# replace=False: np.random.choice samples *with* replacement by default, which
# could show the same test digit in two columns.
random_test_images = np.random.choice(x_test_noisy.shape[0], size=n, replace=False)

# Rows: ground truth, masked input, model reconstruction; one column per sample.
# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(3, n, figsize=(20, 4), squeeze=False)
rows = [
    ('Original', x_test),
    ('Masked', x_test_noisy),
    ('Inpainted', decoded_imgs),
]
for row_axes, (label, images) in zip(axes, rows):
    for ax, image_idx in zip(row_axes, random_test_images):
        # reshape drops the U-Net's trailing (28, 28, 1) channel axis
        ax.imshow(images[image_idx].reshape(28, 28), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
    row_axes[0].set_ylabel(label)
fig.suptitle('U-Net')
plt.show()
