<a href="https://colab.research.google.com/github/lifeisbeautifu1/deep-learning/blob/main/image_inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function, division
import os

from keras.datasets import cifar10
from keras.layers import BatchNormalization, Activation, Dense, Input, Reshape, Flatten, Dropout, multiply, GaussianNoise, Embedding, ZeroPadding2D, MaxPooling2D, LeakyReLU, UpSampling2D, Conv2D
from keras.optimizers.legacy import Adam
from keras.models import Sequential, Model, load_model
from keras import losses
from keras.utils import to_categorical

import matplotlib.pyplot as plt
import numpy as np

In [None]:
class ContextEncoder():
  """Context-encoder GAN that fills in a missing square patch of a CIFAR-10 image.

  The generator maps a 32x32x3 image, in which one random 8x8 patch has been
  set to 1, to a predicted 8x8x3 patch. The discriminator scores 8x8x3 patches
  as real (1) or generated (0). ``combined`` chains the generator into a frozen
  view of the *same* discriminator. One ``combined.train_on_batch`` call
  therefore updates only the generator, using a weighted MSE + adversarial loss.
  """

  def __init__(self, model_dir="drive/MyDrive/saved_model", load_saved=True,
               image_dir="drive/MyDrive/images"):
    """Create or restore the networks and wire up the combined model.

    Args:
      model_dir: directory read for ``generator.keras`` / ``discriminator.keras``
        when ``load_saved`` is true, and written to by ``save_model``.
      load_saved: if true, restore generator and discriminator weights from
        ``model_dir``; otherwise build fresh, untrained networks.
      image_dir: directory that ``sample_images`` writes PNG grids into.
    """
    self.img_rows = 32
    self.img_cols = 32
    self.mask_height = 8
    self.mask_width = 8
    self.channels = 3
    self.num_classes = 2
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.missing_shape = (self.mask_height, self.mask_width, self.channels)
    self.model_dir = model_dir
    self.image_dir = image_dir

    if load_saved:
      # compile=False: the discriminator is (re)compiled below with an explicit
      # trainable state. save_model() writes it after it has been frozen for
      # the combined model, so the saved copy may carry trainable=False.
      # NOTE(review): this drops any saved optimizer state (Adam moments).
      self.discriminator = load_model(
          os.path.join(model_dir, "discriminator.keras"), compile=False)
      self.generator = load_model(
          os.path.join(model_dir, "generator.keras"), compile=False)
    else:
      self.discriminator = self.build_discriminator()
      self.generator = self.build_generator()

    # Compile while trainable so the discriminator's own train_on_batch
    # updates its weights.
    self.discriminator.trainable = True
    self.discriminator.compile(loss="binary_crossentropy",
                               optimizer=Adam(0.0002, 0.5),
                               metrics=["accuracy"])

    # Always rebuilt from the objects above. A separately loaded
    # "encoder.keras" would hold independent copies of both networks, so
    # training it would never update self.generator / self.discriminator.
    self.combined = self._build_combined()

  def _build_combined(self):
    """Return the compiled generator -> frozen-discriminator training model.

    Outputs are ``[generated_patch, validity]``, trained with MSE (weight 0.999)
    and binary cross-entropy (weight 0.001) respectively.
    """
    masked_img = Input(shape=self.img_shape)
    gen_missing = self.generator(masked_img)

    # Freeze the discriminator inside the combined model only. In tf.keras 2.x
    # the separately compiled discriminator keeps the trainable state it was
    # compiled with when its own train_on_batch runs.
    self.discriminator.trainable = False
    valid = self.discriminator(gen_missing)

    combined = Model(masked_img, [gen_missing, valid])
    combined.compile(loss=["mse", "binary_crossentropy"],
                     loss_weights=[0.999, 0.001],
                     optimizer=Adam(0.0002, 0.5))
    return combined

  def build_generator(self):
    """Build the encoder-decoder generator.

    Four stride-2 convolutions downsample the img_shape input (32 -> 2 for the
    default size). Two UpSampling2D stages then bring it back to 8x8, and the
    final layer is a 3-channel tanh output. The result is a Functional Model
    taking img_shape and returning an (8, 8, channels) patch for the defaults.
    """
    model = Sequential()

    # Encoder

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.5))
    model.add(BatchNormalization(momentum=0.8))

    # Decoder

    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))

    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))

    # tanh keeps the patch in [-1, 1], the same range train() scales pixels to.
    model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
    model.add(Activation('tanh'))

    masked_img = Input(self.img_shape)
    gen_missing = model(masked_img)
    return Model(masked_img, gen_missing)

  def build_discriminator(self):
    """Build a CNN that maps a missing_shape patch to a sigmoid real/fake score."""
    model = Sequential()

    model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.missing_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(256, kernel_size=3, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Flatten())
    model.add(Dense(1, activation="sigmoid"))

    img = Input(shape=self.missing_shape)
    validity = model(img)
    return Model(img, validity)

  def mask_random(self, imgs):
    """Blank out one random mask_height x mask_width patch in each image.

    Args:
      imgs: array of shape (n, img_rows, img_cols, channels).

    Returns:
      ``(masked_imgs, missing_parts, (y1, y2, x1, x2))``:
        - masked_imgs: copies of ``imgs`` with the patch set to 1.
        - missing_parts: the removed patches, shape
          (n, mask_height, mask_width, channels).
        - y1, y2, x1, x2: per-image patch bounds (y2/x2 exclusive).
    """
    n = imgs.shape[0]
    # randint's upper bound is exclusive; +1 lets the patch reach the bottom and
    # right edges (and makes a full-image mask valid instead of raising).
    y1 = np.random.randint(0, self.img_rows - self.mask_height + 1, n)
    y2 = y1 + self.mask_height
    x1 = np.random.randint(0, self.img_cols - self.mask_width + 1, n)
    x2 = x1 + self.mask_width

    masked_imgs = np.empty_like(imgs)
    missing_parts = np.empty(shape=(n, self.mask_height, self.mask_width, self.channels))

    for i, img in enumerate(imgs):
      masked_img = img.copy()
      missing_parts[i] = masked_img[y1[i]:y2[i], x1[i]:x2[i], :]
      # 1 is the top of the [-1, 1] range that train() scales pixels to.
      masked_img[y1[i]:y2[i], x1[i]:x2[i], :] = 1
      masked_imgs[i] = masked_img

    return (masked_imgs, missing_parts, (y1, y2, x1, x2))

  def train(self, epochs, batch_size, sample_interval, samples_per_class=50):
    """Adversarially train on a cat/dog subset of CIFAR-10, then save the models.

    Each "epoch" is one iteration on a single batch drawn with replacement.

    Args:
      epochs: number of training iterations.
      batch_size: images per iteration.
      sample_interval: write a sample grid every this many iterations
        (disabled when <= 0).
      samples_per_class: number of cat and of dog images to keep (None = all).
    """
    (X_train, y_train), (_, _) = cifar10.load_data()

    # CIFAR-10 label 3 is "cat" and label 5 is "dog".
    X_cats = X_train[(y_train == 3).flatten()][:samples_per_class]
    X_dogs = X_train[(y_train == 5).flatten()][:samples_per_class]
    X_train = np.vstack((X_cats, X_dogs))

    # Scale [0, 255] pixels to [-1, 1] to match the generator's tanh output.
    X_train = X_train / 127.5 - 1

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
      idx = np.random.randint(0, X_train.shape[0], batch_size)
      imgs = X_train[idx]

      masked_imgs, missing_parts, _ = self.mask_random(imgs)
      gen_missing = self.generator(masked_imgs)

      d_loss_real = self.discriminator.train_on_batch(missing_parts, valid)
      d_loss_fake = self.discriminator.train_on_batch(gen_missing, fake)
      d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

      # Generator step: reconstruct the patch AND make the frozen
      # discriminator label it "valid".
      g_loss = self.combined.train_on_batch(masked_imgs, [missing_parts, valid])

      print("%d [D loss: %f , acc: %.2f%%] [G loss: %f , mse: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

      if sample_interval > 0 and epoch % sample_interval == 0:
        idx = np.random.randint(0, X_train.shape[0], 6)
        self.sample_images(epoch, X_train[idx])

    print("training done. saving the model")
    self.save_model()

  def sample_images(self, epoch, imgs):
    """Save a 3x6 grid (original / masked / inpainted) to ``<image_dir>/<epoch>.png``.

    Args:
      epoch: iteration number, used as the file name.
      imgs: at least 6 images scaled to [-1, 1].
    """
    rows, cols = 3, 6

    masked_imgs, missing_parts, (y1, y2, x1, x2) = self.mask_random(imgs)
    gen_missing = self.generator.predict(masked_imgs)

    # Map [-1, 1] back to [0, 1] for imshow.
    imgs = 0.5 * imgs + 0.5
    masked_imgs = 0.5 * masked_imgs + 0.5
    gen_missing = 0.5 * gen_missing + 0.5

    fig, axs = plt.subplots(rows, cols)
    for i in range(cols):
      axs[0, i].imshow(imgs[i, :, :])
      axs[0, i].axis("off")
      axs[1, i].imshow(masked_imgs[i, :, :])
      axs[1, i].axis("off")

      filled_img = imgs[i].copy()
      filled_img[y1[i]:y2[i], x1[i]:x2[i], :] = gen_missing[i]
      axs[2, i].imshow(filled_img)
      axs[2, i].axis("off")

    os.makedirs(self.image_dir, exist_ok=True)
    fig.savefig(os.path.join(self.image_dir, "%d.png" % epoch))
    plt.close(fig)

  def save_model(self):
    """Write generator, discriminator and combined model into ``model_dir``."""
    os.makedirs(self.model_dir, exist_ok=True)
    self.generator.save(os.path.join(self.model_dir, "generator.keras"))
    self.discriminator.save(os.path.join(self.model_dir, "discriminator.keras"))
    self.combined.save(os.path.join(self.model_dir, "encoder.keras"))










In [None]:
if __name__ == "__main__":
  # Script entry point: build (or restore) the context encoder, then run
  # 1001 single-batch training iterations of 32 images each. Sample grids
  # are written at iterations 0, 500 and 1000; the models are saved at the end.
  context_encoder = ContextEncoder()
  context_encoder.train(
      epochs=1001,
      batch_size=32,
      sample_interval=500,
  )



0 [D loss: 0.471809 , acc: 75.00%] [G loss: 0.173016 , mse: 0.172584]
1 [D loss: 0.431027 , acc: 85.94%] [G loss: 0.225289 , mse: 0.224488]
2 [D loss: 0.247876 , acc: 92.19%] [G loss: 0.186412 , mse: 0.186142]
3 [D loss: 0.339230 , acc: 92.19%] [G loss: 0.190729 , mse: 0.190255]
4 [D loss: 0.474293 , acc: 89.06%] [G loss: 0.185220 , mse: 0.184740]
5 [D loss: 0.323961 , acc: 90.62%] [G loss: 0.206191 , mse: 0.205569]
6 [D loss: 0.308570 , acc: 90.62%] [G loss: 0.200373 , mse: 0.199482]
7 [D loss: 0.320271 , acc: 89.06%] [G loss: 0.190708 , mse: 0.190004]
8 [D loss: 0.278702 , acc: 90.62%] [G loss: 0.155486 , mse: 0.154902]
9 [D loss: 0.305470 , acc: 90.62%] [G loss: 0.162882 , mse: 0.162075]
10 [D loss: 0.202448 , acc: 95.31%] [G loss: 0.178062 , mse: 0.177291]
11 [D loss: 0.299048 , acc: 90.62%] [G loss: 0.157615 , mse: 0.157068]
12 [D loss: 0.238053 , acc: 93.75%] [G loss: 0.162316 , mse: 0.161595]
13 [D loss: 0.404478 , acc: 84.38%] [G loss: 0.216097 , mse: 0.215642]
14 [D loss: 0.32